17170: Fixup gateway auth secret.
[arvados.git] / lib / controller / fed_collections.go
index 8a97c25c94eb35eb57cfa0e751950db03ec92fd1..a0a123129fdacdae34bf8b216d3e6b766a6f5889 100644 (file)
@@ -17,16 +17,11 @@ import (
        "strings"
        "sync"
 
-       "git.curoverse.com/arvados.git/sdk/go/arvados"
-       "git.curoverse.com/arvados.git/sdk/go/httpserver"
-       "git.curoverse.com/arvados.git/sdk/go/keepclient"
+       "git.arvados.org/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/httpserver"
+       "git.arvados.org/arvados.git/sdk/go/keepclient"
 )
 
-type collectionFederatedRequestHandler struct {
-       next    http.Handler
-       handler *Handler
-}
-
 func rewriteSignatures(clusterID string, expectHash string,
        resp *http.Response, requestError error) (newResponse *http.Response, err error) {
 
@@ -159,162 +154,159 @@ type searchRemoteClusterForPDH struct {
        statusCode    *int
 }
 
-func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
-       s.mtx.Lock()
-       defer s.mtx.Unlock()
+func fetchRemoteCollectionByUUID( // Forwards a GET for a collection UUID owned by a non-local cluster; returns true iff it wrote a response to w.
+       h *genericFederatedRequestHandler, // used for h.handler.Cluster, remoteClusterRequest and proxy
+       effectiveMethod string, // only "GET" is handled; any other method returns false
+       clusterID *string, // out: overwritten with uuid[0:5] whenever uuid != "", even when false is returned
+       uuid string, // collection UUID whose first 5 chars are treated as its cluster ID; "" skips UUID handling
+       remainder string, // unused here (presumably kept for a shared delegate signature — verify)
+       w http.ResponseWriter,
+       req *http.Request) bool {
 
-       if *s.sentResponse {
-               // Another request already returned a response
-               return nil, nil
-       }
-
-       if requestError != nil {
-               *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
-               // Record the error and suppress response
-               return nil, nil
+       if effectiveMethod != "GET" {
+               // Only handle GET requests right now
+               return false
        }
 
-       if resp.StatusCode != http.StatusOK {
-               // Suppress returning unsuccessful result.  Maybe
-               // another request will find it.
-               *s.errors = append(*s.errors, fmt.Sprintf("Response to %q from %q: %v", resp.Header.Get(httpserver.HeaderRequestID), s.remoteID, resp.Status))
-               if resp.StatusCode != http.StatusNotFound {
-                       // Got a non-404 error response, convert into BadGateway
-                       *s.statusCode = http.StatusBadGateway
+       if uuid != "" {
+               // Collection UUID GET request. NOTE(review): uuid[0:5] panics if len(uuid) < 5; assumes the caller validated the UUID format — TODO confirm.
+               *clusterID = uuid[0:5]
+               if *clusterID != "" && *clusterID != h.handler.Cluster.ClusterID {
+                       // Remote collection: forward the request to cluster *clusterID and relay its response.
+                       resp, err := h.handler.remoteClusterRequest(*clusterID, req)
+                       newResponse, err := rewriteSignatures(*clusterID, "", resp, err) // a request error is passed through as requestError rather than checked here
+                       h.handler.proxy.ForwardResponse(w, newResponse, err)
+                       return true
                }
-               return nil, nil
-       }
-
-       s.mtx.Unlock()
-
-       // This reads the response body.  We don't want to hold the
-       // lock while doing this because other remote requests could
-       // also have made it to this point, and we don't want a
-       // slow response holding the lock to block a faster response
-       // that is waiting on the lock.
-       newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
-
-       s.mtx.Lock()
-
-       if *s.sentResponse {
-               // Another request already returned a response
-               return nil, nil
-       }
-
-       if err != nil {
-               // Suppress returning unsuccessful result.  Maybe
-               // another request will be successful.
-               *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
-               return nil, nil
        }
 
-       // We have a successful response.  Suppress/cancel all the
-       // other requests/responses.
-       *s.sentResponse = true
-       s.cancelFunc()
-
-       return newResponse, nil
+       return false // empty uuid or local-cluster UUID: not handled here
 }
 
-func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
-       if req.Method != "GET" {
+func fetchRemoteCollectionByPDH( // Serves a GET for a collection by PDH: local cluster first, then all remote clusters in parallel; returns true iff it wrote a response to w.
+       h *genericFederatedRequestHandler,
+       effectiveMethod string, // only "GET" is handled; any other method returns false
+       clusterID *string, // unused here (presumably kept for a shared delegate signature — verify)
+       uuid string, // unused here
+       remainder string, // unused here
+       w http.ResponseWriter,
+       req *http.Request) bool {
+       // Returns false, writing nothing, unless this is a GET whose path matches collectionsByPDHRe; after that it always writes a response and returns true.
+       if effectiveMethod != "GET" {
                // Only handle GET requests right now
-               h.next.ServeHTTP(w, req)
-               return
+               return false
        }
 
-       m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
+       m := collectionsByPDHRe.FindStringSubmatch(req.URL.Path)
        if len(m) != 2 {
-               // Not a collection PDH GET request
-               m = collectionRe.FindStringSubmatch(req.URL.Path)
-               clusterId := ""
-
-               if len(m) > 0 {
-                       clusterId = m[2]
-               }
-
-               if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
-                       // request for remote collection by uuid
-                       resp, cancel, err := h.handler.remoteClusterRequest(clusterId, req)
-                       if cancel != nil {
-                               defer cancel()
-                       }
-                       newResponse, err := rewriteSignatures(clusterId, "", resp, err)
-                       h.handler.proxy.ForwardResponse(w, newResponse, err)
-                       return
-               }
-               // not a collection UUID request, or it is a request
-               // for a local UUID, either way, continue down the
-               // handler stack.
-               h.next.ServeHTTP(w, req)
-               return
+               return false
        }
 
        // Request for collection by PDH.  Search the federation.
 
        // First, query the local cluster.
-       resp, localClusterRequestCancel, err := h.handler.localClusterRequest(req)
-       if localClusterRequestCancel != nil {
-               defer localClusterRequestCancel()
-       }
+       resp, err := h.handler.localClusterRequest(req)
        newResp, err := filterLocalClusterResponse(resp, err)
        if newResp != nil || err != nil {
                h.handler.proxy.ForwardResponse(w, newResp, err)
-               return
+               return true
        }
 
-       sharedContext, cancelFunc := context.WithCancel(req.Context())
-       defer cancelFunc()
-       req = req.WithContext(sharedContext)
-
        // Create a goroutine for each cluster in the
        // RemoteClusters map.  The first valid result gets
        // returned to the client.  When that happens, all
-       // other outstanding requests are cancelled or
-       // suppressed.
-       sentResponse := false
-       mtx := sync.Mutex{}
+       // other outstanding requests are cancelled via sharedContext.
+       sharedContext, cancelFunc := context.WithCancel(req.Context())
+       defer cancelFunc()
+
+       req = req.WithContext(sharedContext) // remote requests inherit sharedContext (assumes remoteClusterRequest honours req.Context())
        wg := sync.WaitGroup{}
-       var errors []string
-       var errorCode int = http.StatusNotFound
+       pdh := m[1] // expected hash handed to rewriteSignatures for each remote response
+       success := make(chan *http.Response) // unbuffered: a winner's send blocks until the loop below receives it or sharedContext is done
+       errorChan := make(chan error, len(h.handler.Cluster.RemoteClusters)) // each goroutine sends at most one error, so these sends never block
+
+       acquire, release := semaphore(h.handler.Cluster.API.MaxRequestAmplification) // presumably a counting semaphore limiting concurrent remote requests — see semaphore()
 
-       // use channel as a semaphore to limit the number of concurrent
-       // requests at a time
-       sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
-       defer close(sem)
        for remoteID := range h.handler.Cluster.RemoteClusters {
                if remoteID == h.handler.Cluster.ClusterID {
                        // No need to query local cluster again
                        continue
                }
-               // blocks until it can put a value into the
-               // channel (which has a max queue capacity)
-               sem <- true
-               if sentResponse {
-                       break
+               if remoteID == "*" {
+                       // This isn't a real remote cluster: it just sets defaults for unlisted remotes.
+                       continue
                }
-               search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
-                       &sharedContext, cancelFunc, &errors, &errorCode}
+               // One goroutine per remote; remoteID is passed as an argument so each goroutine gets its own copy.
                wg.Add(1)
-               go func() {
-                       resp, cancel, err := h.handler.remoteClusterRequest(search.remoteID, req)
-                       if cancel != nil {
-                               defer cancel()
+               go func(remote string) {
+                       defer wg.Done()
+                       acquire()
+                       defer release()
+                       select { // skip this remote if the search is already over (winner delivered, or request context done)
+                       case <-sharedContext.Done():
+                               return
+                       default:
                        }
-                       newResp, err := search.filterRemoteClusterResponse(resp, err)
-                       if newResp != nil || err != nil {
-                               h.handler.proxy.ForwardResponse(w, newResp, err)
+                       // Query the remote; resp.Body is closed on return unless the response was delivered on success (wasSuccess).
+                       resp, err := h.handler.remoteClusterRequest(remote, req)
+                       wasSuccess := false
+                       defer func() {
+                               if resp != nil && !wasSuccess {
+                                       resp.Body.Close()
+                               }
+                       }()
+                       if err != nil {
+                               errorChan <- err
+                               return
+                       }
+                       if resp.StatusCode != http.StatusOK {
+                               errorChan <- HTTPError{resp.Status, resp.StatusCode}
+                               return
+                       }
+                       select { // re-check cancellation before calling rewriteSignatures
+                       case <-sharedContext.Done():
+                               return
+                       default:
                        }
-                       wg.Done()
-                       <-sem
-               }()
-       }
-       wg.Wait()
 
-       if sentResponse {
-               return
+                       newResponse, err := rewriteSignatures(remote, pdh, resp, nil)
+                       if err != nil {
+                               errorChan <- err
+                               return
+                       }
+                       select { // deliver unless cancelled; the main loop receives at most one response
+                       case <-sharedContext.Done():
+                       case success <- newResponse:
+                               wasSuccess = true // skip the deferred resp.Body.Close: the main loop forwards this response
+                       }
+               }(remoteID)
+       }
+       go func() { // once every search goroutine has finished, cancel sharedContext so the loop below reports the errors
+               wg.Wait()
+               cancelFunc()
+       }()
+       // The first successful remote response wins; otherwise report the collected errors.
+       errorCode := http.StatusNotFound
+       // errorCode becomes 502 (StatusBadGateway) if any failure is not an HTTPError with code 404.
+       for {
+               select {
+               case newResp = <-success:
+                       h.handler.proxy.ForwardResponse(w, newResp, nil)
+                       return true
+               case <-sharedContext.Done():
+                       var errors []string
+                       for len(errorChan) > 0 { // NOTE(review): if Done came from the parent request context, goroutines may still be running, so this list can be incomplete
+                               err := <-errorChan
+                               if httperr, ok := err.(HTTPError); !ok || httperr.Code != http.StatusNotFound {
+                                       errorCode = http.StatusBadGateway
+                               }
+                               errors = append(errors, err.Error())
+                       }
+                       httpserver.Errors(w, errors, errorCode)
+                       return true
+               }
        }
 
-       // No successful responses, so return the error
-       httpserver.Errors(w, errors, errorCode)
+       // Unreachable: both select cases return. NOTE(review): go vet's unreachable check likely flags this; consider removing.
+       return true
 }