14262: Move the context deadline to the top of the handler stack
authorPeter Amstutz <pamstutz@veritasgenetics.com>
Thu, 1 Nov 2018 14:19:18 +0000 (10:19 -0400)
committerPeter Amstutz <pamstutz@veritasgenetics.com>
Thu, 1 Nov 2018 14:19:42 +0000 (10:19 -0400)
Arvados-DCO-1.1-Signed-off-by: Peter Amstutz <pamstutz@veritasgenetics.com>

lib/controller/fed_collections.go
lib/controller/fed_containers.go
lib/controller/fed_generic.go
lib/controller/federation.go
lib/controller/federation_test.go
lib/controller/handler.go
lib/controller/proxy.go

index 88b0f95a0267f3979e7b5cfff1f56fbcf3dc32e2..b9cd20582951505fe7b43c07c490be9507535720 100644 (file)
@@ -178,10 +178,7 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
 
                if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
                        // request for remote collection by uuid
-                       resp, cancel, err := h.handler.remoteClusterRequest(clusterId, req)
-                       if cancel != nil {
-                               defer cancel()
-                       }
+                       resp, err := h.handler.remoteClusterRequest(clusterId, req)
                        newResponse, err := rewriteSignatures(clusterId, "", resp, err)
                        h.handler.proxy.ForwardResponse(w, newResponse, err)
                        return
@@ -196,10 +193,7 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
        // Request for collection by PDH.  Search the federation.
 
        // First, query the local cluster.
-       resp, localClusterRequestCancel, err := h.handler.localClusterRequest(req)
-       if localClusterRequestCancel != nil {
-               defer localClusterRequestCancel()
-       }
+       resp, err := h.handler.localClusterRequest(req)
        newResp, err := filterLocalClusterResponse(resp, err)
        if newResp != nil || err != nil {
                h.handler.proxy.ForwardResponse(w, newResp, err)
@@ -244,19 +238,13 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
                        default:
                        }
 
-                       resp, _, err := h.handler.remoteClusterRequest(remote, req)
+                       resp, err := h.handler.remoteClusterRequest(remote, req)
                        wasSuccess := false
                        defer func() {
                                if resp != nil && !wasSuccess {
                                        resp.Body.Close()
                                }
                        }()
-                       // Don't need to do anything with the cancel
-                       // function returned by remoteClusterRequest
-                       // because the context inherits from
-                       // sharedContext, so when sharedContext is
-                       // cancelled it should cancel that one as
-                       // well.
                        if err != nil {
                                errorChan <- err
                                return
index fc627d3fafeb7000e5e3d78eb0efed257b92abe3..e4c80a32cc16bd36d3b45a84077500e186e30304 100644 (file)
@@ -9,7 +9,6 @@ import (
        "encoding/json"
        "fmt"
        "io/ioutil"
-       "log"
        "net/http"
 
        "git.curoverse.com/arvados.git/sdk/go/auth"
@@ -64,8 +63,6 @@ func remoteContainerRequestCreate(
 
        // If runtime_token is not set, create a new token
        if _, ok := containerRequest["runtime_token"]; !ok {
-               log.Printf("ok %v", ok)
-
                // First make sure supplied token is valid.
                creds := auth.NewCredentials()
                creds.LoadTokensFromHTTPRequest(req)
@@ -98,10 +95,7 @@ func remoteContainerRequestCreate(
        req.ContentLength = int64(buf.Len())
        req.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
 
-       resp, cancel, err := h.handler.remoteClusterRequest(*clusterId, req)
-       if cancel != nil {
-               defer cancel()
-       }
+       resp, err := h.handler.remoteClusterRequest(*clusterId, req)
        h.handler.proxy.ForwardResponse(w, resp, err)
        return true
 }
index 7d5b63d3107a66384059403d6016dc5653e03dee..63e61e6908f8b318ead4e151bd13dee302c815d3 100644 (file)
@@ -6,7 +6,6 @@ package controller
 
 import (
        "bytes"
-       "context"
        "encoding/json"
        "fmt"
        "io/ioutil"
@@ -66,16 +65,12 @@ func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
                rc := multiClusterQueryResponseCollector{clusterID: clusterID}
 
                var resp *http.Response
-               var cancel context.CancelFunc
                if clusterID == h.handler.Cluster.ClusterID {
-                       resp, cancel, err = h.handler.localClusterRequest(&remoteReq)
+                       resp, err = h.handler.localClusterRequest(&remoteReq)
                } else {
-                       resp, cancel, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
+                       resp, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
                }
                rc.collectResponse(resp, err)
-               if cancel != nil {
-                       cancel()
-               }
 
                if rc.error != nil {
                        return nil, "", rc.error
@@ -309,10 +304,7 @@ func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *h
        if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
                h.next.ServeHTTP(w, req)
        } else {
-               resp, cancel, err := h.handler.remoteClusterRequest(clusterId, req)
-               if cancel != nil {
-                       defer cancel()
-               }
+               resp, err := h.handler.remoteClusterRequest(clusterId, req)
                h.handler.proxy.ForwardResponse(w, resp, err)
        }
 }
index 0e016f301da1d70536a3cec9890eec853e5b74ae..e08a1c16742a6d5ea9b251d2906b24f6d5b00e61 100644 (file)
@@ -6,7 +6,6 @@ package controller
 
 import (
        "bytes"
-       "context"
        "database/sql"
        "encoding/json"
        "fmt"
@@ -29,10 +28,10 @@ var containerRequestsRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "container
 var collectionRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "collections", "4zz18"))
 var collectionByPDHRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-fA-F]{32}\+[0-9]+)+$`)
 
-func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, context.CancelFunc, error) {
+func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, error) {
        remote, ok := h.Cluster.RemoteClusters[remoteID]
        if !ok {
-               return nil, nil, HTTPError{fmt.Sprintf("no proxy available for cluster %v", remoteID), http.StatusNotFound}
+               return nil, HTTPError{fmt.Sprintf("no proxy available for cluster %v", remoteID), http.StatusNotFound}
        }
        scheme := remote.Scheme
        if scheme == "" {
@@ -40,7 +39,7 @@ func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*htt
        }
        saltedReq, err := h.saltAuthToken(req, remoteID)
        if err != nil {
-               return nil, nil, err
+               return nil, err
        }
        urlOut := &url.URL{
                Scheme:   scheme,
index f6bfca30213017e959f3f67739fe786f650eb515..da640071c523bc388af98fa3214d7328cf715359 100644 (file)
@@ -594,6 +594,15 @@ func (s *FederationSuite) TestCreateRemoteContainerRequestCheckRuntimeToken(c *c
 `))
        req.Header.Set("Authorization", "Bearer "+arvadostest.ActiveToken)
        req.Header.Set("Content-type", "application/json")
