4d32cf221ce49461e092a834ad192460bc37a49d
[arvados.git] / lib / dispatchcloud / test / stub_driver.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "crypto/rand"
9         "encoding/json"
10         "errors"
11         "fmt"
12         "io"
13         "io/ioutil"
14         math_rand "math/rand"
15         "regexp"
16         "strings"
17         "sync"
18         "time"
19
20         "git.arvados.org/arvados.git/lib/cloud"
21         "git.arvados.org/arvados.git/sdk/go/arvados"
22         "github.com/sirupsen/logrus"
23         "golang.org/x/crypto/ssh"
24 )
25
// A StubDriver implements cloud.Driver by setting up local SSH
// servers that do fake command executions.
type StubDriver struct {
	// HostKey is the host key given to the SSH service of each
	// new StubVM. AuthorizedKeys are accepted by every StubVM, in
	// addition to the key passed to Create (if any).
	HostKey        ssh.Signer
	AuthorizedKeys []ssh.PublicKey

	// SetupVM, if set, is called upon creation of each new
	// StubVM. This is the caller's opportunity to customize the
	// VM's error rate and other behaviors.
	SetupVM func(*StubVM)

	// Bugf, if set, is called if a bug is detected in the caller
	// or stub. Typically set to (*check.C)Errorf. If unset,
	// logger.Warnf is called instead.
	Bugf func(string, ...interface{})

	// StubVM's fake crunch-run uses this Queue to read and update
	// container state.
	Queue *Queue

	// Frequency of artificially introduced errors on calls to
	// Destroy. 0=always succeed, 1=always fail.
	ErrorRateDestroy float64

	// If Create() or Instances() is called too frequently, return
	// rate-limiting errors.
	MinTimeBetweenCreateCalls    time.Duration
	MinTimeBetweenInstancesCalls time.Duration

	// If true, Create and Destroy calls block until
	// ReleaseCloudOps() is called.
	HoldCloudOps bool

	// instanceSets lists every StubInstanceSet returned by
	// InstanceSet. holdCloudOps is an unbuffered channel, made on
	// the first InstanceSet call: Create and Destroy send on it
	// when HoldCloudOps is true, and ReleaseCloudOps receives
	// from it.
	instanceSets []*StubInstanceSet
	holdCloudOps chan bool
}
62
63 // InstanceSet returns a new *StubInstanceSet.
64 func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
65         if sd.holdCloudOps == nil {
66                 sd.holdCloudOps = make(chan bool)
67         }
68         sis := StubInstanceSet{
69                 driver:  sd,
70                 logger:  logger,
71                 servers: map[cloud.InstanceID]*StubVM{},
72         }
73         sd.instanceSets = append(sd.instanceSets, &sis)
74
75         var err error
76         if params != nil {
77                 err = json.Unmarshal(params, &sis)
78         }
79         return &sis, err
80 }
81
// InstanceSets returns all instance sets that have been created by
// the driver. This can be used to test a component that uses the
// driver but doesn't expose the InstanceSets it has created.
//
// The returned slice is the driver's own slice, not a copy.
func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
	return sd.instanceSets
}
88
89 // ReleaseCloudOps releases n pending Create/Destroy calls. If there
90 // are fewer than n blocked calls pending, it waits for the rest to
91 // arrive.
92 func (sd *StubDriver) ReleaseCloudOps(n int) {
93         for i := 0; i < n; i++ {
94                 <-sd.holdCloudOps
95         }
96 }
97
// StubInstanceSet is the instance set returned (as a
// cloud.InstanceSet) by StubDriver.InstanceSet. It keeps track of the
// StubVMs that have been created through it.
type StubInstanceSet struct {
	// driver is the StubDriver that created this set, and logger
	// is the logger that was passed to InstanceSet. servers holds
	// the VMs that have been created and not yet destroyed, keyed
	// by instance ID. mtx is held by Create, Instances, Stop, and
	// (stubInstance)Destroy while they access this struct.
	// stopped is set by Stop.
	driver  *StubDriver
	logger  logrus.FieldLogger
	servers map[cloud.InstanceID]*StubVM
	mtx     sync.RWMutex
	stopped bool

	// allowCreateCall and allowInstancesCall are the times before
	// which a Create or Instances call returns a RateLimitError.
	// lastInstanceID is a counter used to build new instance IDs.
	allowCreateCall    time.Time
	allowInstancesCall time.Time
	lastInstanceID     int
}
109
110 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, cmd cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
111         if sis.driver.HoldCloudOps {
112                 sis.driver.holdCloudOps <- true
113         }
114         sis.mtx.Lock()
115         defer sis.mtx.Unlock()
116         if sis.stopped {
117                 return nil, errors.New("StubInstanceSet: Create called after Stop")
118         }
119         if sis.allowCreateCall.After(time.Now()) {
120                 return nil, RateLimitError{sis.allowCreateCall}
121         }
122         sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
123         ak := sis.driver.AuthorizedKeys
124         if authKey != nil {
125                 ak = append([]ssh.PublicKey{authKey}, ak...)
126         }
127         sis.lastInstanceID++
128         svm := &StubVM{
129                 sis:          sis,
130                 id:           cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
131                 tags:         copyTags(tags),
132                 providerType: it.ProviderType,
133                 initCommand:  cmd,
134                 running:      map[string]stubProcess{},
135                 killing:      map[string]bool{},
136         }
137         svm.SSHService = SSHService{
138                 HostKey:        sis.driver.HostKey,
139                 AuthorizedUser: "root",
140                 AuthorizedKeys: ak,
141                 Exec:           svm.Exec,
142         }
143         if setup := sis.driver.SetupVM; setup != nil {
144                 setup(svm)
145         }
146         sis.servers[svm.id] = svm
147         return svm.Instance(), nil
148 }
149
150 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
151         sis.mtx.RLock()
152         defer sis.mtx.RUnlock()
153         if sis.allowInstancesCall.After(time.Now()) {
154                 return nil, RateLimitError{sis.allowInstancesCall}
155         }
156         sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
157         var r []cloud.Instance
158         for _, ss := range sis.servers {
159                 r = append(r, ss.Instance())
160         }
161         return r, nil
162 }
163
// Stop marks the instance set as stopped, after which Create returns
// an error. Calling Stop a second time panics, so that a caller bug
// is noticed.
func (sis *StubInstanceSet) Stop() {
	sis.mtx.Lock()
	defer sis.mtx.Unlock()
	if sis.stopped {
		panic("Stop called twice")
	}
	sis.stopped = true
}
172
// RateLimitError is the error returned by the stub's Create and
// Instances methods when they are called too frequently. Retry is
// the earliest time at which the call is expected to succeed.
type RateLimitError struct {
	Retry time.Time
}

