20520: Add test for custom init command.
[arvados.git] / lib / dispatchcloud / test / stub_driver.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "crypto/rand"
9         "encoding/json"
10         "errors"
11         "fmt"
12         "io"
13         "io/ioutil"
14         math_rand "math/rand"
15         "regexp"
16         "strings"
17         "sync"
18         "time"
19
20         "git.arvados.org/arvados.git/lib/cloud"
21         "git.arvados.org/arvados.git/lib/crunchrun"
22         "git.arvados.org/arvados.git/sdk/go/arvados"
23         "github.com/sirupsen/logrus"
24         "golang.org/x/crypto/ssh"
25 )
26
27 // A StubDriver implements cloud.Driver by setting up local SSH
28 // servers that do fake command executions.
29 type StubDriver struct {
30         HostKey        ssh.Signer
31         AuthorizedKeys []ssh.PublicKey
32
33         // SetupVM, if set, is called upon creation of each new
34         // StubVM. This is the caller's opportunity to customize the
35         // VM's error rate and other behaviors.
36         SetupVM func(*StubVM)
37
38         // Bugf, if set, is called if a bug is detected in the caller
39         // or stub. Typically set to (*check.C)Errorf. If unset,
40         // logger.Warnf is called instead.
41         Bugf func(string, ...interface{})
42
43         // StubVM's fake crunch-run uses this Queue to read and update
44         // container state.
45         Queue *Queue
46
47         // Frequency of artificially introduced errors on calls to
48         // Destroy. 0=always succeed, 1=always fail.
49         ErrorRateDestroy float64
50
51         // If Create() or Instances() is called too frequently, return
52         // rate-limiting errors.
53         MinTimeBetweenCreateCalls    time.Duration
54         MinTimeBetweenInstancesCalls time.Duration
55
56         // If true, Create and Destroy calls block until Release() is
57         // called.
58         HoldCloudOps bool
59
60         instanceSets []*StubInstanceSet
61         holdCloudOps chan bool
62 }
63
64 // InstanceSet returns a new *StubInstanceSet.
65 func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
66         if sd.holdCloudOps == nil {
67                 sd.holdCloudOps = make(chan bool)
68         }
69         sis := StubInstanceSet{
70                 driver:  sd,
71                 logger:  logger,
72                 servers: map[cloud.InstanceID]*StubVM{},
73         }
74         sd.instanceSets = append(sd.instanceSets, &sis)
75
76         var err error
77         if params != nil {
78                 err = json.Unmarshal(params, &sis)
79         }
80         return &sis, err
81 }
82
83 // InstanceSets returns all instances that have been created by the
84 // driver. This can be used to test a component that uses the driver
85 // but doesn't expose the InstanceSets it has created.
86 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
87         return sd.instanceSets
88 }
89
90 // ReleaseCloudOps releases n pending Create/Destroy calls. If there
91 // are fewer than n blocked calls pending, it waits for the rest to
92 // arrive.
93 func (sd *StubDriver) ReleaseCloudOps(n int) {
94         for i := 0; i < n; i++ {
95                 <-sd.holdCloudOps
96         }
97 }
98
99 type StubInstanceSet struct {
100         driver  *StubDriver
101         logger  logrus.FieldLogger
102         servers map[cloud.InstanceID]*StubVM
103         mtx     sync.RWMutex
104         stopped bool
105
106         allowCreateCall    time.Time
107         allowInstancesCall time.Time
108         lastInstanceID     int
109 }
110
111 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, initCommand cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
112         if sis.driver.HoldCloudOps {
113                 sis.driver.holdCloudOps <- true
114         }
115         sis.mtx.Lock()
116         defer sis.mtx.Unlock()
117         if sis.stopped {
118                 return nil, errors.New("StubInstanceSet: Create called after Stop")
119         }
120         if sis.allowCreateCall.After(time.Now()) {
121                 return nil, RateLimitError{sis.allowCreateCall}
122         }
123         sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
124         ak := sis.driver.AuthorizedKeys
125         if authKey != nil {
126                 ak = append([]ssh.PublicKey{authKey}, ak...)
127         }
128         sis.lastInstanceID++
129         svm := &StubVM{
130                 InitCommand:  initCommand,
131                 sis:          sis,
132                 id:           cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
133                 tags:         copyTags(tags),
134                 providerType: it.ProviderType,
135                 running:      map[string]stubProcess{},
136                 killing:      map[string]bool{},
137         }
138         svm.SSHService = SSHService{
139                 HostKey:        sis.driver.HostKey,
140                 AuthorizedUser: "root",
141                 AuthorizedKeys: ak,
142                 Exec:           svm.Exec,
143         }
144         if setup := sis.driver.SetupVM; setup != nil {
145                 setup(svm)
146         }
147         sis.servers[svm.id] = svm
148         return svm.Instance(), nil
149 }
150
151 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
152         sis.mtx.RLock()
153         defer sis.mtx.RUnlock()
154         if sis.allowInstancesCall.After(time.Now()) {
155                 return nil, RateLimitError{sis.allowInstancesCall}
156         }
157         sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
158         var r []cloud.Instance
159         for _, ss := range sis.servers {
160                 r = append(r, ss.Instance())
161         }
162         return r, nil
163 }
164
165 func (sis *StubInstanceSet) Stop() {
166         sis.mtx.Lock()
167         defer sis.mtx.Unlock()
168         if sis.stopped {
169                 panic("Stop called twice")
170         }
171         sis.stopped = true
172 }
173
174 func (sis *StubInstanceSet) StubVMs() (svms []*StubVM) {
175         sis.mtx.Lock()
176         defer sis.mtx.Unlock()
177         for _, vm := range sis.servers {
178                 svms = append(svms, vm)
179         }
180         return
181 }
182
// RateLimitError is returned when a stub cloud operation is called
// too frequently. Retry is the earliest time at which the operation
// will be accepted again.
type RateLimitError struct {
	Retry time.Time
}

