Merge branch '16723-kill-vs-requeue'
[arvados.git] / lib / dispatchcloud / test / stub_driver.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "crypto/rand"
9         "encoding/json"
10         "errors"
11         "fmt"
12         "io"
13         "io/ioutil"
14         math_rand "math/rand"
15         "regexp"
16         "strings"
17         "sync"
18         "time"
19
20         "git.arvados.org/arvados.git/lib/cloud"
21         "git.arvados.org/arvados.git/sdk/go/arvados"
22         "github.com/sirupsen/logrus"
23         "golang.org/x/crypto/ssh"
24 )
25
26 // A StubDriver implements cloud.Driver by setting up local SSH
27 // servers that do fake command executions.
28 type StubDriver struct {
29         HostKey        ssh.Signer
30         AuthorizedKeys []ssh.PublicKey
31
32         // SetupVM, if set, is called upon creation of each new
33         // StubVM. This is the caller's opportunity to customize the
34         // VM's error rate and other behaviors.
35         SetupVM func(*StubVM)
36
37         // Bugf, if set, is called if a bug is detected in the caller
38         // or stub. Typically set to (*check.C)Errorf. If unset,
39         // logger.Warnf is called instead.
40         Bugf func(string, ...interface{})
41
42         // StubVM's fake crunch-run uses this Queue to read and update
43         // container state.
44         Queue *Queue
45
46         // Frequency of artificially introduced errors on calls to
47         // Destroy. 0=always succeed, 1=always fail.
48         ErrorRateDestroy float64
49
50         // If Create() or Instances() is called too frequently, return
51         // rate-limiting errors.
52         MinTimeBetweenCreateCalls    time.Duration
53         MinTimeBetweenInstancesCalls time.Duration
54
55         // If true, Create and Destroy calls block until Release() is
56         // called.
57         HoldCloudOps bool
58
59         instanceSets []*StubInstanceSet
60         holdCloudOps chan bool
61 }
62
63 // InstanceSet returns a new *StubInstanceSet.
64 func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
65         if sd.holdCloudOps == nil {
66                 sd.holdCloudOps = make(chan bool)
67         }
68         sis := StubInstanceSet{
69                 driver:  sd,
70                 logger:  logger,
71                 servers: map[cloud.InstanceID]*StubVM{},
72         }
73         sd.instanceSets = append(sd.instanceSets, &sis)
74
75         var err error
76         if params != nil {
77                 err = json.Unmarshal(params, &sis)
78         }
79         return &sis, err
80 }
81
82 // InstanceSets returns all instances that have been created by the
83 // driver. This can be used to test a component that uses the driver
84 // but doesn't expose the InstanceSets it has created.
85 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
86         return sd.instanceSets
87 }
88
89 // ReleaseCloudOps releases n pending Create/Destroy calls. If there
90 // are fewer than n blocked calls pending, it waits for the rest to
91 // arrive.
92 func (sd *StubDriver) ReleaseCloudOps(n int) {
93         for i := 0; i < n; i++ {
94                 <-sd.holdCloudOps
95         }
96 }
97
98 type StubInstanceSet struct {
99         driver  *StubDriver
100         logger  logrus.FieldLogger
101         servers map[cloud.InstanceID]*StubVM
102         mtx     sync.RWMutex
103         stopped bool
104
105         allowCreateCall    time.Time
106         allowInstancesCall time.Time
107         lastInstanceID     int
108 }
109
110 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, cmd cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
111         if sis.driver.HoldCloudOps {
112                 sis.driver.holdCloudOps <- true
113         }
114         sis.mtx.Lock()
115         defer sis.mtx.Unlock()
116         if sis.stopped {
117                 return nil, errors.New("StubInstanceSet: Create called after Stop")
118         }
119         if sis.allowCreateCall.After(time.Now()) {
120                 return nil, RateLimitError{sis.allowCreateCall}
121         }
122         sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
123         ak := sis.driver.AuthorizedKeys
124         if authKey != nil {
125                 ak = append([]ssh.PublicKey{authKey}, ak...)
126         }
127         sis.lastInstanceID++
128         svm := &StubVM{
129                 sis:          sis,
130                 id:           cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
131                 tags:         copyTags(tags),
132                 providerType: it.ProviderType,
133                 initCommand:  cmd,
134                 running:      map[string]int64{},
135                 killing:      map[string]bool{},
136         }
137         svm.SSHService = SSHService{
138                 HostKey:        sis.driver.HostKey,
139                 AuthorizedUser: "root",
140                 AuthorizedKeys: ak,
141                 Exec:           svm.Exec,
142         }
143         if setup := sis.driver.SetupVM; setup != nil {
144                 setup(svm)
145         }
146         sis.servers[svm.id] = svm
147         return svm.Instance(), nil
148 }
149
150 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
151         sis.mtx.RLock()
152         defer sis.mtx.RUnlock()
153         if sis.allowInstancesCall.After(time.Now()) {
154                 return nil, RateLimitError{sis.allowInstancesCall}
155         }
156         sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
157         var r []cloud.Instance
158         for _, ss := range sis.servers {
159                 r = append(r, ss.Instance())
160         }
161         return r, nil
162 }
163
164 func (sis *StubInstanceSet) Stop() {
165         sis.mtx.Lock()
166         defer sis.mtx.Unlock()
167         if sis.stopped {
168                 panic("Stop called twice")
169         }
170         sis.stopped = true
171 }
172
// RateLimitError is the error returned when Create or Instances is
// called too soon after a previous call. Retry is the earliest time
// at which the call may be attempted again.
type RateLimitError struct {
	Retry time.Time
}

