14262: Refactoring proxy
authorPeter Amstutz <pamstutz@veritasgenetics.com>
Thu, 18 Oct 2018 20:08:28 +0000 (16:08 -0400)
committerPeter Amstutz <pamstutz@veritasgenetics.com>
Tue, 30 Oct 2018 18:12:02 +0000 (14:12 -0400)
Split proxy.Do() into ForwardRequest() and ForwardResponse().

Inversion of control eliminates need for "filter" callback, since the
caller can now modify the response in between the calls to
ForwardRequest() and ForwardResponse().

Arvados-DCO-1.1-Signed-off-by: Peter Amstutz <pamstutz@veritasgenetics.com>

lib/controller/federation.go
lib/controller/handler.go
lib/controller/proxy.go

index f303655747d72ba5a10e11e455152ef566cf797f..c5089fa23512907b26fb499a3fbe17c61e2c6762 100644 (file)
@@ -44,17 +44,10 @@ type collectionFederatedRequestHandler struct {
        handler *Handler
 }
 
-func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, req *http.Request, filter ResponseFilter) {
+func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, error) {
        remote, ok := h.Cluster.RemoteClusters[remoteID]
        if !ok {
-               err := fmt.Errorf("no proxy available for cluster %v", remoteID)
-               if filter != nil {
-                       _, err = filter(nil, err)
-               }
-               if err != nil {
-                       httpserver.Error(w, err.Error(), http.StatusNotFound)
-               }
-               return
+               return nil, HTTPError{fmt.Sprintf("no proxy available for cluster %v", remoteID), http.StatusNotFound}
        }
        scheme := remote.Scheme
        if scheme == "" {
@@ -62,13 +55,7 @@ func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, r
        }
        saltedReq, err := h.saltAuthToken(req, remoteID)
        if err != nil {
-               if filter != nil {
-                       _, err = filter(nil, err)
-               }
-               if err != nil {
-                       httpserver.Error(w, err.Error(), http.StatusBadRequest)
-               }
-               return
+               return nil, err
        }
        urlOut := &url.URL{
                Scheme:   scheme,
@@ -81,7 +68,7 @@ func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, r
        if remote.Insecure {
                client = h.insecureClient
        }
-       h.proxy.Do(w, saltedReq, urlOut, client, filter)
+       return h.proxy.ForwardRequest(saltedReq, urlOut, client)
 }
 
 // Buffer request body, parse form parameters in request, and then
@@ -179,13 +166,14 @@ func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
 
                rc := multiClusterQueryResponseCollector{clusterID: clusterID}
 
+               var resp *http.Response
                if clusterID == h.handler.Cluster.ClusterID {
-                       h.handler.localClusterRequest(w, &remoteReq,
-                               rc.collectResponse)
+                       resp, err = h.handler.localClusterRequest(&remoteReq)
                } else {
-                       h.handler.remoteClusterRequest(clusterID, w, &remoteReq,
-                               rc.collectResponse)
+                       resp, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
                }
+               rc.collectResponse(resp, err)
+
                if rc.error != nil {
                        return nil, "", rc.error
                }
@@ -412,16 +400,14 @@ func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *h
        if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
                h.next.ServeHTTP(w, req)
        } else {
-               h.handler.remoteClusterRequest(clusterId, w, req, nil)
+               resp, err := h.handler.remoteClusterRequest(clusterId, req)
+               h.handler.proxy.ForwardResponse(w, resp, err)
        }
 }
 
