// Copyright (C) The Arvados Authors. All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0

package test

import (
	"crypto/rand"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	math_rand "math/rand"
	"regexp"
	"strings"
	"sync"
	"time"

	"git.arvados.org/arvados.git/lib/cloud"
	"git.arvados.org/arvados.git/lib/crunchrun"
	"git.arvados.org/arvados.git/sdk/go/arvados"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
)

// A StubDriver implements cloud.Driver by setting up local SSH
// servers that do fake command executions.
type StubDriver struct {
	HostKey        ssh.Signer
	AuthorizedKeys []ssh.PublicKey

	// SetupVM, if set, is called upon creation of each new
	// StubVM. This is the caller's opportunity to customize the
	// VM's error rate and other behaviors.
	//
	// If SetupVM returns an error, that error will be returned to
	// the caller of Create(), and the new VM will be discarded.
	SetupVM func(*StubVM) error

	// Bugf, if set, is called if a bug is detected in the caller
	// or stub. Typically set to (*check.C).Errorf. If unset,
	// logger.Warnf is called instead.
	Bugf func(string, ...interface{})

	// StubVM's fake crunch-run uses this Queue to read and update
	// container state.
	Queue *Queue

	// Frequency of artificially introduced errors on calls to
	// Create and Destroy. 0=always succeed, 1=always fail.
	ErrorRateCreate  float64
	ErrorRateDestroy float64

	// If Create() or Instances() is called too frequently, return
	// rate-limiting errors.
	MinTimeBetweenCreateCalls    time.Duration
	MinTimeBetweenInstancesCalls time.Duration

	QuotaMaxInstances int

	// If true, Create and Destroy calls block until
	// ReleaseCloudOps() is called.
	HoldCloudOps bool

	instanceSets []*StubInstanceSet
	holdCloudOps chan bool
}
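
// Illustrative sketch only (not used by any real test; the rates,
// durations, and quota below are arbitrary assumptions): a dispatcher
// test might configure a StubDriver roughly like this before handing
// it to the code under test.
//
//	driver := &StubDriver{
//		Queue:                        queue, // a *test.Queue prepared by the test
//		ErrorRateCreate:              0.05,  // 5% of Create calls fail artificially
//		MinTimeBetweenCreateCalls:    time.Millisecond,
//		MinTimeBetweenInstancesCalls: time.Millisecond,
//		QuotaMaxInstances:            8,
//		SetupVM: func(vm *StubVM) error {
//			vm.Boot = time.Now().Add(10 * time.Millisecond)
//			vm.CrunchRunDetachDelay = time.Millisecond
//			return nil
//		},
//	}
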

// InstanceSet returns a new *StubInstanceSet.
func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger, reg *prometheus.Registry) (cloud.InstanceSet, error) {
	if sd.holdCloudOps == nil {
		sd.holdCloudOps = make(chan bool)
	}
	sis := StubInstanceSet{
		driver:  sd,
		logger:  logger,
		servers: map[cloud.InstanceID]*StubVM{},
	}
	sd.instanceSets = append(sd.instanceSets, &sis)

	var err error
	if params != nil {
		err = json.Unmarshal(params, &sis)
	}
	return &sis, err
}

// InstanceSets returns all instance sets that have been created by
// the driver. This can be used to test a component that uses the
// driver but doesn't expose the instance sets it has created.
func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
	return sd.instanceSets
}

// ReleaseCloudOps releases n pending Create/Destroy calls. If there
// are fewer than n blocked calls pending, it waits for the rest to
// arrive.
func (sd *StubDriver) ReleaseCloudOps(n int) {
	for i := 0; i < n; i++ {
		<-sd.holdCloudOps
	}
}
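
// Illustrative sketch (not taken from a real test): with HoldCloudOps
// enabled, each Create/Destroy call blocks until the test releases it,
// which lets a test decide exactly when simulated cloud operations are
// allowed to proceed.
//
//	driver := &StubDriver{HoldCloudOps: true}
//	// ... start the component under test, which will call Create() ...
//	driver.ReleaseCloudOps(1) // let exactly one pending Create/Destroy proceed
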

type StubInstanceSet struct {
	driver  *StubDriver
	logger  logrus.FieldLogger
	servers map[cloud.InstanceID]*StubVM
	mtx     sync.RWMutex
	stopped bool

	allowCreateCall    time.Time
	allowInstancesCall time.Time
	lastInstanceID     int
}

func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, initCommand cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
	if sis.driver.HoldCloudOps {
		sis.driver.holdCloudOps <- true
	}
	sis.mtx.Lock()
	defer sis.mtx.Unlock()
	if sis.stopped {
		return nil, errors.New("StubInstanceSet: Create called after Stop")
	}
	if sis.allowCreateCall.After(time.Now()) {
		return nil, RateLimitError{sis.allowCreateCall}
	}
	if math_rand.Float64() < sis.driver.ErrorRateCreate {
		return nil, fmt.Errorf("StubInstanceSet: rand < ErrorRateCreate %f", sis.driver.ErrorRateCreate)
	}
	if max := sis.driver.QuotaMaxInstances; max > 0 && len(sis.servers) >= max {
		return nil, QuotaError{fmt.Errorf("StubInstanceSet: reached QuotaMaxInstances %d", max)}
	}
	sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
	ak := sis.driver.AuthorizedKeys
	if authKey != nil {
		ak = append([]ssh.PublicKey{authKey}, ak...)
	}
	sis.lastInstanceID++
	svm := &StubVM{
		InitCommand:  initCommand,
		sis:          sis,
		id:           cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
		tags:         copyTags(tags),
		providerType: it.ProviderType,
		running:      map[string]stubProcess{},
		killing:      map[string]bool{},
	}
	svm.SSHService = SSHService{
		HostKey:        sis.driver.HostKey,
		AuthorizedUser: "root",
		AuthorizedKeys: ak,
		Exec:           svm.Exec,
	}
	if setup := sis.driver.SetupVM; setup != nil {
		err := setup(svm)
		if err != nil {
			return nil, err
		}
	}
	sis.servers[svm.id] = svm
	return svm.Instance(), nil
}

func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
	sis.mtx.RLock()
	defer sis.mtx.RUnlock()
	if sis.allowInstancesCall.After(time.Now()) {
		return nil, RateLimitError{sis.allowInstancesCall}
	}
	sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
	var r []cloud.Instance
	for _, ss := range sis.servers {
		r = append(r, ss.Instance())
	}
	return r, nil
}

func (sis *StubInstanceSet) Stop() {
	sis.mtx.Lock()
	defer sis.mtx.Unlock()
	if sis.stopped {
		panic("Stop called twice")
	}
	sis.stopped = true
}

func (sis *StubInstanceSet) StubVMs() (svms []*StubVM) {
	sis.mtx.Lock()
	defer sis.mtx.Unlock()
	for _, vm := range sis.servers {
		svms = append(svms, vm)
	}
	return
}

type RateLimitError struct{ Retry time.Time }

func (e RateLimitError) Error() string            { return fmt.Sprintf("rate limited until %s", e.Retry) }
func (e RateLimitError) EarliestRetry() time.Time { return e.Retry }

type CapacityError struct{ InstanceTypeSpecific bool }

func (e CapacityError) Error() string                { return "insufficient capacity" }
func (e CapacityError) IsCapacityError() bool        { return true }
func (e CapacityError) IsInstanceTypeSpecific() bool { return e.InstanceTypeSpecific }
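
// Illustrative sketch (an assumption about caller behavior, not code
// from this package): callers generally detect these conditions by
// interface rather than by concrete type, e.g. via the cloud package's
// rate-limit/capacity error interfaces. A minimal check could look
// like this:
//
//	if rl, ok := err.(interface{ EarliestRetry() time.Time }); ok {
//		// back off until rl.EarliestRetry() before calling Create/Instances again
//		_ = rl
//	}
//	if ce, ok := err.(interface{ IsCapacityError() bool }); ok && ce.IsCapacityError() {
//		// treat as a (possibly instance-type-specific) capacity problem
//	}
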

// StubVM is a fake server that runs an SSH service. It represents a
// VM running in a fake cloud.
//
// Note this is distinct from a stubInstance, which is a snapshot of
// the VM's metadata. Like a VM in a real cloud, a StubVM keeps
// running (and might change IP addresses, shut down, etc.) without
// updating any stubInstances that have been returned to callers.
type StubVM struct {
	Boot                  time.Time
	Broken                time.Time
	ReportBroken          time.Time
	CrunchRunMissing      bool
	CrunchRunCrashRate    float64
	CrunchRunDetachDelay  time.Duration
	ArvMountMaxExitLag    time.Duration
	ArvMountDeadlockRate  float64
	ExecuteContainer      func(arvados.Container) int
	CrashRunningContainer func(arvados.Container)
	ExtraCrunchRunArgs    string // extra args expected after "crunch-run --detach --stdin-config "

	// Populated by (*StubInstanceSet).Create()
	InitCommand cloud.InitCommand

	sis          *StubInstanceSet
	id           cloud.InstanceID
	tags         cloud.InstanceTags
	providerType string
	SSHService   SSHService
	running      map[string]stubProcess
	killing      map[string]bool
	lastPID      int64
	deadlocked   string
	sync.Mutex
}
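
// Illustrative sketch (field values are arbitrary assumptions): a test
// that wants misbehaving VMs can adjust these knobs from the driver's
// SetupVM hook.
//
//	driver.SetupVM = func(vm *StubVM) error {
//		vm.Boot = time.Now().Add(time.Second / 2) // slow boot
//		vm.CrunchRunCrashRate = 0.1               // 10% of containers crash
//		vm.ExecuteContainer = func(ctr arvados.Container) int {
//			return 0 // pretend every container exits 0
//		}
//		return nil
//	}
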

type stubProcess struct {
	pid int64

	// crunch-run has exited, but arv-mount process (or something)
	// still holds lock in /var/run/
	exited bool
}

func (svm *StubVM) Instance() stubInstance {
	svm.Lock()
	defer svm.Unlock()
	return stubInstance{
		svm:  svm,
		addr: svm.SSHService.Address(),
		// We deliberately return a cached/stale copy of the
		// real tags here, so that (Instance)Tags() sometimes
		// returns old data after a call to
		// (Instance)SetTags(). This is permitted by the
		// driver interface, and this might help remind
		// callers that they need to tolerate it.
		tags: copyTags(svm.tags),
	}
}

func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
	stdinData, err := ioutil.ReadAll(stdin)
	if err != nil {
		fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
		return 1
	}
	queue := svm.sis.driver.Queue
	uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
	if eta := svm.Boot.Sub(time.Now()); eta > 0 {
		fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
		return 1
	}
	if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
		fmt.Fprintf(stderr, "cannot fork\n")
		return 2
	}
	if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
		fmt.Fprint(stderr, "crunch-run: command not found\n")
		return 1
	}
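	// Simulate "crunch-run --detach": check the config supplied on
	// stdin, register a fake process for the container, and play out
	// the rest of the container's lifecycle in a goroutine.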
	if strings.HasPrefix(command, "crunch-run --detach --stdin-config "+svm.ExtraCrunchRunArgs) {
		var configData crunchrun.ConfigData
		err := json.Unmarshal(stdinData, &configData)
		if err != nil {
			fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
			return 1
		}
		for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
			if configData.Env[name] == "" {
				fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
				return 1
			}
		}
		svm.Lock()
		svm.lastPID++
		pid := svm.lastPID
		svm.running[uuid] = stubProcess{pid: pid}
		svm.Unlock()
		time.Sleep(svm.CrunchRunDetachDelay)
		fmt.Fprintf(stderr, "starting %s\n", uuid)
		logger := svm.sis.logger.WithFields(logrus.Fields{
			"Instance":      svm.id,
			"ContainerUUID": uuid,
			"PID":           pid,
		})
		logger.Printf("[test] starting crunch-run stub")
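		// This goroutine plays the part of the detached
		// crunch-run process: depending on the configured
		// rates it may crash early, crash while running,
		// linger as a stale process (simulating a stuck
		// arv-mount), or complete normally.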
		go func() {
			var ctr arvados.Container
			var started, completed bool
			defer func() {
				logger.Print("[test] exiting crunch-run stub")
				svm.Lock()
				defer svm.Unlock()
				if svm.running[uuid].pid != pid {
					bugf := svm.sis.driver.Bugf
					if bugf == nil {
						bugf = logger.Warnf
					}
					bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s].pid==%d", pid, uuid, svm.running[uuid].pid)
					return
				}
				if !completed {
					logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
					if started && svm.CrashRunningContainer != nil {
						svm.CrashRunningContainer(ctr)
					}
				}
				sproc := svm.running[uuid]
				sproc.exited = true
				svm.running[uuid] = sproc
				svm.Unlock()
				time.Sleep(svm.ArvMountMaxExitLag * time.Duration(math_rand.Float64()))
				svm.Lock()
				if math_rand.Float64() >= svm.ArvMountDeadlockRate {
					delete(svm.running, uuid)
				}
			}()

			crashluck := math_rand.Float64()
			wantCrash := crashluck < svm.CrunchRunCrashRate
			wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2

			ctr, ok := queue.Get(uuid)
			if !ok {
				logger.Print("[test] container not in queue")
				return
			}

			time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)

			svm.Lock()
			killed := svm.killing[uuid]
			svm.Unlock()
			if killed || wantCrashEarly {
				return
			}

			ctr.State = arvados.ContainerStateRunning
			started = queue.Notify(ctr)
			if !started {
				ctr, _ = queue.Get(uuid)
				logger.Print("[test] erroring out because state=Running update was rejected")
				return
			}

			if wantCrash {
				logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
				return
			}
			if svm.ExecuteContainer != nil {
				ctr.ExitCode = svm.ExecuteContainer(ctr)
			}
			logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
			ctr.State = arvados.ContainerStateComplete
			completed = queue.Notify(ctr)
		}()
		return 0
	}
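	// Simulate "crunch-run --list": print one line per known
	// container, marking exited-but-not-cleaned-up processes as
	// "stale", plus "broken" if the VM is configured to report
	// itself broken.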
	if command == "crunch-run --list" {
		svm.Lock()
		defer svm.Unlock()
		for uuid, sproc := range svm.running {
			if sproc.exited {
				fmt.Fprintf(stdout, "%s stale\n", uuid)
			} else {
				fmt.Fprintf(stdout, "%s\n", uuid)
			}
		}
		if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
			fmt.Fprintln(stdout, "broken")
		}
		fmt.Fprintln(stdout, svm.deadlocked)
		return 0
	}
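	// Simulate "crunch-run --kill <uuid>": mark the container as
	// being killed, give the fake process a moment to exit, then
	// report whether it is still running.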
	if strings.HasPrefix(command, "crunch-run --kill ") {
		svm.Lock()
		sproc, running := svm.running[uuid]
		if running && !sproc.exited {
			svm.killing[uuid] = true
			svm.Unlock()
			time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
			svm.Lock()
			sproc, running = svm.running[uuid]
		}
		svm.Unlock()
		if running && !sproc.exited {
			fmt.Fprintf(stderr, "%s: container is running\n", uuid)
			return 1
		}
		fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
		return 0
	}
	if command == "true" {
		return 0
	}
	fmt.Fprintf(stderr, "%q: command not found", command)
	return 1
}

type stubInstance struct {
	svm  *StubVM
	addr string
	tags cloud.InstanceTags
}

func (si stubInstance) ID() cloud.InstanceID {
	return si.svm.id
}

func (si stubInstance) Address() string {
	return si.addr
}

func (si stubInstance) RemoteUser() string {
	return si.svm.SSHService.AuthorizedUser
}

func (si stubInstance) Destroy() error {
	sis := si.svm.sis
	if sis.driver.HoldCloudOps {
		sis.driver.holdCloudOps <- true
	}
	if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
		return errors.New("instance could not be destroyed")
	}
	si.svm.SSHService.Close()
	sis.mtx.Lock()
	defer sis.mtx.Unlock()
	delete(sis.servers, si.svm.id)
	return nil
}

func (si stubInstance) ProviderType() string {
	return si.svm.providerType
}

func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
	tags = copyTags(tags)
	svm := si.svm
	go func() {
		svm.Lock()
		defer svm.Unlock()
		svm.tags = tags
	}()
	return nil
}

func (si stubInstance) Tags() cloud.InstanceTags {
	// Return a copy to ensure a caller can't change our saved
	// tags just by writing to the returned map.
	return copyTags(si.tags)
}

func (si stubInstance) String() string {
	return string(si.svm.id)
}

func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
	buf := make([]byte, 512)
	_, err := io.ReadFull(rand.Reader, buf)
	if err != nil {
		return err
	}
	sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
	if err != nil {
		return err
	}
	return key.Verify(buf, sig)
}

func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
	dst := cloud.InstanceTags{}
	for k, v := range src {
		dst[k] = v
	}
	return dst
}

func (si stubInstance) PriceHistory(arvados.InstanceType) []cloud.InstancePrice {
	return nil
}

type QuotaError struct {
	error
}

func (QuotaError) IsQuotaError() bool { return true }
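
// Illustrative sketch (an assumption about caller behavior): Create
// starts returning QuotaError once QuotaMaxInstances instances exist,
// and callers are expected to detect quota errors by interface, e.g.:
//
//	if qe, ok := err.(interface{ IsQuotaError() bool }); ok && qe.IsQuotaError() {
//		// stop creating instances until some have been destroyed
//	}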