ec16ee2be9d9f345810274d252acec579eaf7ddd
[arvados.git] / sdk / go / arvados / container_gateway.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvados
6
7 import (
8         "context"
9         "io"
10         "net/http"
11         "sync"
12
13         "git.arvados.org/arvados.git/sdk/go/ctxlog"
14         "github.com/sirupsen/logrus"
15 )
16
17 func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Request) {
18         defer cresp.Conn.Close()
19         hj, ok := w.(http.Hijacker)
20         if !ok {
21                 http.Error(w, "ResponseWriter does not support connection upgrade", http.StatusInternalServerError)
22                 return
23         }
24         w.Header().Set("Connection", "upgrade")
25         for k, v := range cresp.Header {
26                 w.Header()[k] = v
27         }
28         w.WriteHeader(http.StatusSwitchingProtocols)
29         conn, bufrw, err := hj.Hijack()
30         if err != nil {
31                 ctxlog.FromContext(req.Context()).WithError(err).Error("error hijacking ResponseWriter")
32                 return
33         }
34         defer conn.Close()
35
36         var bytesIn, bytesOut int64
37         var wg sync.WaitGroup
38         ctx, cancel := context.WithCancel(req.Context())
39         wg.Add(1)
40         go func() {
41                 defer wg.Done()
42                 defer cancel()
43                 n, err := io.CopyN(conn, cresp.Bufrw, int64(cresp.Bufrw.Reader.Buffered()))
44                 bytesOut += n
45                 if err == nil {
46                         n, err = io.Copy(conn, cresp.Conn)
47                         bytesOut += n
48                 }
49                 if err != nil {
50                         ctxlog.FromContext(ctx).WithError(err).Error("error copying downstream")
51                 }
52                 conn.Close()
53         }()
54         wg.Add(1)
55         go func() {
56                 defer wg.Done()
57                 defer cancel()
58                 n, err := io.CopyN(cresp.Conn, bufrw, int64(bufrw.Reader.Buffered()))
59                 bytesIn += n
60                 if err == nil {
61                         n, err = io.Copy(cresp.Conn, conn)
62                         bytesIn += n
63                 }
64                 if err != nil {
65                         ctxlog.FromContext(ctx).WithError(err).Error("error copying upstream")
66                 }
67                 cresp.Conn.Close()
68         }()
69         wg.Wait()
70         if cresp.Logger != nil {
71                 cresp.Logger.WithFields(logrus.Fields{
72                         "bytesIn":  bytesIn,
73                         "bytesOut": bytesOut,
74                 }).Info("closed connection")
75         }
76 }