16306: Merge branch 'master'
[arvados.git] / lib / dispatchcloud / test / ssh_service.go
index b1e4e03b12ea142e925b45fe689217a499e59bb9..31919b566df81769f791782b4cb3709b468122e8 100644 (file)
@@ -18,6 +18,8 @@ import (
        check "gopkg.in/check.v1"
 )
 
+// LoadTestKey returns a public/private SSH keypair, read from the files
+// identified by the path of the private key (the public key is fnm+".pub").
 func LoadTestKey(c *check.C, fnm string) (ssh.PublicKey, ssh.Signer) {
        rawpubkey, err := ioutil.ReadFile(fnm + ".pub")
        c.Assert(err, check.IsNil)
@@ -32,13 +34,14 @@ func LoadTestKey(c *check.C, fnm string) (ssh.PublicKey, ssh.Signer) {
 
 // An SSHExecFunc handles an "exec" session on a multiplexed SSH
 // connection.
-type SSHExecFunc func(command string, stdin io.Reader, stdout, stderr io.Writer) uint32
+type SSHExecFunc func(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32
 
 // An SSHService accepts SSH connections on an available TCP port and
 // passes clients' "exec" sessions to the provided SSHExecFunc.
 type SSHService struct {
        Exec           SSHExecFunc
        HostKey        ssh.Signer
+       AuthorizedUser string
        AuthorizedKeys []ssh.PublicKey
 
        listener net.Listener
@@ -64,6 +67,11 @@ func (ss *SSHService) Address() string {
        return ln.Addr().String()
 }
 
+// RemoteUser returns the username that will be accepted (ss.AuthorizedUser).
+func (ss *SSHService) RemoteUser() string {
+       return ss.AuthorizedUser
+}
+
 // Close shuts down the server and releases resources. Established
 // connections are unaffected.
 func (ss *SSHService) Close() {
@@ -103,7 +111,7 @@ func (ss *SSHService) run() {
        }
        config.AddHostKey(ss.HostKey)
 
-       listener, err := net.Listen("tcp", ":")
+       listener, err := net.Listen("tcp", "127.0.0.1:")
        if err != nil {
                ss.err = err
                return
@@ -146,22 +154,37 @@ func (ss *SSHService) serveConn(nConn net.Conn, config *ssh.ServerConfig) {
                        log.Printf("accept channel: %s", err)
                        return
                }
-               var execReq struct {
-                       Command string
-               }
+               didExec := false
+               sessionEnv := map[string]string{}
                go func() {
                        for req := range reqs {
-                               if req.Type == "exec" && execReq.Command == "" {
+                               switch {
+                               case didExec:
+                                       // Reject anything after exec
+                                       req.Reply(false, nil)
+                               case req.Type == "exec":
+                                       var execReq struct {
+                                               Command string
+                                       }
                                        req.Reply(true, nil)
                                        ssh.Unmarshal(req.Payload, &execReq)
                                        go func() {
                                                var resp struct {
                                                        Status uint32
                                                }
-                                               resp.Status = ss.Exec(execReq.Command, ch, ch, ch.Stderr())
+                                               resp.Status = ss.Exec(sessionEnv, execReq.Command, ch, ch, ch.Stderr())
                                                ch.SendRequest("exit-status", false, ssh.Marshal(&resp))
                                                ch.Close()
                                        }()
+                                       didExec = true
+                               case req.Type == "env":
+                                       var envReq struct {
+                                               Name  string
+                                               Value string
+                                       }
+                                       req.Reply(true, nil)
+                                       ssh.Unmarshal(req.Payload, &envReq)
+                                       sessionEnv[envReq.Name] = envReq.Value
                                }
                        }
                }()