8465: Set proper container flags to provide data on stdin.
[arvados.git] / services / crunch-run / crunchrun.go
index a05f61a858c04321f41ff7c8a3faf97a87618431..6406c1742001b7bdde263ba31342ce2574c18029 100644 (file)
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "bytes"
        "context"
        "encoding/json"
        "errors"
@@ -361,6 +362,13 @@ func (runner *ContainerRunner) SetupMounts() (err error) {
                        }
                }
 
+               if bind == "stdin" {
+                       // Only "collection" and "json" mount kinds may be bound to stdin.
+                       if mnt.Kind != "collection" && mnt.Kind != "json" {
+                               return fmt.Errorf("Unsupported mount kind '%s' for stdin. Only 'collection' or 'json' are supported.", mnt.Kind)
+                       }
+               }
+
                if bind == "/etc/arvados/ca-certificates.crt" {
                        needCertMount = false
                }
@@ -372,7 +380,7 @@ func (runner *ContainerRunner) SetupMounts() (err error) {
                }
 
                switch {
-               case mnt.Kind == "collection":
+               case mnt.Kind == "collection" && bind != "stdin":
                        var src string
                        if mnt.UUID != "" && mnt.PortableDataHash != "" {
                                return fmt.Errorf("Cannot specify both 'uuid' and 'portable_data_hash' for a collection mount")
@@ -663,8 +671,38 @@ func (runner *ContainerRunner) AttachStreams() (err error) {
 
        runner.CrunchLog.Print("Attaching container streams")
 
+       // If a stdin mount is provided, prepare its data (a collection file reader or JSON-encoded bytes) to be fed to the docker container's stdin
+       var stdinRdr keepclient.Reader
+       var stdinJson []byte
+       if stdinMnt, ok := runner.Container.Mounts["stdin"]; ok {
+               if stdinMnt.Kind == "collection" {
+                       var stdinColl arvados.Collection
+                       collId := stdinMnt.UUID
+                       if collId == "" {
+                               collId = stdinMnt.PortableDataHash
+                       }
+                       err = runner.ArvClient.Get("collections", collId, nil, &stdinColl)
+                       if err != nil {
+                               return fmt.Errorf("While getting stding collection: %v", err)
+                       }
+
+                       stdinRdr, err = runner.Kc.ManifestFileReader(manifest.Manifest{Text: stdinColl.ManifestText}, stdinMnt.Path)
+                       if os.IsNotExist(err) {
+                               return fmt.Errorf("stdin collection path not found: %v", stdinMnt.Path)
+                       } else if err != nil {
+                               return fmt.Errorf("While getting stdin collection path %v: %v", stdinMnt.Path, err)
+                       }
+               } else if stdinMnt.Kind == "json" {
+                       stdinJson, err = json.Marshal(stdinMnt.Content)
+                       if err != nil {
+                               return fmt.Errorf("While encoding stdin json data: %v", err)
+                       }
+               }
+       }
+
+       stdinUsed := stdinRdr != nil || len(stdinJson) != 0
        response, err := runner.Docker.ContainerAttach(context.TODO(), runner.ContainerID,
-               dockertypes.ContainerAttachOptions{Stream: true, Stdout: true, Stderr: true})
+               dockertypes.ContainerAttachOptions{Stream: true, Stdin: stdinUsed, Stdout: true, Stderr: true})
        if err != nil {
                return fmt.Errorf("While attaching container stdout/stderr streams: %v", err)
        }
@@ -698,6 +736,27 @@ func (runner *ContainerRunner) AttachStreams() (err error) {
        }
        runner.Stderr = NewThrottledLogger(runner.NewLogWriter("stderr"))
 
+       if stdinRdr != nil {
+               go func() {
+                       _, err := io.Copy(response.Conn, stdinRdr)
+                       if err != nil {
+                               runner.CrunchLog.Print("While writing stdin collection to docker container %q", err)
+                               runner.stop()
+                       }
+                       stdinRdr.Close()
+                       response.CloseWrite()
+               }()
+       } else if len(stdinJson) != 0 {
+               go func() {
+                       _, err := io.Copy(response.Conn, bytes.NewReader(stdinJson))
+                       if err != nil {
+                               runner.CrunchLog.Print("While writing stdin json to docker container %q", err)
+                               runner.stop()
+                       }
+                       response.CloseWrite()
+               }()
+       }
+
        go runner.ProcessDockerAttach(response.Reader)
 
        return nil
@@ -743,6 +802,13 @@ func (runner *ContainerRunner) CreateContainer() error {
                }
        }
 
+       _, stdinUsed := runner.Container.Mounts["stdin"]
+       runner.ContainerConfig.OpenStdin = stdinUsed
+       runner.ContainerConfig.StdinOnce = stdinUsed
+       runner.ContainerConfig.AttachStdin = stdinUsed
+       runner.ContainerConfig.AttachStdout = true
+       runner.ContainerConfig.AttachStderr = true
+
        createdBody, err := runner.Docker.ContainerCreate(context.TODO(), &runner.ContainerConfig, &runner.HostConfig, nil, runner.Container.UUID)
        if err != nil {
                return fmt.Errorf("While creating container: %v", err)