Merge branch 'wtsi/13809-root_url-protocol-port-configuration'
[arvados.git] / lib / controller / proxy.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package controller
6
7 import (
8         "context"
9         "io"
10         "net/http"
11         "net/url"
12         "time"
13
14         "git.curoverse.com/arvados.git/sdk/go/httpserver"
15 )
16
// proxy forwards HTTP requests to another server (see Do).
type proxy struct {
	// Name identifies this proxy; it is intended for use in the
	// Via header of forwarded requests.
	Name string

	// RequestTimeout, when greater than zero, limits how long a
	// single proxied request may take. Zero or negative means no
	// limit is applied.
	RequestTimeout time.Duration
}
21
// dropHeaders lists hop-by-hop headers that shouldn't be forwarded
// when proxying. See
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers
//
// Keys must be in canonical MIME header form (see
// http.CanonicalHeaderKey): http.Header keys are canonicalized when a
// request or response is parsed, and this map is indexed with those
// keys verbatim. In particular "TE" is canonicalized to "Te", so a
// "TE" entry would never match.
var dropHeaders = map[string]bool{
	"Connection":          true,
	"Keep-Alive":          true,
	"Proxy-Authenticate":  true,
	"Proxy-Authorization": true,
	"Te":                  true,
	"Trailer":             true,
	"Transfer-Encoding":   true,
	"Upgrade":             true,
}
34
35 func (p *proxy) Do(w http.ResponseWriter, reqIn *http.Request, urlOut *url.URL, client *http.Client) {
36         // Copy headers from incoming request, then add/replace proxy
37         // headers like Via and X-Forwarded-For.
38         hdrOut := http.Header{}
39         for k, v := range reqIn.Header {
40                 if !dropHeaders[k] {
41                         hdrOut[k] = v
42                 }
43         }
44         xff := reqIn.RemoteAddr
45         if xffIn := reqIn.Header.Get("X-Forwarded-For"); xffIn != "" {
46                 xff = xffIn + "," + xff
47         }
48         hdrOut.Set("X-Forwarded-For", xff)
49         if hdrOut.Get("X-Forwarded-Proto") == "" {
50                 hdrOut.Set("X-Forwarded-Proto", reqIn.URL.Scheme)
51         }
52         hdrOut.Add("Via", reqIn.Proto+" arvados-controller")
53
54         ctx := reqIn.Context()
55         if p.RequestTimeout > 0 {
56                 var cancel context.CancelFunc
57                 ctx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Duration(p.RequestTimeout)))
58                 defer cancel()
59         }
60
61         reqOut := (&http.Request{
62                 Method: reqIn.Method,
63                 URL:    urlOut,
64                 Host:   reqIn.Host,
65                 Header: hdrOut,
66                 Body:   reqIn.Body,
67         }).WithContext(ctx)
68         resp, err := client.Do(reqOut)
69         if err != nil {
70                 httpserver.Error(w, err.Error(), http.StatusBadGateway)
71                 return
72         }
73         for k, v := range resp.Header {
74                 for _, v := range v {
75                         w.Header().Add(k, v)
76                 }
77         }
78         w.WriteHeader(resp.StatusCode)
79         n, err := io.Copy(w, resp.Body)
80         if err != nil {
81                 httpserver.Logger(reqIn).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
82         }
83 }