Merge branch '14325-dispatch-cloud'
[arvados.git] / lib / dispatchcloud / test / stub_driver.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "crypto/rand"
9         "errors"
10         "fmt"
11         "io"
12         math_rand "math/rand"
13         "regexp"
14         "strings"
15         "sync"
16         "time"
17
18         "git.curoverse.com/arvados.git/lib/cloud"
19         "git.curoverse.com/arvados.git/sdk/go/arvados"
20         "github.com/mitchellh/mapstructure"
21         "github.com/sirupsen/logrus"
22         "golang.org/x/crypto/ssh"
23 )
24
25 // A StubDriver implements cloud.Driver by setting up local SSH
26 // servers that do fake command executions.
27 type StubDriver struct {
28         HostKey        ssh.Signer
29         AuthorizedKeys []ssh.PublicKey
30
31         // SetupVM, if set, is called upon creation of each new
32         // StubVM. This is the caller's opportunity to customize the
33         // VM's error rate and other behaviors.
34         SetupVM func(*StubVM)
35
36         // StubVM's fake crunch-run uses this Queue to read and update
37         // container state.
38         Queue *Queue
39
40         // Frequency of artificially introduced errors on calls to
41         // Destroy. 0=always succeed, 1=always fail.
42         ErrorRateDestroy float64
43
44         // If Create() or Instances() is called too frequently, return
45         // rate-limiting errors.
46         MinTimeBetweenCreateCalls    time.Duration
47         MinTimeBetweenInstancesCalls time.Duration
48
49         // If true, Create and Destroy calls block until Release() is
50         // called.
51         HoldCloudOps bool
52
53         instanceSets []*StubInstanceSet
54         holdCloudOps chan bool
55 }
56
57 // InstanceSet returns a new *StubInstanceSet.
58 func (sd *StubDriver) InstanceSet(params map[string]interface{}, id cloud.InstanceSetID, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
59         if sd.holdCloudOps == nil {
60                 sd.holdCloudOps = make(chan bool)
61         }
62         sis := StubInstanceSet{
63                 driver:  sd,
64                 servers: map[cloud.InstanceID]*StubVM{},
65         }
66         sd.instanceSets = append(sd.instanceSets, &sis)
67         return &sis, mapstructure.Decode(params, &sis)
68 }
69
70 // InstanceSets returns all instances that have been created by the
71 // driver. This can be used to test a component that uses the driver
72 // but doesn't expose the InstanceSets it has created.
73 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
74         return sd.instanceSets
75 }
76
77 // ReleaseCloudOps releases n pending Create/Destroy calls. If there
78 // are fewer than n blocked calls pending, it waits for the rest to
79 // arrive.
80 func (sd *StubDriver) ReleaseCloudOps(n int) {
81         for i := 0; i < n; i++ {
82                 <-sd.holdCloudOps
83         }
84 }
85
86 type StubInstanceSet struct {
87         driver  *StubDriver
88         servers map[cloud.InstanceID]*StubVM
89         mtx     sync.RWMutex
90         stopped bool
91
92         allowCreateCall    time.Time
93         allowInstancesCall time.Time
94 }
95
96 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, authKey ssh.PublicKey) (cloud.Instance, error) {
97         if sis.driver.HoldCloudOps {
98                 sis.driver.holdCloudOps <- true
99         }
100         sis.mtx.Lock()
101         defer sis.mtx.Unlock()
102         if sis.stopped {
103                 return nil, errors.New("StubInstanceSet: Create called after Stop")
104         }
105         if sis.allowCreateCall.After(time.Now()) {
106                 return nil, RateLimitError{sis.allowCreateCall}
107         } else {
108                 sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
109         }
110
111         ak := sis.driver.AuthorizedKeys
112         if authKey != nil {
113                 ak = append([]ssh.PublicKey{authKey}, ak...)
114         }
115         svm := &StubVM{
116                 sis:          sis,
117                 id:           cloud.InstanceID(fmt.Sprintf("stub-%s-%x", it.ProviderType, math_rand.Int63())),
118                 tags:         copyTags(tags),
119                 providerType: it.ProviderType,
120         }
121         svm.SSHService = SSHService{
122                 HostKey:        sis.driver.HostKey,
123                 AuthorizedKeys: ak,
124                 Exec:           svm.Exec,
125         }
126         if setup := sis.driver.SetupVM; setup != nil {
127                 setup(svm)
128         }
129         sis.servers[svm.id] = svm
130         return svm.Instance(), nil
131 }
132
133 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
134         sis.mtx.RLock()
135         defer sis.mtx.RUnlock()
136         if sis.allowInstancesCall.After(time.Now()) {
137                 return nil, RateLimitError{sis.allowInstancesCall}
138         } else {
139                 sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
140         }
141         var r []cloud.Instance
142         for _, ss := range sis.servers {
143                 r = append(r, ss.Instance())
144         }
145         return r, nil
146 }
147
148 func (sis *StubInstanceSet) Stop() {
149         sis.mtx.Lock()
150         defer sis.mtx.Unlock()
151         if sis.stopped {
152                 panic("Stop called twice")
153         }
154         sis.stopped = true
155 }
156
// RateLimitError is returned by StubInstanceSet's Create and Instances
// when they are called too soon after a previous accepted call. Retry
// is the earliest time at which the call may be accepted.
type RateLimitError struct {
	Retry time.Time
}

