lib/dispatchcloud/test/stub_driver.go

// Copyright (C) The Arvados Authors. All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0

package test

import (
	"crypto/rand"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	math_rand "math/rand"
	"regexp"
	"strings"
	"sync"
	"time"

	"git.arvados.org/arvados.git/lib/cloud"
	"git.arvados.org/arvados.git/lib/crunchrun"
	"git.arvados.org/arvados.git/sdk/go/arvados"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
)

// A StubDriver implements cloud.Driver by setting up local SSH
// servers that do fake command executions.
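//
// A test would typically configure the stub before building an
// InstanceSet from it, roughly like this (a minimal sketch; the
// queue, logger, registry, and instance set ID are assumed to be
// supplied by the test):
//
//	driver := &StubDriver{
//		Queue:           queue,
//		ErrorRateCreate: 0.1,
//		SetupVM: func(vm *StubVM) {
//			vm.Boot = time.Now().Add(time.Millisecond)
//			vm.CrunchRunCrashRate = 0.1
//		},
//	}
//	is, err := driver.InstanceSet(nil, "test-iset-1", nil, logger, reg)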
type StubDriver struct {
	HostKey        ssh.Signer
	AuthorizedKeys []ssh.PublicKey

	// SetupVM, if set, is called upon creation of each new
	// StubVM. This is the caller's opportunity to customize the
	// VM's error rate and other behaviors.
	SetupVM func(*StubVM)

	// Bugf, if set, is called if a bug is detected in the caller
	// or stub. Typically set to (*check.C)Errorf. If unset,
	// logger.Warnf is called instead.
	Bugf func(string, ...interface{})

	// StubVM's fake crunch-run uses this Queue to read and update
	// container state.
	Queue *Queue

	// Frequency of artificially introduced errors on calls to
	// Create and Destroy. 0=always succeed, 1=always fail.
	ErrorRateCreate  float64
	ErrorRateDestroy float64

	// If Create() or Instances() is called too frequently, return
	// rate-limiting errors.
	MinTimeBetweenCreateCalls    time.Duration
	MinTimeBetweenInstancesCalls time.Duration

	QuotaMaxInstances int

	// If true, Create and Destroy calls block until
	// ReleaseCloudOps() is called.
	HoldCloudOps bool

	instanceSets []*StubInstanceSet
	holdCloudOps chan bool
}

// InstanceSet returns a new *StubInstanceSet.
func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger, reg *prometheus.Registry) (cloud.InstanceSet, error) {
	if sd.holdCloudOps == nil {
		sd.holdCloudOps = make(chan bool)
	}
	sis := StubInstanceSet{
		driver:  sd,
		logger:  logger,
		servers: map[cloud.InstanceID]*StubVM{},
	}
	sd.instanceSets = append(sd.instanceSets, &sis)

	var err error
	if params != nil {
		err = json.Unmarshal(params, &sis)
	}
	return &sis, err
}

// InstanceSets returns all instance sets that have been created by
// the driver. This can be used to test a component that uses the
// driver but doesn't expose the InstanceSets it has created.
func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
	return sd.instanceSets
}

// ReleaseCloudOps releases n pending Create/Destroy calls. If there
// are fewer than n blocked calls pending, it waits for the rest to
// arrive.
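//
// For example, a test that sets HoldCloudOps can let exactly one
// blocked cloud operation proceed (a sketch; the component under
// test is assumed to be calling Create in the background):
//
//	driver.HoldCloudOps = true
//	// ... start the component under test, which calls Create ...
//	driver.ReleaseCloudOps(1) // unblock one pending Create/Destroy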
func (sd *StubDriver) ReleaseCloudOps(n int) {
	for i := 0; i < n; i++ {
		<-sd.holdCloudOps
	}
}

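// StubInstanceSet is a fake cloud.InstanceSet backed by the StubVMs
// it has created. It simulates rate limiting, quota errors, and
// random Create failures according to its StubDriver's settings.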
type StubInstanceSet struct {
	driver  *StubDriver
	logger  logrus.FieldLogger
	servers map[cloud.InstanceID]*StubVM
	mtx     sync.RWMutex
	stopped bool

	allowCreateCall    time.Time
	allowInstancesCall time.Time
	lastInstanceID     int
}

func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, initCommand cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
	if sis.driver.HoldCloudOps {
		sis.driver.holdCloudOps <- true
	}
	sis.mtx.Lock()
	defer sis.mtx.Unlock()
	if sis.stopped {
		return nil, errors.New("StubInstanceSet: Create called after Stop")
	}
	if sis.allowCreateCall.After(time.Now()) {
		return nil, RateLimitError{sis.allowCreateCall}
	}
	if math_rand.Float64() < sis.driver.ErrorRateCreate {
		return nil, fmt.Errorf("StubInstanceSet: rand < ErrorRateCreate %f", sis.driver.ErrorRateCreate)
	}
	if max := sis.driver.QuotaMaxInstances; max > 0 && len(sis.servers) >= max {
		return nil, QuotaError{fmt.Errorf("StubInstanceSet: reached QuotaMaxInstances %d", max)}
	}
	sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
	ak := sis.driver.AuthorizedKeys
	if authKey != nil {
		ak = append([]ssh.PublicKey{authKey}, ak...)
	}
	sis.lastInstanceID++
	svm := &StubVM{
		InitCommand:  initCommand,
		sis:          sis,
		id:           cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
		tags:         copyTags(tags),
		providerType: it.ProviderType,
		running:      map[string]stubProcess{},
		killing:      map[string]bool{},
	}
	svm.SSHService = SSHService{
		HostKey:        sis.driver.HostKey,
		AuthorizedUser: "root",
		AuthorizedKeys: ak,
		Exec:           svm.Exec,
	}
	if setup := sis.driver.SetupVM; setup != nil {
		setup(svm)
	}
	sis.servers[svm.id] = svm
	return svm.Instance(), nil
}

func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
	sis.mtx.RLock()
	defer sis.mtx.RUnlock()
	if sis.allowInstancesCall.After(time.Now()) {
		return nil, RateLimitError{sis.allowInstancesCall}
	}
	sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
	var r []cloud.Instance
	for _, ss := range sis.servers {
		r = append(r, ss.Instance())
	}
	return r, nil
}

func (sis *StubInstanceSet) Stop() {
	sis.mtx.Lock()
	defer sis.mtx.Unlock()
	if sis.stopped {
		panic("Stop called twice")
	}
	sis.stopped = true
}

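// StubVMs returns all StubVMs currently running in the instance set.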
func (sis *StubInstanceSet) StubVMs() (svms []*StubVM) {
	sis.mtx.Lock()
	defer sis.mtx.Unlock()
	for _, vm := range sis.servers {
		svms = append(svms, vm)
	}
	return
}

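// RateLimitError is the error returned by Create and Instances when
// they are called again before the configured minimum interval has
// elapsed. EarliestRetry reports when the caller may try again.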
type RateLimitError struct{ Retry time.Time }

func (e RateLimitError) Error() string            { return fmt.Sprintf("rate limited until %s", e.Retry) }
func (e RateLimitError) EarliestRetry() time.Time { return e.Retry }

// StubVM is a fake server that runs an SSH service. It represents a
// VM running in a fake cloud.
//
// Note this is distinct from a stubInstance, which is a snapshot of
// the VM's metadata. Like a VM in a real cloud, a StubVM keeps
// running (and might change IP addresses, shut down, etc.) without
// updating any stubInstances that have been returned to callers.
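//
// Fields like Boot, CrunchRunDetachDelay, and ExecuteContainer are
// typically set from a StubDriver's SetupVM hook, roughly like this
// (a sketch; the delays and exit code are arbitrary test values):
//
//	driver.SetupVM = func(vm *StubVM) {
//		vm.Boot = time.Now().Add(time.Millisecond)
//		vm.CrunchRunDetachDelay = time.Millisecond
//		vm.ExecuteContainer = func(ctr arvados.Container) int {
//			return 0 // pretend every container succeeds
//		}
//	}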
type StubVM struct {
	Boot                  time.Time
	Broken                time.Time
	ReportBroken          time.Time
	CrunchRunMissing      bool
	CrunchRunCrashRate    float64
	CrunchRunDetachDelay  time.Duration
	ArvMountMaxExitLag    time.Duration
	ArvMountDeadlockRate  float64
	ExecuteContainer      func(arvados.Container) int
	CrashRunningContainer func(arvados.Container)
	ExtraCrunchRunArgs    string // extra args expected after "crunch-run --detach --stdin-config "

	// Populated by (*StubInstanceSet)Create()
	InitCommand cloud.InitCommand

	sis          *StubInstanceSet
	id           cloud.InstanceID
	tags         cloud.InstanceTags
	providerType string
	SSHService   SSHService
	running      map[string]stubProcess
	killing      map[string]bool
	lastPID      int64
	deadlocked   string
	sync.Mutex
}

type stubProcess struct {
	pid int64

	// crunch-run has exited, but arv-mount process (or something)
	// still holds lock in /var/run/
	exited bool
}

func (svm *StubVM) Instance() stubInstance {
	svm.Lock()
	defer svm.Unlock()
	return stubInstance{
		svm:  svm,
		addr: svm.SSHService.Address(),
		// We deliberately return a cached/stale copy of the
		// real tags here, so that (Instance)Tags() sometimes
		// returns old data after a call to
		// (Instance)SetTags(). This is permitted by the
		// driver interface, and this might help remind
		// callers that they need to tolerate it.
		tags: copyTags(svm.tags),
	}
}

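// Exec is the StubVM's fake command handler, wired up as its
// SSHService's Exec callback. It recognizes the "crunch-run
// --detach", "crunch-run --list", and "crunch-run --kill" command
// forms used by the dispatcher, simulating container startup,
// crashes, and cleanup according to the StubVM's settings.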
func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
	stdinData, err := ioutil.ReadAll(stdin)
	if err != nil {
		fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
		return 1
	}
	queue := svm.sis.driver.Queue
	uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
	if eta := svm.Boot.Sub(time.Now()); eta > 0 {
		fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
		return 1
	}
	if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
		fmt.Fprintf(stderr, "cannot fork\n")
		return 2
	}
	if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
		fmt.Fprint(stderr, "crunch-run: command not found\n")
		return 1
	}
	if strings.HasPrefix(command, "crunch-run --detach --stdin-config "+svm.ExtraCrunchRunArgs) {
		var configData crunchrun.ConfigData
		err := json.Unmarshal(stdinData, &configData)
		if err != nil {
			fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
			return 1
		}
		for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
			if configData.Env[name] == "" {
				fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
				return 1
			}
		}
		svm.Lock()
		svm.lastPID++
		pid := svm.lastPID
		svm.running[uuid] = stubProcess{pid: pid}
		svm.Unlock()
		time.Sleep(svm.CrunchRunDetachDelay)
		fmt.Fprintf(stderr, "starting %s\n", uuid)
		logger := svm.sis.logger.WithFields(logrus.Fields{
			"Instance":      svm.id,
			"ContainerUUID": uuid,
			"PID":           pid,
		})
		logger.Printf("[test] starting crunch-run stub")
		go func() {
			var ctr arvados.Container
			var started, completed bool
			defer func() {
				logger.Print("[test] exiting crunch-run stub")
				svm.Lock()
				defer svm.Unlock()
				if svm.running[uuid].pid != pid {
					bugf := svm.sis.driver.Bugf
					if bugf == nil {
						bugf = logger.Warnf
					}
					bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s].pid==%d", pid, uuid, svm.running[uuid].pid)
					return
				}
				if !completed {
					logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
					if started && svm.CrashRunningContainer != nil {
						svm.CrashRunningContainer(ctr)
					}
				}
				sproc := svm.running[uuid]
				sproc.exited = true
				svm.running[uuid] = sproc
				svm.Unlock()
				time.Sleep(time.Duration(float64(svm.ArvMountMaxExitLag) * math_rand.Float64()))
				svm.Lock()
				if math_rand.Float64() >= svm.ArvMountDeadlockRate {
					delete(svm.running, uuid)
				}
			}()

			crashluck := math_rand.Float64()
			wantCrash := crashluck < svm.CrunchRunCrashRate
			wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2

			ctr, ok := queue.Get(uuid)
			if !ok {
				logger.Print("[test] container not in queue")
				return
			}

			time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)

			svm.Lock()
			killed := svm.killing[uuid]
			svm.Unlock()
			if killed || wantCrashEarly {
				return
			}

			ctr.State = arvados.ContainerStateRunning
			started = queue.Notify(ctr)
			if !started {
				ctr, _ = queue.Get(uuid)
				logger.Print("[test] erroring out because state=Running update was rejected")
				return
			}

			if wantCrash {
				logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
				return
			}
			if svm.ExecuteContainer != nil {
				ctr.ExitCode = svm.ExecuteContainer(ctr)
			}
			logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
			ctr.State = arvados.ContainerStateComplete
			completed = queue.Notify(ctr)
		}()
		return 0
	}
	if command == "crunch-run --list" {
		svm.Lock()
		defer svm.Unlock()
		for uuid, sproc := range svm.running {
			if sproc.exited {
				fmt.Fprintf(stdout, "%s stale\n", uuid)
			} else {
				fmt.Fprintf(stdout, "%s\n", uuid)
			}
		}
		if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
			fmt.Fprintln(stdout, "broken")
		}
		fmt.Fprintln(stdout, svm.deadlocked)
		return 0
	}
	if strings.HasPrefix(command, "crunch-run --kill ") {
		svm.Lock()
		sproc, running := svm.running[uuid]
		if running && !sproc.exited {
			svm.killing[uuid] = true
			svm.Unlock()
			time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
			svm.Lock()
			sproc, running = svm.running[uuid]
		}
		svm.Unlock()
		if running && !sproc.exited {
			fmt.Fprintf(stderr, "%s: container is running\n", uuid)
			return 1
		}
		fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
		return 0
	}
	if command == "true" {
		return 0
	}
	fmt.Fprintf(stderr, "%q: command not found", command)
	return 1
}

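// stubInstance is a snapshot of a StubVM's metadata, as returned by
// (*StubVM)Instance(). It implements cloud.Instance; see the StubVM
// comment for how its cached tags can go stale.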
type stubInstance struct {
	svm  *StubVM
	addr string
	tags cloud.InstanceTags
}

func (si stubInstance) ID() cloud.InstanceID {
	return si.svm.id
}

func (si stubInstance) Address() string {
	return si.addr
}

func (si stubInstance) RemoteUser() string {
	return si.svm.SSHService.AuthorizedUser
}

func (si stubInstance) Destroy() error {
	sis := si.svm.sis
	if sis.driver.HoldCloudOps {
		sis.driver.holdCloudOps <- true
	}
	if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
		return errors.New("instance could not be destroyed")
	}
	si.svm.SSHService.Close()
	sis.mtx.Lock()
	defer sis.mtx.Unlock()
	delete(sis.servers, si.svm.id)
	return nil
}

func (si stubInstance) ProviderType() string {
	return si.svm.providerType
}

func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
	tags = copyTags(tags)
	svm := si.svm
	go func() {
		svm.Lock()
		defer svm.Unlock()
		svm.tags = tags
	}()
	return nil
}

func (si stubInstance) Tags() cloud.InstanceTags {
	// Return a copy to ensure a caller can't change our saved
	// tags just by writing to the returned map.
	return copyTags(si.tags)
}

func (si stubInstance) String() string {
	return string(si.svm.id)
}

func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
	buf := make([]byte, 512)
	_, err := io.ReadFull(rand.Reader, buf)
	if err != nil {
		return err
	}
	sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
	if err != nil {
		return err
	}
	return key.Verify(buf, sig)
}

func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
	dst := cloud.InstanceTags{}
	for k, v := range src {
		dst[k] = v
	}
	return dst
}

func (si stubInstance) PriceHistory(arvados.InstanceType) []cloud.InstancePrice {
	return nil
}

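// QuotaError is returned by Create when the stub's QuotaMaxInstances
// limit has been reached. IsQuotaError marks it as a quota error so
// callers can distinguish it from other Create failures.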
type QuotaError struct {
	error
}

func (QuotaError) IsQuotaError() bool { return true }