Merge branch '20754-docker-py-upgrade'
[arvados.git] / lib / dispatchcloud / test / stub_driver.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "crypto/rand"
9         "encoding/json"
10         "errors"
11         "fmt"
12         "io"
13         "io/ioutil"
14         math_rand "math/rand"
15         "regexp"
16         "strings"
17         "sync"
18         "time"
19
20         "git.arvados.org/arvados.git/lib/cloud"
21         "git.arvados.org/arvados.git/lib/crunchrun"
22         "git.arvados.org/arvados.git/sdk/go/arvados"
23         "github.com/sirupsen/logrus"
24         "golang.org/x/crypto/ssh"
25 )
26
// A StubDriver implements cloud.Driver by setting up local SSH
// servers that do fake command executions.
type StubDriver struct {
	// HostKey is the SSH host key used by every StubVM's SSH
	// service; VerifyHostKey checks presented keys against it.
	HostKey ssh.Signer
	// AuthorizedKeys are accepted for SSH logins on every StubVM
	// (in addition to the per-instance key passed to Create).
	AuthorizedKeys []ssh.PublicKey

	// SetupVM, if set, is called upon creation of each new
	// StubVM. This is the caller's opportunity to customize the
	// VM's error rate and other behaviors.
	SetupVM func(*StubVM)

	// Bugf, if set, is called if a bug is detected in the caller
	// or stub. Typically set to (*check.C)Errorf. If unset,
	// logger.Warnf is called instead.
	Bugf func(string, ...interface{})

	// StubVM's fake crunch-run uses this Queue to read and update
	// container state.
	Queue *Queue

	// Frequency of artificially introduced errors on calls to
	// Create and Destroy. 0=always succeed, 1=always fail.
	ErrorRateCreate  float64
	ErrorRateDestroy float64

	// If Create() or Instances() is called too frequently, return
	// rate-limiting errors (RateLimitError).
	MinTimeBetweenCreateCalls    time.Duration
	MinTimeBetweenInstancesCalls time.Duration

	// If true, Create and Destroy calls block until
	// ReleaseCloudOps() is called.
	HoldCloudOps bool

	// instanceSets records every set returned by InstanceSet, in
	// creation order. holdCloudOps is the (unbuffered) channel that
	// held Create/Destroy calls send on; it is created lazily by the
	// first InstanceSet call.
	instanceSets []*StubInstanceSet
	holdCloudOps chan bool
}
64
65 // InstanceSet returns a new *StubInstanceSet.
66 func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
67         if sd.holdCloudOps == nil {
68                 sd.holdCloudOps = make(chan bool)
69         }
70         sis := StubInstanceSet{
71                 driver:  sd,
72                 logger:  logger,
73                 servers: map[cloud.InstanceID]*StubVM{},
74         }
75         sd.instanceSets = append(sd.instanceSets, &sis)
76
77         var err error
78         if params != nil {
79                 err = json.Unmarshal(params, &sis)
80         }
81         return &sis, err
82 }
83
84 // InstanceSets returns all instances that have been created by the
85 // driver. This can be used to test a component that uses the driver
86 // but doesn't expose the InstanceSets it has created.
87 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
88         return sd.instanceSets
89 }
90
91 // ReleaseCloudOps releases n pending Create/Destroy calls. If there
92 // are fewer than n blocked calls pending, it waits for the rest to
93 // arrive.
94 func (sd *StubDriver) ReleaseCloudOps(n int) {
95         for i := 0; i < n; i++ {
96                 <-sd.holdCloudOps
97         }
98 }
99
// StubInstanceSet is the cloud.InstanceSet returned by
// (*StubDriver)InstanceSet. Its fake VMs live in memory, keyed by
// instance ID.
type StubInstanceSet struct {
	driver  *StubDriver
	logger  logrus.FieldLogger
	servers map[cloud.InstanceID]*StubVM
	// mtx guards servers, stopped, and the fields below.
	mtx     sync.RWMutex
	stopped bool

	// allowCreateCall and allowInstancesCall are the earliest times
	// at which the next Create and Instances calls will be accepted
	// (see StubDriver.MinTimeBetween*Calls). lastInstanceID is the
	// counter Create uses to generate instance IDs.
	allowCreateCall    time.Time
	allowInstancesCall time.Time
	lastInstanceID     int
}
111
112 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, initCommand cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
113         if sis.driver.HoldCloudOps {
114                 sis.driver.holdCloudOps <- true
115         }
116         sis.mtx.Lock()
117         defer sis.mtx.Unlock()
118         if sis.stopped {
119                 return nil, errors.New("StubInstanceSet: Create called after Stop")
120         }
121         if sis.allowCreateCall.After(time.Now()) {
122                 return nil, RateLimitError{sis.allowCreateCall}
123         }
124         if math_rand.Float64() < sis.driver.ErrorRateCreate {
125                 return nil, fmt.Errorf("StubInstanceSet: rand < ErrorRateCreate %f", sis.driver.ErrorRateCreate)
126         }
127         sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
128         ak := sis.driver.AuthorizedKeys
129         if authKey != nil {
130                 ak = append([]ssh.PublicKey{authKey}, ak...)
131         }
132         sis.lastInstanceID++
133         svm := &StubVM{
134                 InitCommand:  initCommand,
135                 sis:          sis,
136                 id:           cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
137                 tags:         copyTags(tags),
138                 providerType: it.ProviderType,
139                 running:      map[string]stubProcess{},
140                 killing:      map[string]bool{},
141         }
142         svm.SSHService = SSHService{
143                 HostKey:        sis.driver.HostKey,
144                 AuthorizedUser: "root",
145                 AuthorizedKeys: ak,
146                 Exec:           svm.Exec,
147         }
148         if setup := sis.driver.SetupVM; setup != nil {
149                 setup(svm)
150         }
151         sis.servers[svm.id] = svm
152         return svm.Instance(), nil
153 }
154
155 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
156         sis.mtx.RLock()
157         defer sis.mtx.RUnlock()
158         if sis.allowInstancesCall.After(time.Now()) {
159                 return nil, RateLimitError{sis.allowInstancesCall}
160         }
161         sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
162         var r []cloud.Instance
163         for _, ss := range sis.servers {
164                 r = append(r, ss.Instance())
165         }
166         return r, nil
167 }
168
169 func (sis *StubInstanceSet) Stop() {
170         sis.mtx.Lock()
171         defer sis.mtx.Unlock()
172         if sis.stopped {
173                 panic("Stop called twice")
174         }
175         sis.stopped = true
176 }
177
178 func (sis *StubInstanceSet) StubVMs() (svms []*StubVM) {
179         sis.mtx.Lock()
180         defer sis.mtx.Unlock()
181         for _, vm := range sis.servers {
182                 svms = append(svms, vm)
183         }
184         return
185 }
186
// RateLimitError is returned by the stub's Create and Instances methods
// when they are called again too soon. Retry is the earliest time at
// which the call will be accepted.
type RateLimitError struct {
	Retry time.Time
}

