19146: Remove unneeded special case checks, explain the needed one.
[arvados.git] / lib / controller / proxy.go
index 0f5d43076a5b86ffec29e50735bd1b143289db43..13dfcac16abb0bb27c7b1f3d50d024436453f97c 100644 (file)
@@ -5,41 +5,55 @@
 package controller
 
 import (
-       "context"
        "io"
        "net/http"
        "net/url"
-       "time"
 
-       "git.curoverse.com/arvados.git/sdk/go/httpserver"
+       "git.arvados.org/arvados.git/sdk/go/httpserver"
 )
 
 type proxy struct {
-       Name           string // to use in Via header
-       RequestTimeout time.Duration
+       Name string // to use in Via header
+}
+
+type HTTPError struct {
+       Message string // message text ForwardResponse passes to httpserver.Error
+       Code    int    // HTTP status code ForwardResponse passes to httpserver.Error
+}
+
+func (h HTTPError) Error() string {
+       return h.Message
 }
 
-// headers that shouldn't be forwarded when proxying. See
-// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers
 var dropHeaders = map[string]bool{
+       // Headers that shouldn't be forwarded when proxying. See
+       // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers
        "Connection":          true,
        "Keep-Alive":          true,
        "Proxy-Authenticate":  true,
        "Proxy-Authorization": true,
-       "TE":                true,
-       "Trailer":           true,
+       // (comment/space here makes gofmt1.10 agree with gofmt1.11)
+       "TE":      true,
+       "Trailer": true,
+       "Upgrade": true,
+
+       // Headers that would interfere with Go's automatic
+       // compression/decompression if we forwarded them.
+       "Accept-Encoding":   true,
+       "Content-Encoding":  true,
        "Transfer-Encoding": true,
-       "Content-Encoding":  true, // interfers with Go's automatic compression/decompression
-       "Upgrade":           true,
+
+       // Content-Length depends on encoding.
+       "Content-Length": true,
 }
 
 type ResponseFilter func(*http.Response, error) (*http.Response, error)
 
-func (p *proxy) Do(w http.ResponseWriter,
+// Do forwards reqIn to urlOut using client and returns the upstream response or error.
+func (p *proxy) Do(
        reqIn *http.Request,
        urlOut *url.URL,
-       client *http.Client,
-       filter ResponseFilter) {
+       client *http.Client) (*http.Response, error) {
 
        // Copy headers from incoming request, then add/replace proxy
        // headers like Via and X-Forwarded-For.
@@ -59,61 +73,33 @@ func (p *proxy) Do(w http.ResponseWriter,
        }
        hdrOut.Add("Via", reqIn.Proto+" arvados-controller")
 
-       ctx := reqIn.Context()
-       if p.RequestTimeout > 0 {
-               var cancel context.CancelFunc
-               ctx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Duration(p.RequestTimeout)))
-               defer cancel()
-       }
-
        reqOut := (&http.Request{
                Method: reqIn.Method,
                URL:    urlOut,
                Host:   reqIn.Host,
                Header: hdrOut,
                Body:   reqIn.Body,
-       }).WithContext(ctx)
-
-       resp, err := client.Do(reqOut)
-       if filter == nil && err != nil {
-               httpserver.Error(w, err.Error(), http.StatusBadGateway)
-               return
-       }
-
-       // make sure original response body gets closed
-       originalBody := resp.Body
-       if originalBody != nil {
-               defer originalBody.Close()
-       }
-
-       if filter != nil {
-               resp, err = filter(resp, err)
+       }).WithContext(reqIn.Context()) // cancellation/deadline now follow the incoming request's context
+       return client.Do(reqOut)
+}
 
-               if err != nil {
+// ForwardResponse copies a response (or error) to the downstream client and returns the number of body bytes copied.
+func (p *proxy) ForwardResponse(w http.ResponseWriter, resp *http.Response, err error) (int64, error) {
+       if err != nil {
+               if he, ok := err.(HTTPError); ok {
+                       httpserver.Error(w, he.Message, he.Code)
+               } else {
                        httpserver.Error(w, err.Error(), http.StatusBadGateway)
-                       return
-               }
-               if resp == nil {
-                       // filter() returned a nil response, this means suppress
-                       // writing a response, for the case where there might
-                       // be multiple response writers.
-                       return
-               }
-
-               // the filter gave us a new response body, make sure that gets closed too.
-               if resp.Body != originalBody {
-                       defer resp.Body.Close()
                }
+               return 0, nil // error was already written to w; nothing copied from upstream
        }
 
+       defer resp.Body.Close() // NOTE(review): assumes resp is non-nil whenever err == nil (as with http.Client.Do results)
        for k, v := range resp.Header {
                for _, v := range v {
                        w.Header().Add(k, v)
                }
        }
        w.WriteHeader(resp.StatusCode)
-       n, err := io.Copy(w, resp.Body)
-       if err != nil {
-               httpserver.Logger(reqIn).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
-       }
+       return io.Copy(w, resp.Body) // status already sent, so a copy error can only be reported to the caller
 }