Merge branch '15964-fix-docs' refs #15964
[arvados.git] / lib / dispatchcloud / test / stub_driver.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "crypto/rand"
9         "encoding/json"
10         "errors"
11         "fmt"
12         "io"
13         "io/ioutil"
14         math_rand "math/rand"
15         "regexp"
16         "strings"
17         "sync"
18         "time"
19
20         "git.arvados.org/arvados.git/lib/cloud"
21         "git.arvados.org/arvados.git/sdk/go/arvados"
22         "github.com/sirupsen/logrus"
23         "golang.org/x/crypto/ssh"
24 )
25
26 // A StubDriver implements cloud.Driver by setting up local SSH
27 // servers that do fake command executions.
28 type StubDriver struct {
29         HostKey        ssh.Signer
30         AuthorizedKeys []ssh.PublicKey
31
32         // SetupVM, if set, is called upon creation of each new
33         // StubVM. This is the caller's opportunity to customize the
34         // VM's error rate and other behaviors.
35         SetupVM func(*StubVM)
36
37         // Bugf, if set, is called if a bug is detected in the caller
38         // or stub. Typically set to (*check.C)Errorf. If unset,
39         // logger.Warnf is called instead.
40         Bugf func(string, ...interface{})
41
42         // StubVM's fake crunch-run uses this Queue to read and update
43         // container state.
44         Queue *Queue
45
46         // Frequency of artificially introduced errors on calls to
47         // Destroy. 0=always succeed, 1=always fail.
48         ErrorRateDestroy float64
49
50         // If Create() or Instances() is called too frequently, return
51         // rate-limiting errors.
52         MinTimeBetweenCreateCalls    time.Duration
53         MinTimeBetweenInstancesCalls time.Duration
54
55         // If true, Create and Destroy calls block until Release() is
56         // called.
57         HoldCloudOps bool
58
59         instanceSets []*StubInstanceSet
60         holdCloudOps chan bool
61 }
62
63 // InstanceSet returns a new *StubInstanceSet.
64 func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
65         if sd.holdCloudOps == nil {
66                 sd.holdCloudOps = make(chan bool)
67         }
68         sis := StubInstanceSet{
69                 driver:  sd,
70                 logger:  logger,
71                 servers: map[cloud.InstanceID]*StubVM{},
72         }
73         sd.instanceSets = append(sd.instanceSets, &sis)
74
75         var err error
76         if params != nil {
77                 err = json.Unmarshal(params, &sis)
78         }
79         return &sis, err
80 }
81
82 // InstanceSets returns all instances that have been created by the
83 // driver. This can be used to test a component that uses the driver
84 // but doesn't expose the InstanceSets it has created.
85 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
86         return sd.instanceSets
87 }
88
89 // ReleaseCloudOps releases n pending Create/Destroy calls. If there
90 // are fewer than n blocked calls pending, it waits for the rest to
91 // arrive.
92 func (sd *StubDriver) ReleaseCloudOps(n int) {
93         for i := 0; i < n; i++ {
94                 <-sd.holdCloudOps
95         }
96 }
97
98 type StubInstanceSet struct {
99         driver  *StubDriver
100         logger  logrus.FieldLogger
101         servers map[cloud.InstanceID]*StubVM
102         mtx     sync.RWMutex
103         stopped bool
104
105         allowCreateCall    time.Time
106         allowInstancesCall time.Time
107         lastInstanceID     int
108 }
109
110 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, cmd cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
111         if sis.driver.HoldCloudOps {
112                 sis.driver.holdCloudOps <- true
113         }
114         sis.mtx.Lock()
115         defer sis.mtx.Unlock()
116         if sis.stopped {
117                 return nil, errors.New("StubInstanceSet: Create called after Stop")
118         }
119         if sis.allowCreateCall.After(time.Now()) {
120                 return nil, RateLimitError{sis.allowCreateCall}
121         } else {
122                 sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
123         }
124
125         ak := sis.driver.AuthorizedKeys
126         if authKey != nil {
127                 ak = append([]ssh.PublicKey{authKey}, ak...)
128         }
129         sis.lastInstanceID++
130         svm := &StubVM{
131                 sis:          sis,
132                 id:           cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
133                 tags:         copyTags(tags),
134                 providerType: it.ProviderType,
135                 initCommand:  cmd,
136                 running:      map[string]int64{},
137                 killing:      map[string]bool{},
138         }
139         svm.SSHService = SSHService{
140                 HostKey:        sis.driver.HostKey,
141                 AuthorizedUser: "root",
142                 AuthorizedKeys: ak,
143                 Exec:           svm.Exec,
144         }
145         if setup := sis.driver.SetupVM; setup != nil {
146                 setup(svm)
147         }
148         sis.servers[svm.id] = svm
149         return svm.Instance(), nil
150 }
151
152 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
153         sis.mtx.RLock()
154         defer sis.mtx.RUnlock()
155         if sis.allowInstancesCall.After(time.Now()) {
156                 return nil, RateLimitError{sis.allowInstancesCall}
157         } else {
158                 sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
159         }
160         var r []cloud.Instance
161         for _, ss := range sis.servers {
162                 r = append(r, ss.Instance())
163         }
164         return r, nil
165 }
166
167 func (sis *StubInstanceSet) Stop() {
168         sis.mtx.Lock()
169         defer sis.mtx.Unlock()
170         if sis.stopped {
171                 panic("Stop called twice")
172         }
173         sis.stopped = true
174 }
175
// RateLimitError is the error returned when Create or Instances is
// called too soon after a previous call. Retry is the time at which
// the next call will be accepted.
type RateLimitError struct {
	Retry time.Time
}