+
+       np := arvados.NodeProfile{
+               Controller: arvados.SystemServiceInstance{Listen: ":"},
+               RailsAPI: arvados.SystemServiceInstance{Listen: os.Getenv("ARVADOS_TEST_API_HOST"),
+                       TLS: true, Insecure: true}}
+       s.testHandler.Cluster.ClusterID = "zzzzz"
+       s.testHandler.Cluster.NodeProfiles["*"] = np
+       s.testHandler.NodeProfile = &np
+
        resp := s.testRequest(req)
        c.Check(resp.StatusCode, check.Equals, http.StatusOK)
        var cr struct {
index cbfaaddab4955ba66f2592e152410e10db9b5564..295dde7ca42821b1c8f904eec42ac7e7764812fa 100644 (file)
@@ -50,6 +50,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
                        req.URL.Path = strings.Replace(req.URL.Path, "//", "/", -1)
                }
        }
+       if h.Cluster.HTTPRequestTimeout > 0 {
+               ctx, cancel := context.WithDeadline(req.Context(), time.Now().Add(time.Duration(h.Cluster.HTTPRequestTimeout)))
+               req = req.WithContext(ctx)
+               defer cancel()
+       }
+
        h.handlerStack.ServeHTTP(w, req)
 }
 
@@ -84,8 +90,7 @@ func (h *Handler) setup() {
        h.insecureClient = &ic
 
        h.proxy = &proxy{
-               Name:           "arvados-controller",
-               RequestTimeout: time.Duration(h.Cluster.HTTPRequestTimeout),
+               Name: "arvados-controller",
        }
 }
 
@@ -122,10 +127,10 @@ func prepend(next http.Handler, middleware middlewareFunc) http.Handler {
        })
 }
 
-func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, context.CancelFunc, error) {
+func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, error) {
        urlOut, insecure, err := findRailsAPI(h.Cluster, h.NodeProfile)
        if err != nil {
-               return nil, nil, err
+               return nil, err
        }
        urlOut = &url.URL{
                Scheme:   urlOut.Scheme,
@@ -142,10 +147,7 @@ func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, contex
 }
 
 func (h *Handler) proxyRailsAPI(w http.ResponseWriter, req *http.Request, next http.Handler) {
-       resp, cancel, err := h.localClusterRequest(req)
-       if cancel != nil {
-               defer cancel()
-       }
+       resp, err := h.localClusterRequest(req)
        n, err := h.proxy.ForwardResponse(w, resp, err)
        if err != nil {
                httpserver.Logger(req).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
index c89b9b36ae0cddfc67a127d1e6f8bdac84ab6412..c01c152352e6b8f101179bf38add3b0574a00c5d 100644 (file)
@@ -5,18 +5,15 @@
 package controller
 
 import (
-       "context"
        "io"
        "net/http"
        "net/url"
-       "time"
 
        "git.curoverse.com/arvados.git/sdk/go/httpserver"
 )
 
 type proxy struct {
-       Name           string // to use in Via header
-       RequestTimeout time.Duration
+       Name string // to use in Via header
 }
 
 type HTTPError struct {
@@ -49,7 +46,7 @@ type ResponseFilter func(*http.Response, error) (*http.Response, error)
 func (p *proxy) Do(
        reqIn *http.Request,
        urlOut *url.URL,
-       client *http.Client) (*http.Response, context.CancelFunc, error) {
+       client *http.Client) (*http.Response, error) {
 
        // Copy headers from incoming request, then add/replace proxy
        // headers like Via and X-Forwarded-For.
@@ -69,22 +66,16 @@ func (p *proxy) Do(
        }
        hdrOut.Add("Via", reqIn.Proto+" arvados-controller")
 
-       ctx := reqIn.Context()
-       var cancel context.CancelFunc
-       if p.RequestTimeout > 0 {
-               ctx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Duration(p.RequestTimeout)))
-       }
-
        reqOut := (&http.Request{
                Method: reqIn.Method,
                URL:    urlOut,
                Host:   reqIn.Host,
                Header: hdrOut,
                Body:   reqIn.Body,
-       }).WithContext(ctx)
+       }).WithContext(reqIn.Context())
 
        resp, err := client.Do(reqOut)
-       return resp, cancel, err
+       return resp, err
 }
 
 // Copy a response (or error) to the downstream client