14807: Merge branch 'master'
[arvados.git] / lib / dispatchcloud / test / stub_driver.go
index a2231673fcd08c2230259b915adc41a8f9a3ed58..5873e492213b86f58eaa98850c5c00c073cd2aee 100644 (file)
@@ -6,9 +6,11 @@ package test
 
 import (
        "crypto/rand"
+       "encoding/json"
        "errors"
        "fmt"
        "io"
+       "io/ioutil"
        math_rand "math/rand"
        "regexp"
        "strings"
@@ -17,7 +19,6 @@ import (
 
        "git.curoverse.com/arvados.git/lib/cloud"
        "git.curoverse.com/arvados.git/sdk/go/arvados"
-       "github.com/mitchellh/mapstructure"
        "github.com/sirupsen/logrus"
        "golang.org/x/crypto/ssh"
 )
@@ -55,16 +56,22 @@ type StubDriver struct {
 }
 
 // InstanceSet returns a new *StubInstanceSet.
-func (sd *StubDriver) InstanceSet(params map[string]interface{}, id cloud.InstanceSetID, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
+func (sd *StubDriver) InstanceSet(params json.RawMessage, id cloud.InstanceSetID, logger logrus.FieldLogger) (cloud.InstanceSet, error) {
        if sd.holdCloudOps == nil {
                sd.holdCloudOps = make(chan bool)
        }
        sis := StubInstanceSet{
                driver:  sd,
+               logger:  logger,
                servers: map[cloud.InstanceID]*StubVM{},
        }
        sd.instanceSets = append(sd.instanceSets, &sis)
-       return &sis, mapstructure.Decode(params, &sis)
+
+       var err error
+       if params != nil {
+               err = json.Unmarshal(params, &sis)
+       }
+       return &sis, err
 }
 
 // InstanceSets returns all instances that have been created by the
@@ -85,6 +92,7 @@ func (sd *StubDriver) ReleaseCloudOps(n int) {
 
 type StubInstanceSet struct {
        driver  *StubDriver
+       logger  logrus.FieldLogger
        servers map[cloud.InstanceID]*StubVM
        mtx     sync.RWMutex
        stopped bool
@@ -93,7 +101,7 @@ type StubInstanceSet struct {
        allowInstancesCall time.Time
 }
 
-func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, authKey ssh.PublicKey) (cloud.Instance, error) {
+func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, cmd cloud.InitCommand, authKey ssh.PublicKey) (cloud.Instance, error) {
        if sis.driver.HoldCloudOps {
                sis.driver.holdCloudOps <- true
        }
@@ -117,9 +125,11 @@ func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID,
                id:           cloud.InstanceID(fmt.Sprintf("stub-%s-%x", it.ProviderType, math_rand.Int63())),
                tags:         copyTags(tags),
                providerType: it.ProviderType,
+               initCommand:  cmd,
        }
        svm.SSHService = SSHService{
                HostKey:        sis.driver.HostKey,
+               AuthorizedUser: "root",
                AuthorizedKeys: ak,
                Exec:           svm.Exec,
        }
@@ -177,6 +187,7 @@ type StubVM struct {
        sis          *StubInstanceSet
        id           cloud.InstanceID
        tags         cloud.InstanceTags
+       initCommand  cloud.InitCommand
        providerType string
        SSHService   SSHService
        running      map[string]bool
@@ -200,6 +211,11 @@ func (svm *StubVM) Instance() stubInstance {
 }
 
 func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
+       stdinData, err := ioutil.ReadAll(stdin)
+       if err != nil {
+               fmt.Fprintf(stderr, "error reading stdin: %s\n", err)
+               return 1
+       }
        queue := svm.sis.driver.Queue
        uuid := regexp.MustCompile(`.{5}-dz642-.{15}`).FindString(command)
        if eta := svm.Boot.Sub(time.Now()); eta > 0 {
@@ -214,10 +230,16 @@ func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader,
                fmt.Fprint(stderr, "crunch-run: command not found\n")
                return 1
        }
-       if strings.HasPrefix(command, "crunch-run --detach ") {
+       if strings.HasPrefix(command, "crunch-run --detach --stdin-env ") {
+               var stdinKV map[string]string
+               err := json.Unmarshal(stdinData, &stdinKV)
+               if err != nil {
+                       fmt.Fprintf(stderr, "unmarshal stdin: %s (stdin was: %q)\n", err, stdinData)
+                       return 1
+               }
                for _, name := range []string{"ARVADOS_API_HOST", "ARVADOS_API_TOKEN"} {
-                       if env[name] == "" {
-                               fmt.Fprintf(stderr, "%s missing from environment %q\n", name, env)
+                       if stdinKV[name] == "" {
+                               fmt.Fprintf(stderr, "%s env var missing from stdin %q\n", name, stdin)
                                return 1
                        }
                }
@@ -229,7 +251,7 @@ func (svm *StubVM) Exec(env map[string]string, command string, stdin io.Reader,
                svm.Unlock()
                time.Sleep(svm.CrunchRunDetachDelay)
                fmt.Fprintf(stderr, "starting %s\n", uuid)
-               logger := logrus.WithFields(logrus.Fields{
+               logger := svm.sis.logger.WithFields(logrus.Fields{
                        "Instance":      svm.id,
                        "ContainerUUID": uuid,
                })
@@ -314,6 +336,10 @@ func (si stubInstance) Address() string {
        return si.addr
 }
 
+func (si stubInstance) RemoteUser() string {
+       return si.svm.SSHService.AuthorizedUser
+}
+
 func (si stubInstance) Destroy() error {
        sis := si.svm.sis
        if sis.driver.HoldCloudOps {