X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/f1f74069850d8c5e987ef7d7fc246735ff94d58d..45828d11792f00d732b7d8e667db8b45b7a0f8b4:/lib/crunchrun/container_gateway.go diff --git a/lib/crunchrun/container_gateway.go b/lib/crunchrun/container_gateway.go index 02df06cf21..3cb93fc746 100644 --- a/lib/crunchrun/container_gateway.go +++ b/lib/crunchrun/container_gateway.go @@ -168,6 +168,10 @@ func (gw *Gateway) Start() error { if err != nil { return err } + go func() { + err := srv.Wait() + gw.Log.Printf("gateway server stopped: %s", err) + }() // Get the port number we are listening on (extPort might be // "0" or a port name, in which case this will be different). _, listenPort, err := net.SplitHostPort(srv.Addr) @@ -184,6 +188,7 @@ func (gw *Gateway) Start() error { // non-tunnel connections aren't available; and PORT is the // port number we are listening on. gw.Address = net.JoinHostPort(extHost, listenPort) + gw.Log.Printf("gateway server listening at %s", gw.Address) if gw.ArvadosClient != nil { go gw.maintainTunnel(gw.Address) } @@ -218,16 +223,16 @@ func (gw *Gateway) runTunnel(addr string) error { gw.UpdateTunnelURL(url) } for { - muxconn, err := mux.Accept() + muxconn, err := mux.AcceptStream() if err != nil { return err } - gw.Log.Printf("receiving connection from tunnel, remoteAddr %s", muxconn.RemoteAddr().String()) + gw.Log.Printf("tunnel connection %d started", muxconn.StreamID()) go func() { defer muxconn.Close() gwconn, err := net.Dial("tcp", addr) if err != nil { - gw.Log.Printf("error connecting to %s on behalf of tunnel connection: %s", addr, err) + gw.Log.Printf("tunnel connection %d: error connecting to %s: %s", muxconn.StreamID(), addr, err) return } defer gwconn.Close() @@ -235,13 +240,22 @@ func (gw *Gateway) runTunnel(addr string) error { wg.Add(2) go func() { defer wg.Done() - io.Copy(gwconn, muxconn) + _, err := io.Copy(gwconn, muxconn) + if err != nil { + gw.Log.Printf("tunnel connection %d: mux end: %s", muxconn.StreamID(), err) + } + gwconn.Close() }() go func() { defer wg.Done() - io.Copy(muxconn, gwconn) + _, err := io.Copy(muxconn, gwconn) + if err != nil { + gw.Log.Printf("tunnel connection %d: gateway end: %s", muxconn.StreamID(), err) + } + muxconn.Close() }() wg.Wait() + gw.Log.Printf("tunnel connection %d finished", muxconn.StreamID()) }() } } @@ -386,9 +400,11 @@ func (gw *Gateway) handleDirectTCPIP(ctx context.Context, newch ssh.NewChannel) func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, detachKeys, username string) { ch, reqs, err := newch.Accept() if err != nil { - gw.Log.Printf("accept session channel: %s", err) + gw.Log.Printf("error accepting session channel: %s", err) return } + defer ch.Close() + var pty0, tty0 *os.File // Where to send errors/messages for the client to see logw := io.Writer(ch.Stderr()) @@ -397,10 +413,28 @@ func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, deta eol := "\n" // Env vars to add to child process termEnv := []string(nil) - for req := range reqs { + + started := 0 + wantClose := make(chan struct{}) + for { + var req *ssh.Request + select { + case r, ok := <-reqs: + if !ok { + return + } + req = r + case <-wantClose: + return + } ok := false switch req.Type { case "shell", "exec": + if started++; started != 1 { + // RFC 4254 6.5: "Only one of these + // requests can succeed per channel." + break + } ok = true var payload struct { Command string @@ -420,7 +454,7 @@ func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, deta } defer func() { ch.SendRequest("exit-status", false, ssh.Marshal(&resp)) - ch.Close() + close(wantClose) }() cmd, err := gw.Target.InjectCommand(ctx, detachKeys, username, tty0 != nil, execargs) @@ -430,20 +464,39 @@ func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, deta resp.Status = 1 return } - cmd.Stdin = ch - cmd.Stdout = ch - cmd.Stderr = ch.Stderr() if tty0 != nil { cmd.Stdin = tty0 cmd.Stdout = tty0 cmd.Stderr = tty0 - var wg sync.WaitGroup - defer wg.Wait() - wg.Add(2) - go func() { io.Copy(ch, pty0); wg.Done() }() - go func() { io.Copy(pty0, ch); wg.Done() }() + go io.Copy(ch, pty0) + go io.Copy(pty0, ch) // Send our own debug messages to tty as well. logw = tty0 + } else { + // StdinPipe may seem + // superfluous here, but it's + // not: it causes cmd.Run() to + // return when the subprocess + // exits. Without it, Run() + // waits for stdin to close, + // which causes "ssh ... echo + // ok" (with the client's + // stdin connected to a + // terminal or something) to + // hang. + stdin, err := cmd.StdinPipe() + if err != nil { + fmt.Fprintln(ch.Stderr(), err) + ch.CloseWrite() + resp.Status = 1 + return + } + go func() { + io.Copy(stdin, ch) + stdin.Close() + }() + cmd.Stdout = ch + cmd.Stderr = ch.Stderr() } cmd.SysProcAttr = &syscall.SysProcAttr{ Setctty: tty0 != nil, @@ -511,7 +564,7 @@ func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, deta // would be a gaping security // hole). default: - // fmt.Fprintf(logw, "declining %q req"+eol, req.Type) + // fmt.Fprintf(logw, "declined request %q on ssh channel"+eol, req.Type) } if req.WantReply { req.Reply(ok, nil)