1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
20 "git.arvados.org/arvados.git/lib/cloud"
21 "git.arvados.org/arvados.git/sdk/go/arvados"
22 "github.com/sirupsen/logrus"
23 "golang.org/x/crypto/ssh"
26 // A StubDriver implements cloud.Driver by setting up local SSH
27 // servers that do fake command executions.
28 type StubDriver struct {
30 AuthorizedKeys []ssh.PublicKey
32 // SetupVM, if set, is called upon creation of each new
33 // StubVM. This is the caller's opportunity to customize the
34 // VM's error rate and other behaviors.
37 // Bugf, if set, is called if a bug is detected in the caller
38 // or stub. Typically set to (*check.C)Errorf. If unset,
39 // logger.Warnf is called instead.
40 Bugf func(string, ...interface{})
42 // StubVM's fake crunch-run uses this Queue to read and update
46 // Frequency of artificially introduced errors on calls to
47 // Destroy. 0=always succeed, 1=always fail.
48 ErrorRateDestroy float64
50 // If Create() or Instances() is called too frequently, return
51 // rate-limiting errors.
52 MinTimeBetweenCreateCalls time.Duration
53 MinTimeBetweenInstancesCalls time.Duration
// If true, Create and Destroy calls block until ReleaseCloudOps() is
59 instanceSets []*StubInstanceSet
60 holdCloudOps chan bool
63 // InstanceSet returns a new *StubInstanceSet.
64 func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
65 if sd.holdCloudOps == nil {
66 sd.holdCloudOps = make(chan bool)
68 sis := StubInstanceSet{
71 servers: map[cloud.InstanceID]*StubVM{},
73 sd.instanceSets = append(sd.instanceSets, &sis)
77 err = json.Unmarshal(params, &sis)
// InstanceSets returns all instance sets that have been created by
// the driver. This can be used to test a component that uses the
// driver but doesn't expose the InstanceSets it has created.
85 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
86 return sd.instanceSets
89 // ReleaseCloudOps releases n pending Create/Destroy calls. If there
90 // are fewer than n blocked calls pending, it waits for the rest to
92 func (sd *StubDriver) ReleaseCloudOps(n int) {
93 for i := 0; i < n; i++ {
98 type StubInstanceSet struct {
100 logger logrus.FieldLogger
101 servers map[cloud.InstanceID]*StubVM
105 allowCreateCall time.Time
106 allowInstancesCall time.Time
110 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, cmd cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
111 if sis.driver.HoldCloudOps {
112 sis.driver.holdCloudOps <- true
115 defer sis.mtx.Unlock()
117 return nil, errors.New("StubInstanceSet: Create called after Stop")
119 if sis.allowCreateCall.After(time.Now()) {
120 return nil, RateLimitError{sis.allowCreateCall}
122 sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
123 ak := sis.driver.AuthorizedKeys
125 ak = append([]ssh.PublicKey{authKey}, ak...)
130 id: cloud.InstanceID(fmt.Sprintf("inst%d,%s", sis.lastInstanceID, it.ProviderType)),
131 tags: copyTags(tags),
132 providerType: it.ProviderType,
134 running: map[string]int64{},
135 killing: map[string]bool{},
137 svm.SSHService = SSHService{
138 HostKey: sis.driver.HostKey,
139 AuthorizedUser: "root",
143 if setup := sis.driver.SetupVM; setup != nil {
146 sis.servers[svm.id] = svm
147 return svm.Instance(), nil
150 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
152 defer sis.mtx.RUnlock()
153 if sis.allowInstancesCall.After(time.Now()) {
154 return nil, RateLimitError{sis.allowInstancesCall}
156 sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
157 var r []cloud.Instance
158 for _, ss := range sis.servers {
159 r = append(r, ss.Instance())
164 func (sis *StubInstanceSet) Stop() {
166 defer sis.mtx.Unlock()
168 panic("Stop called twice")
// RateLimitError is the error returned by the stub instance set when
// a call arrives before the time at which the next call is allowed.
type RateLimitError struct {
	// Retry is the earliest time at which the call may be retried.
	Retry time.Time
}

// Error implements the error interface, reporting the time until
// which calls are rate limited.
func (rle RateLimitError) Error() string {
	until := rle.Retry
	return fmt.Sprintf("rate limited until %s", until)
}

// EarliestRetry returns the time stored in Retry.
func (rle RateLimitError) EarliestRetry() time.Time {
	return rle.Retry
}
178 // StubVM is a fake server that runs an SSH service. It represents a
179 // VM running in a fake cloud.
181 // Note this is distinct from a stubInstance, which is a snapshot of
182 // the VM's metadata. Like a VM in a real cloud, a StubVM keeps
183 // running (and might change IP addresses, shut down, etc.) without
184 // updating any stubInstances that have been returned to callers.
188 ReportBroken time.Time
189 CrunchRunMissing bool
190 CrunchRunCrashRate float64
191 CrunchRunDetachDelay time.Duration
192 ExecuteContainer func(arvados.Container) int
193 CrashRunningContainer func(arvados.Container)
197 tags cloud.InstanceTags
198 initCommand cloud.InitCommand
200 SSHService SSHService
201 running map[string]int64
202 killing map[string]bool
207 func (svm *StubVM) Instance() stubInstance {
212 addr: svm.SSHService.Address(),
213 // We deliberately return a cached/stale copy of the
214 // real tags here, so that (Instance)Tags() sometimes
215 // returns old data after a call to
216 // (Instance)SetTags(). This is permitted by the
217 // driver interface, and this might help remind
218 // callers that they need to tolerate it.
219 tags: copyTags(svm.tags),
223 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
224 stdinData, err := ioutil.ReadAll(stdin)
226 fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
229 queue := svm.sis.driver.Queue
230 uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
231 if eta := svm.Boot.Sub(time.Now()); eta > 0 {
232 fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
235 if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
236 fmt.Fprintf(stderr, "cannot fork\n")
239 if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
240 fmt.Fprint(stderr, "crunch-run: command not found\n")
243 if strings.HasPrefix(command, "crunch-run --detach --stdin-env ") {
244 var stdinKV map[string]string
245 err := json.Unmarshal(stdinData, &stdinKV)
247 fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
250 for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
251 if stdinKV[name] == "" {
252 fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
259 svm.running[uuid] = pid
261 time.Sleep(svm.CrunchRunDetachDelay)
262 fmt.Fprintf(stderr, "starting %s\n", uuid)
263 logger := svm.sis.logger.WithFields(logrus.Fields{
265 "ContainerUUID": uuid,
268 logger.Printf("[test] starting crunch-run stub")
270 var ctr arvados.Container
271 var started, completed bool
273 logger.Print("[test] exiting crunch-run stub")
276 if svm.running[uuid] != pid {
277 bugf := svm.sis.driver.Bugf
281 bugf("[test] StubDriver bug or caller bug: pid %d exiting, running[%s]==%d", pid, uuid, svm.running[uuid])
283 delete(svm.running, uuid)
286 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
287 if started && svm.CrashRunningContainer != nil {
288 svm.CrashRunningContainer(ctr)
293 crashluck := math_rand.Float64()
294 wantCrash := crashluck < svm.CrunchRunCrashRate
295 wantCrashEarly := crashluck < svm.CrunchRunCrashRate/2
297 ctr, ok := queue.Get(uuid)
299 logger.Print("[test] container not in queue")
303 time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
306 killed := svm.killing[uuid]
308 if killed || wantCrashEarly {
312 ctr.State = arvados.ContainerStateRunning
313 started = queue.Notify(ctr)
315 ctr, _ = queue.Get(uuid)
316 logger.Print("[test] erroring out because state=Running update was rejected")
321 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
324 if svm.ExecuteContainer != nil {
325 ctr.ExitCode = svm.ExecuteContainer(ctr)
327 logger.WithField("ExitCode", ctr.ExitCode).Print("[test] completing container")
328 ctr.State = arvados.ContainerStateComplete
329 completed = queue.Notify(ctr)
333 if command == "crunch-run --list" {
336 for uuid := range svm.running {
337 fmt.Fprintf(stdout, "%s\n", uuid)
339 if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
340 fmt.Fprintln(stdout, "broken")
344 if strings.HasPrefix(command, "crunch-run --kill ") {
346 _, running := svm.running[uuid]
348 svm.killing[uuid] = true
350 time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
352 _, running = svm.running[uuid]
356 fmt.Fprintf(stderr, "%s: container is running\n", uuid)
359 fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
362 if command == "true" {
365 fmt.Fprintf(stderr, "%q: command not found", command)
369 type stubInstance struct {
372 tags cloud.InstanceTags
375 func (si stubInstance) ID() cloud.InstanceID {
379 func (si stubInstance) Address() string {
383 func (si stubInstance) RemoteUser() string {
384 return si.svm.SSHService.AuthorizedUser
387 func (si stubInstance) Destroy() error {
389 if sis.driver.HoldCloudOps {
390 sis.driver.holdCloudOps <- true
392 if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
393 return errors.New("instance could not be destroyed")
395 si.svm.SSHService.Close()
397 defer sis.mtx.Unlock()
398 delete(sis.servers, si.svm.id)
402 func (si stubInstance) ProviderType() string {
403 return si.svm.providerType
406 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
407 tags = copyTags(tags)
417 func (si stubInstance) Tags() cloud.InstanceTags {
418 // Return a copy to ensure a caller can't change our saved
419 // tags just by writing to the returned map.
420 return copyTags(si.tags)
423 func (si stubInstance) String() string {
424 return string(si.svm.id)
427 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
428 buf := make([]byte, 512)
429 _, err := io.ReadFull(rand.Reader, buf)
433 sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
437 return key.Verify(buf, sig)
440 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
441 dst := cloud.InstanceTags{}
442 for k, v := range src {