X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/78471fbe6370154fe9478a67c29c669a605c22bb..70ecd8ab9a0b82dc6a10ad1e8bf2b35fb8284ab1:/lib/controller/fed_collections.go diff --git a/lib/controller/fed_collections.go b/lib/controller/fed_collections.go index 70dbdc3f51..024af83c2b 100644 --- a/lib/controller/fed_collections.go +++ b/lib/controller/fed_collections.go @@ -34,7 +34,7 @@ func rewriteSignatures(clusterID string, expectHash string, return resp, requestError } - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { return resp, nil } @@ -140,7 +140,7 @@ func filterLocalClusterResponse(resp *http.Response, requestError error) (newRes return resp, requestError } - if resp.StatusCode == 404 { + if resp.StatusCode == http.StatusNotFound { // Suppress returning this result, because we want to // search the federation. return nil, nil @@ -159,64 +159,6 @@ type searchRemoteClusterForPDH struct { statusCode *int } -func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) { - s.mtx.Lock() - defer s.mtx.Unlock() - - if *s.sentResponse { - // Another request already returned a response - return nil, nil - } - - if requestError != nil { - *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError)) - // Record the error and suppress response - return nil, nil - } - - if resp.StatusCode != 200 { - // Suppress returning unsuccessful result. Maybe - // another request will find it. - // TODO collect and return error responses. - *s.errors = append(*s.errors, fmt.Sprintf("Response to %q from %q: %v", httpserver.GetRequestID(resp.Header), s.remoteID, resp.Status)) - if resp.StatusCode != 404 { - // Got a non-404 error response, convert into BadGateway - *s.statusCode = http.StatusBadGateway - } - return nil, nil - } - - s.mtx.Unlock() - - // This reads the response body. We don't want to hold the - // lock while doing this because other remote requests could - // also have made it to this point, and we don't want a - // slow response holding the lock to block a faster response - // that is waiting on the lock. - newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil) - - s.mtx.Lock() - - if *s.sentResponse { - // Another request already returned a response - return nil, nil - } - - if err != nil { - // Suppress returning unsuccessful result. Maybe - // another request will be successful. - *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err)) - return nil, nil - } - - // We have a successful response. Suppress/cancel all the - // other requests/responses. - *s.sentResponse = true - s.cancelFunc() - - return newResponse, nil -} - func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { if req.Method != "GET" { // Only handle GET requests right now @@ -258,55 +200,100 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req return } - sharedContext, cancelFunc := context.WithCancel(req.Context()) - defer cancelFunc() - req = req.WithContext(sharedContext) - // Create a goroutine for each cluster in the // RemoteClusters map. The first valid result gets // returned to the client. When that happens, all - // other outstanding requests are cancelled or - // suppressed. - sentResponse := false - mtx := sync.Mutex{} + // other outstanding requests are cancelled + sharedContext, cancelFunc := context.WithCancel(req.Context()) + req = req.WithContext(sharedContext) wg := sync.WaitGroup{} - var errors []string - var errorCode int = 404 + pdh := m[1] + success := make(chan *http.Response) + errorChan := make(chan error, len(h.handler.Cluster.RemoteClusters)) // use channel as a semaphore to limit the number of concurrent // requests at a time sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency()) - defer close(sem) + + defer cancelFunc() + for remoteID := range h.handler.Cluster.RemoteClusters { if remoteID == h.handler.Cluster.ClusterID { // No need to query local cluster again continue } - // blocks until it can put a value into the - // channel (which has a max queue capacity) - sem <- true - if sentResponse { - break - } - search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse, - &sharedContext, cancelFunc, &errors, &errorCode} + wg.Add(1) - go func() { - resp, err := h.handler.remoteClusterRequest(search.remoteID, req) - newResp, err := search.filterRemoteClusterResponse(resp, err) - if newResp != nil || err != nil { - h.handler.proxy.ForwardResponse(w, newResp, err) + go func(remote string) { + defer wg.Done() + // blocks until it can put a value into the + // channel (which has a max queue capacity) + sem <- true + select { + case <-sharedContext.Done(): + return + default: + } + + resp, err := h.handler.remoteClusterRequest(remote, req) + wasSuccess := false + defer func() { + if resp != nil && !wasSuccess { + resp.Body.Close() + } + }() + if err != nil { + errorChan <- err + return + } + if resp.StatusCode != http.StatusOK { + errorChan <- HTTPError{resp.Status, resp.StatusCode} + return + } + select { + case <-sharedContext.Done(): + return + default: + } + + newResponse, err := rewriteSignatures(remote, pdh, resp, nil) + if err != nil { + errorChan <- err + return + } + select { + case <-sharedContext.Done(): + case success <- newResponse: + wasSuccess = true } - wg.Done() <-sem - }() + }(remoteID) } - wg.Wait() + go func() { + wg.Wait() + cancelFunc() + }() - if sentResponse { - return - } + errorCode := http.StatusNotFound - // No successful responses, so return the error - httpserver.Errors(w, errors, errorCode) + for { + select { + case newResp = <-success: + h.handler.proxy.ForwardResponse(w, newResp, nil) + return + case <-sharedContext.Done(): + var errors []string + for len(errorChan) > 0 { + err := <-errorChan + if httperr, ok := err.(HTTPError); ok { + if httperr.Code != http.StatusNotFound { + errorCode = http.StatusBadGateway + } + } + errors = append(errors, err.Error()) + } + httpserver.Errors(w, errors, errorCode) + return + } + } }