// Error implements the error interface.
func (e RateLimitError) Error() string {
	return fmt.Sprintf("rate limited until %s", e.Retry)
}

// EarliestRetry returns the time at which the rate-limited call may
// be retried.
func (e RateLimitError) EarliestRetry() time.Time {
	return e.Retry
}
161
162 // StubVM is a fake server that runs an SSH service. It represents a
163 // VM running in a fake cloud.
164 //
165 // Note this is distinct from a stubInstance, which is a snapshot of
166 // the VM's metadata. Like a VM in a real cloud, a StubVM keeps
167 // running (and might change IP addresses, shut down, etc.)  without
168 // updating any stubInstances that have been returned to callers.
169 type StubVM struct {
170         Boot                 time.Time
171         Broken               time.Time
172         CrunchRunMissing     bool
173         CrunchRunCrashRate   float64
174         CrunchRunDetachDelay time.Duration
175         ExecuteContainer     func(arvados.Container) int
176
177         sis          *StubInstanceSet
178         id           cloud.InstanceID
179         tags         cloud.InstanceTags
180         providerType string
181         SSHService   SSHService
182         running      map[string]bool
183         sync.Mutex
184 }
185
186 func (svm *StubVM) Instance() stubInstance {
187         svm.Lock()
188         defer svm.Unlock()
189         return stubInstance{
190                 svm:  svm,
191                 addr: svm.SSHService.Address(),
192                 // We deliberately return a cached/stale copy of the
193                 // real tags here, so that (Instance)Tags() sometimes
194                 // returns old data after a call to
195                 // (Instance)SetTags().  This is permitted by the
196                 // driver interface, and this might help remind
197                 // callers that they need to tolerate it.
198                 tags: copyTags(svm.tags),
199         }
200 }
201
202 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
203         queue := svm.sis.driver.Queue
204         uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
205         if eta := svm.Boot.Sub(time.Now()); eta > 0 {
206                 fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
207                 return 1
208         }
209         if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
210                 fmt.Fprintf(stderr, "cannot fork\n")
211                 return 2
212         }
213         if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
214                 fmt.Fprint(stderr, "crunch-run: command not found\n")
215                 return 1
216         }
217         if strings.HasPrefix(command, "crunch-run --detach ") {
218                 for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
219                         if env[name] == "" {
220                                 fmt.Fprintf(stderr, "%s missing from environment %q\n", name, env)
221                                 return 1
222                         }
223                 }
224                 svm.Lock()
225                 if svm.running == nil {
226                         svm.running = map[string]bool{}
227                 }
228                 svm.running[uuid] = true
229                 svm.Unlock()
230                 time.Sleep(svm.CrunchRunDetachDelay)
231                 fmt.Fprintf(stderr, "starting %s\n", uuid)
232                 logger := logrus.WithFields(logrus.Fields{
233                         "Instance":      svm.id,
234                         "ContainerUUID": uuid,
235                 })
236                 logger.Printf("[test] starting crunch-run stub")
237                 go func() {
238                         crashluck := math_rand.Float64()
239                         ctr, ok := queue.Get(uuid)
240                         if !ok {
241                                 logger.Print("[test] container not in queue")
242                                 return
243                         }
244                         if crashluck > svm.CrunchRunCrashRate/2 {
245                                 time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
246                                 ctr.State = arvados.ContainerStateRunning
247                                 queue.Notify(ctr)
248                         }
249
250                         time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
251                         svm.Lock()
252                         _, running := svm.running[uuid]
253                         svm.Unlock()
254                         if !running {
255                                 logger.Print("[test] container was killed")
256                                 return
257                         }
258                         if svm.ExecuteContainer != nil {
259                                 ctr.ExitCode = svm.ExecuteContainer(ctr)
260                         }
261                         // TODO: Check whether the stub instance has
262                         // been destroyed, and if so, don't call
263                         // queue.Notify. Then "container finished
264                         // twice" can be classified as a bug.
265                         if crashluck < svm.CrunchRunCrashRate {
266                                 logger.Print("[test] crashing crunch-run stub")
267                         } else {
268                                 ctr.State = arvados.ContainerStateComplete
269                                 queue.Notify(ctr)
270                         }
271                         logger.Print("[test] exiting crunch-run stub")
272                         svm.Lock()
273                         defer svm.Unlock()
274                         delete(svm.running, uuid)
275                 }()
276                 return 0
277         }
278         if command == "crunch-run --list" {
279                 svm.Lock()
280                 defer svm.Unlock()
281                 for uuid := range svm.running {
282                         fmt.Fprintf(stdout, "%s\n", uuid)
283                 }
284                 return 0
285         }
286         if strings.HasPrefix(command, "crunch-run --kill ") {
287                 svm.Lock()
288                 defer svm.Unlock()
289                 if svm.running[uuid] {
290                         delete(svm.running, uuid)
291                 } else {
292                         fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
293                 }
294                 return 0
295         }
296         if command == "true" {
297                 return 0
298         }
299         fmt.Fprintf(stderr, "%q: command not found", command)
300         return 1
301 }
302
303 type stubInstance struct {
304         svm  *StubVM
305         addr string
306         tags cloud.InstanceTags
307 }
308
309 func (si stubInstance) ID() cloud.InstanceID {
310         return si.svm.id
311 }
312
313 func (si stubInstance) Address() string {
314         return si.addr
315 }
316
317 func (si stubInstance) Destroy() error {
318         sis := si.svm.sis
319         if sis.driver.HoldCloudOps {
320                 sis.driver.holdCloudOps <- true
321         }
322         if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
323                 return errors.New("instance could not be destroyed")
324         }
325         si.svm.SSHService.Close()
326         sis.mtx.Lock()
327         defer sis.mtx.Unlock()
328         delete(sis.servers, si.svm.id)
329         return nil
330 }
331
332 func (si stubInstance) ProviderType() string {
333         return si.svm.providerType
334 }
335
336 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
337         tags = copyTags(tags)
338         svm := si.svm
339         go func() {
340                 svm.Lock()
341                 defer svm.Unlock()
342                 svm.tags = tags
343         }()
344         return nil
345 }
346
347 func (si stubInstance) Tags() cloud.InstanceTags {
348         return si.tags
349 }
350
351 func (si stubInstance) String() string {
352         return string(si.svm.id)
353 }
354
355 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
356         buf := make([]byte, 512)
357         _, err := io.ReadFull(rand.Reader, buf)
358         if err != nil {
359                 return err
360         }
361         sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
362         if err != nil {
363                 return err
364         }
365         return key.Verify(buf, sig)
366 }
367
368 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
369         dst := cloud.InstanceTags{}
370         for k, v := range src {
371                 dst[k] = v
372         }
373         return dst
374 }