[arvados.git] / lib / dispatchcloud / test / stub_driver.go
// Copyright (C) The Arvados Authors. All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0

package test

import (
        "crypto/rand"
        "encoding/json"
        "errors"
        "fmt"
        "io"
        "io/ioutil"
        math_rand "math/rand"
        "regexp"
        "strings"
        "sync"
        "time"

        "git.arvados.org/arvados.git/lib/cloud"
        "git.arvados.org/arvados.git/lib/crunchrun"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "github.com/prometheus/client_golang/prometheus"
        "github.com/sirupsen/logrus"
        "golang.org/x/crypto/ssh"
)

// A StubDriver implements cloud.Driver by setting up local SSH
// servers that do fake command executions.
type StubDriver struct {
        HostKey        ssh.Signer
        AuthorizedKeys []ssh.PublicKey

        // SetupVM, if set, is called upon creation of each new
        // StubVM. This is the caller's opportunity to customize the
        // VM's error rate and other behaviors.
        //
        // If SetupVM returns an error, that error will be returned to
        // the caller of Create(), and the new VM will be discarded.
        SetupVM func(*StubVM) error

        // Bugf, if set, is called if a bug is detected in the caller
        // or stub. Typically set to (*check.C)Errorf. If unset,
        // logger.Warnf is called instead.
        Bugf func(string, ...interface{})

        // StubVM's fake crunch-run uses this Queue to read and update
        // container state.
        Queue *Queue

        // Frequency of artificially introduced errors on calls to
        // Create and Destroy. 0=always succeed, 1=always fail.
        ErrorRateCreate  float64
        ErrorRateDestroy float64

        // If Create() or Instances() is called too frequently, return
        // rate-limiting errors.
        MinTimeBetweenCreateCalls    time.Duration
        MinTimeBetweenInstancesCalls time.Duration

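        // QuotaMaxInstances, if non-zero, is the maximum number of
        // instances allowed to exist at once; Create returns a
        // QuotaError once this many instances exist.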
        QuotaMaxInstances int

        // If true, Create and Destroy calls block until
        // ReleaseCloudOps() is called.
        HoldCloudOps bool

        instanceSets []*StubInstanceSet
        holdCloudOps chan bool
}

// InstanceSet returns a new *StubInstanceSet.
func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger, reg *prometheus.Registry) (cloud.InstanceSet, error) {
        if sd.holdCloudOps == nil {
                sd.holdCloudOps = make(chan bool)
        }
        sis := StubInstanceSet{
                driver:  sd,
                logger:  logger,
                servers: map[cloud.InstanceID]*StubVM{},
        }
        sd.instanceSets = append(sd.instanceSets, &sis)

        var err error
        if params != nil {
                err = json.Unmarshal(params, &sis)
        }
        return &sis, err
}

// InstanceSets returns all instance sets that have been created by
// the driver. This can be used to test a component that uses the
// driver but doesn't expose the InstanceSets it has created.
func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
        return sd.instanceSets
}

// ReleaseCloudOps releases n pending Create/Destroy calls. If there
// are fewer than n blocked calls pending, it waits for the rest to
// arrive.
func (sd *StubDriver) ReleaseCloudOps(n int) {
        for i := 0; i < n; i++ {
                <-sd.holdCloudOps
        }
}

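// StubInstanceSet is a fake cloud.InstanceSet. It keeps track of the
// StubVMs it has created, and applies the rate limits, artificial
// error rates, and instance quota configured on its StubDriver.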
type StubInstanceSet struct {
        driver  *StubDriver
        logger  logrus.FieldLogger
        servers map[cloud.InstanceID]*StubVM
        mtx     sync.RWMutex
        stopped bool

        allowCreateCall    time.Time
        allowInstancesCall time.Time
        lastInstanceID     int
}

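// Create adds a new StubVM to the set and returns a stubInstance
// snapshot of it, after applying the configured error rate, rate
// limit, and quota checks.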
func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, initCommand cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
        if sis.driver.HoldCloudOps {
                sis.driver.holdCloudOps <- true
        }
        sis.mtx.Lock()
        defer sis.mtx.Unlock()
        if sis.stopped {
                return nil, errors.New("StubInstanceSet: Create called after Stop")
        }
        if sis.allowCreateCall.After(time.Now()) {
                return nil, RateLimitError{sis.allowCreateCall}
        }
        if math_rand.Float64() < sis.driver.ErrorRateCreate {
                return nil, fmt.Errorf("StubInstanceSet: rand < ErrorRateCreate %f", sis.driver.ErrorRateCreate)
        }
        if max := sis.driver.QuotaMaxInstances; max > 0 && len(sis.servers) >= max {
                return nil, QuotaError{fmt.Errorf("StubInstanceSet: reached QuotaMaxInstances %d", max)}
        }
        sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
        ak := sis.driver.AuthorizedKeys
        if authKey != nil {
                ak = append([]ssh.PublicKey{authKey}, ak...)
        }
        sis.lastInstanceID++
        svm := &StubVM{
                InitCommand:  initCommand,
                sis:          sis,
                id:           cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
                tags:         copyTags(tags),
                providerType: it.ProviderType,
                running:      map[string]stubProcess{},
                killing:      map[string]bool{},
        }
        svm.SSHService = SSHService{
                HostKey:        sis.driver.HostKey,
                AuthorizedUser: "root",
                AuthorizedKeys: ak,
                Exec:           svm.Exec,
        }
        if setup := sis.driver.SetupVM; setup != nil {
                err := setup(svm)
                if err != nil {
                        return nil, err
                }
        }
        sis.servers[svm.id] = svm
        return svm.Instance(), nil
}

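// Instances returns a current stubInstance snapshot of each StubVM
// in the set, subject to the configured rate limit on Instances
// calls.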
func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
        sis.mtx.RLock()
        defer sis.mtx.RUnlock()
        if sis.allowInstancesCall.After(time.Now()) {
                return nil, RateLimitError{sis.allowInstancesCall}
        }
        sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
        var r []cloud.Instance
        for _, ss := range sis.servers {
                r = append(r, ss.Instance())
        }
        return r, nil
}

// InstanceQuotaGroup returns the first character of the given
// instance type's ProviderType. Use ProviderTypes like "a1", "a2",
// "b1", "b2" to test instance quota group behaviors.
func (sis *StubInstanceSet) InstanceQuotaGroup(it arvados.InstanceType) cloud.InstanceQuotaGroup {
        return cloud.InstanceQuotaGroup(it.ProviderType[:1])
}

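// Stop marks the instance set as stopped. Subsequent calls to Create
// fail, and calling Stop again panics.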
func (sis *StubInstanceSet) Stop() {
        sis.mtx.Lock()
        defer sis.mtx.Unlock()
        if sis.stopped {
                panic("Stop called twice")
        }
        sis.stopped = true
}

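// StubVMs returns all StubVMs that currently exist in the set.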
func (sis *StubInstanceSet) StubVMs() (svms []*StubVM) {
        sis.mtx.Lock()
        defer sis.mtx.Unlock()
        for _, vm := range sis.servers {
                svms = append(svms, vm)
        }
        return
}

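// RateLimitError is a fake cloud-provider rate-limiting error. It
// implements cloud.RateLimitError.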
type RateLimitError struct{ Retry time.Time }

func (e RateLimitError) Error() string            { return fmt.Sprintf("rate limited until %s", e.Retry) }
func (e RateLimitError) EarliestRetry() time.Time { return e.Retry }

var _ = cloud.RateLimitError(RateLimitError{}) // assert the interface is satisfied

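// CapacityError is a fake cloud-provider capacity error. It
// implements cloud.CapacityError.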
type CapacityError struct {
        InstanceTypeSpecific       bool
        InstanceQuotaGroupSpecific bool
}

func (e CapacityError) Error() string                      { return "insufficient capacity" }
func (e CapacityError) IsCapacityError() bool              { return true }
func (e CapacityError) IsInstanceTypeSpecific() bool       { return e.InstanceTypeSpecific }
func (e CapacityError) IsInstanceQuotaGroupSpecific() bool { return e.InstanceQuotaGroupSpecific }

var _ = cloud.CapacityError(CapacityError{}) // assert the interface is satisfied

// StubVM is a fake server that runs an SSH service. It represents a
// VM running in a fake cloud.
//
// Note this is distinct from a stubInstance, which is a snapshot of
// the VM's metadata. Like a VM in a real cloud, a StubVM keeps
// running (and might change IP addresses, shut down, etc.) without
// updating any stubInstances that have been returned to callers.
type StubVM struct {
        Boot                  time.Time                   // fake boot completion time; Exec fails with "stub is booting" until then
        Broken                time.Time                   // if non-zero and in the past, Exec fails with "cannot fork"
        ReportBroken          time.Time                   // if non-zero and in the past, "crunch-run --list" reports "broken"
        CrunchRunMissing      bool                        // if true, crunch-run invocations fail with "command not found"
        CrunchRunCrashRate    float64                     // probability that a fake crunch-run crashes instead of completing
        CrunchRunDetachDelay  time.Duration               // delay before "crunch-run --detach" returns
        ArvMountMaxExitLag    time.Duration               // max time an exited process stays in "crunch-run --list" output
        ArvMountDeadlockRate  float64                     // probability that an exited process stays listed ("stale") forever
        ExecuteContainer      func(arvados.Container) int // if set, determines the container's exit code
        CrashRunningContainer func(arvados.Container)     // if set, called when a crashing fake crunch-run had reached Running
        ExtraCrunchRunArgs    string                      // extra args expected after "crunch-run --detach --stdin-config "

        // Populated by (*StubInstanceSet)Create()
        InitCommand cloud.InitCommand

        sis          *StubInstanceSet
        id           cloud.InstanceID
        tags         cloud.InstanceTags
        providerType string
        SSHService   SSHService
        running      map[string]stubProcess
        killing      map[string]bool
        lastPID      int64
        deadlocked   string
        stubprocs    sync.WaitGroup
        destroying   bool
        sync.Mutex
}

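// stubProcess tracks the state of a fake crunch-run process started
// on a StubVM.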
type stubProcess struct {
        pid int64

        // crunch-run has exited, but arv-mount process (or something)
        // still holds lock in /var/run/
        exited bool
}

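// Instance returns a stubInstance snapshot of the VM's current
// address and tags.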
func (svm *StubVM) Instance() stubInstance {
        svm.Lock()
        defer svm.Unlock()
        return stubInstance{
                svm:  svm,
                addr: svm.SSHService.Address(),
                // We deliberately return a cached/stale copy of the
                // real tags here, so that (Instance)Tags() sometimes
                // returns old data after a call to
                // (Instance)SetTags().  This is permitted by the
                // driver interface, and this might help remind
                // callers that they need to tolerate it.
                tags: copyTags(svm.tags),
        }
}

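// Exec handles a fake command execution on the VM, emulating the
// crunch-run subcommands used by dispatchcloud (--detach, --list,
// --kill) and returning a shell-style exit code.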
func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
        // Ensure we don't start any new stubprocs after Destroy()
        // has started Wait()ing for stubprocs to end.
        svm.Lock()
        if svm.destroying {
                svm.Unlock()
                return 1
        }
        svm.stubprocs.Add(1)
        defer svm.stubprocs.Done()
        svm.Unlock()

        stdinData, err := ioutil.ReadAll(stdin)
        if err != nil {
                fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
                return 1
        }
        queue := svm.sis.driver.Queue
        uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
        if eta := svm.Boot.Sub(time.Now()); eta > 0 {
                fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
                return 1
        }
        if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
                fmt.Fprintf(stderr, "cannot fork\n")
                return 2
        }
        if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
                fmt.Fprint(stderr, "crunch-run: command not found\n")
                return 1
        }
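        // Stub "crunch-run --detach": record a fake process and run
        // the container's lifecycle in a background goroutine.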
        if strings.HasPrefix(command, "crunch-run --detach --stdin-config "+svm.ExtraCrunchRunArgs) {
                var configData crunchrun.ConfigData
                err := json.Unmarshal(stdinData, &configData)
                if err != nil {
                        fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
                        return 1
                }
                for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
                        if configData.Env[name] == "" {
                                fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
                                return 1
                        }
                }
                svm.Lock()
                svm.lastPID++
                pid := svm.lastPID
                svm.running[uuid] = stubProcess{pid: pid}
                svm.Unlock()

                time.Sleep(svm.CrunchRunDetachDelay)

                svm.Lock()
                defer svm.Unlock()
                if svm.destroying {
                        fmt.Fprint(stderr, "crunch-run: killed by system shutdown\n")
                        return 9
                }
                fmt.Fprintf(stderr, "starting %s\n", uuid)
                logger := svm.sis.logger.WithFields(logrus.Fields{
                        "Instance":      svm.id,
                        "ContainerUUID": uuid,
                        "PID":           pid,
                })
                logger.Printf("[test] starting crunch-run stub")
                svm.stubprocs.Add(1)
                go func() {
                        defer svm.stubprocs.Done()
                        var ctr arvados.Container
                        var started, completed bool
                        defer func() {
                                logger.Print("[test] exiting crunch-run stub")
                                svm.Lock()
                                defer svm.Unlock()
                                if svm.destroying {
                                        return
                                }
                                if svm.running[uuid].pid != pid {
                                        bugf := svm.sis.driver.Bugf
                                        if bugf == nil {
                                                bugf = logger.Warnf
                                        }
                                        bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s].pid==%d", pid, uuid, svm.running[uuid].pid)
                                        return
                                }
                                if !completed {
                                        logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
                                        if started && svm.CrashRunningContainer != nil {
                                                svm.CrashRunningContainer(ctr)
                                        }
                                }
                                sproc := svm.running[uuid]
                                sproc.exited = true
                                svm.running[uuid] = sproc
                                svm.Unlock()
                                time.Sleep(svm.ArvMountMaxExitLag * time.Duration(math_rand.Float64()))
                                svm.Lock()
                                if math_rand.Float64() >= svm.ArvMountDeadlockRate {
                                        delete(svm.running, uuid)
                                }
                        }()

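                        // Decide in advance whether this fake
                        // crunch-run will crash, and if so whether it
                        // crashes before or after the container
                        // reaches the Running state.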
                        crashluck := math_rand.Float64()
                        wantCrash := crashluck < svm.CrunchRunCrashRate
                        wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2

                        ctr, ok := queue.Get(uuid)
                        if !ok {
                                logger.Print("[test] container not in queue")
                                return
                        }

                        time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)

                        svm.Lock()
                        killed := svm.killing[uuid]
                        delete(svm.killing, uuid)
                        destroying := svm.destroying
                        svm.Unlock()
                        if killed || wantCrashEarly || destroying {
                                return
                        }

                        ctr.State = arvados.ContainerStateRunning
                        started = queue.Notify(ctr)
                        if !started {
                                ctr, _ = queue.Get(uuid)
                                logger.Print("[test] erroring out because state=Running update was rejected")
                                return
                        }

                        if wantCrash {
                                logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
                                return
                        }
                        if svm.ExecuteContainer != nil {
                                ctr.ExitCode = svm.ExecuteContainer(ctr)
                        }
                        logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
                        ctr.State = arvados.ContainerStateComplete
                        completed = queue.Notify(ctr)
                }()
                return 0
        }
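        // Stub "crunch-run --list": report each fake process, marking
        // exited ones as "stale", plus "broken" if configured.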
        if command == "crunch-run --list" {
                svm.Lock()
                defer svm.Unlock()
                for uuid, sproc := range svm.running {
                        if sproc.exited {
                                fmt.Fprintf(stdout, "%s stale\n", uuid)
                        } else {
                                fmt.Fprintf(stdout, "%s\n", uuid)
                        }
                }
                if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
                        fmt.Fprintln(stdout, "broken")
                }
                fmt.Fprintln(stdout, svm.deadlocked)
                return 0
        }
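        // Stub "crunch-run --kill": mark the process as killed and
        // wait briefly for it to exit.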
        if strings.HasPrefix(command, "crunch-run --kill ") {
                svm.Lock()
                sproc, running := svm.running[uuid]
                if running && !sproc.exited {
                        svm.killing[uuid] = true
                        svm.Unlock()
                        time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
                        svm.Lock()
                        sproc, running = svm.running[uuid]
                }
                svm.Unlock()
                if running && !sproc.exited {
                        fmt.Fprintf(stderr, "%s: container is running\n", uuid)
                        return 1
                }
                fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
                return 0
        }
        if command == "true" {
                return 0
        }
        fmt.Fprintf(stderr, "%q: command not found", command)
        return 1
}

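// stubInstance is a point-in-time snapshot of a StubVM, as returned
// to callers via the cloud.Instance interface.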
type stubInstance struct {
        svm  *StubVM
        addr string
        tags cloud.InstanceTags
}

func (si stubInstance) ID() cloud.InstanceID {
        return si.svm.id
}

func (si stubInstance) Address() string {
        return si.addr
}

func (si stubInstance) RemoteUser() string {
        return si.svm.SSHService.AuthorizedUser
}

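// Destroy shuts down the StubVM behind this instance: it waits for
// pending fake processes to end, closes the SSH service, and removes
// the VM from its instance set. It fails artificially at the rate
// configured by ErrorRateDestroy.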
func (si stubInstance) Destroy() error {
        sis := si.svm.sis
        if sis.driver.HoldCloudOps {
                sis.driver.holdCloudOps <- true
        }
        if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
                return errors.New("instance could not be destroyed")
        }
        si.svm.Lock()
        si.svm.destroying = true
        si.svm.Unlock()
        si.svm.stubprocs.Wait()
        si.svm.SSHService.Close()
        sis.mtx.Lock()
        defer sis.mtx.Unlock()
        delete(sis.servers, si.svm.id)
        return nil
}

func (si stubInstance) ProviderType() string {
        return si.svm.providerType
}

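// SetTags replaces the VM's tags asynchronously; existing
// stubInstance snapshots keep returning the old tags, which the
// cloud driver interface permits.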
func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
        tags = copyTags(tags)
        svm := si.svm
        go func() {
                svm.Lock()
                defer svm.Unlock()
                svm.tags = tags
        }()
        return nil
}

func (si stubInstance) Tags() cloud.InstanceTags {
        // Return a copy to ensure a caller can't change our saved
        // tags just by writing to the returned map.
        return copyTags(si.tags)
}

func (si stubInstance) String() string {
        return string(si.svm.id)
}

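// VerifyHostKey checks that the key offered by the fake VM is the
// driver's HostKey, by signing random data with the driver's key and
// verifying the signature with the offered key.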
func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
        buf := make([]byte, 512)
        _, err := io.ReadFull(rand.Reader, buf)
        if err != nil {
                return err
        }
        sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
        if err != nil {
                return err
        }
        return key.Verify(buf, sig)
}

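// copyTags returns a shallow copy of src, so the caller can modify
// or retain it independently of the original map.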
func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
        dst := cloud.InstanceTags{}
        for k, v := range src {
                dst[k] = v
        }
        return dst
}

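// PriceHistory returns nil; the stub driver does not simulate price
// changes.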
func (si stubInstance) PriceHistory(arvados.InstanceType) []cloud.InstancePrice {
        return nil
}

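// QuotaError wraps an error to indicate that a stub instance quota
// has been reached.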
type QuotaError struct {
        error
}

func (QuotaError) IsQuotaError() bool { return true }