1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
20 "git.curoverse.com/arvados.git/lib/cloud"
21 "git.curoverse.com/arvados.git/sdk/go/arvados"
22 "github.com/sirupsen/logrus"
23 "golang.org/x/crypto/ssh"
26 // A StubDriver implements cloud.Driver by setting up local SSH
27 // servers that do fake command executions.
28 type StubDriver struct {
30 AuthorizedKeys []ssh.PublicKey
32 // SetupVM, if set, is called upon creation of each new
33 // StubVM. This is the caller's opportunity to customize the
34 // VM's error rate and other behaviors.
37 // StubVM's fake crunch-run uses this Queue to read and update
41 // Frequency of artificially introduced errors on calls to
42 // Destroy. 0=always succeed, 1=always fail.
43 ErrorRateDestroy float64
45 // If Create() or Instances() is called too frequently, return
46 // rate-limiting errors.
47 MinTimeBetweenCreateCalls time.Duration
48 MinTimeBetweenInstancesCalls time.Duration
50 // If true, Create and Destroy calls block until ReleaseCloudOps() is
54 instanceSets []*StubInstanceSet
55 holdCloudOps chan bool
58 // InstanceSet returns a new *StubInstanceSet.
59 func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
60 if sd.holdCloudOps == nil {
61 sd.holdCloudOps = make(chan bool)
63 sis := StubInstanceSet{
66 servers: map[cloud.InstanceID]*StubVM{},
68 sd.instanceSets = append(sd.instanceSets, &sis)
72 err = json.Unmarshal(params, &sis)
77 // InstanceSets returns all instance sets that have been created by the
78 // driver. This can be used to test a component that uses the driver
79 // but doesn't expose the InstanceSets it has created.
80 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
81 return sd.instanceSets
84 // ReleaseCloudOps releases n pending Create/Destroy calls. If there
85 // are fewer than n blocked calls pending, it waits for the rest to
87 func (sd *StubDriver) ReleaseCloudOps(n int) {
88 for i := 0; i < n; i++ {
93 type StubInstanceSet struct {
95 logger logrus.FieldLogger
96 servers map[cloud.InstanceID]*StubVM
100 allowCreateCall time.Time
101 allowInstancesCall time.Time
104 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, cmd cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
105 if sis.driver.HoldCloudOps {
106 sis.driver.holdCloudOps <- true
109 defer sis.mtx.Unlock()
111 return nil, errors.New("StubInstanceSet: Create called after Stop")
113 if sis.allowCreateCall.After(time.Now()) {
114 return nil, RateLimitError{sis.allowCreateCall}
116 sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
119 ak := sis.driver.AuthorizedKeys
121 ak = append([]ssh.PublicKey{authKey}, ak...)
125 id: cloud.InstanceID(fmt.Sprintf("stub-%s-%x", it.ProviderType, math_rand.Int63())),
126 tags: copyTags(tags),
127 providerType: it.ProviderType,
129 running: map[string]int64{},
130 killing: map[string]bool{},
132 svm.SSHService = SSHService{
133 HostKey: sis.driver.HostKey,
134 AuthorizedUser: "root",
138 if setup := sis.driver.SetupVM; setup != nil {
141 sis.servers[svm.id] = svm
142 return svm.Instance(), nil
145 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
147 defer sis.mtx.RUnlock()
148 if sis.allowInstancesCall.After(time.Now()) {
149 return nil, RateLimitError{sis.allowInstancesCall}
151 sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
153 var r []cloud.Instance
154 for _, ss := range sis.servers {
155 r = append(r, ss.Instance())
160 func (sis *StubInstanceSet) Stop() {
162 defer sis.mtx.Unlock()
164 panic("Stop called twice")
// RateLimitError is the error returned by the stub's Create and
// Instances calls when they are invoked before the configured
// minimum interval has elapsed.
type RateLimitError struct {
	// Retry is the earliest time at which the rejected call will
	// be accepted again.
	Retry time.Time
}

// Error implements the error interface; the message includes the
// time at which the rate limit lifts.
func (rle RateLimitError) Error() string {
	return fmt.Sprintf("rate limited until %s", rle.Retry)
}

// EarliestRetry returns the Retry time carried by the error.
func (rle RateLimitError) EarliestRetry() time.Time {
	return rle.Retry
}
174 // StubVM is a fake server that runs an SSH service. It represents a
175 // VM running in a fake cloud.
177 // Note this is distinct from a stubInstance, which is a snapshot of
178 // the VM's metadata. Like a VM in a real cloud, a StubVM keeps
179 // running (and might change IP addresses, shut down, etc.) without
180 // updating any stubInstances that have been returned to callers.
184 CrunchRunMissing bool
185 CrunchRunCrashRate float64
186 CrunchRunDetachDelay time.Duration
187 ExecuteContainer func(arvados.Container) int
188 CrashRunningContainer func(arvados.Container)
192 tags cloud.InstanceTags
193 initCommand cloud.InitCommand
195 SSHService SSHService
196 running map[string]int64
197 killing map[string]bool
202 func (svm *StubVM) Instance() stubInstance {
207 addr: svm.SSHService.Address(),
208 // We deliberately return a cached/stale copy of the
209 // real tags here, so that (Instance)Tags() sometimes
210 // returns old data after a call to
211 // (Instance)SetTags(). This is permitted by the
212 // driver interface, and this might help remind
213 // callers that they need to tolerate it.
214 tags: copyTags(svm.tags),
218 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
219 stdinData, err := ioutil.ReadAll(stdin)
221 fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
224 queue := svm.sis.driver.Queue
225 uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
226 if eta := svm.Boot.Sub(time.Now()); eta > 0 {
227 fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
230 if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
231 fmt.Fprintf(stderr, "cannot fork\n")
234 if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
235 fmt.Fprint(stderr, "crunch-run: command not found\n")
238 if strings.HasPrefix(command, "crunch-run --detach --stdin-env ") {
239 var stdinKV map[string]string
240 err := json.Unmarshal(stdinData, &stdinKV)
242 fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
245 for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
246 if stdinKV[name] == "" {
247 fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdin)
254 svm.running[uuid] = pid
256 time.Sleep(svm.CrunchRunDetachDelay)
257 fmt.Fprintf(stderr, "starting %s\n", uuid)
258 logger := svm.sis.logger.WithFields(logrus.Fields{
260 "ContainerUUID": uuid,
263 logger.Printf("[test] starting crunch-run stub")
265 crashluck := math_rand.Float64()
266 ctr, ok := queue.Get(uuid)
268 logger.Print("[test] container not in queue")
273 if ctr.State == arvados.ContainerStateRunning && svm.CrashRunningContainer != nil {
274 svm.CrashRunningContainer(ctr)
278 if crashluck > svm.CrunchRunCrashRate/2 {
279 time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
280 ctr.State = arvados.ContainerStateRunning
281 if !queue.Notify(ctr) {
282 ctr, _ = queue.Get(uuid)
283 logger.Print("[test] erroring out because state=Running update was rejected")
288 time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
292 if svm.running[uuid] != pid {
293 logger.Print("[test] container was killed")
296 delete(svm.running, uuid)
298 if crashluck < svm.CrunchRunCrashRate {
299 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
301 if svm.ExecuteContainer != nil {
302 ctr.ExitCode = svm.ExecuteContainer(ctr)
304 logger.WithField("ExitCode", ctr.ExitCode).Print("[test] exiting crunch-run stub")
305 ctr.State = arvados.ContainerStateComplete
311 if command == "crunch-run --list" {
314 for uuid := range svm.running {
315 fmt.Fprintf(stdout, "%s\n", uuid)
319 if strings.HasPrefix(command, "crunch-run --kill ") {
321 pid, running := svm.running[uuid]
322 if running && !svm.killing[uuid] {
323 svm.killing[uuid] = true
325 time.Sleep(time.Duration(math_rand.Float64()*30) * time.Millisecond)
328 if svm.running[uuid] == pid {
329 // Kill only if the running entry
330 // hasn't since been killed and
331 // replaced with a different one.
332 delete(svm.running, uuid)
334 delete(svm.killing, uuid)
337 time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
339 _, running = svm.running[uuid]
343 fmt.Fprintf(stderr, "%s: container is running\n", uuid)
346 fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
350 if command == "true" {
353 fmt.Fprintf(stderr, "%q: command not found", command)
357 type stubInstance struct {
360 tags cloud.InstanceTags
363 func (si stubInstance) ID() cloud.InstanceID {
367 func (si stubInstance) Address() string {
371 func (si stubInstance) RemoteUser() string {
372 return si.svm.SSHService.AuthorizedUser
375 func (si stubInstance) Destroy() error {
377 if sis.driver.HoldCloudOps {
378 sis.driver.holdCloudOps <- true
380 if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
381 return errors.New("instance could not be destroyed")
383 si.svm.SSHService.Close()
385 defer sis.mtx.Unlock()
386 delete(sis.servers, si.svm.id)
390 func (si stubInstance) ProviderType() string {
391 return si.svm.providerType
394 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
395 tags = copyTags(tags)
405 func (si stubInstance) Tags() cloud.InstanceTags {
406 // Return a copy to ensure a caller can't change our saved
407 // tags just by writing to the returned map.
408 return copyTags(si.tags)
411 func (si stubInstance) String() string {
412 return string(si.svm.id)
415 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
416 buf := make([]byte, 512)
417 _, err := io.ReadFull(rand.Reader, buf)
421 sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
425 return key.Verify(buf, sig)
428 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
429 dst := cloud.InstanceTags{}
430 for k, v := range src {