X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/8b1aca5c3415bfee3b4bc242596e1ee68ddef354..a42604972cccf8dd9c8341c260927a6c48c62b84:/sdk/go/arvados/container_gateway.go diff --git a/sdk/go/arvados/container_gateway.go b/sdk/go/arvados/container_gateway.go index 00c98d572e..897ae434e1 100644 --- a/sdk/go/arvados/container_gateway.go +++ b/sdk/go/arvados/container_gateway.go @@ -14,14 +14,17 @@ import ( "github.com/sirupsen/logrus" ) -func (sshconn ContainerSSHConnection) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Request) { + defer cresp.Conn.Close() hj, ok := w.(http.Hijacker) if !ok { http.Error(w, "ResponseWriter does not support connection upgrade", http.StatusInternalServerError) return } w.Header().Set("Connection", "upgrade") - w.Header().Set("Upgrade", "ssh") + for k, v := range cresp.Header { + w.Header()[k] = v + } w.WriteHeader(http.StatusSwitchingProtocols) conn, bufrw, err := hj.Hijack() if err != nil { @@ -31,44 +34,46 @@ func (sshconn ContainerSSHConnection) ServeHTTP(w http.ResponseWriter, req *http defer conn.Close() var bytesIn, bytesOut int64 + ctx, cancel := context.WithCancel(req.Context()) var wg sync.WaitGroup - ctx, cancel := context.WithCancel(context.Background()) wg.Add(1) go func() { defer wg.Done() defer cancel() - n, err := io.CopyN(conn, sshconn.Bufrw, int64(sshconn.Bufrw.Reader.Buffered())) + n, err := io.CopyN(conn, cresp.Bufrw, int64(cresp.Bufrw.Reader.Buffered())) bytesOut += n if err == nil { - n, err = io.Copy(conn, sshconn.Conn) + n, err = io.Copy(conn, cresp.Conn) bytesOut += n } if err != nil { - ctxlog.FromContext(req.Context()).WithError(err).Error("error copying downstream") + ctxlog.FromContext(ctx).WithError(err).Error("error copying downstream") } }() wg.Add(1) go func() { defer wg.Done() defer cancel() - n, err := io.CopyN(sshconn.Conn, bufrw, int64(bufrw.Reader.Buffered())) + n, err := io.CopyN(cresp.Conn, bufrw, int64(bufrw.Reader.Buffered())) bytesIn += n if err == nil { - n, err = io.Copy(sshconn.Conn, conn) + n, err = io.Copy(cresp.Conn, conn) bytesIn += n } if err != nil { - ctxlog.FromContext(req.Context()).WithError(err).Error("error copying upstream") + ctxlog.FromContext(ctx).WithError(err).Error("error copying upstream") } }() <-ctx.Done() - if sshconn.Logger != nil { - go func() { - wg.Wait() - sshconn.Logger.WithFields(logrus.Fields{ + go func() { + // Wait for both io.Copy goroutines to finish and increment + // their byte counters. + wg.Wait() + if cresp.Logger != nil { + cresp.Logger.WithFields(logrus.Fields{ "bytesIn": bytesIn, "bytesOut": bytesOut, }).Info("closed connection") - }() - } + } + }() }