1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
20 "git.arvados.org/arvados.git/lib/cloud"
21 "git.arvados.org/arvados.git/sdk/go/arvados"
22 "github.com/sirupsen/logrus"
23 "golang.org/x/crypto/ssh"
26 // A StubDriver implements cloud.Driver by setting up local SSH
27 // servers that do fake command executions.
28 type StubDriver struct {
// AuthorizedKeys is the driver-wide list of SSH public keys. Create
// reads it and can prepend its own authKey argument to a copy.
30 AuthorizedKeys []ssh.PublicKey
32 // SetupVM, if set, is called upon creation of each new
33 // StubVM. This is the caller's opportunity to customize the
34 // VM's error rate and other behaviors.
37 // Bugf, if set, is called if a bug is detected in the caller
38 // or stub. Typically set to (*check.C)Errorf. If unset,
39 // logger.Warnf is called instead.
40 Bugf func(string, ...interface{})
42 // StubVM's fake crunch-run uses this Queue to read and update
// containers (Exec calls Queue.Get and Queue.Notify).
46 // Frequency of artificially introduced errors on calls to
47 // Destroy. 0=always succeed, 1=always fail.
48 ErrorRateDestroy float64
50 // If Create() or Instances() is called too frequently, return
51 // rate-limiting errors.
52 MinTimeBetweenCreateCalls time.Duration
53 MinTimeBetweenInstancesCalls time.Duration
55 // If true, Create and Destroy calls block until ReleaseCloudOps() is
// called.
// instanceSets records every StubInstanceSet created by
// InstanceSet(); it is the slice returned by InstanceSets().
59 instanceSets []*StubInstanceSet
// holdCloudOps is created (unbuffered) by InstanceSet() if it is
// still nil. Create and Destroy send on it when HoldCloudOps is true.
60 holdCloudOps chan bool
63 // InstanceSet returns a new *StubInstanceSet.
//
// If the driver's holdCloudOps channel has not been created yet, it
// is created here. The new set is appended to sd.instanceSets (see
// InstanceSets) and params is JSON-decoded into it.
//
// The cloud.SharedResourceTags argument is ignored.
64 func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
65 if sd.holdCloudOps == nil {
66 sd.holdCloudOps = make(chan bool)
68 sis := StubInstanceSet{
71 servers: map[cloud.InstanceID]*StubVM{},
// Register the set with the driver so InstanceSets() can report it.
73 sd.instanceSets = append(sd.instanceSets, &sis)
// Let the caller-supplied JSON params populate the set.
77 err = json.Unmarshal(params, &sis)
82 // InstanceSets returns all InstanceSets that have been created by the
83 // driver. This can be used to test a component that uses the driver
84 // but doesn't expose the InstanceSets it has created.
//
// The returned slice is the driver's own slice, not a copy.
85 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
86 return sd.instanceSets
89 // ReleaseCloudOps releases n pending Create/Destroy calls. If there
90 // are fewer than n blocked calls pending, it waits for the rest to
// arrive.
//
// NOTE(review): the loop body is not visible in this excerpt;
// presumably each iteration receives once from sd.holdCloudOps, the
// channel Create and Destroy send on — confirm.
92 func (sd *StubDriver) ReleaseCloudOps(n int) {
93 for i := 0; i < n; i++ {
// StubInstanceSet is the instance set type created by
// StubDriver.InstanceSet. It keeps track of the StubVMs it has
// created and of the rate-limit state for Create and Instances.
98 type StubInstanceSet struct {
// logger is used by StubVM.Exec for the crunch-run stub's log
// messages.
100 logger logrus.FieldLogger
// servers holds the VMs added by Create, keyed by instance ID.
// Destroy deletes entries.
101 servers map[cloud.InstanceID]*StubVM
// Earliest times at which the next Create / Instances call is
// accepted; a call made before that time gets a RateLimitError.
105 allowCreateCall time.Time
106 allowInstancesCall time.Time
// Create adds a new StubVM to the set and returns a snapshot of it
// (see StubVM.Instance).
//
// If the driver's HoldCloudOps flag is set, the call first blocks
// sending on the driver's holdCloudOps channel. It returns an error
// when called after Stop, and a RateLimitError when called before
// sis.allowCreateCall, which is moved MinTimeBetweenCreateCalls into
// the future by each accepted call. The driver's SetupVM hook, if
// set, is consulted before the VM is stored in sis.servers.
110 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, cmd cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
111 if sis.driver.HoldCloudOps {
112 sis.driver.holdCloudOps <- true
115 defer sis.mtx.Unlock()
117 return nil, errors.New("StubInstanceSet: Create called after Stop")
119 if sis.allowCreateCall.After(time.Now()) {
120 return nil, RateLimitError{sis.allowCreateCall}
// Call accepted: push the rate-limit window forward.
122 sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
123 ak := sis.driver.AuthorizedKeys
// Build a new slice with authKey first, followed by the driver's
// keys; the driver's own slice is not modified.
125 ak = append([]ssh.PublicKey{authKey}, ak...)
// The ID has the form "inst<lastInstanceID>,<ProviderType>".
130 id: cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
// Copy the tags so later changes to the caller's map don't affect
// the VM.
131 tags: copyTags(tags),
132 providerType: it.ProviderType,
134 running: map[string]int64{},
135 killing: map[string]bool{},
137 svm.SSHService = SSHService{
138 HostKey: sis.driver.HostKey,
139 AuthorizedUser: "root",
143 if setup := sis.driver.SetupVM; setup != nil {
146 sis.servers[svm.id] = svm
147 return svm.Instance(), nil
// Instances collects a snapshot (see StubVM.Instance) of every VM
// currently in sis.servers. The cloud.InstanceTags argument is
// ignored.
//
// A call made before sis.allowInstancesCall gets a RateLimitError;
// each accepted call moves that time
// MinTimeBetweenInstancesCalls into the future.
//
// NOTE(review): allowInstancesCall is written below while the
// deferred unlock is RUnlock. If the lock taken at the top of this
// method is only a read lock, concurrent callers can race on that
// field — confirm, and consider taking the full Lock instead.
150 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
152 defer sis.mtx.RUnlock()
153 if sis.allowInstancesCall.After(time.Now()) {
154 return nil, RateLimitError{sis.allowInstancesCall}
156 sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
157 var r []cloud.Instance
158 for _, ss := range sis.servers {
159 r = append(r, ss.Instance())
// Stop shuts down the instance set; Create reports an error when
// called after Stop.
//
// NOTE(review): only the panic is visible in this excerpt —
// presumably it is guarded by a check that the set was already
// stopped, as its message says. Confirm against the full method.
164 func (sis *StubInstanceSet) Stop() {
166 defer sis.mtx.Unlock()
168 panic("Stop called twice")
// RateLimitError is the error returned by StubInstanceSet's Create
// and Instances methods when they are called again too soon.
type RateLimitError struct {
	// Retry is the time at which the next call will be accepted.
	Retry time.Time
}

// Error implements the error interface, reporting the time until
// which calls are being rate limited.
func (rle RateLimitError) Error() string {
	const format = "rate limited until %s"
	return fmt.Sprintf(format, rle.Retry)
}

// EarliestRetry returns the time at which the caller may try again.
func (rle RateLimitError) EarliestRetry() time.Time {
	retryAt := rle.Retry
	return retryAt
}
178 // StubVM is a fake server that runs an SSH service. It represents a
179 // VM running in a fake cloud.
181 // Note this is distinct from a stubInstance, which is a snapshot of
182 // the VM's metadata. Like a VM in a real cloud, a StubVM keeps
183 // running (and might change IP addresses, shut down, etc.) without
184 // updating any stubInstances that have been returned to callers.
// If non-zero and in the past, "crunch-run --list" output includes a
// "broken" line.
188 ReportBroken time.Time
// If true, any command containing "crunch-run" gets
// "crunch-run: command not found" on stderr.
189 CrunchRunMissing bool
// A random draw in [0,1) below this rate makes the fake crunch-run
// want to crash; a draw below half this rate makes it want to crash
// early, before the container is set to Running.
190 CrunchRunCrashRate float64
// How long the "crunch-run --detach" handler sleeps before reporting
// "starting <uuid>".
191 CrunchRunDetachDelay time.Duration
// If set, called with the container by the fake crunch-run; its
// return value becomes the container's ExitCode.
192 ExecuteContainer func(arvados.Container) int
// If set, called with the container when the fake crunch-run crashes
// after having started it.
193 CrashRunningContainer func(arvados.Container)
// tags are the VM's current tags; Instance() hands out copies.
197 tags cloud.InstanceTags
198 initCommand cloud.InitCommand
200 SSHService SSHService
// running maps a container UUID to the fake pid of the crunch-run
// stub handling it.
201 running map[string]int64
// killing marks container UUIDs that have a "crunch-run --kill" in
// progress.
202 killing map[string]bool
// Instance returns a stubInstance snapshot of the VM, capturing its
// SSH service's current address and a copy of its current tags.
207 func (svm *StubVM) Instance() stubInstance {
212 addr: svm.SSHService.Address(),
213 // We deliberately return a cached/stale copy of the
214 // real tags here, so that (Instance)Tags() sometimes
215 // returns old data after a call to
216 // (Instance)SetTags(). This is permitted by the
217 // driver interface, and this might help remind
218 // callers that they need to tolerate it.
219 tags: copyTags(svm.tags),
// Exec simulates running command on the VM. All of stdin is read up
// front; results and diagnostics are written to stdout and stderr.
//
// Before looking at the command, Exec writes a failure message to
// stderr in these situations:
//   - "stub is booting" while svm.Boot is still in the future,
//   - "cannot fork" once svm.Broken is non-zero and in the past,
//   - "crunch-run: command not found" when CrunchRunMissing is set
//     and the command contains "crunch-run".
//
// Recognized commands:
//   - "crunch-run --detach --stdin-env ...": stdin must be a JSON
//     object containing ARVADOS_API_HOST and ARVADOS_API_TOKEN. The
//     container is recorded in svm.running and a fake crunch-run
//     moves it through Running to Complete via the driver's Queue,
//     subject to CrunchRunCrashRate and to being killed.
//   - "crunch-run --list": prints the running container UUIDs, plus
//     "broken" once ReportBroken has passed.
//   - "crunch-run --kill ...": removes the container's running entry
//     after a short random delay, then reports on stderr whether the
//     container is still running.
//   - "true".
//
// Any other command gets a "command not found" message on stderr.
//
// NOTE(review): the return statements are not visible in this
// excerpt; the uint32 result is presumably the command's exit
// status — confirm.
223 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
224 stdinData, err := ioutil.ReadAll(stdin)
226 fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
229 queue := svm.sis.driver.Queue
// Pick the container UUID (xxxxx-dz642-xxxxxxxxxxxxxxx) out of the
// command line; uuid is "" if there is none.
// NOTE(review): the pattern is compiled on every call; it could be
// compiled once at package scope.
230 uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
231 if eta := svm.Boot.Sub(time.Now()); eta > 0 {
232 fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
235 if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
236 fmt.Fprintf(stderr, "cannot fork\n")
239 if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
240 fmt.Fprint(stderr, "crunch-run: command not found\n")
243 if strings.HasPrefix(command, "crunch-run --detach --stdin-env ") {
244 var stdinKV map[string]string
245 err := json.Unmarshal(stdinData, &stdinKV)
247 fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
// Both API credentials must be present (non-empty) in stdin.
250 for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
251 if stdinKV[name] == "" {
252 fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
// Record this fake crunch-run process as the one handling uuid.
259 svm.running[uuid] = pid
261 time.Sleep(svm.CrunchRunDetachDelay)
262 fmt.Fprintf(stderr, "starting %s\n", uuid)
263 logger := svm.sis.logger.WithFields(logrus.Fields{
265 "ContainerUUID": uuid,
268 logger.Printf("[test] starting crunch-run stub")
270 var ctr arvados.Container
271 var started, completed bool
273 logger.Print("[test] exiting crunch-run stub")
// At exit the running entry should still hold our pid; anything
// else is reported as a bug.
276 if svm.running[uuid] != pid {
278 bugf := svm.sis.driver.Bugf
282 bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s]==%d", pid, uuid, svm.running[uuid])
285 delete(svm.running, uuid)
288 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
289 if started && svm.CrashRunningContainer != nil {
290 svm.CrashRunningContainer(ctr)
// A single random draw decides both outcomes: below
// CrunchRunCrashRate means crash, below half of it means crash
// early.
295 crashluck := math_rand.Float64()
296 wantCrash := crashluck < svm.CrunchRunCrashRate
297 wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2
299 ctr, ok := queue.Get(uuid)
301 logger.Print("[test] container not in queue")
// Random delay of 0-19 whole milliseconds.
305 time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
// "crunch-run --kill" deletes our running entry, so a pid mismatch
// here means this process has been killed.
308 killed := svm.running[uuid] != pid
310 if killed || wantCrashEarly {
314 ctr.State = arvados.ContainerStateRunning
315 started = queue.Notify(ctr)
// Re-read the queue's copy of the container; per the log message
// below, this path handles a rejected state=Running update.
317 ctr, _ = queue.Get(uuid)
318 logger.Print("[test] erroring out because state=Running update was rejected")
323 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
326 if svm.ExecuteContainer != nil {
327 ctr.ExitCode = svm.ExecuteContainer(ctr)
329 logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
330 ctr.State = arvados.ContainerStateComplete
331 completed = queue.Notify(ctr)
335 if command == "crunch-run --list" {
338 for uuid := range svm.running {
339 fmt.Fprintf(stdout, "%s\n", uuid)
341 if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
342 fmt.Fprintln(stdout, "broken")
346 if strings.HasPrefix(command, "crunch-run --kill ") {
348 pid, running := svm.running[uuid]
// Start a kill only if the container is running and no other kill
// for it is already in progress.
349 if running && !svm.killing[uuid] {
350 svm.killing[uuid] = true
// Random delay of 0-29 whole milliseconds before the kill lands.
352 time.Sleep(time.Duration(math_rand.Float64()*30) * time.Millisecond)
355 if svm.running[uuid] == pid {
356 // Kill only if the running entry
357 // hasn't since been killed and
358 // replaced with a different one.
359 delete(svm.running, uuid)
361 delete(svm.killing, uuid)
// Brief random pause (0 or 1 ms), then report whether the container
// is still running.
364 time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
366 _, running = svm.running[uuid]
370 fmt.Fprintf(stderr, "%s: container is running\n", uuid)
373 fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
376 if command == "true" {
// NOTE(review): unlike the other stderr messages in this method,
// this one has no trailing newline.
379 fmt.Fprintf(stderr, "%q: command not found", command)
// stubInstance is a snapshot of a StubVM's metadata, created by
// StubVM.Instance and handed to callers. Its saved tags are not
// updated when the VM's tags change later.
383 type stubInstance struct {
// tags is the copy of the VM's tags taken when the snapshot was
// created.
386 tags cloud.InstanceTags
389 func (si stubInstance) ID() cloud.InstanceID {
393 func (si stubInstance) Address() string {
// RemoteUser returns the AuthorizedUser configured on the VM's
// SSHService (StubInstanceSet.Create sets it to "root").
397 func (si stubInstance) RemoteUser() string {
398 return si.svm.SSHService.AuthorizedUser
// Destroy closes the VM's SSH service and removes the VM from its
// instance set's servers map.
//
// If the driver's HoldCloudOps flag is set, the call first blocks
// sending on the driver's holdCloudOps channel. After that, with
// probability ErrorRateDestroy, it returns an error instead of
// destroying the VM.
401 func (si stubInstance) Destroy() error {
403 if sis.driver.HoldCloudOps {
404 sis.driver.holdCloudOps <- true
// Artificially injected failure (see StubDriver.ErrorRateDestroy).
406 if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
407 return errors.New("instance could not be destroyed")
409 si.svm.SSHService.Close()
411 defer sis.mtx.Unlock()
412 delete(sis.servers, si.svm.id)
// ProviderType returns the provider type recorded on the VM
// (StubInstanceSet.Create copies it from the InstanceType).
416 func (si stubInstance) ProviderType() string {
417 return si.svm.providerType
// SetTags starts by taking a private copy of tags, so later changes
// to the caller's map have no effect on the stub.
//
// NOTE(review): the rest of the method (how and when the copy is
// applied to the VM) is not visible in this excerpt.
420 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
421 tags = copyTags(tags)
// Tags returns the tags saved in this snapshot, which may be older
// than the VM's current tags.
431 func (si stubInstance) Tags() cloud.InstanceTags {
432 // Return a copy to ensure a caller can't change our saved
433 // tags just by writing to the returned map.
434 return copyTags(si.tags)
// String returns the VM's instance ID as a string.
437 func (si stubInstance) String() string {
438 return string(si.svm.id)
// VerifyHostKey checks key against the driver's HostKey: it signs
// 512 random bytes with the driver's HostKey and returns the result
// of verifying that signature with key. The client argument is not
// used in the code visible here.
441 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
// Fresh random payload for each verification.
442 buf := make([]byte, 512)
443 _, err := io.ReadFull(rand.Reader, buf)
447 sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
// key matches only if it verifies a signature made by the driver's
// HostKey.
451 return key.Verify(buf, sig)
454 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
455 dst := cloud.InstanceTags{}
456 for k, v := range src {