// Error implements the error interface.
func (rle RateLimitError) Error() string {
	return fmt.Sprintf("rate limited until %s", rle.Retry)
}

// EarliestRetry returns the time before which the call should not
// be retried.
func (rle RateLimitError) EarliestRetry() time.Time {
	return rle.Retry
}
177
178 // StubVM is a fake server that runs an SSH service. It represents a
179 // VM running in a fake cloud.
180 //
181 // Note this is distinct from a stubInstance, which is a snapshot of
182 // the VM's metadata. Like a VM in a real cloud, a StubVM keeps
183 // running (and might change IP addresses, shut down, etc.)  without
184 // updating any stubInstances that have been returned to callers.
185 type StubVM struct {
186         Boot                  time.Time
187         Broken                time.Time
188         ReportBroken          time.Time
189         CrunchRunMissing      bool
190         CrunchRunCrashRate    float64
191         CrunchRunDetachDelay  time.Duration
192         ExecuteContainer      func(arvados.Container) int
193         CrashRunningContainer func(arvados.Container)
194
195         sis          *StubInstanceSet
196         id           cloud.InstanceID
197         tags         cloud.InstanceTags
198         initCommand  cloud.InitCommand
199         providerType string
200         SSHService   SSHService
201         running      map[string]int64
202         killing      map[string]bool
203         lastPID      int64
204         sync.Mutex
205 }
206
207 func (svm *StubVM) Instance() stubInstance {
208         svm.Lock()
209         defer svm.Unlock()
210         return stubInstance{
211                 svm:  svm,
212                 addr: svm.SSHService.Address(),
213                 // We deliberately return a cached/stale copy of the
214                 // real tags here, so that (Instance)Tags() sometimes
215                 // returns old data after a call to
216                 // (Instance)SetTags().  This is permitted by the
217                 // driver interface, and this might help remind
218                 // callers that they need to tolerate it.
219                 tags: copyTags(svm.tags),
220         }
221 }
222
223 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
224         stdinData, err := ioutil.ReadAll(stdin)
225         if err != nil {
226                 fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
227                 return 1
228         }
229         queue := svm.sis.driver.Queue
230         uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
231         if eta := svm.Boot.Sub(time.Now()); eta > 0 {
232                 fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
233                 return 1
234         }
235         if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
236                 fmt.Fprintf(stderr, "cannot fork\n")
237                 return 2
238         }
239         if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
240                 fmt.Fprint(stderr, "crunch-run: command not found\n")
241                 return 1
242         }
243         if strings.HasPrefix(command, "crunch-run --detach --stdin-env ") {
244                 var stdinKV map[string]string
245                 err := json.Unmarshal(stdinData, &stdinKV)
246                 if err != nil {
247                         fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
248                         return 1
249                 }
250                 for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
251                         if stdinKV[name] == "" {
252                                 fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
253                                 return 1
254                         }
255                 }
256                 svm.Lock()
257                 svm.lastPID++
258                 pid := svm.lastPID
259                 svm.running[uuid] = pid
260                 svm.Unlock()
261                 time.Sleep(svm.CrunchRunDetachDelay)
262                 fmt.Fprintf(stderr, "starting %s\n", uuid)
263                 logger := svm.sis.logger.WithFields(logrus.Fields{
264                         "Instance":      svm.id,
265                         "ContainerUUID": uuid,
266                         "PID":           pid,
267                 })
268                 logger.Printf("[test] starting crunch-run stub")
269                 go func() {
270                         var ctr arvados.Container
271                         var started, completed bool
272                         defer func() {
273                                 logger.Print("[test] exiting crunch-run stub")
274                                 svm.Lock()
275                                 defer svm.Unlock()
276                                 if svm.running[uuid] != pid {
277                                         if !completed {
278                                                 bugf := svm.sis.driver.Bugf
279                                                 if bugf == nil {
280                                                         bugf = logger.Warnf
281                                                 }
282                                                 bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s]==%d", pid, uuid, svm.running[uuid])
283                                         }
284                                 } else {
285                                         delete(svm.running, uuid)
286                                 }
287                                 if !completed {
288                                         logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
289                                         if started && svm.CrashRunningContainer != nil {
290                                                 svm.CrashRunningContainer(ctr)
291                                         }
292                                 }
293                         }()
294
295                         crashluck := math_rand.Float64()
296                         wantCrash := crashluck < svm.CrunchRunCrashRate
297                         wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2
298
299                         ctr, ok := queue.Get(uuid)
300                         if !ok {
301                                 logger.Print("[test] container not in queue")
302                                 return
303                         }
304
305                         time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
306
307                         svm.Lock()
308                         killed := svm.running[uuid] != pid
309                         svm.Unlock()
310                         if killed || wantCrashEarly {
311                                 return
312                         }
313
314                         ctr.State = arvados.ContainerStateRunning
315                         started = queue.Notify(ctr)
316                         if !started {
317                                 ctr, _ = queue.Get(uuid)
318                                 logger.Print("[test] erroring out because state=Running update was rejected")
319                                 return
320                         }
321
322                         if wantCrash {
323                                 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
324                                 return
325                         }
326                         if svm.ExecuteContainer != nil {
327                                 ctr.ExitCode = svm.ExecuteContainer(ctr)
328                         }
329                         logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
330                         ctr.State = arvados.ContainerStateComplete
331                         completed = queue.Notify(ctr)
332                 }()
333                 return 0
334         }
335         if command == "crunch-run --list" {
336                 svm.Lock()
337                 defer svm.Unlock()
338                 for uuid := range svm.running {
339                         fmt.Fprintf(stdout, "%s\n", uuid)
340                 }
341                 if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
342                         fmt.Fprintln(stdout, "broken")
343                 }
344                 return 0
345         }
346         if strings.HasPrefix(command, "crunch-run --kill ") {
347                 svm.Lock()
348                 pid, running := svm.running[uuid]
349                 if running && !svm.killing[uuid] {
350                         svm.killing[uuid] = true
351                         go func() {
352                                 time.Sleep(time.Duration(math_rand.Float64()*30) * time.Millisecond)
353                                 svm.Lock()
354                                 defer svm.Unlock()
355                                 if svm.running[uuid] == pid {
356                                         // Kill only if the running entry
357                                         // hasn't since been killed and
358                                         // replaced with a different one.
359                                         delete(svm.running, uuid)
360                                 }
361                                 delete(svm.killing, uuid)
362                         }()
363                         svm.Unlock()
364                         time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
365                         svm.Lock()
366                         _, running = svm.running[uuid]
367                 }
368                 svm.Unlock()
369                 if running {
370                         fmt.Fprintf(stderr, "%s: container is running\n", uuid)
371                         return 1
372                 }
373                 fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
374                 return 0
375         }
376         if command == "true" {
377                 return 0
378         }
379         fmt.Fprintf(stderr, "%q: command not found", command)
380         return 1
381 }
382
383 type stubInstance struct {
384         svm  *StubVM
385         addr string
386         tags cloud.InstanceTags
387 }
388
389 func (si stubInstance) ID() cloud.InstanceID {
390         return si.svm.id
391 }
392
393 func (si stubInstance) Address() string {
394         return si.addr
395 }
396
397 func (si stubInstance) RemoteUser() string {
398         return si.svm.SSHService.AuthorizedUser
399 }
400
401 func (si stubInstance) Destroy() error {
402         sis := si.svm.sis
403         if sis.driver.HoldCloudOps {
404                 sis.driver.holdCloudOps <- true
405         }
406         if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
407                 return errors.New("instance could not be destroyed")
408         }
409         si.svm.SSHService.Close()
410         sis.mtx.Lock()
411         defer sis.mtx.Unlock()
412         delete(sis.servers, si.svm.id)
413         return nil
414 }
415
416 func (si stubInstance) ProviderType() string {
417         return si.svm.providerType
418 }
419
420 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
421         tags = copyTags(tags)
422         svm := si.svm
423         go func() {
424                 svm.Lock()
425                 defer svm.Unlock()
426                 svm.tags = tags
427         }()
428         return nil
429 }
430
431 func (si stubInstance) Tags() cloud.InstanceTags {
432         // Return a copy to ensure a caller can't change our saved
433         // tags just by writing to the returned map.
434         return copyTags(si.tags)
435 }
436
437 func (si stubInstance) String() string {
438         return string(si.svm.id)
439 }
440
441 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
442         buf := make([]byte, 512)
443         _, err := io.ReadFull(rand.Reader, buf)
444         if err != nil {
445                 return err
446         }
447         sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
448         if err != nil {
449                 return err
450         }
451         return key.Verify(buf, sig)
452 }
453
454 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
455         dst := cloud.InstanceTags{}
456         for k, v := range src {
457                 dst[k] = v
458         }
459         return dst
460 }