// Copyright (C) The Arvados Authors. All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0

package test

import (
	"crypto/rand"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	math_rand "math/rand"
	"regexp"
	"strings"
	"sync"
	"time"

	"git.arvados.org/arvados.git/lib/cloud"
	"git.arvados.org/arvados.git/lib/crunchrun"
	"git.arvados.org/arvados.git/sdk/go/arvados"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
)

// A StubDriver implements cloud.Driver by setting up local SSH
// servers that do fake command executions.
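//
// A rough usage sketch (illustrative only; the empty Queue and the
// logrus/prometheus values shown here are placeholders, not taken
// from this file):
//
//	drv := &StubDriver{Queue: &Queue{}}
//	is, err := drv.InstanceSet(nil, "test-instance-set", nil, logrus.New(), prometheus.NewRegistry())
//	// ...check err, then use is as a cloud.InstanceSet in tests...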
type StubDriver struct {
	HostKey        ssh.Signer
	AuthorizedKeys []ssh.PublicKey

	// SetupVM, if set, is called upon creation of each new
	// StubVM. This is the caller's opportunity to customize the
	// VM's error rate and other behaviors.
	SetupVM func(*StubVM)

	// Bugf, if set, is called if a bug is detected in the caller
	// or stub. Typically set to (*check.C).Errorf. If unset,
	// logger.Warnf is called instead.
	Bugf func(string, ...interface{})

	// StubVM's fake crunch-run uses this Queue to read and update
	// container state.
	Queue *Queue

	// Frequency of artificially introduced errors on calls to
	// Create and Destroy. 0=always succeed, 1=always fail.
	ErrorRateCreate  float64
	ErrorRateDestroy float64

	// If Create() or Instances() is called too frequently, return
	// rate-limiting errors.
	MinTimeBetweenCreateCalls    time.Duration
	MinTimeBetweenInstancesCalls time.Duration

	// If true, Create and Destroy calls block until
	// ReleaseCloudOps() is called.
	HoldCloudOps bool

	instanceSets []*StubInstanceSet
	holdCloudOps chan bool
}

// InstanceSet returns a new *StubInstanceSet.
func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger, reg *prometheus.Registry) (cloud.InstanceSet, error) {
	if sd.holdCloudOps == nil {
		sd.holdCloudOps = make(chan bool)
	}
	sis := StubInstanceSet{
		driver:  sd,
		logger:  logger,
		servers: map[cloud.InstanceID]*StubVM{},
	}
	sd.instanceSets = append(sd.instanceSets, &sis)

	var err error
	if params != nil {
		err = json.Unmarshal(params, &sis)
	}
	return &sis, err
}

// InstanceSets returns all instance sets that have been created by
// the driver. This can be used to test a component that uses the
// driver but doesn't expose the InstanceSets it has created.
func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
	return sd.instanceSets
}

// ReleaseCloudOps releases n pending Create/Destroy calls. If there
// are fewer than n blocked calls pending, it waits for the rest to
// arrive.
func (sd *StubDriver) ReleaseCloudOps(n int) {
	for i := 0; i < n; i++ {
		<-sd.holdCloudOps
	}
}

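// StubInstanceSet is a fake cloud.InstanceSet. It tracks the StubVMs
// it has created and applies the error rates and rate limits
// configured in its StubDriver.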
type StubInstanceSet struct {
	driver  *StubDriver
	logger  logrus.FieldLogger
	servers map[cloud.InstanceID]*StubVM
	mtx     sync.RWMutex
	stopped bool

	allowCreateCall    time.Time
	allowInstancesCall time.Time
	lastInstanceID     int
}

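// Create adds a new StubVM to the instance set and returns it as a
// cloud.Instance. Depending on the StubDriver's settings, it may
// block until ReleaseCloudOps is called, return a RateLimitError, or
// fail with an artificially injected error.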
func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, initCommand cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
	if sis.driver.HoldCloudOps {
		sis.driver.holdCloudOps <- true
	}
	sis.mtx.Lock()
	defer sis.mtx.Unlock()
	if sis.stopped {
		return nil, errors.New("StubInstanceSet: Create called after Stop")
	}
	if sis.allowCreateCall.After(time.Now()) {
		return nil, RateLimitError{sis.allowCreateCall}
	}
	if math_rand.Float64() < sis.driver.ErrorRateCreate {
		return nil, fmt.Errorf("StubInstanceSet: rand < ErrorRateCreate %f", sis.driver.ErrorRateCreate)
	}
	sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
	ak := sis.driver.AuthorizedKeys
	if authKey != nil {
		ak = append([]ssh.PublicKey{authKey}, ak...)
	}
	sis.lastInstanceID++
	svm := &StubVM{
		InitCommand:  initCommand,
		sis:          sis,
		id:           cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
		tags:         copyTags(tags),
		providerType: it.ProviderType,
		running:      map[string]stubProcess{},
		killing:      map[string]bool{},
	}
	svm.SSHService = SSHService{
		HostKey:        sis.driver.HostKey,
		AuthorizedUser: "root",
		AuthorizedKeys: ak,
		Exec:           svm.Exec,
	}
	if setup := sis.driver.SetupVM; setup != nil {
		setup(svm)
	}
	sis.servers[svm.id] = svm
	return svm.Instance(), nil
}

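// Instances returns a snapshot of the current StubVMs as
// cloud.Instances. It returns a RateLimitError if called again
// before MinTimeBetweenInstancesCalls has elapsed.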
func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
	sis.mtx.RLock()
	defer sis.mtx.RUnlock()
	if sis.allowInstancesCall.After(time.Now()) {
		return nil, RateLimitError{sis.allowInstancesCall}
	}
	sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
	var r []cloud.Instance
	for _, ss := range sis.servers {
		r = append(r, ss.Instance())
	}
	return r, nil
}

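// Stop marks the instance set as stopped: subsequent Create calls
// will fail. Calling Stop a second time panics.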
func (sis *StubInstanceSet) Stop() {
	sis.mtx.Lock()
	defer sis.mtx.Unlock()
	if sis.stopped {
		panic("Stop called twice")
	}
	sis.stopped = true
}

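// StubVMs returns the StubVMs currently in the instance set, in
// unspecified order.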
func (sis *StubInstanceSet) StubVMs() (svms []*StubVM) {
	sis.mtx.Lock()
	defer sis.mtx.Unlock()
	for _, vm := range sis.servers {
		svms = append(svms, vm)
	}
	return
}

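// RateLimitError is returned by Create and Instances when they are
// called more frequently than the configured minimum intervals
// allow. EarliestRetry reports when the call may be retried.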
type RateLimitError struct{ Retry time.Time }

func (e RateLimitError) Error() string            { return fmt.Sprintf("rate limited until %s", e.Retry) }
func (e RateLimitError) EarliestRetry() time.Time { return e.Retry }

// StubVM is a fake server that runs an SSH service. It represents a
// VM running in a fake cloud.
//
// Note this is distinct from a stubInstance, which is a snapshot of
// the VM's metadata. Like a VM in a real cloud, a StubVM keeps
// running (and might change IP addresses, shut down, etc.) without
// updating any stubInstances that have been returned to callers.
type StubVM struct {
	Boot                  time.Time
	Broken                time.Time
	ReportBroken          time.Time
	CrunchRunMissing      bool
	CrunchRunCrashRate    float64
	CrunchRunDetachDelay  time.Duration
	ArvMountMaxExitLag    time.Duration
	ArvMountDeadlockRate  float64
	ExecuteContainer      func(arvados.Container) int
	CrashRunningContainer func(arvados.Container)
	ExtraCrunchRunArgs    string // extra args expected after "crunch-run --detach --stdin-config "

	// Populated by (*StubInstanceSet).Create()
	InitCommand cloud.InitCommand

	sis          *StubInstanceSet
	id           cloud.InstanceID
	tags         cloud.InstanceTags
	providerType string
	SSHService   SSHService
	running      map[string]stubProcess
	killing      map[string]bool
	lastPID      int64
	deadlocked   string
	sync.Mutex
}

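// stubProcess tracks the state of one fake crunch-run process
// started on a StubVM.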
type stubProcess struct {
	pid int64

	// crunch-run has exited, but arv-mount process (or something)
	// still holds lock in /var/run/
	exited bool
}

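// Instance returns a point-in-time snapshot of the StubVM as a
// cloud.Instance.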
func (svm *StubVM) Instance() stubInstance {
	svm.Lock()
	defer svm.Unlock()
	return stubInstance{
		svm:  svm,
		addr: svm.SSHService.Address(),
		// We deliberately return a cached/stale copy of the
		// real tags here, so that (Instance)Tags() sometimes
		// returns old data after a call to
		// (Instance)SetTags().  This is permitted by the
		// driver interface, and this might help remind
		// callers that they need to tolerate it.
		tags: copyTags(svm.tags),
	}
}

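// Exec is the command handler for the stub VM's SSH service. It
// emulates a small set of commands: "crunch-run --detach
// --stdin-config" (which starts a fake crunch-run process),
// "crunch-run --list", "crunch-run --kill", and "true". The return
// value is a shell-style exit code.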
func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
	stdinData, err := ioutil.ReadAll(stdin)
	if err != nil {
		fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
		return 1
	}
	queue := svm.sis.driver.Queue
	uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
	if eta := svm.Boot.Sub(time.Now()); eta > 0 {
		fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
		return 1
	}
	if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
		fmt.Fprintf(stderr, "cannot fork\n")
		return 2
	}
	if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
		fmt.Fprint(stderr, "crunch-run: command not found\n")
		return 1
	}
	if strings.HasPrefix(command, "crunch-run --detach --stdin-config "+svm.ExtraCrunchRunArgs) {
		var configData crunchrun.ConfigData
		err := json.Unmarshal(stdinData, &configData)
		if err != nil {
			fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
			return 1
		}
		for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
			if configData.Env[name] == "" {
				fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
				return 1
			}
		}
		svm.Lock()
		svm.lastPID++
		pid := svm.lastPID
		svm.running[uuid] = stubProcess{pid: pid}
		svm.Unlock()
		time.Sleep(svm.CrunchRunDetachDelay)
		fmt.Fprintf(stderr, "starting %s\n", uuid)
		logger := svm.sis.logger.WithFields(logrus.Fields{
			"Instance":      svm.id,
			"ContainerUUID": uuid,
			"PID":           pid,
		})
		logger.Printf("[test] starting crunch-run stub")
		go func() {
			var ctr arvados.Container
			var started, completed bool
			defer func() {
				logger.Print("[test] exiting crunch-run stub")
				svm.Lock()
				defer svm.Unlock()
				if svm.running[uuid].pid != pid {
					bugf := svm.sis.driver.Bugf
					if bugf == nil {
						bugf = logger.Warnf
					}
					bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s].pid==%d", pid, uuid, svm.running[uuid].pid)
					return
				}
				if !completed {
					logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
					if started && svm.CrashRunningContainer != nil {
						svm.CrashRunningContainer(ctr)
					}
				}
				sproc := svm.running[uuid]
				sproc.exited = true
				svm.running[uuid] = sproc
				svm.Unlock()
				time.Sleep(svm.ArvMountMaxExitLag * time.Duration(math_rand.Float64()))
				svm.Lock()
				if math_rand.Float64() >= svm.ArvMountDeadlockRate {
					delete(svm.running, uuid)
				}
			}()

			crashluck := math_rand.Float64()
			wantCrash := crashluck < svm.CrunchRunCrashRate
			wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2

			ctr, ok := queue.Get(uuid)
			if !ok {
				logger.Print("[test] container not in queue")
				return
			}

			time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)

			svm.Lock()
			killed := svm.killing[uuid]
			svm.Unlock()
			if killed || wantCrashEarly {
				return
			}

			ctr.State = arvados.ContainerStateRunning
			started = queue.Notify(ctr)
			if !started {
				ctr, _ = queue.Get(uuid)
				logger.Print("[test] erroring out because state=Running update was rejected")
				return
			}

			if wantCrash {
				logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
				return
			}
			if svm.ExecuteContainer != nil {
				ctr.ExitCode = svm.ExecuteContainer(ctr)
			}
			logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
			ctr.State = arvados.ContainerStateComplete
			completed = queue.Notify(ctr)
		}()
		return 0
	}
	if command == "crunch-run --list" {
		svm.Lock()
		defer svm.Unlock()
		for uuid, sproc := range svm.running {
			if sproc.exited {
				fmt.Fprintf(stdout, "%s stale\n", uuid)
			} else {
				fmt.Fprintf(stdout, "%s\n", uuid)
			}
		}
		if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
			fmt.Fprintln(stdout, "broken")
		}
		fmt.Fprintln(stdout, svm.deadlocked)
		return 0
	}
	if strings.HasPrefix(command, "crunch-run --kill ") {
		svm.Lock()
		sproc, running := svm.running[uuid]
		if running && !sproc.exited {
			svm.killing[uuid] = true
			svm.Unlock()
			time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
			svm.Lock()
			sproc, running = svm.running[uuid]
		}
		svm.Unlock()
		if running && !sproc.exited {
			fmt.Fprintf(stderr, "%s: container is running\n", uuid)
			return 1
		}
		fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
		return 0
	}
	if command == "true" {
		return 0
	}
	fmt.Fprintf(stderr, "%q: command not found", command)
	return 1
}

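// stubInstance is a point-in-time snapshot of a StubVM's metadata,
// satisfying the cloud.Instance interface. It does not track later
// changes to the VM (see the StubVM comment above).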
type stubInstance struct {
	svm  *StubVM
	addr string
	tags cloud.InstanceTags
}

func (si stubInstance) ID() cloud.InstanceID {
	return si.svm.id
}

func (si stubInstance) Address() string {
	return si.addr
}

func (si stubInstance) RemoteUser() string {
	return si.svm.SSHService.AuthorizedUser
}

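// Destroy shuts down the StubVM's SSH service and removes it from
// the instance set. Depending on the StubDriver's settings, it may
// block until ReleaseCloudOps is called, or fail with an
// artificially injected error.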
func (si stubInstance) Destroy() error {
	sis := si.svm.sis
	if sis.driver.HoldCloudOps {
		sis.driver.holdCloudOps <- true
	}
	if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
		return errors.New("instance could not be destroyed")
	}
	si.svm.SSHService.Close()
	sis.mtx.Lock()
	defer sis.mtx.Unlock()
	delete(sis.servers, si.svm.id)
	return nil
}

func (si stubInstance) ProviderType() string {
	return si.svm.providerType
}

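// SetTags updates the StubVM's tags asynchronously, so Tags may
// return stale data for a short time afterwards, as permitted by the
// cloud.Instance interface.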
func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
	tags = copyTags(tags)
	svm := si.svm
	go func() {
		svm.Lock()
		defer svm.Unlock()
		svm.tags = tags
	}()
	return nil
}

func (si stubInstance) Tags() cloud.InstanceTags {
	// Return a copy to ensure a caller can't change our saved
	// tags just by writing to the returned map.
	return copyTags(si.tags)
}

func (si stubInstance) String() string {
	return string(si.svm.id)
}

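// VerifyHostKey checks that the given public key verifies a
// signature freshly generated with the driver's HostKey, i.e., that
// the stub VM presented the expected host key.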
func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
	buf := make([]byte, 512)
	_, err := io.ReadFull(rand.Reader, buf)
	if err != nil {
		return err
	}
	sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
	if err != nil {
		return err
	}
	return key.Verify(buf, sig)
}

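// copyTags returns a new InstanceTags map with the same contents as
// src.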
func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
	dst := cloud.InstanceTags{}
	for k, v := range src {
		dst[k] = v
	}
	return dst
}

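// PriceHistory always returns nil: the stub has no pricing data.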
func (si stubInstance) PriceHistory(arvados.InstanceType) []cloud.InstancePrice {
	return nil
}