Merge branch '20520-instance-init-command'
[arvados.git] / lib / cloud / cloudtest / tester.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package cloudtest
6
7 import (
8         "crypto/rand"
9         "encoding/json"
10         "errors"
11         "fmt"
12         "time"
13
14         "git.arvados.org/arvados.git/lib/cloud"
15         "git.arvados.org/arvados.git/lib/dispatchcloud/sshexecutor"
16         "git.arvados.org/arvados.git/lib/dispatchcloud/worker"
17         "git.arvados.org/arvados.git/sdk/go/arvados"
18         "github.com/sirupsen/logrus"
19         "golang.org/x/crypto/ssh"
20 )
21
// errTestInstanceNotFound is returned by refreshTestInstance when the
// driver's instance list does not contain the test instance.
var errTestInstanceNotFound = errors.New("test instance missing from cloud provider's list")
25
// A tester does a sequence of operations to test a cloud driver and
// configuration. Run() should be called only once, after assigning
// suitable values to public fields.
type tester struct {
	// Logger receives all progress and error messages.
	Logger              logrus.FieldLogger
	// Tags are passed to Driver.InstanceSet, copied onto the test
	// instance's tags at Create time, and verified by checkTags.
	Tags                cloud.SharedResourceTags
	// TagKeyPrefix is prepended to the "InstanceSetID" and
	// "InstanceSecret" tag keys.
	TagKeyPrefix        string
	// SetID is passed to Driver.InstanceSet and stored in the
	// InstanceSetID tag; instances whose InstanceSetID tag equals
	// SetID are the only ones this tester considers.
	SetID               cloud.InstanceSetID
	// DestroyExisting: if true, Run destroys pre-existing instances
	// tagged with SetID before testing; if false, finding any makes
	// Run fail.
	DestroyExisting     bool
	// ProbeInterval is the pause between BootProbeCommand attempts.
	ProbeInterval       time.Duration
	// SyncInterval is the pause between instance-list polls.
	SyncInterval        time.Duration
	// TimeoutBooting bounds the time, measured from just before
	// Create, for the instance to appear and pass BootProbeCommand.
	TimeoutBooting      time.Duration
	// Driver and DriverParameters are used to obtain the InstanceSet.
	Driver              cloud.Driver
	DriverParameters    json.RawMessage
	// InstanceType and ImageID are passed to InstanceSet.Create.
	InstanceType        arvados.InstanceType
	ImageID             cloud.ImageID
	// SSHKey's public key is passed to Create; the key itself is the
	// signer used by the SSH executor.
	SSHKey              ssh.Signer
	// SSHPort is the target port for SSH connections.
	SSHPort             string
	// BootProbeCommand is run repeatedly until it succeeds, which is
	// taken to mean the instance has booted.
	BootProbeCommand    string
	// InstanceInitCommand is appended (after a newline) to the
	// TagVerifier init command that is passed to Create.
	InstanceInitCommand cloud.InitCommand
	// ShellCommand, if non-empty, is run once after boot succeeds.
	ShellCommand        string
	// PauseBeforeDestroy, if non-nil, is called (after login info is
	// shown again) before the test instance is destroyed.
	PauseBeforeDestroy  func()

	// is is the InstanceSet obtained from Driver in Run.
	is              cloud.InstanceSet
	// testInstance is the instance under test; nil until it has been
	// created/found, and reset to nil once it is destroyed.
	testInstance    *worker.TagVerifier
	// secret is the random value stored in the InstanceSecret tag and
	// given to the TagVerifier.
	secret          string
	// executor is the SSH executor, created lazily by updateExecutor.
	executor        *sshexecutor.Executor
	// showedLoginInfo is set once an "ssh ..." command line has been
	// logged for the operator.
	showedLoginInfo bool

	// NOTE(review): failed is not referenced anywhere in this file;
	// possibly unused -- confirm before relying on it.
	failed bool
}
57
58 // Run the test suite as specified, clean up as needed, and return
59 // true (everything is OK) or false (something went wrong).
60 func (t *tester) Run() bool {
61         // This flag gets set when we encounter a non-fatal error, so
62         // we can continue doing more tests but remember to return
63         // false (failure) at the end.
64         deferredError := false
65
66         var err error
67         t.is, err = t.Driver.InstanceSet(t.DriverParameters, t.SetID, t.Tags, t.Logger)
68         if err != nil {
69                 t.Logger.WithError(err).Info("error initializing driver")
70                 return false
71         }
72
73         for {
74                 // Don't send the driver any filters when getting the
75                 // initial instance list. This way we can log an
76                 // instance count (N=...)  that includes all instances
77                 // in this service account, even if they don't have
78                 // the same InstanceSetID.
79                 insts, err := t.getInstances(nil)
80                 if err != nil {
81                         t.Logger.WithError(err).Info("error getting list of instances")
82                         return false
83                 }
84
85                 foundExisting := false
86                 for _, i := range insts {
87                         if i.Tags()[t.TagKeyPrefix+"InstanceSetID"] != string(t.SetID) {
88                                 continue
89                         }
90                         lgr := t.Logger.WithFields(logrus.Fields{
91                                 "Instance":      i.ID(),
92                                 "InstanceSetID": t.SetID,
93                         })
94                         foundExisting = true
95                         if t.DestroyExisting {
96                                 lgr.Info("destroying existing instance with our InstanceSetID")
97                                 t0 := time.Now()
98                                 err := i.Destroy()
99                                 lgr := lgr.WithField("Duration", time.Since(t0))
100                                 if err != nil {
101                                         lgr.WithError(err).Error("error destroying existing instance")
102                                 } else {
103                                         lgr.Info("Destroy() call succeeded")
104                                 }
105                         } else {
106                                 lgr.Error("found existing instance with our InstanceSetID")
107                         }
108                 }
109                 if !foundExisting {
110                         break
111                 } else if t.DestroyExisting {
112                         t.sleepSyncInterval()
113                 } else {
114                         t.Logger.Error("cannot continue with existing instances -- clean up manually, use -destroy-existing=true, or choose a different -instance-set-id")
115                         return false
116                 }
117         }
118
119         t.secret = randomHex(40)
120
121         tags := cloud.InstanceTags{}
122         for k, v := range t.Tags {
123                 tags[k] = v
124         }
125         tags[t.TagKeyPrefix+"InstanceSetID"] = string(t.SetID)
126         tags[t.TagKeyPrefix+"InstanceSecret"] = t.secret
127
128         defer t.destroyTestInstance()
129
130         bootDeadline := time.Now().Add(t.TimeoutBooting)
131         initCommand := worker.TagVerifier{Instance: nil, Secret: t.secret, ReportVerified: nil}.InitCommand() + "\n" + t.InstanceInitCommand
132
133         t.Logger.WithFields(logrus.Fields{
134                 "InstanceType":         t.InstanceType.Name,
135                 "ProviderInstanceType": t.InstanceType.ProviderType,
136                 "ImageID":              t.ImageID,
137                 "Tags":                 tags,
138                 "InitCommand":          initCommand,
139         }).Info("creating instance")
140         t0 := time.Now()
141         inst, err := t.is.Create(t.InstanceType, t.ImageID, tags, initCommand, t.SSHKey.PublicKey())
142         lgrC := t.Logger.WithField("Duration", time.Since(t0))
143         if err != nil {
144                 // Create() might have failed due to a bug or network
145                 // error even though the creation was successful, so
146                 // it's safer to wait a bit for an instance to appear.
147                 deferredError = true
148                 lgrC.WithError(err).Error("error creating test instance")
149                 t.Logger.WithField("Deadline", bootDeadline).Info("waiting for instance to appear anyway, in case the Create response was incorrect")
150                 for err = t.refreshTestInstance(); err != nil; err = t.refreshTestInstance() {
151                         if time.Now().After(bootDeadline) {
152                                 t.Logger.Error("timed out")
153                                 return false
154                         }
155                         t.sleepSyncInterval()
156                 }
157                 t.Logger.WithField("Instance", t.testInstance.ID()).Info("new instance appeared")
158                 t.showLoginInfo()
159         } else {
160                 // Create() succeeded. Make sure the new instance
161                 // appears right away in the Instances() list.
162                 lgrC.WithField("Instance", inst.ID()).Info("created instance")
163                 t.testInstance = &worker.TagVerifier{Instance: inst, Secret: t.secret, ReportVerified: nil}
164                 t.showLoginInfo()
165                 err = t.refreshTestInstance()
166                 if err == errTestInstanceNotFound {
167                         t.Logger.WithError(err).Error("cloud/driver Create succeeded, but instance is not in list")
168                         deferredError = true
169                 } else if err != nil {
170                         t.Logger.WithError(err).Error("error getting list of instances")
171                         return false
172                 }
173         }
174
175         if !t.checkTags() {
176                 // checkTags() already logged the errors
177                 deferredError = true
178         }
179
180         if !t.waitForBoot(bootDeadline) {
181                 deferredError = true
182         }
183
184         if t.ShellCommand != "" {
185                 err = t.runShellCommand(t.ShellCommand)
186                 if err != nil {
187                         t.Logger.WithError(err).Error("shell command failed")
188                         deferredError = true
189                 }
190         }
191
192         if fn := t.PauseBeforeDestroy; fn != nil {
193                 t.showLoginInfo()
194                 fn()
195         }
196
197         return !deferredError
198 }
199
200 // If the test instance has an address, log an "ssh user@host" command
201 // line that the operator can paste into another terminal, and set
202 // t.showedLoginInfo.
203 //
204 // If the test instance doesn't have an address yet, do nothing.
205 func (t *tester) showLoginInfo() {
206         t.updateExecutor()
207         host, port := t.executor.TargetHostPort()
208         if host == "" {
209                 return
210         }
211         user := t.testInstance.RemoteUser()
212         t.Logger.WithField("Command", fmt.Sprintf("ssh -p%s %s@%s", port, user, host)).Info("showing login information")
213         t.showedLoginInfo = true
214 }
215
216 // Get the latest instance list from the driver. If our test instance
217 // is found, assign it to t.testIntance.
218 func (t *tester) refreshTestInstance() error {
219         insts, err := t.getInstances(cloud.InstanceTags{t.TagKeyPrefix + "InstanceSetID": string(t.SetID)})
220         if err != nil {
221                 return err
222         }
223         for _, i := range insts {
224                 if t.testInstance == nil {
225                         // Filter by InstanceSetID tag value
226                         if i.Tags()[t.TagKeyPrefix+"InstanceSetID"] != string(t.SetID) {
227                                 continue
228                         }
229                 } else {
230                         // Filter by instance ID
231                         if i.ID() != t.testInstance.ID() {
232                                 continue
233                         }
234                 }
235                 t.Logger.WithFields(logrus.Fields{
236                         "Instance": i.ID(),
237                         "Address":  i.Address(),
238                 }).Info("found our instance in returned list")
239                 t.testInstance = &worker.TagVerifier{Instance: i, Secret: t.secret, ReportVerified: nil}
240                 if !t.showedLoginInfo {
241                         t.showLoginInfo()
242                 }
243                 return nil
244         }
245         return errTestInstanceNotFound
246 }
247
248 // Get the list of instances, passing the given tags to the cloud
249 // driver to filter results.
250 //
251 // Return only the instances that have our InstanceSetID tag.
252 func (t *tester) getInstances(tags cloud.InstanceTags) ([]cloud.Instance, error) {
253         var ret []cloud.Instance
254         t.Logger.WithField("FilterTags", tags).Info("getting instance list")
255         t0 := time.Now()
256         insts, err := t.is.Instances(tags)
257         if err != nil {
258                 return nil, err
259         }
260         t.Logger.WithFields(logrus.Fields{
261                 "Duration": time.Since(t0),
262                 "N":        len(insts),
263         }).Info("got instance list")
264         for _, i := range insts {
265                 if i.Tags()[t.TagKeyPrefix+"InstanceSetID"] == string(t.SetID) {
266                         ret = append(ret, i)
267                 }
268         }
269         return ret, nil
270 }
271
272 // Check that t.testInstance has every tag in t.Tags. If not, log an
273 // error and return false.
274 func (t *tester) checkTags() bool {
275         ok := true
276         for k, v := range t.Tags {
277                 if got := t.testInstance.Tags()[k]; got != v {
278                         ok = false
279                         t.Logger.WithFields(logrus.Fields{
280                                 "Key":           k,
281                                 "ExpectedValue": v,
282                                 "GotValue":      got,
283                         }).Error("tag is missing from test instance")
284                 }
285         }
286         if ok {
287                 t.Logger.Info("all expected tags are present")
288         }
289         return ok
290 }
291
292 // Run t.BootProbeCommand on t.testInstance until it succeeds or the
293 // deadline arrives.
294 func (t *tester) waitForBoot(deadline time.Time) bool {
295         for time.Now().Before(deadline) {
296                 err := t.runShellCommand(t.BootProbeCommand)
297                 if err == nil {
298                         return true
299                 }
300                 t.sleepProbeInterval()
301                 t.refreshTestInstance()
302         }
303         t.Logger.Error("timed out")
304         return false
305 }
306
307 // Create t.executor and/or update its target to t.testInstance's
308 // current address.
309 func (t *tester) updateExecutor() {
310         if t.executor == nil {
311                 t.executor = sshexecutor.New(t.testInstance)
312                 t.executor.SetTargetPort(t.SSHPort)
313                 t.executor.SetSigners(t.SSHKey)
314         } else {
315                 t.executor.SetTarget(t.testInstance)
316         }
317 }
318
319 func (t *tester) runShellCommand(cmd string) error {
320         t.updateExecutor()
321         t.Logger.WithFields(logrus.Fields{
322                 "Command": cmd,
323         }).Info("executing remote command")
324         t0 := time.Now()
325         stdout, stderr, err := t.executor.Execute(nil, cmd, nil)
326         lgr := t.Logger.WithFields(logrus.Fields{
327                 "Duration": time.Since(t0),
328                 "Command":  cmd,
329                 "stdout":   string(stdout),
330                 "stderr":   string(stderr),
331         })
332         if err != nil {
333                 lgr.WithError(err).Info("remote command failed")
334         } else {
335                 lgr.Info("remote command succeeded")
336         }
337         return err
338 }
339
340 // currently, this tries forever until it can return true (success).
341 func (t *tester) destroyTestInstance() bool {
342         if t.testInstance == nil {
343                 return true
344         }
345         for {
346                 lgr := t.Logger.WithField("Instance", t.testInstance.ID())
347                 lgr.Info("destroying instance")
348                 t0 := time.Now()
349
350                 err := t.testInstance.Destroy()
351                 lgrDur := lgr.WithField("Duration", time.Since(t0))
352                 if err != nil {
353                         lgrDur.WithError(err).Error("error destroying instance")
354                 } else {
355                         lgrDur.Info("destroyed instance")
356                 }
357
358                 err = t.refreshTestInstance()
359                 if err == errTestInstanceNotFound {
360                         lgr.Info("instance no longer appears in list")
361                         t.testInstance = nil
362                         return true
363                 } else if err == nil {
364                         lgr.Info("instance still exists after calling Destroy")
365                         t.sleepSyncInterval()
366                         continue
367                 } else {
368                         t.Logger.WithError(err).Error("error getting list of instances")
369                         continue
370                 }
371         }
372 }
373
374 func (t *tester) sleepSyncInterval() {
375         t.Logger.WithField("Duration", t.SyncInterval).Info("waiting SyncInterval")
376         time.Sleep(t.SyncInterval)
377 }
378
379 func (t *tester) sleepProbeInterval() {
380         t.Logger.WithField("Duration", t.ProbeInterval).Info("waiting ProbeInterval")
381         time.Sleep(t.ProbeInterval)
382 }
383
// Return a random string of exactly n lowercase hexadecimal digits
// (n*4 random bits from crypto/rand).
//
// Odd n is supported: one extra random byte is drawn and its final
// hex digit is dropped. (Previously odd n silently produced n-1
// digits.) n <= 0 yields "".
//
// Panics if the system's secure random source fails, since callers
// use the result as a secret and cannot proceed without one.
func randomHex(n int) string {
	if n <= 0 {
		return ""
	}
	buf := make([]byte, (n+1)/2)
	if _, err := rand.Read(buf); err != nil {
		panic(err)
	}
	return fmt.Sprintf("%x", buf)[:n]
}