19166: Close ssh session when exec/shell command exits.
[arvados.git] / sdk / go / arvados / container_gateway.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvados
6
7 import (
8         "context"
9         "io"
10         "net/http"
11         "sync"
12
13         "git.arvados.org/arvados.git/sdk/go/ctxlog"
14         "github.com/sirupsen/logrus"
15 )
16
17 func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Request) {
18         defer cresp.Conn.Close()
19         hj, ok := w.(http.Hijacker)
20         if !ok {
21                 http.Error(w, "ResponseWriter does not support connection upgrade", http.StatusInternalServerError)
22                 return
23         }
24         w.Header().Set("Connection", "upgrade")
25         for k, v := range cresp.Header {
26                 w.Header()[k] = v
27         }
28         w.WriteHeader(http.StatusSwitchingProtocols)
29         conn, bufrw, err := hj.Hijack()
30         if err != nil {
31                 ctxlog.FromContext(req.Context()).WithError(err).Error("error hijacking ResponseWriter")
32                 return
33         }
34         defer conn.Close()
35
36         var bytesIn, bytesOut int64
37         ctx, cancel := context.WithCancel(req.Context())
38         var wg sync.WaitGroup
39         wg.Add(1)
40         go func() {
41                 defer wg.Done()
42                 defer cancel()
43                 n, err := io.CopyN(conn, cresp.Bufrw, int64(cresp.Bufrw.Reader.Buffered()))
44                 bytesOut += n
45                 if err == nil {
46                         n, err = io.Copy(conn, cresp.Conn)
47                         bytesOut += n
48                 }
49                 if err != nil {
50                         ctxlog.FromContext(ctx).WithError(err).Error("error copying downstream")
51                 }
52         }()
53         wg.Add(1)
54         go func() {
55                 defer wg.Done()
56                 defer cancel()
57                 n, err := io.CopyN(cresp.Conn, bufrw, int64(bufrw.Reader.Buffered()))
58                 bytesIn += n
59                 if err == nil {
60                         n, err = io.Copy(cresp.Conn, conn)
61                         bytesIn += n
62                 }
63                 if err != nil {
64                         ctxlog.FromContext(ctx).WithError(err).Error("error copying upstream")
65                 }
66         }()
67         <-ctx.Done()
68         go func() {
69                 // Wait for both io.Copy goroutines to finish and increment
70                 // their byte counters.
71                 wg.Wait()
72                 if cresp.Logger != nil {
73                         cresp.Logger.WithFields(logrus.Fields{
74                                 "bytesIn":  bytesIn,
75                                 "bytesOut": bytesOut,
76                         }).Info("closed connection")
77                 }
78         }()
79 }