15954: Fix process start/stop race.
[arvados.git] / lib / dispatchcloud / test / stub_driver.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "crypto/rand"
9         "encoding/json"
10         "errors"
11         "fmt"
12         "io"
13         "io/ioutil"
14         math_rand "math/rand"
15         "regexp"
16         "strings"
17         "sync"
18         "time"
19
20         "git.arvados.org/arvados.git/lib/cloud"
21         "git.arvados.org/arvados.git/sdk/go/arvados"
22         "github.com/sirupsen/logrus"
23         "golang.org/x/crypto/ssh"
24 )
25
// A StubDriver implements cloud.Driver by setting up local SSH
// servers that do fake command executions.
type StubDriver struct {
	// HostKey is handed to the SSHService of every StubVM, and is
	// used by (stubInstance)VerifyHostKey to sign its challenge.
	HostKey ssh.Signer
	// AuthorizedKeys are handed to the SSHService of every
	// StubVM, after the key (if any) passed to Create.
	AuthorizedKeys []ssh.PublicKey

	// SetupVM, if set, is called upon creation of each new
	// StubVM. This is the caller's opportunity to customize the
	// VM's error rate and other behaviors.
	SetupVM func(*StubVM)

	// StubVM's fake crunch-run uses this Queue to read and update
	// container state.
	Queue *Queue

	// Frequency of artificially introduced errors on calls to
	// Destroy. 0=always succeed, 1=always fail.
	ErrorRateDestroy float64

	// If Create() or Instances() is called too frequently, return
	// rate-limiting errors.
	MinTimeBetweenCreateCalls    time.Duration
	MinTimeBetweenInstancesCalls time.Duration

	// If true, Create and Destroy calls block until
	// ReleaseCloudOps() is called.
	HoldCloudOps bool

	// instanceSets holds every set returned by InstanceSet, in
	// creation order.
	instanceSets []*StubInstanceSet
	// holdCloudOps is the channel Create and Destroy send on when
	// HoldCloudOps is true; ReleaseCloudOps receives from it.
	holdCloudOps chan bool
}
57
58 // InstanceSet returns a new *StubInstanceSet.
59 func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
60         if sd.holdCloudOps == nil {
61                 sd.holdCloudOps = make(chan bool)
62         }
63         sis := StubInstanceSet{
64                 driver:  sd,
65                 logger:  logger,
66                 servers: map[cloud.InstanceID]*StubVM{},
67         }
68         sd.instanceSets = append(sd.instanceSets, &sis)
69
70         var err error
71         if params != nil {
72                 err = json.Unmarshal(params, &sis)
73         }
74         return &sis, err
75 }
76
// InstanceSets returns all instance sets that have been created by
// the driver. This can be used to test a component that uses the
// driver but doesn't expose the InstanceSets it has created.
//
// The returned slice is the driver's own slice, not a copy.
func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
	return sd.instanceSets
}
83
84 // ReleaseCloudOps releases n pending Create/Destroy calls. If there
85 // are fewer than n blocked calls pending, it waits for the rest to
86 // arrive.
87 func (sd *StubDriver) ReleaseCloudOps(n int) {
88         for i := 0; i < n; i++ {
89                 <-sd.holdCloudOps
90         }
91 }
92
// StubInstanceSet is the instance set returned by
// (*StubDriver)InstanceSet. It keeps track of the StubVMs that have
// been created through it and not yet destroyed.
type StubInstanceSet struct {
	driver *StubDriver
	logger logrus.FieldLogger
	// servers holds the existing VMs, keyed by instance ID.
	// Entries are added by Create and removed by Destroy.
	servers map[cloud.InstanceID]*StubVM
	// mtx is held by Create, Instances, Stop, and Destroy while
	// they access the other fields of the set.
	mtx sync.RWMutex
	// stopped is set by Stop; Create fails once it is true.
	stopped bool

	// Earliest times at which the next Create / Instances call
	// will be accepted rather than answered with a
	// RateLimitError.
	allowCreateCall    time.Time
	allowInstancesCall time.Time
}
103
104 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, cmd cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
105         if sis.driver.HoldCloudOps {
106                 sis.driver.holdCloudOps <- true
107         }
108         sis.mtx.Lock()
109         defer sis.mtx.Unlock()
110         if sis.stopped {
111                 return nil, errors.New("StubInstanceSet: Create called after Stop")
112         }
113         if sis.allowCreateCall.After(time.Now()) {
114                 return nil, RateLimitError{sis.allowCreateCall}
115         } else {
116                 sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
117         }
118
119         ak := sis.driver.AuthorizedKeys
120         if authKey != nil {
121                 ak = append([]ssh.PublicKey{authKey}, ak...)
122         }
123         svm := &StubVM{
124                 sis:          sis,
125                 id:           cloud.InstanceID(fmt.Sprintf("stub-%s-%x", it.ProviderType, math_rand.Int63())),
126                 tags:         copyTags(tags),
127                 providerType: it.ProviderType,
128                 initCommand:  cmd,
129                 running:      map[string]int64{},
130                 killing:      map[string]bool{},
131         }
132         svm.SSHService = SSHService{
133                 HostKey:        sis.driver.HostKey,
134                 AuthorizedUser: "root",
135                 AuthorizedKeys: ak,
136                 Exec:           svm.Exec,
137         }
138         if setup := sis.driver.SetupVM; setup != nil {
139                 setup(svm)
140         }
141         sis.servers[svm.id] = svm
142         return svm.Instance(), nil
143 }
144
145 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
146         sis.mtx.RLock()
147         defer sis.mtx.RUnlock()
148         if sis.allowInstancesCall.After(time.Now()) {
149                 return nil, RateLimitError{sis.allowInstancesCall}
150         } else {
151                 sis.allowInstancesCall = time.Now().Add(sis.driver.MinTimeBetweenInstancesCalls)
152         }
153         var r []cloud.Instance
154         for _, ss := range sis.servers {
155                 r = append(r, ss.Instance())
156         }
157         return r, nil
158 }
159
160 func (sis *StubInstanceSet) Stop() {
161         sis.mtx.Lock()
162         defer sis.mtx.Unlock()
163         if sis.stopped {
164                 panic("Stop called twice")
165         }
166         sis.stopped = true
167 }
168
// RateLimitError is returned by Create and Instances when they are
// called too soon after a previous call.
type RateLimitError struct {
	// Retry is the earliest time at which a new call will be
	// accepted.
	Retry time.Time
}

