1b31a71a264fabf865f981f5f94eab1649847ac4
[arvados.git] / lib / dispatchcloud / test / stub_driver.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "crypto/rand"
9         "encoding/json"
10         "errors"
11         "fmt"
12         "io"
13         "io/ioutil"
14         math_rand "math/rand"
15         "regexp"
16         "strings"
17         "sync"
18         "time"
19
20         "git.arvados.org/arvados.git/lib/cloud"
21         "git.arvados.org/arvados.git/sdk/go/arvados"
22         "github.com/sirupsen/logrus"
23         "golang.org/x/crypto/ssh"
24 )
25
// A StubDriver implements cloud.Driver by setting up local SSH
// servers that do fake command executions.
type StubDriver struct {
        // HostKey is used as the SSH host key of every StubVM, and
        // to produce the signature checked by VerifyHostKey.
        HostKey        ssh.Signer
        // AuthorizedKeys are accepted by every StubVM's SSH service,
        // in addition to the key passed to Create (if any).
        AuthorizedKeys []ssh.PublicKey

        // SetupVM, if set, is called upon creation of each new
        // StubVM. This is the caller's opportunity to customize the
        // VM's error rate and other behaviors.
        SetupVM func(*StubVM)

        // Bugf, if set, is called if a bug is detected in the caller
        // or stub. Typically set to (*check.C)Errorf. If unset,
        // logger.Warnf is called instead.
        Bugf func(string, ...interface{})

        // StubVM's fake crunch-run uses this Queue to read and update
        // container state.
        Queue *Queue

        // Frequency of artificially introduced errors on calls to
        // Destroy. 0=always succeed, 1=always fail.
        ErrorRateDestroy float64

        // If Create() or Instances() is called too frequently, return
        // rate-limiting errors.
        MinTimeBetweenCreateCalls    time.Duration
        MinTimeBetweenInstancesCalls time.Duration

        // If true, Create and Destroy calls block until
        // ReleaseCloudOps() is called.
        HoldCloudOps bool

        // instanceSets accumulates every StubInstanceSet returned by
        // InstanceSet().
        instanceSets []*StubInstanceSet
        // holdCloudOps is the channel on which Create and Destroy
        // send (and ReleaseCloudOps receives) when HoldCloudOps is
        // true. It is created by the first call to InstanceSet().
        holdCloudOps chan bool
}
62
63 // InstanceSet returns a new *StubInstanceSet.
64 func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
65         if sd.holdCloudOps == nil {
66                 sd.holdCloudOps = make(chan bool)
67         }
68         sis := StubInstanceSet{
69                 driver:  sd,
70                 logger:  logger,
71                 servers: map[cloud.InstanceID]*StubVM{},
72         }
73         sd.instanceSets = append(sd.instanceSets, &sis)
74
75         var err error
76         if params != nil {
77                 err = json.Unmarshal(params, &sis)
78         }
79         return &sis, err
80 }
81
// InstanceSets returns all instance sets that have been created by
// the driver, in creation order. This can be used to test a component
// that uses the driver but doesn't expose the InstanceSets it has
// created.
//
// The returned slice is the driver's own slice, not a copy.
func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
        return sd.instanceSets
}
88
89 // ReleaseCloudOps releases n pending Create/Destroy calls. If there
90 // are fewer than n blocked calls pending, it waits for the rest to
91 // arrive.
92 func (sd *StubDriver) ReleaseCloudOps(n int) {
93         for i := 0; i < n; i++ {
94                 <-sd.holdCloudOps
95         }
96 }
97
// StubInstanceSet is the instance set returned by
// (*StubDriver)InstanceSet. It keeps its fake VMs in an in-memory
// map.
type StubInstanceSet struct {
        driver  *StubDriver                  // driver that created this set
        logger  logrus.FieldLogger           // logger passed to InstanceSet()
        servers map[cloud.InstanceID]*StubVM // VMs created and not yet destroyed
        mtx     sync.RWMutex                 // held by Create, Instances, Stop and Destroy while accessing the fields of this struct
        stopped bool                         // set by Stop(); Create fails afterwards

        // Earliest times at which the next Create / Instances call
        // will be accepted rather than rejected with a
        // RateLimitError.
        allowCreateCall    time.Time
        allowInstancesCall time.Time
        // lastInstanceID is the counter used to build the ID of the
        // most recently created VM.
        lastInstanceID     int
}
109
110 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, cmd cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
111         if sis.driver.HoldCloudOps {
112                 sis.driver.holdCloudOps <- true
113         }
114         sis.mtx.Lock()
115         defer sis.mtx.Unlock()
116         if sis.stopped {
117                 return nil, errors.New("StubInstanceSet: Create called after Stop")
118         }
119         if sis.allowCreateCall.After(time.Now()) {
120                 return nil, RateLimitError{sis.allowCreateCall}
121         }
122         sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
123         ak := sis.driver.AuthorizedKeys
124         if authKey != nil {
125                 ak = append([]ssh.PublicKey{authKey}, ak...)
126         }
127         sis.lastInstanceID++
128         svm := &StubVM{
129                 sis:          sis,
130                 id:           cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
131                 tags:         copyTags(tags),
132                 providerType: it.ProviderType,
133                 initCommand:  cmd,
134                 running:      map[string]stubProcess{},
135                 killing:      map[string]bool{},
136         }
137         svm.SSHService = SSHService{
138                 HostKey:        sis.driver.HostKey,
139                 AuthorizedUser: "root",
140                 AuthorizedKeys: ak,
141                 Exec:           svm.Exec,
142         }
143         if setup := sis.driver.SetupVM; setup != nil {
144                 setup(svm)
145         }
146         sis.servers[svm.id] = svm
147         return svm.Instance(), nil
148 }
149
150 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
151         sis.mtx.RLock()
152         defer sis.mtx.RUnlock()
153         if sis.allowInstancesCall.After(time.Now()) {
154                 return nil, RateLimitError{sis.allowInstancesCall}
155         }
156         sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
157         var r []cloud.Instance
158         for _, ss := range sis.servers {
159                 r = append(r, ss.Instance())
160         }
161         return r, nil
162 }
163
164 func (sis *StubInstanceSet) Stop() {
165         sis.mtx.Lock()
166         defer sis.mtx.Unlock()
167         if sis.stopped {
168                 panic("Stop called twice")
169         }
170         sis.stopped = true
171 }
172
// RateLimitError is the error returned by Create and Instances when
// they are called too soon after a previous call. Retry is the time
// at which the next call will be accepted.
type RateLimitError struct {
	Retry time.Time
}

