19166: Close ssh session when exec/shell command exits.
author Tom Clegg <tom@curii.com>
Wed, 6 Jul 2022 18:44:57 +0000 (14:44 -0400)
committer Tom Clegg <tom@curii.com>
Wed, 6 Jul 2022 18:44:57 +0000 (14:44 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

cmd/arvados-client/container_gateway.go
lib/crunchrun/container_gateway.go
sdk/go/arvados/container_gateway.go

index aca6c5b797fa4ec3b036ee8300ae3f4fcbe5e885..55f8c33bc70c77d31f13f16bb924ee4c2a6a1613 100644 (file)
@@ -160,7 +160,9 @@ Options:
                fmt.Fprintf(stderr, "target UUID is not a container or container request UUID: %s\n", targetUUID)
                return 1
        }
                fmt.Fprintf(stderr, "target UUID is not a container or container request UUID: %s\n", targetUUID)
                return 1
        }
-       sshconn, err := rpcconn.ContainerSSH(context.TODO(), arvados.ContainerSSHOptions{
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       sshconn, err := rpcconn.ContainerSSH(ctx, arvados.ContainerSSHOptions{
                UUID:          targetUUID,
                DetachKeys:    *detachKeys,
                LoginUsername: loginUsername,
                UUID:          targetUUID,
                DetachKeys:    *detachKeys,
                LoginUsername: loginUsername,
@@ -176,7 +178,6 @@ Options:
                return 0
        }
 
                return 0
        }
 
-       ctx, cancel := context.WithCancel(context.Background())
        go func() {
                defer cancel()
                _, err := io.Copy(stdout, sshconn.Conn)
        go func() {
                defer cancel()
                _, err := io.Copy(stdout, sshconn.Conn)
index 6fae73798cc6263e330614103196f156e149ee39..1002de7335495e8d6c42e9afa151f0346a4c7267 100644 (file)
@@ -242,18 +242,16 @@ func (gw *Gateway) runTunnel(addr string) error {
                                defer wg.Done()
                                _, err := io.Copy(gwconn, muxconn)
                                if err != nil {
                                defer wg.Done()
                                _, err := io.Copy(gwconn, muxconn)
                                if err != nil {
-                                       gw.Log.Printf("tunnel connection %d: tunnel: %s", muxconn.StreamID(), err)
+                                       gw.Log.Printf("tunnel connection %d: mux end: %s", muxconn.StreamID(), err)
                                }
                                }
-                               muxconn.Close()
                                gwconn.Close()
                        }()
                        go func() {
                                defer wg.Done()
                                _, err := io.Copy(muxconn, gwconn)
                                if err != nil {
                                gwconn.Close()
                        }()
                        go func() {
                                defer wg.Done()
                                _, err := io.Copy(muxconn, gwconn)
                                if err != nil {
-                                       gw.Log.Printf("tunnel connection %d: gateway: %s", muxconn.StreamID(), err)
+                                       gw.Log.Printf("tunnel connection %d: gateway end: %s", muxconn.StreamID(), err)
                                }
                                }
-                               gwconn.Close()
                                muxconn.Close()
                        }()
                        wg.Wait()
                                muxconn.Close()
                        }()
                        wg.Wait()
@@ -402,9 +400,11 @@ func (gw *Gateway) handleDirectTCPIP(ctx context.Context, newch ssh.NewChannel)
 func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, detachKeys, username string) {
        ch, reqs, err := newch.Accept()
        if err != nil {
 func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, detachKeys, username string) {
        ch, reqs, err := newch.Accept()
        if err != nil {
-               gw.Log.Printf("accept session channel: %s", err)
+               gw.Log.Printf("error accepting session channel: %s", err)
                return
        }
                return
        }
+       defer ch.Close()
+
        var pty0, tty0 *os.File
        // Where to send errors/messages for the client to see
        logw := io.Writer(ch.Stderr())
        var pty0, tty0 *os.File
        // Where to send errors/messages for the client to see
        logw := io.Writer(ch.Stderr())
@@ -413,10 +413,28 @@ func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, deta
        eol := "\n"
        // Env vars to add to child process
        termEnv := []string(nil)
        eol := "\n"
        // Env vars to add to child process
        termEnv := []string(nil)
-       for req := range reqs {
+
+       started := 0
+       wantClose := make(chan struct{})
+       for {
+               var req *ssh.Request
+               select {
+               case r, ok := <-reqs:
+                       if !ok {
+                               return
+                       }
+                       req = r
+               case <-wantClose:
+                       return
+               }
                ok := false
                switch req.Type {
                case "shell", "exec":
                ok := false
                switch req.Type {
                case "shell", "exec":
+                       if started++; started != 1 {
+                               // RFC 4254 6.5: "Only one of these
+                               // requests can succeed per channel."
+                               break
+                       }
                        ok = true
                        var payload struct {
                                Command string
                        ok = true
                        var payload struct {
                                Command string
@@ -436,7 +454,7 @@ func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, deta
                                }
                                defer func() {
                                        ch.SendRequest("exit-status", false, ssh.Marshal(&resp))
                                }
                                defer func() {
                                        ch.SendRequest("exit-status", false, ssh.Marshal(&resp))
-                                       ch.Close()
+                                       close(wantClose)
                                }()
 
                                cmd, err := gw.Target.InjectCommand(ctx, detachKeys, username, tty0 != nil, execargs)
                                }()
 
                                cmd, err := gw.Target.InjectCommand(ctx, detachKeys, username, tty0 != nil, execargs)
@@ -446,20 +464,39 @@ func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, deta
                                        resp.Status = 1
                                        return
                                }
                                        resp.Status = 1
                                        return
                                }
-                               cmd.Stdin = ch
-                               cmd.Stdout = ch
-                               cmd.Stderr = ch.Stderr()
                                if tty0 != nil {
                                        cmd.Stdin = tty0
                                        cmd.Stdout = tty0
                                        cmd.Stderr = tty0
                                if tty0 != nil {
                                        cmd.Stdin = tty0
                                        cmd.Stdout = tty0
                                        cmd.Stderr = tty0
-                                       var wg sync.WaitGroup
-                                       defer wg.Wait()
-                                       wg.Add(2)
-                                       go func() { io.Copy(ch, pty0); wg.Done() }()
-                                       go func() { io.Copy(pty0, ch); wg.Done() }()
+                                       go io.Copy(ch, pty0)
+                                       go io.Copy(pty0, ch)
                                        // Send our own debug messages to tty as well.
                                        logw = tty0
                                        // Send our own debug messages to tty as well.
                                        logw = tty0
+                               } else {
+                                       // StdinPipe may seem
+                                       // superfluous here, but it's
+                                       // not: it causes cmd.Run() to
+                                       // return when the subprocess
+                                       // exits. Without it, Run()
+                                       // waits for stdin to close,
+                                       // which causes "ssh ... echo
+                                       // ok" (with the client's
+                                       // stdin connected to a
+                                       // terminal or something) to
+                                       // hang.
+                                       stdin, err := cmd.StdinPipe()
+                                       if err != nil {
+                                               fmt.Fprintln(ch.Stderr(), err)
+                                               ch.CloseWrite()
+                                               resp.Status = 1
+                                               return
+                                       }
+                                       go func() {
+                                               io.Copy(stdin, ch)
+                                               stdin.Close()
+                                       }()
+                                       cmd.Stdout = ch
+                                       cmd.Stderr = ch.Stderr()
                                }
                                cmd.SysProcAttr = &syscall.SysProcAttr{
                                        Setctty: tty0 != nil,
                                }
                                cmd.SysProcAttr = &syscall.SysProcAttr{
                                        Setctty: tty0 != nil,
@@ -527,7 +564,7 @@ func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, deta
                        // would be a gaping security
                        // hole).
                default:
                        // would be a gaping security
                        // hole).
                default:
-                       // fmt.Fprintf(logw, "declining %q req"+eol, req.Type)
+                       fmt.Fprintf(logw, "declined request %q on ssh channel"+eol, req.Type)
                }
                if req.WantReply {
                        req.Reply(ok, nil)
                }
                if req.WantReply {
                        req.Reply(ok, nil)
index ec16ee2be9d9f345810274d252acec579eaf7ddd..897ae434e14bd0a0392d041a125a598b2c1d8b34 100644 (file)
@@ -34,8 +34,8 @@ func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Reque
        defer conn.Close()
 
        var bytesIn, bytesOut int64
        defer conn.Close()
 
        var bytesIn, bytesOut int64
-       var wg sync.WaitGroup
        ctx, cancel := context.WithCancel(req.Context())
        ctx, cancel := context.WithCancel(req.Context())
+       var wg sync.WaitGroup
        wg.Add(1)
        go func() {
                defer wg.Done()
        wg.Add(1)
        go func() {
                defer wg.Done()
@@ -49,7 +49,6 @@ func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Reque
                if err != nil {
                        ctxlog.FromContext(ctx).WithError(err).Error("error copying downstream")
                }
                if err != nil {
                        ctxlog.FromContext(ctx).WithError(err).Error("error copying downstream")
                }
-               conn.Close()
        }()
        wg.Add(1)
        go func() {
        }()
        wg.Add(1)
        go func() {
@@ -64,13 +63,17 @@ func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Reque
                if err != nil {
                        ctxlog.FromContext(ctx).WithError(err).Error("error copying upstream")
                }
                if err != nil {
                        ctxlog.FromContext(ctx).WithError(err).Error("error copying upstream")
                }
-               cresp.Conn.Close()
        }()
        }()
-       wg.Wait()
-       if cresp.Logger != nil {
-               cresp.Logger.WithFields(logrus.Fields{
-                       "bytesIn":  bytesIn,
-                       "bytesOut": bytesOut,
-               }).Info("closed connection")
-       }
+       <-ctx.Done()
+       go func() {
+               // Wait for both io.Copy goroutines to finish and increment
+               // their byte counters.
+               wg.Wait()
+               if cresp.Logger != nil {
+                       cresp.Logger.WithFields(logrus.Fields{
+                               "bytesIn":  bytesIn,
+                               "bytesOut": bytesOut,
+                       }).Info("closed connection")
+               }
+       }()
 }
 }