ce33fb3105a218a537fe8b0c28cf122041b962a5
[arvados.git] / sdk / go / arvados / container_gateway.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvados
6
7 import (
8         "context"
9         "io"
10         "net/http"
11         "sync"
12
13         "git.arvados.org/arvados.git/sdk/go/ctxlog"
14         "github.com/sirupsen/logrus"
15 )
16
17 func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Request) {
18         hj, ok := w.(http.Hijacker)
19         if !ok {
20                 http.Error(w, "ResponseWriter does not support connection upgrade", http.StatusInternalServerError)
21                 return
22         }
23         w.Header().Set("Connection", "upgrade")
24         for k, v := range cresp.Header {
25                 w.Header()[k] = v
26         }
27         w.WriteHeader(http.StatusSwitchingProtocols)
28         conn, bufrw, err := hj.Hijack()
29         if err != nil {
30                 ctxlog.FromContext(req.Context()).WithError(err).Error("error hijacking ResponseWriter")
31                 return
32         }
33         defer conn.Close()
34
35         var bytesIn, bytesOut int64
36         var wg sync.WaitGroup
37         ctx, cancel := context.WithCancel(req.Context())
38         wg.Add(1)
39         go func() {
40                 defer wg.Done()
41                 defer cancel()
42                 n, err := io.CopyN(conn, cresp.Bufrw, int64(cresp.Bufrw.Reader.Buffered()))
43                 bytesOut += n
44                 if err == nil {
45                         n, err = io.Copy(conn, cresp.Conn)
46                         bytesOut += n
47                 }
48                 if err != nil {
49                         ctxlog.FromContext(ctx).WithError(err).Error("error copying downstream")
50                 }
51         }()
52         wg.Add(1)
53         go func() {
54                 defer wg.Done()
55                 defer cancel()
56                 n, err := io.CopyN(cresp.Conn, bufrw, int64(bufrw.Reader.Buffered()))
57                 bytesIn += n
58                 if err == nil {
59                         n, err = io.Copy(cresp.Conn, conn)
60                         bytesIn += n
61                 }
62                 if err != nil {
63                         ctxlog.FromContext(ctx).WithError(err).Error("error copying upstream")
64                 }
65         }()
66         wg.Wait()
67         if cresp.Logger != nil {
68                 cresp.Logger.WithFields(logrus.Fields{
69                         "bytesIn":  bytesIn,
70                         "bytesOut": bytesOut,
71                 }).Info("closed connection")
72         }
73 }