X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/56e130608f8977d20b21c54f6ab8973d71e045a0..135bada0fe08de2b678ede684d43a155c4351ed3:/sdk/go/arvados/container_gateway.go diff --git a/sdk/go/arvados/container_gateway.go b/sdk/go/arvados/container_gateway.go index 07f8c0793c..897ae434e1 100644 --- a/sdk/go/arvados/container_gateway.go +++ b/sdk/go/arvados/container_gateway.go @@ -8,18 +8,23 @@ import ( "context" "io" "net/http" + "sync" "git.arvados.org/arvados.git/sdk/go/ctxlog" + "github.com/sirupsen/logrus" ) -func (sshconn ContainerSSHConnection) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Request) { + defer cresp.Conn.Close() hj, ok := w.(http.Hijacker) if !ok { http.Error(w, "ResponseWriter does not support connection upgrade", http.StatusInternalServerError) return } w.Header().Set("Connection", "upgrade") - w.Header().Set("Upgrade", "ssh") + for k, v := range cresp.Header { + w.Header()[k] = v + } w.WriteHeader(http.StatusSwitchingProtocols) conn, bufrw, err := hj.Hijack() if err != nil { @@ -28,26 +33,47 @@ func (sshconn ContainerSSHConnection) ServeHTTP(w http.ResponseWriter, req *http } defer conn.Close() - ctx, cancel := context.WithCancel(context.Background()) + var bytesIn, bytesOut int64 + ctx, cancel := context.WithCancel(req.Context()) + var wg sync.WaitGroup + wg.Add(1) go func() { + defer wg.Done() defer cancel() - _, err := io.CopyN(conn, sshconn.Bufrw, int64(sshconn.Bufrw.Reader.Buffered())) + n, err := io.CopyN(conn, cresp.Bufrw, int64(cresp.Bufrw.Reader.Buffered())) + bytesOut += n if err == nil { - _, err = io.Copy(conn, sshconn.Conn) + n, err = io.Copy(conn, cresp.Conn) + bytesOut += n } if err != nil { - ctxlog.FromContext(req.Context()).WithError(err).Error("error copying downstream") + ctxlog.FromContext(ctx).WithError(err).Error("error copying downstream") } }() + wg.Add(1) go func() { + defer wg.Done() defer cancel() - _, err := io.CopyN(sshconn.Conn, bufrw, int64(bufrw.Reader.Buffered())) + n, err := io.CopyN(cresp.Conn, bufrw, int64(bufrw.Reader.Buffered())) + bytesIn += n if err == nil { - _, err = io.Copy(sshconn.Conn, conn) + n, err = io.Copy(cresp.Conn, conn) + bytesIn += n } if err != nil { - ctxlog.FromContext(req.Context()).WithError(err).Error("error copying upstream") + ctxlog.FromContext(ctx).WithError(err).Error("error copying upstream") } }() <-ctx.Done() + go func() { + // Wait for both io.Copy goroutines to finish and increment + // their byte counters. + wg.Wait() + if cresp.Logger != nil { + cresp.Logger.WithFields(logrus.Fields{ + "bytesIn": bytesIn, + "bytesOut": bytesOut, + }).Info("closed connection") + } + }() }