17344: Reset "broken node" flag on loopback instance create/reset.
[arvados.git] / lib / cloud / loopback / loopback.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package loopback
6
7 import (
8         "bytes"
9         "crypto/rand"
10         "crypto/rsa"
11         "encoding/json"
12         "errors"
13         "io"
14         "os"
15         "os/exec"
16         "os/user"
17         "strings"
18         "sync"
19         "syscall"
20
21         "git.arvados.org/arvados.git/lib/cloud"
22         "git.arvados.org/arvados.git/lib/dispatchcloud/test"
23         "git.arvados.org/arvados.git/sdk/go/arvados"
24         "github.com/sirupsen/logrus"
25         "golang.org/x/crypto/ssh"
26 )
27
28 // Driver is the loopback implementation of the cloud.Driver interface.
29 var Driver = cloud.DriverFunc(newInstanceSet)
30
31 var (
32         errUnimplemented = errors.New("function not implemented by loopback driver")
33         errQuota         = quotaError("loopback driver is always at quota")
34 )
35
36 type quotaError string
37
38 func (e quotaError) IsQuotaError() bool { return true }
39 func (e quotaError) Error() string      { return string(e) }
40
41 type instanceSet struct {
42         instanceSetID cloud.InstanceSetID
43         logger        logrus.FieldLogger
44         instances     []*instance
45         mtx           sync.Mutex
46 }
47
48 func newInstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
49         is := &instanceSet{
50                 instanceSetID: instanceSetID,
51                 logger:        logger,
52         }
53         return is, nil
54 }
55
56 func (is *instanceSet) Create(it arvados.InstanceType, _ cloud.ImageID, tags cloud.InstanceTags, _ cloud.InitCommand, pubkey ssh.PublicKey) (cloud.Instance, error) {
57         is.mtx.Lock()
58         defer is.mtx.Unlock()
59         if len(is.instances) > 0 {
60                 return nil, errQuota
61         }
62         // A crunch-run process running in a previous instance may
63         // have marked the node as broken. In the loopback scenario a
64         // destroy+create cycle doesn't fix whatever was broken -- but
65         // nothing else will either, so the best we can do is remove
66         // the "broken" flag and try again.
67         if err := os.Remove("/var/lock/crunch-run-broken"); err != nil && !errors.Is(err, os.ErrNotExist) {
68                 return nil, err
69         }
70         u, err := user.Current()
71         if err != nil {
72                 return nil, err
73         }
74         hostRSAKey, err := rsa.GenerateKey(rand.Reader, 1024)
75         if err != nil {
76                 return nil, err
77         }
78         hostKey, err := ssh.NewSignerFromKey(hostRSAKey)
79         if err != nil {
80                 return nil, err
81         }
82         hostPubKey, err := ssh.NewPublicKey(hostRSAKey.Public())
83         if err != nil {
84                 return nil, err
85         }
86         inst := &instance{
87                 is:           is,
88                 instanceType: it,
89                 adminUser:    u.Username,
90                 tags:         tags,
91                 hostPubKey:   hostPubKey,
92                 sshService: test.SSHService{
93                         HostKey:        hostKey,
94                         AuthorizedUser: u.Username,
95                         AuthorizedKeys: []ssh.PublicKey{pubkey},
96                 },
97         }
98         inst.sshService.Exec = inst.sshExecFunc
99         go inst.sshService.Start()
100         is.instances = []*instance{inst}
101         return inst, nil
102 }
103
104 func (is *instanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
105         is.mtx.Lock()
106         defer is.mtx.Unlock()
107         var ret []cloud.Instance
108         for _, inst := range is.instances {
109                 ret = append(ret, inst)
110         }
111         return ret, nil
112 }
113
114 func (is *instanceSet) Stop() {
115         is.mtx.Lock()
116         defer is.mtx.Unlock()
117         for _, inst := range is.instances {
118                 inst.sshService.Close()
119         }
120 }
121
122 type instance struct {
123         is           *instanceSet
124         instanceType arvados.InstanceType
125         adminUser    string
126         tags         cloud.InstanceTags
127         hostPubKey   ssh.PublicKey
128         sshService   test.SSHService
129 }
130
131 func (i *instance) ID() cloud.InstanceID     { return cloud.InstanceID(i.instanceType.ProviderType) }
132 func (i *instance) String() string           { return i.instanceType.ProviderType }
133 func (i *instance) ProviderType() string     { return i.instanceType.ProviderType }
134 func (i *instance) Address() string          { return i.sshService.Address() }
135 func (i *instance) RemoteUser() string       { return i.adminUser }
136 func (i *instance) Tags() cloud.InstanceTags { return i.tags }
137 func (i *instance) SetTags(tags cloud.InstanceTags) error {
138         i.tags = tags
139         return nil
140 }
141 func (i *instance) Destroy() error {
142         i.is.mtx.Lock()
143         defer i.is.mtx.Unlock()
144         i.is.instances = i.is.instances[:0]
145         return nil
146 }
147 func (i *instance) VerifyHostKey(pubkey ssh.PublicKey, _ *ssh.Client) error {
148         if !bytes.Equal(pubkey.Marshal(), i.hostPubKey.Marshal()) {
149                 return errors.New("host key mismatch")
150         }
151         return nil
152 }
153 func (i *instance) sshExecFunc(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
154         cmd := exec.Command("sh", "-c", strings.TrimPrefix(command, "sudo "))
155         cmd.Stdin = stdin
156         cmd.Stdout = stdout
157         cmd.Stderr = stderr
158         for k, v := range env {
159                 cmd.Env = append(cmd.Env, k+"="+v)
160         }
161         // Prevent child process from using our tty.
162         cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
163         err := cmd.Run()
164         if err == nil {
165                 return 0
166         } else if err, ok := err.(*exec.ExitError); !ok {
167                 return 1
168         } else if code := err.ExitCode(); code < 0 {
169                 return 1
170         } else {
171                 return uint32(code)
172         }
173 }