// Error implements the error interface.
func (e RateLimitError) Error() string {
	return "rate limited until " + e.Retry.String()
}

// EarliestRetry returns the time before which a retry will be
// rate-limited again.
func (e RateLimitError) EarliestRetry() time.Time {
	return e.Retry
}
180
181 // StubVM is a fake server that runs an SSH service. It represents a
182 // VM running in a fake cloud.
183 //
184 // Note this is distinct from a stubInstance, which is a snapshot of
185 // the VM's metadata. Like a VM in a real cloud, a StubVM keeps
186 // running (and might change IP addresses, shut down, etc.)  without
187 // updating any stubInstances that have been returned to callers.
188 type StubVM struct {
189         Boot                  time.Time
190         Broken                time.Time
191         ReportBroken          time.Time
192         CrunchRunMissing      bool
193         CrunchRunCrashRate    float64
194         CrunchRunDetachDelay  time.Duration
195         ExecuteContainer      func(arvados.Container) int
196         CrashRunningContainer func(arvados.Container)
197
198         sis          *StubInstanceSet
199         id           cloud.InstanceID
200         tags         cloud.InstanceTags
201         initCommand  cloud.InitCommand
202         providerType string
203         SSHService   SSHService
204         running      map[string]int64
205         killing      map[string]bool
206         lastPID      int64
207         sync.Mutex
208 }
209
210 func (svm *StubVM) Instance() stubInstance {
211         svm.Lock()
212         defer svm.Unlock()
213         return stubInstance{
214                 svm:  svm,
215                 addr: svm.SSHService.Address(),
216                 // We deliberately return a cached/stale copy of the
217                 // real tags here, so that (Instance)Tags() sometimes
218                 // returns old data after a call to
219                 // (Instance)SetTags().  This is permitted by the
220                 // driver interface, and this might help remind
221                 // callers that they need to tolerate it.
222                 tags: copyTags(svm.tags),
223         }
224 }
225
226 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
227         stdinData, err := ioutil.ReadAll(stdin)
228         if err != nil {
229                 fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
230                 return 1
231         }
232         queue := svm.sis.driver.Queue
233         uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
234         if eta := svm.Boot.Sub(time.Now()); eta > 0 {
235                 fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
236                 return 1
237         }
238         if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
239                 fmt.Fprintf(stderr, "cannot fork\n")
240                 return 2
241         }
242         if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
243                 fmt.Fprint(stderr, "crunch-run: command not found\n")
244                 return 1
245         }
246         if strings.HasPrefix(command, "crunch-run --detach --stdin-env ") {
247                 var stdinKV map[string]string
248                 err := json.Unmarshal(stdinData, &stdinKV)
249                 if err != nil {
250                         fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
251                         return 1
252                 }
253                 for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
254                         if stdinKV[name] == "" {
255                                 fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
256                                 return 1
257                         }
258                 }
259                 svm.Lock()
260                 svm.lastPID++
261                 pid := svm.lastPID
262                 svm.running[uuid] = pid
263                 svm.Unlock()
264                 time.Sleep(svm.CrunchRunDetachDelay)
265                 fmt.Fprintf(stderr, "starting %s\n", uuid)
266                 logger := svm.sis.logger.WithFields(logrus.Fields{
267                         "Instance":      svm.id,
268                         "ContainerUUID": uuid,
269                         "PID":           pid,
270                 })
271                 logger.Printf("[test] starting crunch-run stub")
272                 go func() {
273                         var ctr arvados.Container
274                         var started, completed bool
275                         defer func() {
276                                 logger.Print("[test] exiting crunch-run stub")
277                                 svm.Lock()
278                                 defer svm.Unlock()
279                                 if svm.running[uuid] != pid {
280                                         if !completed {
281                                                 bugf := svm.sis.driver.Bugf
282                                                 if bugf == nil {
283                                                         bugf = logger.Warnf
284                                                 }
285                                                 bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s]==%d", pid, uuid, svm.running[uuid])
286                                         }
287                                 } else {
288                                         delete(svm.running, uuid)
289                                 }
290                                 if !completed {
291                                         logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
292                                         if started && svm.CrashRunningContainer != nil {
293                                                 svm.CrashRunningContainer(ctr)
294                                         }
295                                 }
296                         }()
297
298                         crashluck := math_rand.Float64()
299                         wantCrash := crashluck < svm.CrunchRunCrashRate
300                         wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2
301
302                         ctr, ok := queue.Get(uuid)
303                         if !ok {
304                                 logger.Print("[test] container not in queue")
305                                 return
306                         }
307
308                         time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
309
310                         svm.Lock()
311                         killed := svm.running[uuid] != pid
312                         svm.Unlock()
313                         if killed || wantCrashEarly {
314                                 return
315                         }
316
317                         ctr.State = arvados.ContainerStateRunning
318                         started = queue.Notify(ctr)
319                         if !started {
320                                 ctr, _ = queue.Get(uuid)
321                                 logger.Print("[test] erroring out because state=Running update was rejected")
322                                 return
323                         }
324
325                         if wantCrash {
326                                 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
327                                 return
328                         }
329                         if svm.ExecuteContainer != nil {
330                                 ctr.ExitCode = svm.ExecuteContainer(ctr)
331                         }
332                         logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
333                         ctr.State = arvados.ContainerStateComplete
334                         completed = queue.Notify(ctr)
335                 }()
336                 return 0
337         }
338         if command == "crunch-run --list" {
339                 svm.Lock()
340                 defer svm.Unlock()
341                 for uuid := range svm.running {
342                         fmt.Fprintf(stdout, "%s\n", uuid)
343                 }
344                 if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
345                         fmt.Fprintln(stdout, "broken")
346                 }
347                 return 0
348         }
349         if strings.HasPrefix(command, "crunch-run --kill ") {
350                 svm.Lock()
351                 pid, running := svm.running[uuid]
352                 if running && !svm.killing[uuid] {
353                         svm.killing[uuid] = true
354                         go func() {
355                                 time.Sleep(time.Duration(math_rand.Float64()*30) * time.Millisecond)
356                                 svm.Lock()
357                                 defer svm.Unlock()
358                                 if svm.running[uuid] == pid {
359                                         // Kill only if the running entry
360                                         // hasn't since been killed and
361                                         // replaced with a different one.
362                                         delete(svm.running, uuid)
363                                 }
364                                 delete(svm.killing, uuid)
365                         }()
366                         svm.Unlock()
367                         time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
368                         svm.Lock()
369                         _, running = svm.running[uuid]
370                 }
371                 svm.Unlock()
372                 if running {
373                         fmt.Fprintf(stderr, "%s: container is running\n", uuid)
374                         return 1
375                 } else {
376                         fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
377                         return 0
378                 }
379         }
380         if command == "true" {
381                 return 0
382         }
383         fmt.Fprintf(stderr, "%q: command not found", command)
384         return 1
385 }
386
387 type stubInstance struct {
388         svm  *StubVM
389         addr string
390         tags cloud.InstanceTags
391 }
392
393 func (si stubInstance) ID() cloud.InstanceID {
394         return si.svm.id
395 }
396
397 func (si stubInstance) Address() string {
398         return si.addr
399 }
400
401 func (si stubInstance) RemoteUser() string {
402         return si.svm.SSHService.AuthorizedUser
403 }
404
405 func (si stubInstance) Destroy() error {
406         sis := si.svm.sis
407         if sis.driver.HoldCloudOps {
408                 sis.driver.holdCloudOps <- true
409         }
410         if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
411                 return errors.New("instance could not be destroyed")
412         }
413         si.svm.SSHService.Close()
414         sis.mtx.Lock()
415         defer sis.mtx.Unlock()
416         delete(sis.servers, si.svm.id)
417         return nil
418 }
419
420 func (si stubInstance) ProviderType() string {
421         return si.svm.providerType
422 }
423
424 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
425         tags = copyTags(tags)
426         svm := si.svm
427         go func() {
428                 svm.Lock()
429                 defer svm.Unlock()
430                 svm.tags = tags
431         }()
432         return nil
433 }
434
435 func (si stubInstance) Tags() cloud.InstanceTags {
436         // Return a copy to ensure a caller can't change our saved
437         // tags just by writing to the returned map.
438         return copyTags(si.tags)
439 }
440
441 func (si stubInstance) String() string {
442         return string(si.svm.id)
443 }
444
445 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
446         buf := make([]byte, 512)
447         _, err := io.ReadFull(rand.Reader, buf)
448         if err != nil {
449                 return err
450         }
451         sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
452         if err != nil {
453                 return err
454         }
455         return key.Verify(buf, sig)
456 }
457
458 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
459         dst := cloud.InstanceTags{}
460         for k, v := range src {
461                 dst[k] = v
462         }
463         return dst
464 }