// Error implements the error interface.
func (e RateLimitError) Error() string {
	return fmt.Sprintf("rate limited until %s", e.Retry)
}

// EarliestRetry reports the time before which the caller should not
// retry the rate-limited operation.
func (e RateLimitError) EarliestRetry() time.Time {
	return e.Retry
}
177
// StubVM is a fake server that runs an SSH service. It represents a
// VM running in a fake cloud.
//
// Note this is distinct from a stubInstance, which is a snapshot of
// the VM's metadata. Like a VM in a real cloud, a StubVM keeps
// running (and might change IP addresses, shut down, etc.)  without
// updating any stubInstances that have been returned to callers.
type StubVM struct {
	// The exported fields below tune the behavior of Exec, and
	// are typically set in StubDriver.SetupVM:
	//
	// Boot: until this time, every command fails with a "stub is
	// booting" message.
	//
	// Broken: if non-zero, every command fails with "cannot
	// fork" (exit code 2) after this time.
	//
	// ReportBroken: if non-zero, "crunch-run --list" output
	// includes a "broken" line after this time.
	//
	// CrunchRunMissing: if true, any command containing
	// "crunch-run" fails with "command not found".
	//
	// CrunchRunCrashRate: probability that a fake crunch-run
	// process exits without completing its container. Half of
	// those crashes happen before the container is set to the
	// Running state.
	//
	// CrunchRunDetachDelay: how long "crunch-run --detach" sleeps
	// before returning.
	//
	// ArvMountMaxExitLag: upper bound on the random delay between
	// a fake crunch-run process exiting and its entry being
	// removed from the running list.
	//
	// ArvMountDeadlockRate: probability that an exited process's
	// entry is never removed from the running list.
	//
	// ExecuteContainer: if set, called to obtain the container's
	// exit code.
	//
	// CrashRunningContainer: if set, called when a fake
	// crunch-run process exits without completing a container
	// that it had already set to the Running state.
	Boot                  time.Time
	Broken                time.Time
	ReportBroken          time.Time
	CrunchRunMissing      bool
	CrunchRunCrashRate    float64
	CrunchRunDetachDelay  time.Duration
	ArvMountMaxExitLag    time.Duration
	ArvMountDeadlockRate  float64
	ExecuteContainer      func(arvados.Container) int
	CrashRunningContainer func(arvados.Container)

	// sis is the instance set this VM belongs to. id, tags,
	// initCommand, and providerType are set by Create. running
	// maps container UUID to its fake crunch-run process, killing
	// records UUIDs for which "crunch-run --kill" was requested,
	// and lastPID is the most recently assigned fake PID. The
	// embedded Mutex is held while tags, running, killing, and
	// lastPID are accessed.
	sis          *StubInstanceSet
	id           cloud.InstanceID
	tags         cloud.InstanceTags
	initCommand  cloud.InitCommand
	providerType string
	SSHService   SSHService
	running      map[string]stubProcess
	killing      map[string]bool
	lastPID      int64
	deadlocked   string
	sync.Mutex
}
209
// stubProcess is a StubVM's record of one fake crunch-run process.
type stubProcess struct {
	// pid is the fake process ID, assigned from the VM's lastPID
	// counter when the process is started.
	pid int64

	// crunch-run has exited, but arv-mount process (or something)
	// still holds lock in /var/run/
	exited bool
}
217
218 func (svm *StubVM) Instance() stubInstance {
219         svm.Lock()
220         defer svm.Unlock()
221         return stubInstance{
222                 svm:  svm,
223                 addr: svm.SSHService.Address(),
224                 // We deliberately return a cached/stale copy of the
225                 // real tags here, so that (Instance)Tags() sometimes
226                 // returns old data after a call to
227                 // (Instance)SetTags().  This is permitted by the
228                 // driver interface, and this might help remind
229                 // callers that they need to tolerate it.
230                 tags: copyTags(svm.tags),
231         }
232 }
233
234 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
235         stdinData, err := ioutil.ReadAll(stdin)
236         if err != nil {
237                 fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
238                 return 1
239         }
240         queue := svm.sis.driver.Queue
241         uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
242         if eta := svm.Boot.Sub(time.Now()); eta > 0 {
243                 fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
244                 return 1
245         }
246         if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
247                 fmt.Fprintf(stderr, "cannot fork\n")
248                 return 2
249         }
250         if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
251                 fmt.Fprint(stderr, "crunch-run: command not found\n")
252                 return 1
253         }
254         if strings.HasPrefix(command, "crunch-run --detach --stdin-env ") {
255                 var stdinKV map[string]string
256                 err := json.Unmarshal(stdinData, &stdinKV)
257                 if err != nil {
258                         fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
259                         return 1
260                 }
261                 for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
262                         if stdinKV[name] == "" {
263                                 fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
264                                 return 1
265                         }
266                 }
267                 svm.Lock()
268                 svm.lastPID++
269                 pid := svm.lastPID
270                 svm.running[uuid] = stubProcess{pid: pid}
271                 svm.Unlock()
272                 time.Sleep(svm.CrunchRunDetachDelay)
273                 fmt.Fprintf(stderr, "starting %s\n", uuid)
274                 logger := svm.sis.logger.WithFields(logrus.Fields{
275                         "Instance":      svm.id,
276                         "ContainerUUID": uuid,
277                         "PID":           pid,
278                 })
279                 logger.Printf("[test] starting crunch-run stub")
280                 go func() {
281                         var ctr arvados.Container
282                         var started, completed bool
283                         defer func() {
284                                 logger.Print("[test] exiting crunch-run stub")
285                                 svm.Lock()
286                                 defer svm.Unlock()
287                                 if svm.running[uuid].pid != pid {
288                                         bugf := svm.sis.driver.Bugf
289                                         if bugf == nil {
290                                                 bugf = logger.Warnf
291                                         }
292                                         bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s].pid==%d", pid, uuid, svm.running[uuid].pid)
293                                         return
294                                 }
295                                 if !completed {
296                                         logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
297                                         if started && svm.CrashRunningContainer != nil {
298                                                 svm.CrashRunningContainer(ctr)
299                                         }
300                                 }
301                                 sproc := svm.running[uuid]
302                                 sproc.exited = true
303                                 svm.running[uuid] = sproc
304                                 svm.Unlock()
305                                 time.Sleep(svm.ArvMountMaxExitLag * time.Duration(math_rand.Float64()))
306                                 svm.Lock()
307                                 if math_rand.Float64() >= svm.ArvMountDeadlockRate {
308                                         delete(svm.running, uuid)
309                                 }
310                         }()
311
312                         crashluck := math_rand.Float64()
313                         wantCrash := crashluck < svm.CrunchRunCrashRate
314                         wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2
315
316                         ctr, ok := queue.Get(uuid)
317                         if !ok {
318                                 logger.Print("[test] container not in queue")
319                                 return
320                         }
321
322                         time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
323
324                         svm.Lock()
325                         killed := svm.killing[uuid]
326                         svm.Unlock()
327                         if killed || wantCrashEarly {
328                                 return
329                         }
330
331                         ctr.State = arvados.ContainerStateRunning
332                         started = queue.Notify(ctr)
333                         if !started {
334                                 ctr, _ = queue.Get(uuid)
335                                 logger.Print("[test] erroring out because state=Running update was rejected")
336                                 return
337                         }
338
339                         if wantCrash {
340                                 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
341                                 return
342                         }
343                         if svm.ExecuteContainer != nil {
344                                 ctr.ExitCode = svm.ExecuteContainer(ctr)
345                         }
346                         logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
347                         ctr.State = arvados.ContainerStateComplete
348                         completed = queue.Notify(ctr)
349                 }()
350                 return 0
351         }
352         if command == "crunch-run --list" {
353                 svm.Lock()
354                 defer svm.Unlock()
355                 for uuid, sproc := range svm.running {
356                         if sproc.exited {
357                                 fmt.Fprintf(stdout, "%s stale\n", uuid)
358                         } else {
359                                 fmt.Fprintf(stdout, "%s\n", uuid)
360                         }
361                 }
362                 if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
363                         fmt.Fprintln(stdout, "broken")
364                 }
365                 fmt.Fprintln(stdout, svm.deadlocked)
366                 return 0
367         }
368         if strings.HasPrefix(command, "crunch-run --kill ") {
369                 svm.Lock()
370                 sproc, running := svm.running[uuid]
371                 if running && !sproc.exited {
372                         svm.killing[uuid] = true
373                         svm.Unlock()
374                         time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
375                         svm.Lock()
376                         sproc, running = svm.running[uuid]
377                 }
378                 svm.Unlock()
379                 if running && !sproc.exited {
380                         fmt.Fprintf(stderr, "%s: container is running\n", uuid)
381                         return 1
382                 }
383                 fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
384                 return 0
385         }
386         if command == "true" {
387                 return 0
388         }
389         fmt.Fprintf(stderr, "%q: command not found", command)
390         return 1
391 }
392
// stubInstance is a snapshot of a StubVM's metadata, as returned by
// (*StubVM)Instance. addr and tags hold the VM's SSH address and a
// copy of its tags as they were when the snapshot was taken; svm
// points back to the VM itself.
type stubInstance struct {
	svm  *StubVM
	addr string
	tags cloud.InstanceTags
}
398
// ID returns the instance ID of the underlying VM.
func (si stubInstance) ID() cloud.InstanceID {
	return si.svm.id
}
402
// Address returns the VM's SSH service address as it was when this
// snapshot was taken.
func (si stubInstance) Address() string {
	return si.addr
}
406
// RemoteUser returns the username accepted by the VM's SSH service.
func (si stubInstance) RemoteUser() string {
	return si.svm.SSHService.AuthorizedUser
}
410
411 func (si stubInstance) Destroy() error {
412         sis := si.svm.sis
413         if sis.driver.HoldCloudOps {
414                 sis.driver.holdCloudOps <- true
415         }
416         if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
417                 return errors.New("instance could not be destroyed")
418         }
419         si.svm.SSHService.Close()
420         sis.mtx.Lock()
421         defer sis.mtx.Unlock()
422         delete(sis.servers, si.svm.id)
423         return nil
424 }
425
// ProviderType returns the provider type of the instance type the VM
// was created with.
func (si stubInstance) ProviderType() string {
	return si.svm.providerType
}
429
430 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
431         tags = copyTags(tags)
432         svm := si.svm
433         go func() {
434                 svm.Lock()
435                 defer svm.Unlock()
436                 svm.tags = tags
437         }()
438         return nil
439 }
440
// Tags returns the tags captured in this snapshot. These are not
// refreshed from the VM, so they can be out of date after SetTags.
func (si stubInstance) Tags() cloud.InstanceTags {
	// Return a copy to ensure a caller can't change our saved
	// tags just by writing to the returned map.
	return copyTags(si.tags)
}
446
// String returns the instance ID as a string.
func (si stubInstance) String() string {
	return string(si.svm.id)
}
450
451 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
452         buf := make([]byte, 512)
453         _, err := io.ReadFull(rand.Reader, buf)
454         if err != nil {
455                 return err
456         }
457         sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
458         if err != nil {
459                 return err
460         }
461         return key.Verify(buf, sig)
462 }
463
464 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
465         dst := cloud.InstanceTags{}
466         for k, v := range src {
467                 dst[k] = v
468         }
469         return dst
470 }