// Error implements the error interface.
func (e RateLimitError) Error() string {
	msg := fmt.Sprintf("rate limited until %s", e.Retry)
	return msg
}

// EarliestRetry reports the time before which the caller should not
// retry the rate-limited operation.
func (e RateLimitError) EarliestRetry() time.Time {
	return e.Retry
}
187
188 // StubVM is a fake server that runs an SSH service. It represents a
189 // VM running in a fake cloud.
190 //
191 // Note this is distinct from a stubInstance, which is a snapshot of
192 // the VM's metadata. Like a VM in a real cloud, a StubVM keeps
193 // running (and might change IP addresses, shut down, etc.)  without
194 // updating any stubInstances that have been returned to callers.
195 type StubVM struct {
196         Boot                  time.Time
197         Broken                time.Time
198         ReportBroken          time.Time
199         CrunchRunMissing      bool
200         CrunchRunCrashRate    float64
201         CrunchRunDetachDelay  time.Duration
202         ArvMountMaxExitLag    time.Duration
203         ArvMountDeadlockRate  float64
204         ExecuteContainer      func(arvados.Container) int
205         CrashRunningContainer func(arvados.Container)
206         ExtraCrunchRunArgs    string // extra args expected after "crunch-run --detach --stdin-config "
207
208         // Populated by (*StubInstanceSet)Create()
209         InitCommand cloud.InitCommand
210
211         sis          *StubInstanceSet
212         id           cloud.InstanceID
213         tags         cloud.InstanceTags
214         providerType string
215         SSHService   SSHService
216         running      map[string]stubProcess
217         killing      map[string]bool
218         lastPID      int64
219         deadlocked   string
220         sync.Mutex
221 }
222
// stubProcess is the bookkeeping entry for one fake crunch-run
// process on a StubVM.
type stubProcess struct {
	// pid identifies the process within its StubVM.
	pid int64

	// exited is true once crunch-run itself has finished while an
	// arv-mount process (or something) is still holding the lock
	// in /var/run/.
	exited bool
}
230
231 func (svm *StubVM) Instance() stubInstance {
232         svm.Lock()
233         defer svm.Unlock()
234         return stubInstance{
235                 svm:  svm,
236                 addr: svm.SSHService.Address(),
237                 // We deliberately return a cached/stale copy of the
238                 // real tags here, so that (Instance)Tags() sometimes
239                 // returns old data after a call to
240                 // (Instance)SetTags().  This is permitted by the
241                 // driver interface, and this might help remind
242                 // callers that they need to tolerate it.
243                 tags: copyTags(svm.tags),
244         }
245 }
246
247 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
248         stdinData, err := ioutil.ReadAll(stdin)
249         if err != nil {
250                 fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
251                 return 1
252         }
253         queue := svm.sis.driver.Queue
254         uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
255         if eta := svm.Boot.Sub(time.Now()); eta > 0 {
256                 fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
257                 return 1
258         }
259         if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
260                 fmt.Fprintf(stderr, "cannot fork\n")
261                 return 2
262         }
263         if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
264                 fmt.Fprint(stderr, "crunch-run: command not found\n")
265                 return 1
266         }
267         if strings.HasPrefix(command, "crunch-run --detach --stdin-config "+svm.ExtraCrunchRunArgs) {
268                 var configData crunchrun.ConfigData
269                 err := json.Unmarshal(stdinData, &configData)
270                 if err != nil {
271                         fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
272                         return 1
273                 }
274                 for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
275                         if configData.Env[name] == "" {
276                                 fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
277                                 return 1
278                         }
279                 }
280                 svm.Lock()
281                 svm.lastPID++
282                 pid := svm.lastPID
283                 svm.running[uuid] = stubProcess{pid: pid}
284                 svm.Unlock()
285                 time.Sleep(svm.CrunchRunDetachDelay)
286                 fmt.Fprintf(stderr, "starting %s\n", uuid)
287                 logger := svm.sis.logger.WithFields(logrus.Fields{
288                         "Instance":      svm.id,
289                         "ContainerUUID": uuid,
290                         "PID":           pid,
291                 })
292                 logger.Printf("[test] starting crunch-run stub")
293                 go func() {
294                         var ctr arvados.Container
295                         var started, completed bool
296                         defer func() {
297                                 logger.Print("[test] exiting crunch-run stub")
298                                 svm.Lock()
299                                 defer svm.Unlock()
300                                 if svm.running[uuid].pid != pid {
301                                         bugf := svm.sis.driver.Bugf
302                                         if bugf == nil {
303                                                 bugf = logger.Warnf
304                                         }
305                                         bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s].pid==%d", pid, uuid, svm.running[uuid].pid)
306                                         return
307                                 }
308                                 if !completed {
309                                         logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
310                                         if started && svm.CrashRunningContainer != nil {
311                                                 svm.CrashRunningContainer(ctr)
312                                         }
313                                 }
314                                 sproc := svm.running[uuid]
315                                 sproc.exited = true
316                                 svm.running[uuid] = sproc
317                                 svm.Unlock()
318                                 time.Sleep(svm.ArvMountMaxExitLag * time.Duration(math_rand.Float64()))
319                                 svm.Lock()
320                                 if math_rand.Float64() >= svm.ArvMountDeadlockRate {
321                                         delete(svm.running, uuid)
322                                 }
323                         }()
324
325                         crashluck := math_rand.Float64()
326                         wantCrash := crashluck < svm.CrunchRunCrashRate
327                         wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2
328
329                         ctr, ok := queue.Get(uuid)
330                         if !ok {
331                                 logger.Print("[test] container not in queue")
332                                 return
333                         }
334
335                         time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
336
337                         svm.Lock()
338                         killed := svm.killing[uuid]
339                         svm.Unlock()
340                         if killed || wantCrashEarly {
341                                 return
342                         }
343
344                         ctr.State = arvados.ContainerStateRunning
345                         started = queue.Notify(ctr)
346                         if !started {
347                                 ctr, _ = queue.Get(uuid)
348                                 logger.Print("[test] erroring out because state=Running update was rejected")
349                                 return
350                         }
351
352                         if wantCrash {
353                                 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
354                                 return
355                         }
356                         if svm.ExecuteContainer != nil {
357                                 ctr.ExitCode = svm.ExecuteContainer(ctr)
358                         }
359                         logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
360                         ctr.State = arvados.ContainerStateComplete
361                         completed = queue.Notify(ctr)
362                 }()
363                 return 0
364         }
365         if command == "crunch-run --list" {
366                 svm.Lock()
367                 defer svm.Unlock()
368                 for uuid, sproc := range svm.running {
369                         if sproc.exited {
370                                 fmt.Fprintf(stdout, "%s stale\n", uuid)
371                         } else {
372                                 fmt.Fprintf(stdout, "%s\n", uuid)
373                         }
374                 }
375                 if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
376                         fmt.Fprintln(stdout, "broken")
377                 }
378                 fmt.Fprintln(stdout, svm.deadlocked)
379                 return 0
380         }
381         if strings.HasPrefix(command, "crunch-run --kill ") {
382                 svm.Lock()
383                 sproc, running := svm.running[uuid]
384                 if running && !sproc.exited {
385                         svm.killing[uuid] = true
386                         svm.Unlock()
387                         time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
388                         svm.Lock()
389                         sproc, running = svm.running[uuid]
390                 }
391                 svm.Unlock()
392                 if running && !sproc.exited {
393                         fmt.Fprintf(stderr, "%s: container is running\n", uuid)
394                         return 1
395                 }
396                 fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
397                 return 0
398         }
399         if command == "true" {
400                 return 0
401         }
402         fmt.Fprintf(stderr, "%q: command not found", command)
403         return 1
404 }
405
406 type stubInstance struct {
407         svm  *StubVM
408         addr string
409         tags cloud.InstanceTags
410 }
411
412 func (si stubInstance) ID() cloud.InstanceID {
413         return si.svm.id
414 }
415
416 func (si stubInstance) Address() string {
417         return si.addr
418 }
419
420 func (si stubInstance) RemoteUser() string {
421         return si.svm.SSHService.AuthorizedUser
422 }
423
424 func (si stubInstance) Destroy() error {
425         sis := si.svm.sis
426         if sis.driver.HoldCloudOps {
427                 sis.driver.holdCloudOps <- true
428         }
429         if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
430                 return errors.New("instance could not be destroyed")
431         }
432         si.svm.SSHService.Close()
433         sis.mtx.Lock()
434         defer sis.mtx.Unlock()
435         delete(sis.servers, si.svm.id)
436         return nil
437 }
438
439 func (si stubInstance) ProviderType() string {
440         return si.svm.providerType
441 }
442
443 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
444         tags = copyTags(tags)
445         svm := si.svm
446         go func() {
447                 svm.Lock()
448                 defer svm.Unlock()
449                 svm.tags = tags
450         }()
451         return nil
452 }
453
454 func (si stubInstance) Tags() cloud.InstanceTags {
455         // Return a copy to ensure a caller can't change our saved
456         // tags just by writing to the returned map.
457         return copyTags(si.tags)
458 }
459
460 func (si stubInstance) String() string {
461         return string(si.svm.id)
462 }
463
464 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
465         buf := make([]byte, 512)
466         _, err := io.ReadFull(rand.Reader, buf)
467         if err != nil {
468                 return err
469         }
470         sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
471         if err != nil {
472                 return err
473         }
474         return key.Verify(buf, sig)
475 }
476
477 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
478         dst := cloud.InstanceTags{}
479         for k, v := range src {
480                 dst[k] = v
481         }
482         return dst
483 }
484
485 func (si stubInstance) PriceHistory(arvados.InstanceType) []cloud.InstancePrice {
486         return nil
487 }