From: Tom Clegg Date: Wed, 6 Jul 2022 18:44:57 +0000 (-0400) Subject: 19166: Close ssh session when exec/shell command exits. X-Git-Tag: 2.5.0~106^2~4 X-Git-Url: https://git.arvados.org/arvados.git/commitdiff_plain/a42604972cccf8dd9c8341c260927a6c48c62b84 19166: Close ssh session when exec/shell command exits. Arvados-DCO-1.1-Signed-off-by: Tom Clegg --- diff --git a/cmd/arvados-client/container_gateway.go b/cmd/arvados-client/container_gateway.go index aca6c5b797..55f8c33bc7 100644 --- a/cmd/arvados-client/container_gateway.go +++ b/cmd/arvados-client/container_gateway.go @@ -160,7 +160,9 @@ Options: fmt.Fprintf(stderr, "target UUID is not a container or container request UUID: %s\n", targetUUID) return 1 } - sshconn, err := rpcconn.ContainerSSH(context.TODO(), arvados.ContainerSSHOptions{ + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sshconn, err := rpcconn.ContainerSSH(ctx, arvados.ContainerSSHOptions{ UUID: targetUUID, DetachKeys: *detachKeys, LoginUsername: loginUsername, @@ -176,7 +178,6 @@ Options: return 0 } - ctx, cancel := context.WithCancel(context.Background()) go func() { defer cancel() _, err := io.Copy(stdout, sshconn.Conn) diff --git a/lib/crunchrun/container_gateway.go b/lib/crunchrun/container_gateway.go index 6fae73798c..1002de7335 100644 --- a/lib/crunchrun/container_gateway.go +++ b/lib/crunchrun/container_gateway.go @@ -242,18 +242,16 @@ func (gw *Gateway) runTunnel(addr string) error { defer wg.Done() _, err := io.Copy(gwconn, muxconn) if err != nil { - gw.Log.Printf("tunnel connection %d: tunnel: %s", muxconn.StreamID(), err) + gw.Log.Printf("tunnel connection %d: mux end: %s", muxconn.StreamID(), err) } - muxconn.Close() gwconn.Close() }() go func() { defer wg.Done() _, err := io.Copy(muxconn, gwconn) if err != nil { - gw.Log.Printf("tunnel connection %d: gateway: %s", muxconn.StreamID(), err) + gw.Log.Printf("tunnel connection %d: gateway end: %s", muxconn.StreamID(), err) } - gwconn.Close() muxconn.Close() }() wg.Wait() @@ -402,9 +400,11 @@ func (gw *Gateway) handleDirectTCPIP(ctx context.Context, newch ssh.NewChannel) func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, detachKeys, username string) { ch, reqs, err := newch.Accept() if err != nil { - gw.Log.Printf("accept session channel: %s", err) + gw.Log.Printf("error accepting session channel: %s", err) return } + defer ch.Close() + var pty0, tty0 *os.File // Where to send errors/messages for the client to see logw := io.Writer(ch.Stderr()) @@ -413,10 +413,28 @@ func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, deta eol := "\n" // Env vars to add to child process termEnv := []string(nil) - for req := range reqs { + + started := 0 + wantClose := make(chan struct{}) + for { + var req *ssh.Request + select { + case r, ok := <-reqs: + if !ok { + return + } + req = r + case <-wantClose: + return + } ok := false switch req.Type { case "shell", "exec": + if started++; started != 1 { + // RFC 4254 6.5: "Only one of these + // requests can succeed per channel." + break + } ok = true var payload struct { Command string @@ -436,7 +454,7 @@ func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, deta } defer func() { ch.SendRequest("exit-status", false, ssh.Marshal(&resp)) - ch.Close() + close(wantClose) }() cmd, err := gw.Target.InjectCommand(ctx, detachKeys, username, tty0 != nil, execargs) @@ -446,20 +464,39 @@ func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, deta resp.Status = 1 return } - cmd.Stdin = ch - cmd.Stdout = ch - cmd.Stderr = ch.Stderr() if tty0 != nil { cmd.Stdin = tty0 cmd.Stdout = tty0 cmd.Stderr = tty0 - var wg sync.WaitGroup - defer wg.Wait() - wg.Add(2) - go func() { io.Copy(ch, pty0); wg.Done() }() - go func() { io.Copy(pty0, ch); wg.Done() }() + go io.Copy(ch, pty0) + go io.Copy(pty0, ch) // Send our own debug messages to tty as well. logw = tty0 + } else { + // StdinPipe may seem + // superfluous here, but it's + // not: it causes cmd.Run() to + // return when the subprocess + // exits. Without it, Run() + // waits for stdin to close, + // which causes "ssh ... echo + // ok" (with the client's + // stdin connected to a + // terminal or something) to + // hang. + stdin, err := cmd.StdinPipe() + if err != nil { + fmt.Fprintln(ch.Stderr(), err) + ch.CloseWrite() + resp.Status = 1 + return + } + go func() { + io.Copy(stdin, ch) + stdin.Close() + }() + cmd.Stdout = ch + cmd.Stderr = ch.Stderr() } cmd.SysProcAttr = &syscall.SysProcAttr{ Setctty: tty0 != nil, @@ -527,7 +564,7 @@ func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, deta // would be a gaping security // hole). default: - // fmt.Fprintf(logw, "declining %q req"+eol, req.Type) + fmt.Fprintf(logw, "declined request %q on ssh channel"+eol, req.Type) } if req.WantReply { req.Reply(ok, nil) diff --git a/sdk/go/arvados/container_gateway.go b/sdk/go/arvados/container_gateway.go index ec16ee2be9..897ae434e1 100644 --- a/sdk/go/arvados/container_gateway.go +++ b/sdk/go/arvados/container_gateway.go @@ -34,8 +34,8 @@ func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Reque defer conn.Close() var bytesIn, bytesOut int64 - var wg sync.WaitGroup ctx, cancel := context.WithCancel(req.Context()) + var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() @@ -49,7 +49,6 @@ func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Reque if err != nil { ctxlog.FromContext(ctx).WithError(err).Error("error copying downstream") } - conn.Close() }() wg.Add(1) go func() { @@ -64,13 +63,17 @@ func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Reque if err != nil { ctxlog.FromContext(ctx).WithError(err).Error("error copying upstream") } - cresp.Conn.Close() }() - wg.Wait() - if cresp.Logger != nil { - cresp.Logger.WithFields(logrus.Fields{ - "bytesIn": bytesIn, - "bytesOut": bytesOut, - }).Info("closed connection") - } + <-ctx.Done() + go func() { + // Wait for both io.Copy goroutines to finish and increment + // their byte counters. + wg.Wait() + if cresp.Logger != nil { + cresp.Logger.WithFields(logrus.Fields{ + "bytesIn": bytesIn, + "bytesOut": bytesOut, + }).Info("closed connection") + } + }() }