Merge branch 'main' into 21224-project-details
[arvados.git] / lib / cloud / loopback / loopback.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package loopback
6
7 import (
8         "bytes"
9         "crypto/rand"
10         "crypto/rsa"
11         "encoding/json"
12         "errors"
13         "io"
14         "os"
15         "os/exec"
16         "os/user"
17         "strings"
18         "sync"
19         "syscall"
20
21         "git.arvados.org/arvados.git/lib/cloud"
22         "git.arvados.org/arvados.git/lib/dispatchcloud/test"
23         "git.arvados.org/arvados.git/sdk/go/arvados"
24         "github.com/prometheus/client_golang/prometheus"
25         "github.com/sirupsen/logrus"
26         "golang.org/x/crypto/ssh"
27 )
28
29 // Driver is the loopback implementation of the cloud.Driver interface.
30 var Driver = cloud.DriverFunc(newInstanceSet)
31
32 var (
33         errUnimplemented = errors.New("function not implemented by loopback driver")
34         errQuota         = quotaError("loopback driver is always at quota")
35 )
36
37 type quotaError string
38
39 func (e quotaError) IsQuotaError() bool { return true }
40 func (e quotaError) Error() string      { return string(e) }
41
42 type instanceSet struct {
43         instanceSetID cloud.InstanceSetID
44         logger        logrus.FieldLogger
45         instances     []*instance
46         mtx           sync.Mutex
47 }
48
49 func newInstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger, reg *prometheus.Registry) (cloud.InstanceSet, error) {
50         is := &instanceSet{
51                 instanceSetID: instanceSetID,
52                 logger:        logger,
53         }
54         return is, nil
55 }
56
57 func (is *instanceSet) Create(it arvados.InstanceType, _ cloud.ImageID, tags cloud.InstanceTags, _ cloud.InitCommand, pubkey ssh.PublicKey) (cloud.Instance, error) {
58         is.mtx.Lock()
59         defer is.mtx.Unlock()
60         if len(is.instances) > 0 {
61                 return nil, errQuota
62         }
63         // A crunch-run process running in a previous instance may
64         // have marked the node as broken. In the loopback scenario a
65         // destroy+create cycle doesn't fix whatever was broken -- but
66         // nothing else will either, so the best we can do is remove
67         // the "broken" flag and try again.
68         if err := os.Remove("/var/lock/crunch-run-broken"); err == nil {
69                 is.logger.Info("removed /var/lock/crunch-run-broken")
70         } else if !errors.Is(err, os.ErrNotExist) {
71                 return nil, err
72         }
73         u, err := user.Current()
74         if err != nil {
75                 return nil, err
76         }
77         hostRSAKey, err := rsa.GenerateKey(rand.Reader, 1024)
78         if err != nil {
79                 return nil, err
80         }
81         hostKey, err := ssh.NewSignerFromKey(hostRSAKey)
82         if err != nil {
83                 return nil, err
84         }
85         hostPubKey, err := ssh.NewPublicKey(hostRSAKey.Public())
86         if err != nil {
87                 return nil, err
88         }
89         inst := &instance{
90                 is:           is,
91                 instanceType: it,
92                 adminUser:    u.Username,
93                 tags:         tags,
94                 hostPubKey:   hostPubKey,
95                 sshService: test.SSHService{
96                         HostKey:        hostKey,
97                         AuthorizedUser: u.Username,
98                         AuthorizedKeys: []ssh.PublicKey{pubkey},
99                 },
100         }
101         inst.sshService.Exec = inst.sshExecFunc
102         go inst.sshService.Start()
103         is.instances = []*instance{inst}
104         return inst, nil
105 }
106
107 func (is *instanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
108         is.mtx.Lock()
109         defer is.mtx.Unlock()
110         var ret []cloud.Instance
111         for _, inst := range is.instances {
112                 ret = append(ret, inst)
113         }
114         return ret, nil
115 }
116
117 func (is *instanceSet) Stop() {
118         is.mtx.Lock()
119         defer is.mtx.Unlock()
120         for _, inst := range is.instances {
121                 inst.sshService.Close()
122         }
123 }
124
125 type instance struct {
126         is           *instanceSet
127         instanceType arvados.InstanceType
128         adminUser    string
129         tags         cloud.InstanceTags
130         hostPubKey   ssh.PublicKey
131         sshService   test.SSHService
132 }
133
134 func (i *instance) ID() cloud.InstanceID                                    { return cloud.InstanceID(i.instanceType.ProviderType) }
135 func (i *instance) String() string                                          { return i.instanceType.ProviderType }
136 func (i *instance) ProviderType() string                                    { return i.instanceType.ProviderType }
137 func (i *instance) Address() string                                         { return i.sshService.Address() }
138 func (i *instance) PriceHistory(arvados.InstanceType) []cloud.InstancePrice { return nil }
139 func (i *instance) RemoteUser() string                                      { return i.adminUser }
140 func (i *instance) Tags() cloud.InstanceTags                                { return i.tags }
141 func (i *instance) SetTags(tags cloud.InstanceTags) error {
142         i.tags = tags
143         return nil
144 }
145 func (i *instance) Destroy() error {
146         i.is.mtx.Lock()
147         defer i.is.mtx.Unlock()
148         i.is.instances = i.is.instances[:0]
149         return nil
150 }
151 func (i *instance) VerifyHostKey(pubkey ssh.PublicKey, _ *ssh.Client) error {
152         if !bytes.Equal(pubkey.Marshal(), i.hostPubKey.Marshal()) {
153                 return errors.New("host key mismatch")
154         }
155         return nil
156 }
157 func (i *instance) sshExecFunc(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
158         cmd := exec.Command("sh", "-c", strings.TrimPrefix(command, "sudo "))
159         cmd.Stdin = stdin
160         cmd.Stdout = stdout
161         cmd.Stderr = stderr
162         for k, v := range env {
163                 cmd.Env = append(cmd.Env, k+"="+v)
164         }
165         // Prevent child process from using our tty.
166         cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
167         err := cmd.Run()
168         if err == nil {
169                 return 0
170         } else if err, ok := err.(*exec.ExitError); !ok {
171                 return 1
172         } else if code := err.ExitCode(); code < 0 {
173                 return 1
174         } else {
175                 return uint32(code)
176         }
177 }