Merge branch 'master' into 14874-protected-collection-properties
[arvados.git] / lib / controller / fed_collections.go
index 024af83c2b36cdd254ec1248ed3df6e3a71e02e0..07daf2f90ef28b3199e856c93134aa5b6975fab3 100644 (file)
@@ -22,11 +22,6 @@ import (
        "git.curoverse.com/arvados.git/sdk/go/keepclient"
 )
 
-type collectionFederatedRequestHandler struct {
-       next    http.Handler
-       handler *Handler
-}
-
 func rewriteSignatures(clusterID string, expectHash string,
        resp *http.Response, requestError error) (newResponse *http.Response, err error) {
 
@@ -159,35 +154,52 @@ type searchRemoteClusterForPDH struct {
        statusCode    *int
 }
 
-func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
-       if req.Method != "GET" {
+func fetchRemoteCollectionByUUID(
+       h *genericFederatedRequestHandler,
+       effectiveMethod string,
+       clusterId *string,
+       uuid string,
+       remainder string,
+       w http.ResponseWriter,
+       req *http.Request) bool {
+
+       if effectiveMethod != "GET" {
                // Only handle GET requests right now
-               h.next.ServeHTTP(w, req)
-               return
+               return false
        }
 
-       m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
-       if len(m) != 2 {
-               // Not a collection PDH GET request
-               m = collectionRe.FindStringSubmatch(req.URL.Path)
-               clusterId := ""
-
-               if len(m) > 0 {
-                       clusterId = m[2]
-               }
-
-               if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
+       if uuid != "" {
+               // Collection UUID GET request
+               *clusterId = uuid[0:5]
+               if *clusterId != "" && *clusterId != h.handler.Cluster.ClusterID {
                        // request for remote collection by uuid
-                       resp, err := h.handler.remoteClusterRequest(clusterId, req)
-                       newResponse, err := rewriteSignatures(clusterId, "", resp, err)
+                       resp, err := h.handler.remoteClusterRequest(*clusterId, req)
+                       newResponse, err := rewriteSignatures(*clusterId, "", resp, err)
                        h.handler.proxy.ForwardResponse(w, newResponse, err)
-                       return
+                       return true
                }
-               // not a collection UUID request, or it is a request
-               // for a local UUID, either way, continue down the
-               // handler stack.
-               h.next.ServeHTTP(w, req)
-               return
+       }
+
+       return false
+}
+
+func fetchRemoteCollectionByPDH(
+       h *genericFederatedRequestHandler,
+       effectiveMethod string,
+       clusterId *string,
+       uuid string,
+       remainder string,
+       w http.ResponseWriter,
+       req *http.Request) bool {
+
+       if effectiveMethod != "GET" {
+               // Only handle GET requests right now
+               return false
+       }
+
+       m := collectionsByPDHRe.FindStringSubmatch(req.URL.Path)
+       if len(m) != 2 {
+               return false
        }
 
        // Request for collection by PDH.  Search the federation.
@@ -197,7 +209,7 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
        newResp, err := filterLocalClusterResponse(resp, err)
        if newResp != nil || err != nil {
                h.handler.proxy.ForwardResponse(w, newResp, err)
-               return
+               return true
        }
 
        // Create a goroutine for each cluster in the
@@ -205,17 +217,15 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
        // returned to the client.  When that happens, all
        // other outstanding requests are cancelled
        sharedContext, cancelFunc := context.WithCancel(req.Context())
+       defer cancelFunc()
+
        req = req.WithContext(sharedContext)
        wg := sync.WaitGroup{}
        pdh := m[1]
        success := make(chan *http.Response)
        errorChan := make(chan error, len(h.handler.Cluster.RemoteClusters))
 
-       // use channel as a semaphore to limit the number of concurrent
-       // requests at a time
-       sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
-
-       defer cancelFunc()
+       acquire, release := semaphore(h.handler.Cluster.API.MaxRequestAmplification)
 
        for remoteID := range h.handler.Cluster.RemoteClusters {
                if remoteID == h.handler.Cluster.ClusterID {
@@ -226,9 +236,8 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
                wg.Add(1)
                go func(remote string) {
                        defer wg.Done()
-                       // blocks until it can put a value into the
-                       // channel (which has a max queue capacity)
-                       sem <- true
+                       acquire()
+                       defer release()
                        select {
                        case <-sharedContext.Done():
                                return
@@ -266,7 +275,6 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
                        case success <- newResponse:
                                wasSuccess = true
                        }
-                       <-sem
                }(remoteID)
        }
        go func() {
@@ -280,7 +288,7 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
                select {
                case newResp = <-success:
                        h.handler.proxy.ForwardResponse(w, newResp, nil)
-                       return
+                       return true
                case <-sharedContext.Done():
                        var errors []string
                        for len(errorChan) > 0 {
@@ -293,7 +301,10 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
                                errors = append(errors, err.Error())
                        }
                        httpserver.Errors(w, errors, errorCode)
-                       return
+                       return true
                }
        }
+
+       // shouldn't ever get here
+       return true
 }