-type rewriteSignaturesClusterId struct {
-       clusterID  string
-       expectHash string
-}
+func rewriteSignatures(clusterID string, expectHash string,
+       resp *http.Response, requestError error) (newResponse *http.Response, err error) {
 
-func (rw rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
        if requestError != nil {
                return resp, requestError
        }
@@ -471,7 +457,7 @@ func (rw rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requ
                        m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
                        if m != nil {
                                // Rewrite the block signature to be a remote signature
-                               _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], rw.clusterID, m[5][2:], m[8])
+                               _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterID, m[5][2:], m[8])
                                if err != nil {
                                        return nil, fmt.Errorf("Error updating manifest: %v", err)
                                }
@@ -499,17 +485,17 @@ func (rw rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requ
 
        // Check that expected hash is consistent with
        // portable_data_hash field of the returned record
-       if rw.expectHash == "" {
-               rw.expectHash = col.PortableDataHash
-       } else if rw.expectHash != col.PortableDataHash {
-               return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", rw.expectHash, col.PortableDataHash)
+       if expectHash == "" {
+               expectHash = col.PortableDataHash
+       } else if expectHash != col.PortableDataHash {
+               return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", expectHash, col.PortableDataHash)
        }
 
        // Certify that the computed hash of the manifest_text matches our expectation
        sum := hasher.Sum(nil)
        computedHash := fmt.Sprintf("%x+%v", sum, sz)
-       if computedHash != rw.expectHash {
-               return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, rw.expectHash)
+       if computedHash != expectHash {
+               return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, expectHash)
        }
 
        col.ManifestText = updatedManifest.String()
@@ -585,7 +571,7 @@ func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Respo
        // also have made it to this point, and we don't want a
        // slow response holding the lock to block a faster response
        // that is waiting on the lock.
-       newResponse, err = rewriteSignaturesClusterId{s.remoteID, s.pdh}.rewriteSignatures(resp, nil)
+       newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
 
        s.mtx.Lock()
 
@@ -628,8 +614,9 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
 
                if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
                        // request for remote collection by uuid
-                       h.handler.remoteClusterRequest(clusterId, w, req,
-                               rewriteSignaturesClusterId{clusterId, ""}.rewriteSignatures)
+                       resp, err := h.handler.remoteClusterRequest(clusterId, req)
+                       newResponse, err := rewriteSignatures(clusterId, "", resp, err)
+                       h.handler.proxy.ForwardResponse(w, newResponse, err)
                        return
                }
                // not a collection UUID request, or it is a request
@@ -642,7 +629,10 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
        // Request for collection by PDH.  Search the federation.
 
        // First, query the local cluster.
-       if h.handler.localClusterRequest(w, req, filterLocalClusterResponse) {
+       resp, err := h.handler.localClusterRequest(req)
+       newResp, err := filterLocalClusterResponse(resp, err)
+       if newResp != nil || err != nil {
+               h.handler.proxy.ForwardResponse(w, newResp, err)
                return
        }
 
@@ -680,7 +670,11 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
                        &sharedContext, cancelFunc, &errors, &errorCode}
                wg.Add(1)
                go func() {
-                       h.handler.remoteClusterRequest(search.remoteID, w, req, search.filterRemoteClusterResponse)
+                       resp, err := h.handler.remoteClusterRequest(search.remoteID, req)
+                       newResp, err := search.filterRemoteClusterResponse(resp, err)
+                       if newResp != nil || err != nil {
+                               h.handler.proxy.ForwardResponse(w, newResp, err)
+                       }
                        wg.Done()
                        <-sem
                }()
index 0c31815cba21f2869e7ae4ddf73c880bf4d0a5c8..5e9012949bece7d74144c51d88c8a25eb3fe248e 100644 (file)
@@ -121,14 +121,10 @@ func prepend(next http.Handler, middleware middlewareFunc) http.Handler {
        })
 }
 
