Fix 2.4.2 upgrade notes formatting refs #19330
[arvados.git] / lib / dispatchcloud / test / stub_driver.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "crypto/rand"
9         "encoding/json"
10         "errors"
11         "fmt"
12         "io"
13         "io/ioutil"
14         math_rand "math/rand"
15         "regexp"
16         "strings"
17         "sync"
18         "time"
19
20         "git.arvados.org/arvados.git/lib/cloud"
21         "git.arvados.org/arvados.git/lib/crunchrun"
22         "git.arvados.org/arvados.git/sdk/go/arvados"
23         "github.com/sirupsen/logrus"
24         "golang.org/x/crypto/ssh"
25 )
26
27 // A StubDriver implements cloud.Driver by setting up local SSH
28 // servers that do fake command executions.
29 type StubDriver struct {
30         HostKey        ssh.Signer
31         AuthorizedKeys []ssh.PublicKey
32
33         // SetupVM, if set, is called upon creation of each new
34         // StubVM. This is the caller's opportunity to customize the
35         // VM's error rate and other behaviors.
36         SetupVM func(*StubVM)
37
38         // Bugf, if set, is called if a bug is detected in the caller
39         // or stub. Typically set to (*check.C)Errorf. If unset,
40         // logger.Warnf is called instead.
41         Bugf func(string, ...interface{})
42
43         // StubVM's fake crunch-run uses this Queue to read and update
44         // container state.
45         Queue *Queue
46
47         // Frequency of artificially introduced errors on calls to
48         // Destroy. 0=always succeed, 1=always fail.
49         ErrorRateDestroy float64
50
51         // If Create() or Instances() is called too frequently, return
52         // rate-limiting errors.
53         MinTimeBetweenCreateCalls    time.Duration
54         MinTimeBetweenInstancesCalls time.Duration
55
56         // If true, Create and Destroy calls block until Release() is
57         // called.
58         HoldCloudOps bool
59
60         instanceSets []*StubInstanceSet
61         holdCloudOps chan bool
62 }
63
64 // InstanceSet returns a new *StubInstanceSet.
65 func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
66         if sd.holdCloudOps == nil {
67                 sd.holdCloudOps = make(chan bool)
68         }
69         sis := StubInstanceSet{
70                 driver:  sd,
71                 logger:  logger,
72                 servers: map[cloud.InstanceID]*StubVM{},
73         }
74         sd.instanceSets = append(sd.instanceSets, &sis)
75
76         var err error
77         if params != nil {
78                 err = json.Unmarshal(params, &sis)
79         }
80         return &sis, err
81 }
82
83 // InstanceSets returns all instances that have been created by the
84 // driver. This can be used to test a component that uses the driver
85 // but doesn't expose the InstanceSets it has created.
86 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
87         return sd.instanceSets
88 }
89
90 // ReleaseCloudOps releases n pending Create/Destroy calls. If there
91 // are fewer than n blocked calls pending, it waits for the rest to
92 // arrive.
93 func (sd *StubDriver) ReleaseCloudOps(n int) {
94         for i := 0; i < n; i++ {
95                 <-sd.holdCloudOps
96         }
97 }
98
99 type StubInstanceSet struct {
100         driver  *StubDriver
101         logger  logrus.FieldLogger
102         servers map[cloud.InstanceID]*StubVM
103         mtx     sync.RWMutex
104         stopped bool
105
106         allowCreateCall    time.Time
107         allowInstancesCall time.Time
108         lastInstanceID     int
109 }
110
111 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, cmd cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
112         if sis.driver.HoldCloudOps {
113                 sis.driver.holdCloudOps <- true
114         }
115         sis.mtx.Lock()
116         defer sis.mtx.Unlock()
117         if sis.stopped {
118                 return nil, errors.New("StubInstanceSet: Create called after Stop")
119         }
120         if sis.allowCreateCall.After(time.Now()) {
121                 return nil, RateLimitError{sis.allowCreateCall}
122         }
123         sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
124         ak := sis.driver.AuthorizedKeys
125         if authKey != nil {
126                 ak = append([]ssh.PublicKey{authKey}, ak...)
127         }
128         sis.lastInstanceID++
129         svm := &StubVM{
130                 sis:          sis,
131                 id:           cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
132                 tags:         copyTags(tags),
133                 providerType: it.ProviderType,
134                 initCommand:  cmd,
135                 running:      map[string]stubProcess{},
136                 killing:      map[string]bool{},
137         }
138         svm.SSHService = SSHService{
139                 HostKey:        sis.driver.HostKey,
140                 AuthorizedUser: "root",
141                 AuthorizedKeys: ak,
142                 Exec:           svm.Exec,
143         }
144         if setup := sis.driver.SetupVM; setup != nil {
145                 setup(svm)
146         }
147         sis.servers[svm.id] = svm
148         return svm.Instance(), nil
149 }
150
151 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
152         sis.mtx.RLock()
153         defer sis.mtx.RUnlock()
154         if sis.allowInstancesCall.After(time.Now()) {
155                 return nil, RateLimitError{sis.allowInstancesCall}
156         }
157         sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
158         var r []cloud.Instance
159         for _, ss := range sis.servers {
160                 r = append(r, ss.Instance())
161         }
162         return r, nil
163 }
164
165 func (sis *StubInstanceSet) Stop() {
166         sis.mtx.Lock()
167         defer sis.mtx.Unlock()
168         if sis.stopped {
169                 panic("Stop called twice")
170         }
171         sis.stopped = true
172 }
173
// RateLimitError is returned by StubInstanceSet when Create or
// Instances is called too soon after the previous call. Retry holds
// the earliest time at which another call will be accepted.
type RateLimitError struct {
	Retry time.Time
}