// Error implements the error interface.
func (e RateLimitError) Error() string {
	retryAt := e.Retry
	return fmt.Sprintf("rate limited until %s", retryAt)
}

// EarliestRetry returns the time before which retrying is pointless.
func (e RateLimitError) EarliestRetry() time.Time {
	return e.Retry
}
177
// StubVM is a fake server that runs an SSH service. It represents a
// VM running in a fake cloud.
//
// Note this is distinct from a stubInstance, which is a snapshot of
// the VM's metadata. Like a VM in a real cloud, a StubVM keeps
// running (and might change IP addresses, shut down, etc.)  without
// updating any stubInstances that have been returned to callers.
//
// The exported fields configure the behavior of Exec; they are meant
// to be set from the driver's SetupVM hook.
type StubVM struct {
        // Until this time, every command fails with a "stub is
        // booting" message.
        Boot                  time.Time
        // If non-zero, every command fails with "cannot fork" (exit
        // code 2) after this time.
        Broken                time.Time
        // If non-zero, "crunch-run --list" output includes a
        // "broken" line after this time.
        ReportBroken          time.Time
        // If true, any command containing "crunch-run" fails with
        // "command not found".
        CrunchRunMissing      bool
        // Probability that a fake crunch-run process exits without
        // completing its container. Half of those crashes happen
        // before the container is put in Running state.
        CrunchRunCrashRate    float64
        // Time to sleep before "crunch-run --detach" returns.
        CrunchRunDetachDelay  time.Duration
        // Upper bound on the delay between a fake crunch-run process
        // exiting and its entry being removed from the list of
        // running processes.
        ArvMountMaxExitLag    time.Duration
        // Probability that the entry of an exited process is never
        // removed, so it keeps being listed as stale.
        ArvMountDeadlockRate  float64
        // If set, called to obtain the exit code of a container that
        // runs to completion.
        ExecuteContainer      func(arvados.Container) int
        // If set, called when a fake crunch-run process exits without
        // completing a container it had put in Running state.
        CrashRunningContainer func(arvados.Container)
        ExtraCrunchRunArgs    string // extra args expected after "crunch-run --detach --stdin-env "

        sis          *StubInstanceSet       // set this VM belongs to
        id           cloud.InstanceID       // ID assigned by Create
        tags         cloud.InstanceTags     // current tags; replaced asynchronously by SetTags
        initCommand  cloud.InitCommand      // init command passed to Create
        providerType string                 // ProviderType of the instance type passed to Create
        SSHService   SSHService             // fake SSH server; its Exec handler is (*StubVM)Exec
        running      map[string]stubProcess // fake crunch-run processes, keyed by container UUID
        killing      map[string]bool        // container UUIDs for which "crunch-run --kill" was received
        lastPID      int64                  // PID given to the most recently started fake process
        deadlocked   string                 // printed as the last line of "crunch-run --list" output
        // The embedded mutex is held by Instance, Exec and SetTags
        // while they access tags, running, killing and lastPID.
        sync.Mutex
}
210
// stubProcess is the bookkeeping entry for one fake crunch-run
// process started on a StubVM.
type stubProcess struct {
        // pid is the fake process ID, taken from the VM's lastPID
        // counter when the process is started.
        pid int64

        // crunch-run has exited, but arv-mount process (or something)
        // still holds lock in /var/run/
        exited bool
}
218
219 func (svm *StubVM) Instance() stubInstance {
220         svm.Lock()
221         defer svm.Unlock()
222         return stubInstance{
223                 svm:  svm,
224                 addr: svm.SSHService.Address(),
225                 // We deliberately return a cached/stale copy of the
226                 // real tags here, so that (Instance)Tags() sometimes
227                 // returns old data after a call to
228                 // (Instance)SetTags().  This is permitted by the
229                 // driver interface, and this might help remind
230                 // callers that they need to tolerate it.
231                 tags: copyTags(svm.tags),
232         }
233 }
234
235 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
236         stdinData, err := ioutil.ReadAll(stdin)
237         if err != nil {
238                 fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
239                 return 1
240         }
241         queue := svm.sis.driver.Queue
242         uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
243         if eta := svm.Boot.Sub(time.Now()); eta > 0 {
244                 fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
245                 return 1
246         }
247         if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
248                 fmt.Fprintf(stderr, "cannot fork\n")
249                 return 2
250         }
251         if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
252                 fmt.Fprint(stderr, "crunch-run: command not found\n")
253                 return 1
254         }
255         if strings.HasPrefix(command, "crunch-run --detach --stdin-env "+svm.ExtraCrunchRunArgs) {
256                 var stdinKV map[string]string
257                 err := json.Unmarshal(stdinData, &stdinKV)
258                 if err != nil {
259                         fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
260                         return 1
261                 }
262                 for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
263                         if stdinKV[name] == "" {
264                                 fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
265                                 return 1
266                         }
267                 }
268                 svm.Lock()
269                 svm.lastPID++
270                 pid := svm.lastPID
271                 svm.running[uuid] = stubProcess{pid: pid}
272                 svm.Unlock()
273                 time.Sleep(svm.CrunchRunDetachDelay)
274                 fmt.Fprintf(stderr, "starting %s\n", uuid)
275                 logger := svm.sis.logger.WithFields(logrus.Fields{
276                         "Instance":      svm.id,
277                         "ContainerUUID": uuid,
278                         "PID":           pid,
279                 })
280                 logger.Printf("[test] starting crunch-run stub")
281                 go func() {
282                         var ctr arvados.Container
283                         var started, completed bool
284                         defer func() {
285                                 logger.Print("[test] exiting crunch-run stub")
286                                 svm.Lock()
287                                 defer svm.Unlock()
288                                 if svm.running[uuid].pid != pid {
289                                         bugf := svm.sis.driver.Bugf
290                                         if bugf == nil {
291                                                 bugf = logger.Warnf
292                                         }
293                                         bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s].pid==%d", pid, uuid, svm.running[uuid].pid)
294                                         return
295                                 }
296                                 if !completed {
297                                         logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
298                                         if started && svm.CrashRunningContainer != nil {
299                                                 svm.CrashRunningContainer(ctr)
300                                         }
301                                 }
302                                 sproc := svm.running[uuid]
303                                 sproc.exited = true
304                                 svm.running[uuid] = sproc
305                                 svm.Unlock()
306                                 time.Sleep(svm.ArvMountMaxExitLag * time.Duration(math_rand.Float64()))
307                                 svm.Lock()
308                                 if math_rand.Float64() >= svm.ArvMountDeadlockRate {
309                                         delete(svm.running, uuid)
310                                 }
311                         }()
312
313                         crashluck := math_rand.Float64()
314                         wantCrash := crashluck < svm.CrunchRunCrashRate
315                         wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2
316
317                         ctr, ok := queue.Get(uuid)
318                         if !ok {
319                                 logger.Print("[test] container not in queue")
320                                 return
321                         }
322
323                         time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
324
325                         svm.Lock()
326                         killed := svm.killing[uuid]
327                         svm.Unlock()
328                         if killed || wantCrashEarly {
329                                 return
330                         }
331
332                         ctr.State = arvados.ContainerStateRunning
333                         started = queue.Notify(ctr)
334                         if !started {
335                                 ctr, _ = queue.Get(uuid)
336                                 logger.Print("[test] erroring out because state=Running update was rejected")
337                                 return
338                         }
339
340                         if wantCrash {
341                                 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
342                                 return
343                         }
344                         if svm.ExecuteContainer != nil {
345                                 ctr.ExitCode = svm.ExecuteContainer(ctr)
346                         }
347                         logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
348                         ctr.State = arvados.ContainerStateComplete
349                         completed = queue.Notify(ctr)
350                 }()
351                 return 0
352         }
353         if command == "crunch-run --list" {
354                 svm.Lock()
355                 defer svm.Unlock()
356                 for uuid, sproc := range svm.running {
357                         if sproc.exited {
358                                 fmt.Fprintf(stdout, "%s stale\n", uuid)
359                         } else {
360                                 fmt.Fprintf(stdout, "%s\n", uuid)
361                         }
362                 }
363                 if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
364                         fmt.Fprintln(stdout, "broken")
365                 }
366                 fmt.Fprintln(stdout, svm.deadlocked)
367                 return 0
368         }
369         if strings.HasPrefix(command, "crunch-run --kill ") {
370                 svm.Lock()
371                 sproc, running := svm.running[uuid]
372                 if running && !sproc.exited {
373                         svm.killing[uuid] = true
374                         svm.Unlock()
375                         time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
376                         svm.Lock()
377                         sproc, running = svm.running[uuid]
378                 }
379                 svm.Unlock()
380                 if running && !sproc.exited {
381                         fmt.Fprintf(stderr, "%s: container is running\n", uuid)
382                         return 1
383                 }
384                 fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
385                 return 0
386         }
387         if command == "true" {
388                 return 0
389         }
390         fmt.Fprintf(stderr, "%q: command not found", command)
391         return 1
392 }
393
// stubInstance is the value returned by (*StubVM)Instance: a pointer
// to the VM plus the address and tags the VM had when the snapshot
// was taken.
type stubInstance struct {
        svm  *StubVM            // the VM this snapshot was taken from
        addr string             // SSH service address at snapshot time
        tags cloud.InstanceTags // copy of the VM's tags at snapshot time
}
399
// ID returns the instance ID of the underlying StubVM.
func (si stubInstance) ID() cloud.InstanceID {
        return si.svm.id
}
403
// Address returns the VM's SSH service address as it was when this
// snapshot was taken.
func (si stubInstance) Address() string {
        return si.addr
}
407
// RemoteUser returns the user name accepted by the VM's SSH service
// (its AuthorizedUser).
func (si stubInstance) RemoteUser() string {
        return si.svm.SSHService.AuthorizedUser
}
411
412 func (si stubInstance) Destroy() error {
413         sis := si.svm.sis
414         if sis.driver.HoldCloudOps {
415                 sis.driver.holdCloudOps <- true
416         }
417         if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
418                 return errors.New("instance could not be destroyed")
419         }
420         si.svm.SSHService.Close()
421         sis.mtx.Lock()
422         defer sis.mtx.Unlock()
423         delete(sis.servers, si.svm.id)
424         return nil
425 }
426
// ProviderType returns the provider type the underlying StubVM was
// created with.
func (si stubInstance) ProviderType() string {
        return si.svm.providerType
}
430
431 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
432         tags = copyTags(tags)
433         svm := si.svm
434         go func() {
435                 svm.Lock()
436                 defer svm.Unlock()
437                 svm.tags = tags
438         }()
439         return nil
440 }
441
// Tags returns the tags the VM had when this snapshot was taken. Tags
// set later with SetTags are not reflected.
func (si stubInstance) Tags() cloud.InstanceTags {
        // Return a copy to ensure a caller can't change our saved
        // tags just by writing to the returned map.
        return copyTags(si.tags)
}
447
// String returns the instance ID of the underlying StubVM as a
// string.
func (si stubInstance) String() string {
        return string(si.svm.id)
}
451
452 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
453         buf := make([]byte, 512)
454         _, err := io.ReadFull(rand.Reader, buf)
455         if err != nil {
456                 return err
457         }
458         sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
459         if err != nil {
460                 return err
461         }
462         return key.Verify(buf, sig)
463 }
464
465 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
466         dst := cloud.InstanceTags{}
467         for k, v := range src {
468                 dst[k] = v
469         }
470         return dst
471 }