-// localClusterRequest sets up a request so it can be proxied to the
-// local API server using proxy.Do().  Returns true if a response was
-// written, false if not.
-func (h *Handler) localClusterRequest(w http.ResponseWriter, req *http.Request, filter ResponseFilter) bool {
+func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, error) {
        urlOut, insecure, err := findRailsAPI(h.Cluster, h.NodeProfile)
        if err != nil {
-               httpserver.Error(w, err.Error(), http.StatusInternalServerError)
-               return true
+               return nil, err
        }
        urlOut = &url.URL{
                Scheme:   urlOut.Scheme,
@@ -141,12 +137,14 @@ func (h *Handler) localClusterRequest(w http.ResponseWriter, req *http.Request,
        if insecure {
                client = h.insecureClient
        }
-       return h.proxy.Do(w, req, urlOut, client, filter)
+       return h.proxy.ForwardRequest(req, urlOut, client)
 }
 
 func (h *Handler) proxyRailsAPI(w http.ResponseWriter, req *http.Request, next http.Handler) {
-       if !h.localClusterRequest(w, req, nil) && next != nil {
-               next.ServeHTTP(w, req)
+       resp, err := h.localClusterRequest(req)
+       n, err := h.proxy.ForwardResponse(w, resp, err)
+       if err != nil {
+               httpserver.Logger(req).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
        }
 }
 
index 951cb9d25fe24ba74a5697d54187847cfc84ae1a..9aecdc1b2c6bd9cb81877eed5c4ede52eee3ba21 100644 (file)
@@ -19,6 +19,15 @@ type proxy struct {
        RequestTimeout time.Duration
 }
 
+type HTTPError struct {
+       Message string
+       Code    int
+}
+
+func (h HTTPError) Error() string {
+       return h.Message
+}
+
 // headers that shouldn't be forwarded when proxying. See
 // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers
 var dropHeaders = map[string]bool{
@@ -36,15 +45,11 @@ var dropHeaders = map[string]bool{
 
 type ResponseFilter func(*http.Response, error) (*http.Response, error)
 
-// Do sends a request, passes the result to the filter (if provided)
-// and then if the result is not suppressed by the filter, sends the
-// request to the ResponseWriter.  Returns true if a response was written,
-// false if not.
-func (p *proxy) Do(w http.ResponseWriter,
+// Forward a request to downstream service, and return response or error.
+func (p *proxy) ForwardRequest(
        reqIn *http.Request,
        urlOut *url.URL,
-       client *http.Client,
-       filter ResponseFilter) bool {
+       client *http.Client) (*http.Response, error) {
 
        // Copy headers from incoming request, then add/replace proxy
        // headers like Via and X-Forwarded-For.
@@ -79,50 +84,26 @@ func (p *proxy) Do(w http.ResponseWriter,
                Body:   reqIn.Body,
        }).WithContext(ctx)
 
-       resp, err := client.Do(reqOut)
-       if filter == nil && err != nil {
-               httpserver.Error(w, err.Error(), http.StatusBadGateway)
-               return true
-       }
-
-       // make sure original response body gets closed
-       var originalBody io.ReadCloser
-       if resp != nil {
-               originalBody = resp.Body
-               if originalBody != nil {
-                       defer originalBody.Close()
-               }
-       }
-
-       if filter != nil {
-               resp, err = filter(resp, err)
+       return client.Do(reqOut)
+}
 
-               if err != nil {
+// Copy a response (or error) to the upstream client
+func (p *proxy) ForwardResponse(w http.ResponseWriter, resp *http.Response, err error) (int64, error) {
+       if err != nil {
+               if he, ok := err.(HTTPError); ok {
+                       httpserver.Error(w, he.Message, he.Code)
+               } else {
                        httpserver.Error(w, err.Error(), http.StatusBadGateway)
-                       return true
-               }
-               if resp == nil {
-                       // filter() returned a nil response, this means suppress
-                       // writing a response, for the case where there might
-                       // be multiple response writers.
-                       return false
-               }
-
-               // the filter gave us a new response body, make sure that gets closed too.
-               if resp.Body != originalBody {
-                       defer resp.Body.Close()
                }
+               return 0, nil
        }
 
+       defer resp.Body.Close()
        for k, v := range resp.Header {
                for _, v := range v {
                        w.Header().Add(k, v)
                }
        }
        w.WriteHeader(resp.StatusCode)
-       n, err := io.Copy(w, resp.Body)
-       if err != nil {
-               httpserver.Logger(reqIn).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
-       }
-       return true
+       return io.Copy(w, resp.Body)
 }