Merge branch '21666-provision-test-improvement'
[arvados.git] / lib / dispatchcloud / test / stub_driver.go
index 5ca83d263c1c481bd71c968299744e2cf9b2486d..2265be6e1610015358036f515c50acea5bad5c11 100644 (file)
@@ -20,6 +20,7 @@ import (
        "git.arvados.org/arvados.git/lib/cloud"
        "git.arvados.org/arvados.git/lib/crunchrun"
        "git.arvados.org/arvados.git/sdk/go/arvados"
+       "github.com/prometheus/client_golang/prometheus"
        "github.com/sirupsen/logrus"
        "golang.org/x/crypto/ssh"
 )
@@ -33,7 +34,10 @@ type StubDriver struct {
        // SetupVM, if set, is called upon creation of each new
        // StubVM. This is the caller's opportunity to customize the
        // VM's error rate and other behaviors.
-       SetupVM func(*StubVM)
+       //
+       // If SetupVM returns an error, that error will be returned to
+       // the caller of Create(), and the new VM will be discarded.
+       SetupVM func(*StubVM) error
 
        // Bugf, if set, is called if a bug is detected in the caller
        // or stub. Typically set to (*check.C)Errorf. If unset,
@@ -54,6 +58,8 @@ type StubDriver struct {
        MinTimeBetweenCreateCalls    time.Duration
        MinTimeBetweenInstancesCalls time.Duration
 
+       QuotaMaxInstances int
+
        // If true, Create and Destroy calls block until Release() is
        // called.
        HoldCloudOps bool
@@ -63,7 +69,7 @@ type StubDriver struct {
 }
 
 // InstanceSet returns a new *StubInstanceSet.
-func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
+func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger, reg *prometheus.Registry) (cloud.InstanceSet, error) {
        if sd.holdCloudOps == nil {
                sd.holdCloudOps = make(chan bool)
        }
@@ -124,6 +130,9 @@ func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID,
        if math_rand.Float64() < sis.driver.ErrorRateCreate {
                return nil, fmt.Errorf("StubInstanceSet: rand < ErrorRateCreate %f", sis.driver.ErrorRateCreate)
        }
+       if max := sis.driver.QuotaMaxInstances; max > 0 && len(sis.servers) >= max {
+               return nil, QuotaError{fmt.Errorf("StubInstanceSet: reached QuotaMaxInstances %d", max)}
+       }
        sis.allowCreateCall = time.Now().Add(sis.driver.MinTimeBetweenCreateCalls)
        ak := sis.driver.AuthorizedKeys
        if authKey != nil {
@@ -146,7 +155,10 @@ func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID,
                Exec:           svm.Exec,
        }
        if setup := sis.driver.SetupVM; setup != nil {
-               setup(svm)
+               err := setup(svm)
+               if err != nil {
+                       return nil, err
+               }
        }
        sis.servers[svm.id] = svm
        return svm.Instance(), nil
@@ -189,6 +201,12 @@ type RateLimitError struct{ Retry time.Time }
 func (e RateLimitError) Error() string            { return fmt.Sprintf("rate limited until %s", e.Retry) }
 func (e RateLimitError) EarliestRetry() time.Time { return e.Retry }
 
+type CapacityError struct{ InstanceTypeSpecific bool }
+
+func (e CapacityError) Error() string                { return "insufficient capacity" }
+func (e CapacityError) IsCapacityError() bool        { return true }
+func (e CapacityError) IsInstanceTypeSpecific() bool { return e.InstanceTypeSpecific }
+
 // StubVM is a fake server that runs an SSH service. It represents a
 // VM running in a fake cloud.
 //
@@ -221,6 +239,8 @@ type StubVM struct {
        killing      map[string]bool
        lastPID      int64
        deadlocked   string
+       stubprocs    sync.WaitGroup
+       destroying   bool
        sync.Mutex
 }
 
@@ -249,6 +269,17 @@ func (svm *StubVM) Instance() stubInstance {
 }
 
 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
+       // Ensure we don't start any new stubprocs after Destroy()
+       // has started Wait()ing for stubprocs to end.
+       svm.Lock()
+       if svm.destroying {
+               svm.Unlock()
+               return 1
+       }
+       svm.stubprocs.Add(1)
+       defer svm.stubprocs.Done()
+       svm.Unlock()
+
        stdinData, err := ioutil.ReadAll(stdin)
        if err != nil {
                fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
@@ -286,7 +317,15 @@ func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader,
                pid := svm.lastPID
                svm.running[uuid] = stubProcess{pid: pid}
                svm.Unlock()
+
                time.Sleep(svm.CrunchRunDetachDelay)
+
+               svm.Lock()
+               defer svm.Unlock()
+               if svm.destroying {
+                       fmt.Fprint(stderr, "crunch-run: killed by system shutdown\n")
+                       return 9
+               }
                fmt.Fprintf(stderr, "starting %s\n", uuid)
                logger := svm.sis.logger.WithFields(logrus.Fields{
                        "Instance":      svm.id,
@@ -294,13 +333,18 @@ func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader,
                        "PID":           pid,
                })
                logger.Printf("[test] starting crunch-run stub")
+               svm.stubprocs.Add(1)
                go func() {
+                       defer svm.stubprocs.Done()
                        var ctr arvados.Container
                        var started, completed bool
                        defer func() {
                                logger.Print("[test] exiting crunch-run stub")
                                svm.Lock()
                                defer svm.Unlock()
+                               if svm.destroying {
+                                       return
+                               }
                                if svm.running[uuid].pid != pid {
                                        bugf := svm.sis.driver.Bugf
                                        if bugf == nil {
@@ -340,8 +384,10 @@ func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader,
 
                        svm.Lock()
                        killed := svm.killing[uuid]
+                       delete(svm.killing, uuid)
+                       destroying := svm.destroying
                        svm.Unlock()
-                       if killed || wantCrashEarly {
+                       if killed || wantCrashEarly || destroying {
                                return
                        }
 
@@ -433,6 +479,10 @@ func (si stubInstance) Destroy() error {
        if math_rand.Float64() < si.svm.sis.driver.ErrorRateDestroy {
                return errors.New("instance could not be destroyed")
        }
+       si.svm.Lock()
+       si.svm.destroying = true
+       si.svm.Unlock()
+       si.svm.stubprocs.Wait()
        si.svm.SSHService.Close()
        sis.mtx.Lock()
        defer sis.mtx.Unlock()
@@ -489,3 +539,9 @@ func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
 func (si stubInstance) PriceHistory(arvados.InstanceType) []cloud.InstancePrice {
        return nil
 }
+
+type QuotaError struct {
+       error
+}
+
+func (QuotaError) IsQuotaError() bool { return true }