1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
20 "git.arvados.org/arvados.git/lib/cloud"
21 "git.arvados.org/arvados.git/sdk/go/arvados"
22 "github.com/sirupsen/logrus"
23 "golang.org/x/crypto/ssh"
26 // A StubDriver implements cloud.Driver by setting up local SSH
27 // servers that do fake command executions.
28 type StubDriver struct {
30 AuthorizedKeys []ssh.PublicKey
32 // SetupVM, if set, is called upon creation of each new
33 // StubVM. This is the caller's opportunity to customize the
34 // VM's error rate and other behaviors.
37 // Bugf, if set, is called if a bug is detected in the caller
38 // or stub. Typically set to (*check.C)Errorf. If unset,
39 // logger.Warnf is called instead.
40 Bugf func(string, ...interface{})
// StubVM's fake crunch-run uses this Queue to read and update
// container state.
46 // Frequency of artificially introduced errors on calls to
47 // Destroy. 0=always succeed, 1=always fail.
48 ErrorRateDestroy float64
50 // If Create() or Instances() is called too frequently, return
51 // rate-limiting errors.
52 MinTimeBetweenCreateCalls time.Duration
53 MinTimeBetweenInstancesCalls time.Duration
// If true, Create and Destroy calls block until
// ReleaseCloudOps() is called.
59 instanceSets []*StubInstanceSet
60 holdCloudOps chan bool
63 // InstanceSet returns a new *StubInstanceSet.
64 func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
65 if sd.holdCloudOps == nil {
66 sd.holdCloudOps = make(chan bool)
68 sis := StubInstanceSet{
71 servers: map[cloud.InstanceID]*StubVM{},
73 sd.instanceSets = append(sd.instanceSets, &sis)
77 err = json.Unmarshal(params, &sis)
82 // InstanceSets returns all instances that have been created by the
83 // driver. This can be used to test a component that uses the driver
84 // but doesn't expose the InstanceSets it has created.
85 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
86 return sd.instanceSets
// ReleaseCloudOps releases n pending Create/Destroy calls. If there
// are fewer than n blocked calls pending, it waits for the rest to
// arrive.
92 func (sd *StubDriver) ReleaseCloudOps(n int) {
93 for i := 0; i < n; i++ {
98 type StubInstanceSet struct {
100 logger logrus.FieldLogger
101 servers map[cloud.InstanceID]*StubVM
105 allowCreateCall time.Time
106 allowInstancesCall time.Time
110 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, cmd cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
111 if sis.driver.HoldCloudOps {
112 sis.driver.holdCloudOps <- true
115 defer sis.mtx.Unlock()
117 return nil, errors.New("StubInstanceSet: Create called after Stop")
119 if sis.allowCreateCall.After(time.Now()) {
120 return nil, RateLimitError{sis.allowCreateCall}
122 sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
125 ak := sis.driver.AuthorizedKeys
127 ak = append([]ssh.PublicKey{authKey}, ak...)
132 id: cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
133 tags: copyTags(tags),
134 providerType: it.ProviderType,
136 running: map[string]int64{},
137 killing: map[string]bool{},
139 svm.SSHService = SSHService{
140 HostKey: sis.driver.HostKey,
141 AuthorizedUser: "root",
145 if setup := sis.driver.SetupVM; setup != nil {
148 sis.servers[svm.id] = svm
149 return svm.Instance(), nil
152 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
154 defer sis.mtx.RUnlock()
155 if sis.allowInstancesCall.After(time.Now()) {
156 return nil, RateLimitError{sis.allowInstancesCall}
158 sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
160 var r []cloud.Instance
161 for _, ss := range sis.servers {
162 r = append(r, ss.Instance())
167 func (sis *StubInstanceSet) Stop() {
169 defer sis.mtx.Unlock()
171 panic("Stop called twice")
// RateLimitError is the error value the stub instance set returns
// when Create or Instances is called before the time at which the
// next call is allowed. Retry is that time.
type RateLimitError struct {
	Retry time.Time
}

// Error implements the error interface. The message includes the
// Retry time.
func (rle RateLimitError) Error() string {
	return fmt.Sprintf("rate limited until %s", rle.Retry)
}

// EarliestRetry returns the Retry time carried by the error.
func (rle RateLimitError) EarliestRetry() time.Time {
	return rle.Retry
}
181 // StubVM is a fake server that runs an SSH service. It represents a
182 // VM running in a fake cloud.
184 // Note this is distinct from a stubInstance, which is a snapshot of
185 // the VM's metadata. Like a VM in a real cloud, a StubVM keeps
186 // running (and might change IP addresses, shut down, etc.) without
187 // updating any stubInstances that have been returned to callers.
191 ReportBroken time.Time
192 CrunchRunMissing bool
193 CrunchRunCrashRate float64
194 CrunchRunDetachDelay time.Duration
195 ExecuteContainer func(arvados.Container) int
196 CrashRunningContainer func(arvados.Container)
200 tags cloud.InstanceTags
201 initCommand cloud.InitCommand
203 SSHService SSHService
204 running map[string]int64
205 killing map[string]bool
210 func (svm *StubVM) Instance() stubInstance {
215 addr: svm.SSHService.Address(),
216 // We deliberately return a cached/stale copy of the
217 // real tags here, so that (Instance)Tags() sometimes
218 // returns old data after a call to
219 // (Instance)SetTags(). This is permitted by the
220 // driver interface, and this might help remind
221 // callers that they need to tolerate it.
222 tags: copyTags(svm.tags),
226 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
227 stdinData, err := ioutil.ReadAll(stdin)
229 fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
232 queue := svm.sis.driver.Queue
233 uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
234 if eta := svm.Boot.Sub(time.Now()); eta > 0 {
235 fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
238 if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
239 fmt.Fprintf(stderr, "cannot fork\n")
242 if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
243 fmt.Fprint(stderr, "crunch-run: command not found\n")
246 if strings.HasPrefix(command, "crunch-run --detach --stdin-env ") {
247 var stdinKV map[string]string
248 err := json.Unmarshal(stdinData, &stdinKV)
250 fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
253 for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
254 if stdinKV[name] == "" {
255 fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
262 svm.running[uuid] = pid
264 time.Sleep(svm.CrunchRunDetachDelay)
265 fmt.Fprintf(stderr, "starting %s\n", uuid)
266 logger := svm.sis.logger.WithFields(logrus.Fields{
268 "ContainerUUID": uuid,
271 logger.Printf("[test] starting crunch-run stub")
273 var ctr arvados.Container
274 var started, completed bool
276 logger.Print("[test] exiting crunch-run stub")
279 if svm.running[uuid] != pid {
281 bugf := svm.sis.driver.Bugf
285 bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s]==%d", pid, uuid, svm.running[uuid])
288 delete(svm.running, uuid)
291 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
292 if started && svm.CrashRunningContainer != nil {
293 svm.CrashRunningContainer(ctr)
298 crashluck := math_rand.Float64()
299 wantCrash := crashluck < svm.CrunchRunCrashRate
300 wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2
302 ctr, ok := queue.Get(uuid)
304 logger.Print("[test] container not in queue")
308 time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
311 killed := svm.running[uuid] != pid
313 if killed || wantCrashEarly {
317 ctr.State = arvados.ContainerStateRunning
318 started = queue.Notify(ctr)
320 ctr, _ = queue.Get(uuid)
321 logger.Print("[test] erroring out because state=Running update was rejected")
326 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
329 if svm.ExecuteContainer != nil {
330 ctr.ExitCode = svm.ExecuteContainer(ctr)
332 logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
333 ctr.State = arvados.ContainerStateComplete
334 completed = queue.Notify(ctr)
338 if command == "crunch-run --list" {
341 for uuid := range svm.running {
342 fmt.Fprintf(stdout, "%s\n", uuid)
344 if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
345 fmt.Fprintln(stdout, "broken")
349 if strings.HasPrefix(command, "crunch-run --kill ") {
351 pid, running := svm.running[uuid]
352 if running && !svm.killing[uuid] {
353 svm.killing[uuid] = true
355 time.Sleep(time.Duration(math_rand.Float64()*30) * time.Millisecond)
358 if svm.running[uuid] == pid {
359 // Kill only if the running entry
360 // hasn't since been killed and
361 // replaced with a different one.
362 delete(svm.running, uuid)
364 delete(svm.killing, uuid)
367 time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
369 _, running = svm.running[uuid]
373 fmt.Fprintf(stderr, "%s: container is running\n", uuid)
376 fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
380 if command == "true" {
383 fmt.Fprintf(stderr, "%q: command not found", command)
387 type stubInstance struct {
390 tags cloud.InstanceTags
393 func (si stubInstance) ID() cloud.InstanceID {
397 func (si stubInstance) Address() string {
401 func (si stubInstance) RemoteUser() string {
402 return si.svm.SSHService.AuthorizedUser
405 func (si stubInstance) Destroy() error {
407 if sis.driver.HoldCloudOps {
408 sis.driver.holdCloudOps <- true
410 if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
411 return errors.New("instance could not be destroyed")
413 si.svm.SSHService.Close()
415 defer sis.mtx.Unlock()
416 delete(sis.servers, si.svm.id)
420 func (si stubInstance) ProviderType() string {
421 return si.svm.providerType
424 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
425 tags = copyTags(tags)
435 func (si stubInstance) Tags() cloud.InstanceTags {
436 // Return a copy to ensure a caller can't change our saved
437 // tags just by writing to the returned map.
438 return copyTags(si.tags)
441 func (si stubInstance) String() string {
442 return string(si.svm.id)
445 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
446 buf := make([]byte, 512)
447 _, err := io.ReadFull(rand.Reader, buf)
451 sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
455 return key.Verify(buf, sig)
458 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
459 dst := cloud.InstanceTags{}
460 for k, v := range src {