X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/703179225b04309485c0a1cefb794df6c919e84f..508f13840841afc5938f7210a999ff58f002b29d:/lib/controller/fed_collections.go diff --git a/lib/controller/fed_collections.go b/lib/controller/fed_collections.go index 8a97c25c94..07daf2f90e 100644 --- a/lib/controller/fed_collections.go +++ b/lib/controller/fed_collections.go @@ -22,11 +22,6 @@ import ( "git.curoverse.com/arvados.git/sdk/go/keepclient" ) -type collectionFederatedRequestHandler struct { - next http.Handler - handler *Handler -} - func rewriteSignatures(clusterID string, expectHash string, resp *http.Response, requestError error) (newResponse *http.Response, err error) { @@ -159,162 +154,157 @@ type searchRemoteClusterForPDH struct { statusCode *int } -func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) { - s.mtx.Lock() - defer s.mtx.Unlock() +func fetchRemoteCollectionByUUID( + h *genericFederatedRequestHandler, + effectiveMethod string, + clusterId *string, + uuid string, + remainder string, + w http.ResponseWriter, + req *http.Request) bool { - if *s.sentResponse { - // Another request already returned a response - return nil, nil - } - - if requestError != nil { - *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError)) - // Record the error and suppress response - return nil, nil + if effectiveMethod != "GET" { + // Only handle GET requests right now + return false } - if resp.StatusCode != http.StatusOK { - // Suppress returning unsuccessful result. Maybe - // another request will find it. - *s.errors = append(*s.errors, fmt.Sprintf("Response to %q from %q: %v", resp.Header.Get(httpserver.HeaderRequestID), s.remoteID, resp.Status)) - if resp.StatusCode != http.StatusNotFound { - // Got a non-404 error response, convert into BadGateway - *s.statusCode = http.StatusBadGateway + if uuid != "" { + // Collection UUID GET request + *clusterId = uuid[0:5] + if *clusterId != "" && *clusterId != h.handler.Cluster.ClusterID { + // request for remote collection by uuid + resp, err := h.handler.remoteClusterRequest(*clusterId, req) + newResponse, err := rewriteSignatures(*clusterId, "", resp, err) + h.handler.proxy.ForwardResponse(w, newResponse, err) + return true } - return nil, nil - } - - s.mtx.Unlock() - - // This reads the response body. We don't want to hold the - // lock while doing this because other remote requests could - // also have made it to this point, and we don't want a - // slow response holding the lock to block a faster response - // that is waiting on the lock. - newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil) - - s.mtx.Lock() - - if *s.sentResponse { - // Another request already returned a response - return nil, nil - } - - if err != nil { - // Suppress returning unsuccessful result. Maybe - // another request will be successful. - *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err)) - return nil, nil } - // We have a successful response. Suppress/cancel all the - // other requests/responses. - *s.sentResponse = true - s.cancelFunc() - - return newResponse, nil + return false } -func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if req.Method != "GET" { +func fetchRemoteCollectionByPDH( + h *genericFederatedRequestHandler, + effectiveMethod string, + clusterId *string, + uuid string, + remainder string, + w http.ResponseWriter, + req *http.Request) bool { + + if effectiveMethod != "GET" { // Only handle GET requests right now - h.next.ServeHTTP(w, req) - return + return false } - m := collectionByPDHRe.FindStringSubmatch(req.URL.Path) + m := collectionsByPDHRe.FindStringSubmatch(req.URL.Path) if len(m) != 2 { - // Not a collection PDH GET request - m = collectionRe.FindStringSubmatch(req.URL.Path) - clusterId := "" - - if len(m) > 0 { - clusterId = m[2] - } - - if clusterId != "" && clusterId != h.handler.Cluster.ClusterID { - // request for remote collection by uuid - resp, cancel, err := h.handler.remoteClusterRequest(clusterId, req) - if cancel != nil { - defer cancel() - } - newResponse, err := rewriteSignatures(clusterId, "", resp, err) - h.handler.proxy.ForwardResponse(w, newResponse, err) - return - } - // not a collection UUID request, or it is a request - // for a local UUID, either way, continue down the - // handler stack. - h.next.ServeHTTP(w, req) - return + return false } // Request for collection by PDH. Search the federation. // First, query the local cluster. - resp, localClusterRequestCancel, err := h.handler.localClusterRequest(req) - if localClusterRequestCancel != nil { - defer localClusterRequestCancel() - } + resp, err := h.handler.localClusterRequest(req) newResp, err := filterLocalClusterResponse(resp, err) if newResp != nil || err != nil { h.handler.proxy.ForwardResponse(w, newResp, err) - return + return true } - sharedContext, cancelFunc := context.WithCancel(req.Context()) - defer cancelFunc() - req = req.WithContext(sharedContext) - // Create a goroutine for each cluster in the // RemoteClusters map. The first valid result gets // returned to the client. When that happens, all - // other outstanding requests are cancelled or - // suppressed. - sentResponse := false - mtx := sync.Mutex{} + // other outstanding requests are cancelled + sharedContext, cancelFunc := context.WithCancel(req.Context()) + defer cancelFunc() + + req = req.WithContext(sharedContext) wg := sync.WaitGroup{} - var errors []string - var errorCode int = http.StatusNotFound + pdh := m[1] + success := make(chan *http.Response) + errorChan := make(chan error, len(h.handler.Cluster.RemoteClusters)) + + acquire, release := semaphore(h.handler.Cluster.API.MaxRequestAmplification) - // use channel as a semaphore to limit the number of concurrent - // requests at a time - sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency()) - defer close(sem) for remoteID := range h.handler.Cluster.RemoteClusters { if remoteID == h.handler.Cluster.ClusterID { // No need to query local cluster again continue } - // blocks until it can put a value into the - // channel (which has a max queue capacity) - sem <- true - if sentResponse { - break - } - search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse, - &sharedContext, cancelFunc, &errors, &errorCode} + wg.Add(1) - go func() { - resp, cancel, err := h.handler.remoteClusterRequest(search.remoteID, req) - if cancel != nil { - defer cancel() + go func(remote string) { + defer wg.Done() + acquire() + defer release() + select { + case <-sharedContext.Done(): + return + default: } - newResp, err := search.filterRemoteClusterResponse(resp, err) - if newResp != nil || err != nil { - h.handler.proxy.ForwardResponse(w, newResp, err) + + resp, err := h.handler.remoteClusterRequest(remote, req) + wasSuccess := false + defer func() { + if resp != nil && !wasSuccess { + resp.Body.Close() + } + }() + if err != nil { + errorChan <- err + return + } + if resp.StatusCode != http.StatusOK { + errorChan <- HTTPError{resp.Status, resp.StatusCode} + return + } + select { + case <-sharedContext.Done(): + return + default: } - wg.Done() - <-sem - }() - } - wg.Wait() - if sentResponse { - return + newResponse, err := rewriteSignatures(remote, pdh, resp, nil) + if err != nil { + errorChan <- err + return + } + select { + case <-sharedContext.Done(): + case success <- newResponse: + wasSuccess = true + } + }(remoteID) + } + go func() { + wg.Wait() + cancelFunc() + }() + + errorCode := http.StatusNotFound + + for { + select { + case newResp = <-success: + h.handler.proxy.ForwardResponse(w, newResp, nil) + return true + case <-sharedContext.Done(): + var errors []string + for len(errorChan) > 0 { + err := <-errorChan + if httperr, ok := err.(HTTPError); ok { + if httperr.Code != http.StatusNotFound { + errorCode = http.StatusBadGateway + } + } + errors = append(errors, err.Error()) + } + httpserver.Errors(w, errors, errorCode) + return true + } } - // No successful responses, so return the error - httpserver.Errors(w, errors, errorCode) + // shouldn't ever get here + return true }