13619: MultiClusterQuery passes test
[arvados.git] / lib / controller / proxy.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package controller
6
7 import (
8         "context"
9         "io"
10         "net/http"
11         "net/url"
12         "time"
13
14         "git.curoverse.com/arvados.git/sdk/go/httpserver"
15 )
16
// proxy forwards incoming HTTP requests to another server; see Do.
type proxy struct {
	// Name is meant for use in the Via header.
	Name string

	// RequestTimeout, when positive, limits how long a proxied
	// request may take before its context expires.
	RequestTimeout time.Duration
}
21
// dropHeaders lists headers that shouldn't be forwarded when
// proxying. See
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers
//
// Lookups are done with the keys of an http.Header map, which
// net/http stores in canonical form (first letter and any letter
// following a hyphen upper-cased, the rest lower-cased), so the keys
// here must be spelled the same way.
var dropHeaders = map[string]bool{
	"Connection":          true,
	"Keep-Alive":          true,
	"Proxy-Authenticate":  true,
	"Proxy-Authorization": true,
	"Trailer":             true,
	"Upgrade":             true,

	// The canonical form of the TE header is "Te"; an entry
	// spelled "TE" alone never matches a canonicalized key. The
	// upper-case spelling is kept for any direct lookups.
	"Te": true,
	"TE": true,

	// *-Encoding headers interfere with Go's automatic
	// compression/decompression.
	"Accept-Encoding":   true,
	"Content-Encoding":  true,
	"Transfer-Encoding": true,
}
36
// ResponseFilter inspects, and may replace, the outcome of a proxied
// request. It is called with the response and error obtained from
// the upstream server and returns the response and error that should
// be acted on instead. As used by (*proxy).Do, a non-nil returned
// error is reported to the client, and a nil response with a nil
// error means no response should be written.
type ResponseFilter func(resp *http.Response, err error) (*http.Response, error)
38
39 // Do sends a request, passes the result to the filter (if provided)
40 // and then if the result is not suppressed by the filter, sends the
41 // request to the ResponseWriter.  Returns true if a response was written,
42 // false if not.
43 func (p *proxy) Do(w http.ResponseWriter,
44         reqIn *http.Request,
45         urlOut *url.URL,
46         client *http.Client,
47         filter ResponseFilter) bool {
48
49         // Copy headers from incoming request, then add/replace proxy
50         // headers like Via and X-Forwarded-For.
51         hdrOut := http.Header{}
52         for k, v := range reqIn.Header {
53                 if !dropHeaders[k] {
54                         hdrOut[k] = v
55                 }
56         }
57         xff := reqIn.RemoteAddr
58         if xffIn := reqIn.Header.Get("X-Forwarded-For"); xffIn != "" {
59                 xff = xffIn + "," + xff
60         }
61         hdrOut.Set("X-Forwarded-For", xff)
62         if hdrOut.Get("X-Forwarded-Proto") == "" {
63                 hdrOut.Set("X-Forwarded-Proto", reqIn.URL.Scheme)
64         }
65         hdrOut.Add("Via", reqIn.Proto+" arvados-controller")
66
67         ctx := reqIn.Context()
68         if p.RequestTimeout > 0 {
69                 var cancel context.CancelFunc
70                 ctx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Duration(p.RequestTimeout)))
71                 defer cancel()
72         }
73
74         reqOut := (&http.Request{
75                 Method: reqIn.Method,
76                 URL:    urlOut,
77                 Host:   reqIn.Host,
78                 Header: hdrOut,
79                 Body:   reqIn.Body,
80         }).WithContext(ctx)
81
82         resp, err := client.Do(reqOut)
83         if filter == nil && err != nil {
84                 httpserver.Error(w, err.Error(), http.StatusBadGateway)
85                 return true
86         }
87
88         // make sure original response body gets closed
89         var originalBody io.ReadCloser
90         if resp != nil {
91                 originalBody = resp.Body
92                 if originalBody != nil {
93                         defer originalBody.Close()
94                 }
95         }
96
97         if filter != nil {
98                 resp, err = filter(resp, err)
99
100                 if err != nil {
101                         httpserver.Error(w, err.Error(), http.StatusBadGateway)
102                         return true
103                 }
104                 if resp == nil {
105                         // filter() returned a nil response, this means suppress
106                         // writing a response, for the case where there might
107                         // be multiple response writers.
108                         return false
109                 }
110
111                 // the filter gave us a new response body, make sure that gets closed too.
112                 if resp.Body != originalBody {
113                         defer resp.Body.Close()
114                 }
115         }
116
117         for k, v := range resp.Header {
118                 for _, v := range v {
119                         w.Header().Add(k, v)
120                 }
121         }
122         w.WriteHeader(resp.StatusCode)
123         n, err := io.Copy(w, resp.Body)
124         if err != nil {
125                 httpserver.Logger(reqIn).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
126         }
127         return true
128 }