Merge branch '15370-loopback-dispatchcloud'
[arvados.git] / lib / cloud / loopback / loopback.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package loopback
6
7 import (
8         "bytes"
9         "crypto/rand"
10         "crypto/rsa"
11         "encoding/json"
12         "errors"
13         "io"
14         "os/exec"
15         "os/user"
16         "strings"
17         "sync"
18         "syscall"
19
20         "git.arvados.org/arvados.git/lib/cloud"
21         "git.arvados.org/arvados.git/lib/dispatchcloud/test"
22         "git.arvados.org/arvados.git/sdk/go/arvados"
23         "github.com/sirupsen/logrus"
24         "golang.org/x/crypto/ssh"
25 )
26
27 // Driver is the loopback implementation of the cloud.Driver interface.
28 var Driver = cloud.DriverFunc(newInstanceSet)
29
30 var (
31         errUnimplemented = errors.New("function not implemented by loopback driver")
32         errQuota         = quotaError("loopback driver is always at quota")
33 )
34
35 type quotaError string
36
37 func (e quotaError) IsQuotaError() bool { return true }
38 func (e quotaError) Error() string      { return string(e) }
39
40 type instanceSet struct {
41         instanceSetID cloud.InstanceSetID
42         logger        logrus.FieldLogger
43         instances     []*instance
44         mtx           sync.Mutex
45 }
46
47 func newInstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
48         is := &instanceSet{
49                 instanceSetID: instanceSetID,
50                 logger:        logger,
51         }
52         return is, nil
53 }
54
55 func (is *instanceSet) Create(it arvados.InstanceType, _ cloud.ImageID, tags cloud.InstanceTags, _ cloud.InitCommand, pubkey ssh.PublicKey) (cloud.Instance, error) {
56         is.mtx.Lock()
57         defer is.mtx.Unlock()
58         if len(is.instances) > 0 {
59                 return nil, errQuota
60         }
61         u, err := user.Current()
62         if err != nil {
63                 return nil, err
64         }
65         hostRSAKey, err := rsa.GenerateKey(rand.Reader, 1024)
66         if err != nil {
67                 return nil, err
68         }
69         hostKey, err := ssh.NewSignerFromKey(hostRSAKey)
70         if err != nil {
71                 return nil, err
72         }
73         hostPubKey, err := ssh.NewPublicKey(hostRSAKey.Public())
74         if err != nil {
75                 return nil, err
76         }
77         inst := &instance{
78                 is:           is,
79                 instanceType: it,
80                 adminUser:    u.Username,
81                 tags:         tags,
82                 hostPubKey:   hostPubKey,
83                 sshService: test.SSHService{
84                         HostKey:        hostKey,
85                         AuthorizedUser: u.Username,
86                         AuthorizedKeys: []ssh.PublicKey{pubkey},
87                 },
88         }
89         inst.sshService.Exec = inst.sshExecFunc
90         go inst.sshService.Start()
91         is.instances = []*instance{inst}
92         return inst, nil
93 }
94
95 func (is *instanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
96         is.mtx.Lock()
97         defer is.mtx.Unlock()
98         var ret []cloud.Instance
99         for _, inst := range is.instances {
100                 ret = append(ret, inst)
101         }
102         return ret, nil
103 }
104
105 func (is *instanceSet) Stop() {
106         is.mtx.Lock()
107         defer is.mtx.Unlock()
108         for _, inst := range is.instances {
109                 inst.sshService.Close()
110         }
111 }
112
113 type instance struct {
114         is           *instanceSet
115         instanceType arvados.InstanceType
116         adminUser    string
117         tags         cloud.InstanceTags
118         hostPubKey   ssh.PublicKey
119         sshService   test.SSHService
120 }
121
122 func (i *instance) ID() cloud.InstanceID     { return cloud.InstanceID(i.instanceType.ProviderType) }
123 func (i *instance) String() string           { return i.instanceType.ProviderType }
124 func (i *instance) ProviderType() string     { return i.instanceType.ProviderType }
125 func (i *instance) Address() string          { return i.sshService.Address() }
126 func (i *instance) RemoteUser() string       { return i.adminUser }
127 func (i *instance) Tags() cloud.InstanceTags { return i.tags }
128 func (i *instance) SetTags(tags cloud.InstanceTags) error {
129         i.tags = tags
130         return nil
131 }
132 func (i *instance) Destroy() error {
133         i.is.mtx.Lock()
134         defer i.is.mtx.Unlock()
135         i.is.instances = i.is.instances[:0]
136         return nil
137 }
138 func (i *instance) VerifyHostKey(pubkey ssh.PublicKey, _ *ssh.Client) error {
139         if !bytes.Equal(pubkey.Marshal(), i.hostPubKey.Marshal()) {
140                 return errors.New("host key mismatch")
141         }
142         return nil
143 }
144 func (i *instance) sshExecFunc(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
145         cmd := exec.Command("sh", "-c", strings.TrimPrefix(command, "sudo "))
146         cmd.Stdin = stdin
147         cmd.Stdout = stdout
148         cmd.Stderr = stderr
149         for k, v := range env {
150                 cmd.Env = append(cmd.Env, k+"="+v)
151         }
152         // Prevent child process from using our tty.
153         cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
154         err := cmd.Run()
155         if err == nil {
156                 return 0
157         } else if err, ok := err.(*exec.ExitError); !ok {
158                 return 1
159         } else if code := err.ExitCode(); code < 0 {
160                 return 1
161         } else {
162                 return uint32(code)
163         }
164 }