14262: Make sure cancel() from proxy.Do() gets called
authorPeter Amstutz <pamstutz@veritasgenetics.com>
Mon, 29 Oct 2018 19:36:45 +0000 (15:36 -0400)
committerPeter Amstutz <pamstutz@veritasgenetics.com>
Tue, 30 Oct 2018 18:12:02 +0000 (14:12 -0400)
Arvados-DCO-1.1-Signed-off-by: Peter Amstutz <pamstutz@veritasgenetics.com>

lib/controller/fed_collections.go
lib/controller/fed_containers.go
lib/controller/fed_generic.go
lib/controller/federation.go
lib/controller/federation_test.go
lib/controller/handler.go
lib/controller/proxy.go
sdk/go/httpserver/id_generator.go

index 70dbdc3f51b54e7f7268638489663cd6d682a1a5..8a97c25c94eb35eb57cfa0e751950db03ec92fd1 100644 (file)
@@ -34,7 +34,7 @@ func rewriteSignatures(clusterID string, expectHash string,
                return resp, requestError
        }
 
-       if resp.StatusCode != 200 {
+       if resp.StatusCode != http.StatusOK {
                return resp, nil
        }
 
@@ -140,7 +140,7 @@ func filterLocalClusterResponse(resp *http.Response, requestError error) (newRes
                return resp, requestError
        }
 
-       if resp.StatusCode == 404 {
+       if resp.StatusCode == http.StatusNotFound {
                // Suppress returning this result, because we want to
                // search the federation.
                return nil, nil
@@ -174,12 +174,11 @@ func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Respo
                return nil, nil
        }
 
-       if resp.StatusCode != 200 {
+       if resp.StatusCode != http.StatusOK {
                // Suppress returning unsuccessful result.  Maybe
                // another request will find it.
-               // TODO collect and return error responses.
-               *s.errors = append(*s.errors, fmt.Sprintf("Response to %q from %q: %v", httpserver.GetRequestID(resp.Header), s.remoteID, resp.Status))
-               if resp.StatusCode != 404 {
+               *s.errors = append(*s.errors, fmt.Sprintf("Response to %q from %q: %v", resp.Header.Get(httpserver.HeaderRequestID), s.remoteID, resp.Status))
+               if resp.StatusCode != http.StatusNotFound {
                        // Got a non-404 error response, convert into BadGateway
                        *s.statusCode = http.StatusBadGateway
                }
@@ -236,7 +235,10 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
 
                if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
                        // request for remote collection by uuid
-                       resp, err := h.handler.remoteClusterRequest(clusterId, req)
+                       resp, cancel, err := h.handler.remoteClusterRequest(clusterId, req)
+                       if cancel != nil {
+                               defer cancel()
+                       }
                        newResponse, err := rewriteSignatures(clusterId, "", resp, err)
                        h.handler.proxy.ForwardResponse(w, newResponse, err)
                        return
@@ -251,7 +253,10 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
        // Request for collection by PDH.  Search the federation.
 
        // First, query the local cluster.
-       resp, err := h.handler.localClusterRequest(req)
+       resp, localClusterRequestCancel, err := h.handler.localClusterRequest(req)
+       if localClusterRequestCancel != nil {
+               defer localClusterRequestCancel()
+       }
        newResp, err := filterLocalClusterResponse(resp, err)
        if newResp != nil || err != nil {
                h.handler.proxy.ForwardResponse(w, newResp, err)
@@ -271,7 +276,7 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
        mtx := sync.Mutex{}
        wg := sync.WaitGroup{}
        var errors []string
-       var errorCode int = 404
+       var errorCode int = http.StatusNotFound
 
        // use channel as a semaphore to limit the number of concurrent
        // requests at a time
@@ -292,7 +297,10 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
                        &sharedContext, cancelFunc, &errors, &errorCode}
                wg.Add(1)
                go func() {
-                       resp, err := h.handler.remoteClusterRequest(search.remoteID, req)
+                       resp, cancel, err := h.handler.remoteClusterRequest(search.remoteID, req)
+                       if cancel != nil {
+                               defer cancel()
+                       }
                        newResp, err := search.filterRemoteClusterResponse(resp, err)
                        if newResp != nil || err != nil {
                                h.handler.proxy.ForwardResponse(w, newResp, err)
index ccb2401bb78250c36eea03a033370c0f91f0fa47..a3c292583f2df626f2323449f93ff3752d746a3d 100644 (file)
@@ -95,7 +95,10 @@ func remoteContainerRequestCreate(
        req.ContentLength = int64(buf.Len())
        req.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
 
-       resp, err := h.handler.remoteClusterRequest(*clusterId, req)
+       resp, cancel, err := h.handler.remoteClusterRequest(*clusterId, req)
+       if cancel != nil {
+               defer cancel()
+       }
        h.handler.proxy.ForwardResponse(w, resp, err)
        return true
 }
index 63e61e6908f8b318ead4e151bd13dee302c815d3..7d5b63d3107a66384059403d6016dc5653e03dee 100644 (file)
@@ -6,6 +6,7 @@ package controller
 
 import (
        "bytes"
+       "context"
        "encoding/json"
        "fmt"
        "io/ioutil"
@@ -65,12 +66,16 @@ func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
                rc := multiClusterQueryResponseCollector{clusterID: clusterID}
 
                var resp *http.Response
+               var cancel context.CancelFunc
                if clusterID == h.handler.Cluster.ClusterID {
-                       resp, err = h.handler.localClusterRequest(&remoteReq)
+                       resp, cancel, err = h.handler.localClusterRequest(&remoteReq)
                } else {
-                       resp, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
+                       resp, cancel, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
                }
                rc.collectResponse(resp, err)
+               if cancel != nil {
+                       cancel()
+               }
 
                if rc.error != nil {
                        return nil, "", rc.error
@@ -304,7 +309,10 @@ func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *h
        if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
                h.next.ServeHTTP(w, req)
        } else {
-               resp, err := h.handler.remoteClusterRequest(clusterId, req)
+               resp, cancel, err := h.handler.remoteClusterRequest(clusterId, req)
+               if cancel != nil {
+                       defer cancel()
+               }
                h.handler.proxy.ForwardResponse(w, resp, err)
        }
 }
index dc0aa908c59dc22242cadf8c5c05ad6831d00e0a..0e016f301da1d70536a3cec9890eec853e5b74ae 100644 (file)
@@ -6,6 +6,7 @@ package controller
 
 import (
        "bytes"
+       "context"
        "database/sql"
        "encoding/json"
        "fmt"
@@ -28,10 +29,10 @@ var containerRequestsRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "container
 var collectionRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "collections", "4zz18"))
 var collectionByPDHRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-fA-F]{32}\+[0-9]+)+$`)
 
-func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, error) {
+func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, context.CancelFunc, error) {
        remote, ok := h.Cluster.RemoteClusters[remoteID]
        if !ok {
-               return nil, HTTPError{fmt.Sprintf("no proxy available for cluster %v", remoteID), http.StatusNotFound}
+               return nil, nil, HTTPError{fmt.Sprintf("no proxy available for cluster %v", remoteID), http.StatusNotFound}
        }
        scheme := remote.Scheme
        if scheme == "" {
@@ -39,7 +40,7 @@ func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*htt
        }
        saltedReq, err := h.saltAuthToken(req, remoteID)
        if err != nil {
-               return nil, err
+               return nil, nil, err
        }
        urlOut := &url.URL{
                Scheme:   scheme,
@@ -52,7 +53,7 @@ func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*htt
        if remote.Insecure {
                client = h.insecureClient
        }
-       return h.proxy.ForwardRequest(saltedReq, urlOut, client)
+       return h.proxy.Do(saltedReq, urlOut, client)
 }
 
 // Buffer request body, parse form parameters in request, and then
index 7842ad05d7366b2c62d1a77fa042ae948a1f0f02..f6bfca30213017e959f3f67739fe786f650eb515 100644 (file)
@@ -94,8 +94,8 @@ func (s *FederationSuite) SetUpTest(c *check.C) {
 func (s *FederationSuite) remoteMockHandler(w http.ResponseWriter, req *http.Request) {
        b := &bytes.Buffer{}
        io.Copy(b, req.Body)
-       req.Body = ioutil.NopCloser(b)
        req.Body.Close()
+       req.Body = ioutil.NopCloser(b)
        s.remoteMockRequests = append(s.remoteMockRequests, *req)
 }
 
index 5e9012949bece7d74144c51d88c8a25eb3fe248e..cbfaaddab4955ba66f2592e152410e10db9b5564 100644 (file)
@@ -5,6 +5,7 @@
 package controller
 
 import (
+       "context"
        "database/sql"
        "errors"
        "net"
@@ -121,10 +122,10 @@ func prepend(next http.Handler, middleware middlewareFunc) http.Handler {
        })
 }
 
-func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, error) {
+func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, context.CancelFunc, error) {
        urlOut, insecure, err := findRailsAPI(h.Cluster, h.NodeProfile)
        if err != nil {
-               return nil, err
+               return nil, nil, err
        }
        urlOut = &url.URL{
                Scheme:   urlOut.Scheme,
@@ -137,11 +138,14 @@ func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, error)
        if insecure {
                client = h.insecureClient
        }
-       return h.proxy.ForwardRequest(req, urlOut, client)
+       return h.proxy.Do(req, urlOut, client)
 }
 
 func (h *Handler) proxyRailsAPI(w http.ResponseWriter, req *http.Request, next http.Handler) {
-       resp, err := h.localClusterRequest(req)
+       resp, cancel, err := h.localClusterRequest(req)
+       if cancel != nil {
+               defer cancel()
+       }
        n, err := h.proxy.ForwardResponse(w, resp, err)
        if err != nil {
                httpserver.Logger(req).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
index b7f3c4f72acf4754884313e44f10793b3bf486ee..c89b9b36ae0cddfc67a127d1e6f8bdac84ab6412 100644 (file)
@@ -45,11 +45,11 @@ var dropHeaders = map[string]bool{
 
 type ResponseFilter func(*http.Response, error) (*http.Response, error)
 
-// Forward a request to downstream service, and return response or error.
-func (p *proxy) ForwardRequest(
+// Forward a request to upstream service, and return response or error.
+func (p *proxy) Do(
        reqIn *http.Request,
        urlOut *url.URL,
-       client *http.Client) (*http.Response, error) {
+       client *http.Client) (*http.Response, context.CancelFunc, error) {
 
        // Copy headers from incoming request, then add/replace proxy
        // headers like Via and X-Forwarded-For.
@@ -70,8 +70,9 @@ func (p *proxy) ForwardRequest(
        hdrOut.Add("Via", reqIn.Proto+" arvados-controller")
 
        ctx := reqIn.Context()
+       var cancel context.CancelFunc
        if p.RequestTimeout > 0 {
-               ctx, _ = context.WithDeadline(ctx, time.Now().Add(time.Duration(p.RequestTimeout)))
+               ctx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Duration(p.RequestTimeout)))
        }
 
        reqOut := (&http.Request{
@@ -82,10 +83,11 @@ func (p *proxy) ForwardRequest(
                Body:   reqIn.Body,
        }).WithContext(ctx)
 
-       return client.Do(reqOut)
+       resp, err := client.Do(reqOut)
+       return resp, cancel, err
 }
 
-// Copy a response (or error) to the upstream client
+// Copy a response (or error) to the downstream client
 func (p *proxy) ForwardResponse(w http.ResponseWriter, resp *http.Response, err error) (int64, error) {
        if err != nil {
                if he, ok := err.(HTTPError); ok {
index 6093a8a7b720b9eb6258b4873c4f7b52964e78e7..14d89873b60f7d902a39a6b337eea78e8040d0c3 100644 (file)
@@ -12,6 +12,10 @@ import (
        "time"
 )
 
+const (
+       HeaderRequestID = "X-Request-Id"
+)
+
 // IDGenerator generates alphanumeric strings suitable for use as
 // unique IDs (a given IDGenerator will never return the same ID
 // twice).
@@ -44,16 +48,12 @@ func (g *IDGenerator) Next() string {
 func AddRequestIDs(h http.Handler) http.Handler {
        gen := &IDGenerator{Prefix: "req-"}
        return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-               if req.Header.Get("X-Request-Id") == "" {
+               if req.Header.Get(HeaderRequestID) == "" {
                        if req.Header == nil {
                                req.Header = http.Header{}
                        }
-                       req.Header.Set("X-Request-Id", gen.Next())
+                       req.Header.Set(HeaderRequestID, gen.Next())
                }
                h.ServeHTTP(w, req)
        })
 }
-
-func GetRequestID(h http.Header) string {
-       return h.Get("X-Request-Id")
-}