// Error implements the error interface.
func (e RateLimitError) Error() string {
	return fmt.Sprint("rate limited until ", e.Retry)
}

// EarliestRetry reports when the rate-limited call may be retried.
func (e RateLimitError) EarliestRetry() time.Time {
	return e.Retry
}
191
// StubVM is a fake server that runs an SSH service. It represents a
// VM running in a fake cloud.
//
// Note this is distinct from a stubInstance, which is a snapshot of
// the VM's metadata. Like a VM in a real cloud, a StubVM keeps
// running (and might change IP addresses, shut down, etc.) without
// updating any stubInstances that have been returned to callers.
//
// The exported fields control how Exec behaves, and are typically set
// by the StubDriver's SetupVM hook:
//
//   - Boot: until this time, every command fails ("stub is booting").
//   - Broken: once this (non-zero) time has passed, every command
//     fails with exit status 2 ("cannot fork").
//   - ReportBroken: once this (non-zero) time has passed,
//     "crunch-run --list" output includes a "broken" line.
//   - CrunchRunMissing: if true, any command mentioning crunch-run
//     fails with "command not found".
//   - CrunchRunCrashRate: probability that a fake crunch-run process
//     exits without completing its container. Half of those crashes
//     happen before the container is set to Running.
//   - CrunchRunDetachDelay: how long "crunch-run --detach" sleeps
//     before returning.
//   - ArvMountMaxExitLag: upper bound on the delay between the fake
//     crunch-run exiting and its entry being dropped from the
//     "crunch-run --list" output.
//   - ArvMountDeadlockRate: probability that an exited process's
//     entry is never dropped (it stays "stale" in --list output).
//   - ExecuteContainer: if set, called to "run" the container; its
//     return value becomes the container's exit code.
//   - CrashRunningContainer: if set, called when a fake crunch-run
//     process exits after setting its container to Running but
//     without completing it.
type StubVM struct {
	Boot                  time.Time
	Broken                time.Time
	ReportBroken          time.Time
	CrunchRunMissing      bool
	CrunchRunCrashRate    float64
	CrunchRunDetachDelay  time.Duration
	ArvMountMaxExitLag    time.Duration
	ArvMountDeadlockRate  float64
	ExecuteContainer      func(arvados.Container) int
	CrashRunningContainer func(arvados.Container)
	ExtraCrunchRunArgs    string // extra args expected after "crunch-run --detach --stdin-config "

	// Populated by (*StubInstanceSet)Create()
	InitCommand cloud.InitCommand

	// Internal state. The embedded Mutex guards running, killing,
	// lastPID and tags. running maps container UUID to its fake
	// crunch-run process; killing marks UUIDs for which
	// "crunch-run --kill" was requested. deadlocked is printed as the
	// last line of "crunch-run --list" output. NOTE(review): nothing
	// in this file assigns deadlocked; confirm whether tests set it.
	sis          *StubInstanceSet
	id           cloud.InstanceID
	tags         cloud.InstanceTags
	providerType string
	SSHService   SSHService
	running      map[string]stubProcess
	killing      map[string]bool
	lastPID      int64
	deadlocked   string
	sync.Mutex
}
226
// stubProcess records one fake crunch-run process started by
// (*StubVM)Exec.
type stubProcess struct {
	// pid is the fake process ID, unique within its StubVM.
	pid int64

	// crunch-run has exited, but arv-mount process (or something)
	// still holds lock in /var/run/, so "crunch-run --list" reports
	// the container as "stale".
	exited bool
}
234
235 func (svm *StubVM) Instance() stubInstance {
236         svm.Lock()
237         defer svm.Unlock()
238         return stubInstance{
239                 svm:  svm,
240                 addr: svm.SSHService.Address(),
241                 // We deliberately return a cached/stale copy of the
242                 // real tags here, so that (Instance)Tags() sometimes
243                 // returns old data after a call to
244                 // (Instance)SetTags().  This is permitted by the
245                 // driver interface, and this might help remind
246                 // callers that they need to tolerate it.
247                 tags: copyTags(svm.tags),
248         }
249 }
250
251 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
252         stdinData, err := ioutil.ReadAll(stdin)
253         if err != nil {
254                 fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
255                 return 1
256         }
257         queue := svm.sis.driver.Queue
258         uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
259         if eta := svm.Boot.Sub(time.Now()); eta > 0 {
260                 fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
261                 return 1
262         }
263         if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
264                 fmt.Fprintf(stderr, "cannot fork\n")
265                 return 2
266         }
267         if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
268                 fmt.Fprint(stderr, "crunch-run: command not found\n")
269                 return 1
270         }
271         if strings.HasPrefix(command, "crunch-run --detach --stdin-config "+svm.ExtraCrunchRunArgs) {
272                 var configData crunchrun.ConfigData
273                 err := json.Unmarshal(stdinData, &configData)
274                 if err != nil {
275                         fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
276                         return 1
277                 }
278                 for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
279                         if configData.Env[name] == "" {
280                                 fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
281                                 return 1
282                         }
283                 }
284                 svm.Lock()
285                 svm.lastPID++
286                 pid := svm.lastPID
287                 svm.running[uuid] = stubProcess{pid: pid}
288                 svm.Unlock()
289                 time.Sleep(svm.CrunchRunDetachDelay)
290                 fmt.Fprintf(stderr, "starting %s\n", uuid)
291                 logger := svm.sis.logger.WithFields(logrus.Fields{
292                         "Instance":      svm.id,
293                         "ContainerUUID": uuid,
294                         "PID":           pid,
295                 })
296                 logger.Printf("[test] starting crunch-run stub")
297                 go func() {
298                         var ctr arvados.Container
299                         var started, completed bool
300                         defer func() {
301                                 logger.Print("[test] exiting crunch-run stub")
302                                 svm.Lock()
303                                 defer svm.Unlock()
304                                 if svm.running[uuid].pid != pid {
305                                         bugf := svm.sis.driver.Bugf
306                                         if bugf == nil {
307                                                 bugf = logger.Warnf
308                                         }
309                                         bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s].pid==%d", pid, uuid, svm.running[uuid].pid)
310                                         return
311                                 }
312                                 if !completed {
313                                         logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
314                                         if started && svm.CrashRunningContainer != nil {
315                                                 svm.CrashRunningContainer(ctr)
316                                         }
317                                 }
318                                 sproc := svm.running[uuid]
319                                 sproc.exited = true
320                                 svm.running[uuid] = sproc
321                                 svm.Unlock()
322                                 time.Sleep(svm.ArvMountMaxExitLag * time.Duration(math_rand.Float64()))
323                                 svm.Lock()
324                                 if math_rand.Float64() >= svm.ArvMountDeadlockRate {
325                                         delete(svm.running, uuid)
326                                 }
327                         }()
328
329                         crashluck := math_rand.Float64()
330                         wantCrash := crashluck < svm.CrunchRunCrashRate
331                         wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2
332
333                         ctr, ok := queue.Get(uuid)
334                         if !ok {
335                                 logger.Print("[test] container not in queue")
336                                 return
337                         }
338
339                         time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
340
341                         svm.Lock()
342                         killed := svm.killing[uuid]
343                         svm.Unlock()
344                         if killed || wantCrashEarly {
345                                 return
346                         }
347
348                         ctr.State = arvados.ContainerStateRunning
349                         started = queue.Notify(ctr)
350                         if !started {
351                                 ctr, _ = queue.Get(uuid)
352                                 logger.Print("[test] erroring out because state=Running update was rejected")
353                                 return
354                         }
355
356                         if wantCrash {
357                                 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
358                                 return
359                         }
360                         if svm.ExecuteContainer != nil {
361                                 ctr.ExitCode = svm.ExecuteContainer(ctr)
362                         }
363                         logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
364                         ctr.State = arvados.ContainerStateComplete
365                         completed = queue.Notify(ctr)
366                 }()
367                 return 0
368         }
369         if command == "crunch-run --list" {
370                 svm.Lock()
371                 defer svm.Unlock()
372                 for uuid, sproc := range svm.running {
373                         if sproc.exited {
374                                 fmt.Fprintf(stdout, "%s stale\n", uuid)
375                         } else {
376                                 fmt.Fprintf(stdout, "%s\n", uuid)
377                         }
378                 }
379                 if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
380                         fmt.Fprintln(stdout, "broken")
381                 }
382                 fmt.Fprintln(stdout, svm.deadlocked)
383                 return 0
384         }
385         if strings.HasPrefix(command, "crunch-run --kill ") {
386                 svm.Lock()
387                 sproc, running := svm.running[uuid]
388                 if running && !sproc.exited {
389                         svm.killing[uuid] = true
390                         svm.Unlock()
391                         time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
392                         svm.Lock()
393                         sproc, running = svm.running[uuid]
394                 }
395                 svm.Unlock()
396                 if running && !sproc.exited {
397                         fmt.Fprintf(stderr, "%s: container is running\n", uuid)
398                         return 1
399                 }
400                 fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
401                 return 0
402         }
403         if command == "true" {
404                 return 0
405         }
406         fmt.Fprintf(stderr, "%q: command not found", command)
407         return 1
408 }
409
// stubInstance is the cloud.Instance handed out by the stub driver. It
// is a snapshot of a StubVM taken by (*StubVM)Instance: addr and tags
// are copied at snapshot time and do not follow later changes to the
// VM.
type stubInstance struct {
	svm  *StubVM
	addr string
	tags cloud.InstanceTags
}
415
416 func (si stubInstance) ID() cloud.InstanceID {
417         return si.svm.id
418 }
419
420 func (si stubInstance) Address() string {
421         return si.addr
422 }
423
424 func (si stubInstance) RemoteUser() string {
425         return si.svm.SSHService.AuthorizedUser
426 }
427
428 func (si stubInstance) Destroy() error {
429         sis := si.svm.sis
430         if sis.driver.HoldCloudOps {
431                 sis.driver.holdCloudOps <- true
432         }
433         if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
434                 return errors.New("instance could not be destroyed")
435         }
436         si.svm.SSHService.Close()
437         sis.mtx.Lock()
438         defer sis.mtx.Unlock()
439         delete(sis.servers, si.svm.id)
440         return nil
441 }
442
443 func (si stubInstance) ProviderType() string {
444         return si.svm.providerType
445 }
446
447 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
448         tags = copyTags(tags)
449         svm := si.svm
450         go func() {
451                 svm.Lock()
452                 defer svm.Unlock()
453                 svm.tags = tags
454         }()
455         return nil
456 }
457
458 func (si stubInstance) Tags() cloud.InstanceTags {
459         // Return a copy to ensure a caller can't change our saved
460         // tags just by writing to the returned map.
461         return copyTags(si.tags)
462 }
463
464 func (si stubInstance) String() string {
465         return string(si.svm.id)
466 }
467
468 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
469         buf := make([]byte, 512)
470         _, err := io.ReadFull(rand.Reader, buf)
471         if err != nil {
472                 return err
473         }
474         sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
475         if err != nil {
476                 return err
477         }
478         return key.Verify(buf, sig)
479 }
480
481 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
482         dst := cloud.InstanceTags{}
483         for k, v := range src {
484                 dst[k] = v
485         }
486         return dst
487 }
488
489 func (si stubInstance) PriceHistory(arvados.InstanceType) []cloud.InstancePrice {
490         return nil
491 }