// Error implements the error interface.
func (e RateLimitError) Error() string {
	retryAt := e.Retry
	return fmt.Sprintf("rate limited until %s", retryAt)
}

// EarliestRetry reports the time before which retrying is pointless.
func (e RateLimitError) EarliestRetry() time.Time {
	return e.Retry
}
173
// StubVM is a fake server that runs an SSH service. It represents a
// VM running in a fake cloud.
//
// Note this is distinct from a stubInstance, which is a snapshot of
// the VM's metadata. Like a VM in a real cloud, a StubVM keeps
// running (and might change IP addresses, shut down, etc.)  without
// updating any stubInstances that have been returned to callers.
type StubVM struct {
	// Boot is the time at which the VM finishes booting. Until
	// then, Exec fails every command with a "stub is booting"
	// message.
	Boot time.Time
	// Broken, if non-zero, is the time after which Exec fails
	// every command with "cannot fork".
	Broken time.Time
	// ReportBroken, if non-zero, is the time after which
	// "crunch-run --list" adds a "broken" line to its output.
	ReportBroken time.Time
	// CrunchRunMissing makes any command containing "crunch-run"
	// fail with "command not found".
	CrunchRunMissing bool
	// CrunchRunCrashRate controls how often the fake crunch-run
	// process exits without finishing its container; see Exec
	// for how it is compared against a random draw.
	CrunchRunCrashRate float64
	// CrunchRunDetachDelay is how long "crunch-run --detach"
	// sleeps, after registering the new process, before it
	// reports that the process has started.
	CrunchRunDetachDelay time.Duration
	// ExecuteContainer, if set, is called by the fake crunch-run
	// process when it does not crash; the return value is used
	// as the container's exit code.
	ExecuteContainer func(arvados.Container) int
	// CrashRunningContainer, if set, is called when the fake
	// crunch-run process exits while its copy of the container
	// is still in the Running state.
	CrashRunningContainer func(arvados.Container)

	sis          *StubInstanceSet
	id           cloud.InstanceID
	tags         cloud.InstanceTags
	initCommand  cloud.InitCommand
	providerType string
	SSHService   SSHService
	// running maps the UUID of each container with a live fake
	// crunch-run process to that process's PID.
	running map[string]int64
	// killing holds the UUIDs for which a "crunch-run --kill" is
	// currently in progress.
	killing map[string]bool
	// lastPID is the PID given to the most recently started fake
	// crunch-run process.
	lastPID int64
	// The embedded mutex is held while tags, running, killing,
	// and lastPID are accessed.
	sync.Mutex
}
202
203 func (svm *StubVM) Instance() stubInstance {
204         svm.Lock()
205         defer svm.Unlock()
206         return stubInstance{
207                 svm:  svm,
208                 addr: svm.SSHService.Address(),
209                 // We deliberately return a cached/stale copy of the
210                 // real tags here, so that (Instance)Tags() sometimes
211                 // returns old data after a call to
212                 // (Instance)SetTags().  This is permitted by the
213                 // driver interface, and this might help remind
214                 // callers that they need to tolerate it.
215                 tags: copyTags(svm.tags),
216         }
217 }
218
219 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
220         stdinData, err := ioutil.ReadAll(stdin)
221         if err != nil {
222                 fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
223                 return 1
224         }
225         queue := svm.sis.driver.Queue
226         uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
227         if eta := svm.Boot.Sub(time.Now()); eta > 0 {
228                 fmt.Fprintf(stderr, "stub is booting, ETA %s\n", eta)
229                 return 1
230         }
231         if !svm.Broken.IsZero() && svm.Broken.Before(time.Now()) {
232                 fmt.Fprintf(stderr, "cannot fork\n")
233                 return 2
234         }
235         if svm.CrunchRunMissing && strings.Contains(command, "crunch-run") {
236                 fmt.Fprint(stderr, "crunch-run: command not found\n")
237                 return 1
238         }
239         if strings.HasPrefix(command, "crunch-run --detach --stdin-env ") {
240                 var stdinKV map[string]string
241                 err := json.Unmarshal(stdinData, &stdinKV)
242                 if err != nil {
243                         fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
244                         return 1
245                 }
246                 for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
247                         if stdinKV[name] == "" {
248                                 fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdinData)
249                                 return 1
250                         }
251                 }
252                 svm.Lock()
253                 svm.lastPID++
254                 pid := svm.lastPID
255                 svm.running[uuid] = pid
256                 svm.Unlock()
257                 time.Sleep(svm.CrunchRunDetachDelay)
258                 fmt.Fprintf(stderr, "starting %s\n", uuid)
259                 logger := svm.sis.logger.WithFields(logrus.Fields{
260                         "Instance":      svm.id,
261                         "ContainerUUID": uuid,
262                         "PID":           pid,
263                 })
264                 logger.Printf("[test] starting crunch-run stub")
265                 go func() {
266                         crashluck := math_rand.Float64()
267                         ctr, ok := queue.Get(uuid)
268                         if !ok {
269                                 logger.Print("[test] container not in queue")
270                                 return
271                         }
272
273                         defer func() {
274                                 if ctr.State == arvados.ContainerStateRunning && svm.CrashRunningContainer != nil {
275                                         svm.CrashRunningContainer(ctr)
276                                 }
277                         }()
278
279                         if crashluck > svm.CrunchRunCrashRate/2 {
280                                 time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
281                                 ctr.State = arvados.ContainerStateRunning
282                                 if !queue.Notify(ctr) {
283                                         ctr, _ = queue.Get(uuid)
284                                         logger.Print("[test] erroring out because state=Running update was rejected")
285                                         return
286                                 }
287                         }
288
289                         time.Sleep(time.Duration(math_rand.Float64()*20) * time.Millisecond)
290
291                         svm.Lock()
292                         defer svm.Unlock()
293                         if svm.running[uuid] != pid {
294                                 logger.Print("[test] container was killed")
295                                 return
296                         }
297                         delete(svm.running, uuid)
298
299                         if crashluck < svm.CrunchRunCrashRate {
300                                 logger.WithField("State", ctr.State).Print("[test] crashing crunch-run stub")
301                         } else {
302                                 if svm.ExecuteContainer != nil {
303                                         ctr.ExitCode = svm.ExecuteContainer(ctr)
304                                 }
305                                 logger.WithField("ExitCode", ctr.ExitCode).Print("[test] exiting crunch-run stub")
306                                 ctr.State = arvados.ContainerStateComplete
307                                 go queue.Notify(ctr)
308                         }
309                 }()
310                 return 0
311         }
312         if command == "crunch-run --list" {
313                 svm.Lock()
314                 defer svm.Unlock()
315                 for uuid := range svm.running {
316                         fmt.Fprintf(stdout, "%s\n", uuid)
317                 }
318                 if !svm.ReportBroken.IsZero() && svm.ReportBroken.Before(time.Now()) {
319                         fmt.Fprintln(stdout, "broken")
320                 }
321                 return 0
322         }
323         if strings.HasPrefix(command, "crunch-run --kill ") {
324                 svm.Lock()
325                 pid, running := svm.running[uuid]
326                 if running && !svm.killing[uuid] {
327                         svm.killing[uuid] = true
328                         go func() {
329                                 time.Sleep(time.Duration(math_rand.Float64()*30) * time.Millisecond)
330                                 svm.Lock()
331                                 defer svm.Unlock()
332                                 if svm.running[uuid] == pid {
333                                         // Kill only if the running entry
334                                         // hasn't since been killed and
335                                         // replaced with a different one.
336                                         delete(svm.running, uuid)
337                                 }
338                                 delete(svm.killing, uuid)
339                         }()
340                         svm.Unlock()
341                         time.Sleep(time.Duration(math_rand.Float64()*2) * time.Millisecond)
342                         svm.Lock()
343                         _, running = svm.running[uuid]
344                 }
345                 svm.Unlock()
346                 if running {
347                         fmt.Fprintf(stderr, "%s: container is running\n", uuid)
348                         return 1
349                 } else {
350                         fmt.Fprintf(stderr, "%s: container is not running\n", uuid)
351                         return 0
352                 }
353         }
354         if command == "true" {
355                 return 0
356         }
357         fmt.Fprintf(stderr, "%q: command not found", command)
358         return 1
359 }
360
// stubInstance is a snapshot of a StubVM's metadata, as returned by
// (*StubVM)Instance. The address and tags are the values captured
// when the snapshot was taken.
type stubInstance struct {
	svm  *StubVM
	addr string
	tags cloud.InstanceTags
}
366
// ID returns the ID of the underlying StubVM.
func (si stubInstance) ID() cloud.InstanceID {
	return si.svm.id
}
370
// Address returns the address captured when this snapshot was taken.
func (si stubInstance) Address() string {
	return si.addr
}
374
// RemoteUser returns the AuthorizedUser of the VM's SSH service.
func (si stubInstance) RemoteUser() string {
	return si.svm.SSHService.AuthorizedUser
}
378
379 func (si stubInstance) Destroy() error {
380         sis := si.svm.sis
381         if sis.driver.HoldCloudOps {
382                 sis.driver.holdCloudOps <- true
383         }
384         if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
385                 return errors.New("instance could not be destroyed")
386         }
387         si.svm.SSHService.Close()
388         sis.mtx.Lock()
389         defer sis.mtx.Unlock()
390         delete(sis.servers, si.svm.id)
391         return nil
392 }
393
// ProviderType returns the provider type the VM was created with.
func (si stubInstance) ProviderType() string {
	return si.svm.providerType
}
397
398 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
399         tags = copyTags(tags)
400         svm := si.svm
401         go func() {
402                 svm.Lock()
403                 defer svm.Unlock()
404                 svm.tags = tags
405         }()
406         return nil
407 }
408
// Tags returns the tags captured when this snapshot was taken, not
// necessarily the VM's current tags.
func (si stubInstance) Tags() cloud.InstanceTags {
	// Return a copy to ensure a caller can't change our saved
	// tags just by writing to the returned map.
	return copyTags(si.tags)
}
414
// String returns the VM's instance ID as a plain string.
func (si stubInstance) String() string {
	return string(si.svm.id)
}
418
419 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
420         buf := make([]byte, 512)
421         _, err := io.ReadFull(rand.Reader, buf)
422         if err != nil {
423                 return err
424         }
425         sig, err := si.svm.sis.driver.HostKey.Sign(rand.Reader, buf)
426         if err != nil {
427                 return err
428         }
429         return key.Verify(buf, sig)
430 }
431
432 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
433         dst := cloud.InstanceTags{}
434         for k, v := range src {
435                 dst[k] = v
436         }
437         return dst
438 }