14262: Make sure cancel() from proxy.Do() gets called
[arvados.git] / lib / controller / proxy.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package controller
6
7 import (
8         "context"
9         "io"
10         "net/http"
11         "net/url"
12         "time"
13
14         "git.curoverse.com/arvados.git/sdk/go/httpserver"
15 )
16
// proxy forwards incoming HTTP requests to an upstream service (see
// Do) and relays the upstream result to the downstream client (see
// ForwardResponse).
type proxy struct {
	// NOTE(review): Do currently writes a fixed "arvados-controller"
	// token in the Via header and does not read Name -- confirm
	// which is intended.
	Name           string // to use in Via header
	// RequestTimeout, when greater than zero, makes Do attach a
	// deadline this far in the future to each outgoing request.
	// Zero or negative means Do adds no deadline of its own.
	RequestTimeout time.Duration
}
21
// HTTPError is an error that carries the HTTP status code to report
// to the downstream client along with its message.
type HTTPError struct {
	// Message is the text returned by Error().
	Message string
	// Code is the HTTP status code associated with the error.
	Code int
}

// Error implements the error interface by returning the Message
// field unchanged.
func (e HTTPError) Error() string {
	return e.Message
}
30
// dropHeaders lists headers that shouldn't be forwarded when
// proxying. See
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers
//
// Keys are looked up verbatim against http.Header keys, which
// net/http stores in canonical MIME form (see
// http.CanonicalHeaderKey), so every header needs an entry spelled in
// that form.
var dropHeaders = map[string]bool{
	"Connection":          true,
	"Keep-Alive":          true,
	"Proxy-Authenticate":  true,
	"Proxy-Authorization": true,

	// "Te" is the canonical form of the TE header; a "TE" key alone
	// never matches a header parsed by net/http. The upper-case
	// spelling is kept for headers inserted into the map directly.
	"Te":      true,
	"TE":      true,
	"Trailer": true,
	"Upgrade": true,

	// *-Encoding headers interfere with Go's automatic
	// compression/decompression.
	"Transfer-Encoding": true,
	"Content-Encoding":  true,
	"Accept-Encoding":   true,
}
45
// ResponseFilter is a function that receives a response and an error
// and returns a (possibly different) response and error.
// NOTE(review): presumably applied to the result of an upstream
// request before it is forwarded downstream -- confirm against
// callers.
type ResponseFilter func(*http.Response, error) (*http.Response, error)
47
48 // Forward a request to upstream service, and return response or error.
49 func (p *proxy) Do(
50         reqIn *http.Request,
51         urlOut *url.URL,
52         client *http.Client) (*http.Response, context.CancelFunc, error) {
53
54         // Copy headers from incoming request, then add/replace proxy
55         // headers like Via and X-Forwarded-For.
56         hdrOut := http.Header{}
57         for k, v := range reqIn.Header {
58                 if !dropHeaders[k] {
59                         hdrOut[k] = v
60                 }
61         }
62         xff := reqIn.RemoteAddr
63         if xffIn := reqIn.Header.Get("X-Forwarded-For"); xffIn != "" {
64                 xff = xffIn + "," + xff
65         }
66         hdrOut.Set("X-Forwarded-For", xff)
67         if hdrOut.Get("X-Forwarded-Proto") == "" {
68                 hdrOut.Set("X-Forwarded-Proto", reqIn.URL.Scheme)
69         }
70         hdrOut.Add("Via", reqIn.Proto+" arvados-controller")
71
72         ctx := reqIn.Context()
73         var cancel context.CancelFunc
74         if p.RequestTimeout > 0 {
75                 ctx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Duration(p.RequestTimeout)))
76         }
77
78         reqOut := (&http.Request{
79                 Method: reqIn.Method,
80                 URL:    urlOut,
81                 Host:   reqIn.Host,
82                 Header: hdrOut,
83                 Body:   reqIn.Body,
84         }).WithContext(ctx)
85
86         resp, err := client.Do(reqOut)
87         return resp, cancel, err
88 }
89
90 // Copy a response (or error) to the downstream client
91 func (p *proxy) ForwardResponse(w http.ResponseWriter, resp *http.Response, err error) (int64, error) {
92         if err != nil {
93                 if he, ok := err.(HTTPError); ok {
94                         httpserver.Error(w, he.Message, he.Code)
95                 } else {
96                         httpserver.Error(w, err.Error(), http.StatusBadGateway)
97                 }
98                 return 0, nil
99         }
100
101         defer resp.Body.Close()
102         for k, v := range resp.Header {
103                 for _, v := range v {
104                         w.Header().Add(k, v)
105                 }
106         }
107         w.WriteHeader(resp.StatusCode)
108         return io.Copy(w, resp.Body)
109 }