// Error implements the error interface.
func (e RateLimitError) Error() string {
	msg := fmt.Sprintf("rate limited until %s", e.Retry)
	return msg
}

// EarliestRetry returns the time after which a retry may succeed.
func (e RateLimitError) EarliestRetry() time.Time {
	return e.Retry
}
178
179 // StubVM is a fake server that runs an SSH service. It represents a
180 // VM running in a fake cloud.
181 //
182 // Note this is distinct from a stubInstance, which is a snapshot of
183 // the VM's metadata. Like a VM in a real cloud, a StubVM keeps
184 // running (and might change IP addresses, shut down, etc.)  without
185 // updating any stubInstances that have been returned to callers.
186 type StubVM struct {
187         Boot                  time.Time
188         Broken                time.Time
189         ReportBroken          time.Time
190         CrunchRunMissing      bool
191         CrunchRunCrashRate    float64
192         CrunchRunDetachDelay  time.Duration
193         ArvMountMaxExitLag    time.Duration
194         ArvMountDeadlockRate  float64
195         ExecuteContainer      func(arvados.Container) int
196         CrashRunningContainer func(arvados.Container)
197         ExtraCrunchRunArgs    string // extra args expected after "crunch-run --detach --stdin-config "
198
199         sis          *StubInstanceSet
200         id           cloud.InstanceID
201         tags         cloud.InstanceTags
202         initCommand  cloud.InitCommand
203         providerType string
204         SSHService   SSHService
205         running      map[string]stubProcess
206         killing      map[string]bool
207         lastPID      int64
208         deadlocked   string
209         sync.Mutex
210 }
211
// stubProcess describes one simulated crunch-run process on a StubVM.
type stubProcess struct {
	// pid is the fake process ID assigned when the process started.
	pid int64

	// exited is true once crunch-run itself has exited while something
	// (e.g. arv-mount) is still holding its lock in /var/run/, so the
	// process still shows up, as stale.
	exited bool
}
219
220 func (svm *StubVM) Instance() stubInstance {
221         svm.Lock()
222         defer svm.Unlock()
223         return stubInstance{
224                 svm:  svm,
225                 addr: svm.SSHService.Address(),
226                 // We deliberately return a cached/stale copy of the
227                 // real tags here, so that (Instance)Tags() sometimes
228                 // returns old data after a call to
229                 // (Instance)SetTags().  This is permitted by the
230                 // driver interface, and this might help remind
231                 // callers that they need to tolerate it.
232                 tags: copyTags(svm.tags),
233         }
234 }
235
236 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
237         stdinData, err := ioutil.ReadAll(stdin)
238         if err != nil {
239                 fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
240                 return 1
241         }
242         queue := svm.sis.driver.Queue
243         uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
244         if eta := svm.Boot.Sub(time.Now()); eta > 0 {
245                 fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
246                 return 1
247         }
248         if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
249                 fmt.Fprintf(stderr, "cannot fork\n")
250                 return 2
251         }
252         if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
253                 fmt.Fprint(stderr, "crunch-run: command not found\n")
254                 return 1
255         }
256         if strings.HasPrefix(command, "crunch-run --detach --stdin-config "+svm.ExtraCrunchRunArgs) {
257                 var configData crunchrun.ConfigData
258                 err := json.Unmarshal(stdinData, &configData)
259                 if err != nil {
260                         fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
261                         return 1
262                 }
263                 for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
264                         if configData.Env[name] == "" {
265                                 fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
266                                 return 1
267                         }
268                 }
269                 svm.Lock()
270                 svm.lastPID++
271                 pid := svm.lastPID
272                 svm.running[uuid] = stubProcess{pid: pid}
273                 svm.Unlock()
274                 time.Sleep(svm.CrunchRunDetachDelay)
275                 fmt.Fprintf(stderr, "starting %s\n", uuid)
276                 logger := svm.sis.logger.WithFields(logrus.Fields{
277                         "Instance":      svm.id,
278                         "ContainerUUID": uuid,
279                         "PID":           pid,
280                 })
281                 logger.Printf("[test] starting crunch-run stub")
282                 go func() {
283                         var ctr arvados.Container
284                         var started, completed bool
285                         defer func() {
286                                 logger.Print("[test] exiting crunch-run stub")
287                                 svm.Lock()
288                                 defer svm.Unlock()
289                                 if svm.running[uuid].pid != pid {
290                                         bugf := svm.sis.driver.Bugf
291                                         if bugf == nil {
292                                                 bugf = logger.Warnf
293                                         }
294                                         bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s].pid==%d", pid, uuid, svm.running[uuid].pid)
295                                         return
296                                 }
297                                 if !completed {
298                                         logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
299                                         if started && svm.CrashRunningContainer != nil {
300                                                 svm.CrashRunningContainer(ctr)
301                                         }
302                                 }
303                                 sproc := svm.running[uuid]
304                                 sproc.exited = true
305                                 svm.running[uuid] = sproc
306                                 svm.Unlock()
307                                 time.Sleep(svm.ArvMountMaxExitLag * time.Duration(math_rand.Float64()))
308                                 svm.Lock()
309                                 if math_rand.Float64() >= svm.ArvMountDeadlockRate {
310                                         delete(svm.running, uuid)
311                                 }
312                         }()
313
314                         crashluck := math_rand.Float64()
315                         wantCrash := crashluck < svm.CrunchRunCrashRate
316                         wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2
317
318                         ctr, ok := queue.Get(uuid)
319                         if !ok {
320                                 logger.Print("[test] container not in queue")
321                                 return
322                         }
323
324                         time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
325
326                         svm.Lock()
327                         killed := svm.killing[uuid]
328                         svm.Unlock()
329                         if killed || wantCrashEarly {
330                                 return
331                         }
332
333                         ctr.State = arvados.ContainerStateRunning
334                         started = queue.Notify(ctr)
335                         if !started {
336                                 ctr, _ = queue.Get(uuid)
337                                 logger.Print("[test] erroring out because state=Running update was rejected")
338                                 return
339                         }
340
341                         if wantCrash {
342                                 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
343                                 return
344                         }
345                         if svm.ExecuteContainer != nil {
346                                 ctr.ExitCode = svm.ExecuteContainer(ctr)
347                         }
348                         logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
349                         ctr.State = arvados.ContainerStateComplete
350                         completed = queue.Notify(ctr)
351                 }()
352                 return 0
353         }
354         if command == "crunch-run --list" {
355                 svm.Lock()
356                 defer svm.Unlock()
357                 for uuid, sproc := range svm.running {
358                         if sproc.exited {
359                                 fmt.Fprintf(stdout, "%s stale\n", uuid)
360                         } else {
361                                 fmt.Fprintf(stdout, "%s\n", uuid)
362                         }
363                 }
364                 if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
365                         fmt.Fprintln(stdout, "broken")
366                 }
367                 fmt.Fprintln(stdout, svm.deadlocked)
368                 return 0
369         }
370         if strings.HasPrefix(command, "crunch-run --kill ") {
371                 svm.Lock()
372                 sproc, running := svm.running[uuid]
373                 if running && !sproc.exited {
374                         svm.killing[uuid] = true
375                         svm.Unlock()
376                         time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
377                         svm.Lock()
378                         sproc, running = svm.running[uuid]
379                 }
380                 svm.Unlock()
381                 if running && !sproc.exited {
382                         fmt.Fprintf(stderr, "%s: container is running\n", uuid)
383                         return 1
384                 }
385                 fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
386                 return 0
387         }
388         if command == "true" {
389                 return 0
390         }
391         fmt.Fprintf(stderr, "%q: command not found", command)
392         return 1
393 }
394
395 type stubInstance struct {
396         svm  *StubVM
397         addr string
398         tags cloud.InstanceTags
399 }
400
401 func (si stubInstance) ID() cloud.InstanceID {
402         return si.svm.id
403 }
404
405 func (si stubInstance) Address() string {
406         return si.addr
407 }
408
409 func (si stubInstance) RemoteUser() string {
410         return si.svm.SSHService.AuthorizedUser
411 }
412
413 func (si stubInstance) Destroy() error {
414         sis := si.svm.sis
415         if sis.driver.HoldCloudOps {
416                 sis.driver.holdCloudOps <- true
417         }
418         if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
419                 return errors.New("instance could not be destroyed")
420         }
421         si.svm.SSHService.Close()
422         sis.mtx.Lock()
423         defer sis.mtx.Unlock()
424         delete(sis.servers, si.svm.id)
425         return nil
426 }
427
428 func (si stubInstance) ProviderType() string {
429         return si.svm.providerType
430 }
431
432 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
433         tags = copyTags(tags)
434         svm := si.svm
435         go func() {
436                 svm.Lock()
437                 defer svm.Unlock()
438                 svm.tags = tags
439         }()
440         return nil
441 }
442
443 func (si stubInstance) Tags() cloud.InstanceTags {
444         // Return a copy to ensure a caller can't change our saved
445         // tags just by writing to the returned map.
446         return copyTags(si.tags)
447 }
448
449 func (si stubInstance) String() string {
450         return string(si.svm.id)
451 }
452
453 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
454         buf := make([]byte, 512)
455         _, err := io.ReadFull(rand.Reader, buf)
456         if err != nil {
457                 return err
458         }
459         sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
460         if err != nil {
461                 return err
462         }
463         return key.Verify(buf, sig)
464 }
465
466 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
467         dst := cloud.InstanceTags{}
468         for k, v := range src {
469                 dst[k] = v
470         }
471         return dst
472 }