17800: Fixes bug.
[arvados.git] / sdk / go / arvados / container_gateway.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvados
6
7 import (
8         "context"
9         "io"
10         "net/http"
11         "sync"
12
13         "git.arvados.org/arvados.git/sdk/go/ctxlog"
14         "github.com/sirupsen/logrus"
15 )
16
17 func (sshconn ContainerSSHConnection) ServeHTTP(w http.ResponseWriter, req *http.Request) {
18         hj, ok := w.(http.Hijacker)
19         if !ok {
20                 http.Error(w, "ResponseWriter does not support connection upgrade", http.StatusInternalServerError)
21                 return
22         }
23         w.Header().Set("Connection", "upgrade")
24         w.Header().Set("Upgrade", "ssh")
25         w.WriteHeader(http.StatusSwitchingProtocols)
26         conn, bufrw, err := hj.Hijack()
27         if err != nil {
28                 ctxlog.FromContext(req.Context()).WithError(err).Error("error hijacking ResponseWriter")
29                 return
30         }
31         defer conn.Close()
32
33         var bytesIn, bytesOut int64
34         var wg sync.WaitGroup
35         ctx, cancel := context.WithCancel(context.Background())
36         wg.Add(1)
37         go func() {
38                 defer wg.Done()
39                 defer cancel()
40                 n, err := io.CopyN(conn, sshconn.Bufrw, int64(sshconn.Bufrw.Reader.Buffered()))
41                 bytesOut += n
42                 if err == nil {
43                         n, err = io.Copy(conn, sshconn.Conn)
44                         bytesOut += n
45                 }
46                 if err != nil {
47                         ctxlog.FromContext(req.Context()).WithError(err).Error("error copying downstream")
48                 }
49         }()
50         wg.Add(1)
51         go func() {
52                 defer wg.Done()
53                 defer cancel()
54                 n, err := io.CopyN(sshconn.Conn, bufrw, int64(bufrw.Reader.Buffered()))
55                 bytesIn += n
56                 if err == nil {
57                         n, err = io.Copy(sshconn.Conn, conn)
58                         bytesIn += n
59                 }
60                 if err != nil {
61                         ctxlog.FromContext(req.Context()).WithError(err).Error("error copying upstream")
62                 }
63         }()
64         <-ctx.Done()
65         if sshconn.Logger != nil {
66                 go func() {
67                         wg.Wait()
68                         sshconn.Logger.WithFields(logrus.Fields{
69                                 "bytesIn":  bytesIn,
70                                 "bytesOut": bytesOut,
71                         }).Info("closed connection")
72                 }()
73         }
74 }