From: Peter Amstutz
Date: Mon, 29 Oct 2018 21:09:08 +0000 (-0400)
Subject: 14262: Rewrite collectionFederatedRequestHandler PDH search to use channels
X-Git-Tag: 1.3.0~55^2~7
X-Git-Url: https://git.arvados.org/arvados.git/commitdiff_plain/4427f2c5f740d03d5ee38745159f61b6805843e7?hp=703179225b04309485c0a1cefb794df6c919e84f

14262: Rewrite collectionFederatedRequestHandler PDH search to use channels

Arvados-DCO-1.1-Signed-off-by: Peter Amstutz
---

diff --git a/lib/controller/fed_collections.go b/lib/controller/fed_collections.go
index 8a97c25c94..88b0f95a02 100644
--- a/lib/controller/fed_collections.go
+++ b/lib/controller/fed_collections.go
@@ -159,63 +159,6 @@ type searchRemoteClusterForPDH struct {
 	statusCode *int
 }
 
-func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
-	s.mtx.Lock()
-	defer s.mtx.Unlock()
-
-	if *s.sentResponse {
-		// Another request already returned a response
-		return nil, nil
-	}
-
-	if requestError != nil {
-		*s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
-		// Record the error and suppress response
-		return nil, nil
-	}
-
-	if resp.StatusCode != http.StatusOK {
-		// Suppress returning unsuccessful result. Maybe
-		// another request will find it.
-		*s.errors = append(*s.errors, fmt.Sprintf("Response to %q from %q: %v", resp.Header.Get(httpserver.HeaderRequestID), s.remoteID, resp.Status))
-		if resp.StatusCode != http.StatusNotFound {
-			// Got a non-404 error response, convert into BadGateway
-			*s.statusCode = http.StatusBadGateway
-		}
-		return nil, nil
-	}
-
-	s.mtx.Unlock()
-
-	// This reads the response body. We don't want to hold the
-	// lock while doing this because other remote requests could
-	// also have made it to this point, and we don't want a
-	// slow response holding the lock to block a faster response
-	// that is waiting on the lock.
-	newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
-
-	s.mtx.Lock()
-
-	if *s.sentResponse {
-		// Another request already returned a response
-		return nil, nil
-	}
-
-	if err != nil {
-		// Suppress returning unsuccessful result. Maybe
-		// another request will be successful.
-		*s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
-		return nil, nil
-	}
-
-	// We have a successful response. Suppress/cancel all the
-	// other requests/responses.
-	*s.sentResponse = true
-	s.cancelFunc()
-
-	return newResponse, nil
-}
-
 func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 	if req.Method != "GET" {
 		// Only handle GET requests right now
@@ -263,58 +206,124 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
 		return
 	}
 
-	sharedContext, cancelFunc := context.WithCancel(req.Context())
-	defer cancelFunc()
-	req = req.WithContext(sharedContext)
-
 	// Create a goroutine for each cluster in the
 	// RemoteClusters map. The first valid result gets
 	// returned to the client. When that happens, all
-	// other outstanding requests are cancelled or
-	// suppressed.
-	sentResponse := false
-	mtx := sync.Mutex{}
+	// other outstanding requests are cancelled
+	sharedContext, cancelFunc := context.WithCancel(req.Context())
+	defer cancelFunc()
+	req = req.WithContext(sharedContext)
 	wg := sync.WaitGroup{}
-	var errors []string
-	var errorCode int = http.StatusNotFound
+	pdh := m[1]
+	// success is unbuffered: a goroutine hands over its response
+	// only while this handler is still waiting for one, and gives
+	// up once sharedContext is cancelled.
+	success := make(chan *http.Response)
+	// errorChan has room for one error per goroutine, so reporting
+	// an error never blocks, even after this handler has returned.
+	errorChan := make(chan error, len(h.handler.Cluster.RemoteClusters))
 
 	// use channel as a semaphore to limit the number of concurrent
 	// requests at a time
 	sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
-	defer close(sem)
+	// None of these channels is closed: goroutines may still send
+	// on them after this handler returns, and a send on a closed
+	// channel panics.
+
 	for remoteID := range h.handler.Cluster.RemoteClusters {
 		if remoteID == h.handler.Cluster.ClusterID {
 			// No need to query local cluster again
 			continue
 		}
-		// blocks until it can put a value into the
-		// channel (which has a max queue capacity)
-		sem <- true
-		if sentResponse {
-			break
-		}
-		search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
-			&sharedContext, cancelFunc, &errors, &errorCode}
+
 		wg.Add(1)
-		go func() {
-			resp, cancel, err := h.handler.remoteClusterRequest(search.remoteID, req)
-			if cancel != nil {
-				defer cancel()
+		go func(remote string) {
+			defer wg.Done()
+			// blocks until it can put a value into the
+			// channel (which has a max queue capacity)
+			sem <- true
+			// Free the slot on every return path, not only after
+			// a successful send; otherwise each error or
+			// cancellation leaks a slot, and once they are all
+			// gone the remaining goroutines block forever.
+			defer func() { <-sem }()
+			select {
+			case <-sharedContext.Done():
+				return
+			default:
+			}
+
+			resp, _, err := h.handler.remoteClusterRequest(remote, req)
+			wasSuccess := false
+			defer func() {
+				if resp != nil && !wasSuccess {
+					resp.Body.Close()
+				}
+			}()
+			// Don't need to do anything with the cancel
+			// function returned by remoteClusterRequest
+			// because the context inherits from
+			// sharedContext, so when sharedContext is
+			// cancelled it should cancel that one as
+			// well.
+			if err != nil {
+				errorChan <- err
+				return
 			}
-			newResp, err := search.filterRemoteClusterResponse(resp, err)
-			if newResp != nil || err != nil {
-				h.handler.proxy.ForwardResponse(w, newResp, err)
+			if resp.StatusCode != http.StatusOK {
+				errorChan <- HTTPError{resp.Status, resp.StatusCode}
+				return
+			}
+			select {
+			case <-sharedContext.Done():
+				return
+			default:
+			}
+
+			newResponse, err := rewriteSignatures(remote, pdh, resp, nil)
+			if err != nil {
+				errorChan <- err
+				return
+			}
+			select {
+			case <-sharedContext.Done():
+			case success <- newResponse:
+				wasSuccess = true
 			}
-			wg.Done()
-			<-sem
-		}()
+		}(remoteID)
 	}
-	wg.Wait()
+	// When every goroutine has finished, cancel sharedContext so
+	// the loop below stops waiting and reports the errors.
+	go func() {
+		wg.Wait()
+		cancelFunc()
+	}()
 
-	if sentResponse {
-		return
-	}
+	var errors []string
+	errorCode := http.StatusNotFound
+	recordError := func(err error) {
+		if httperr, ok := err.(HTTPError); ok && httperr.Code != http.StatusNotFound {
+			// Got a non-404 error response, convert into BadGateway
+			errorCode = http.StatusBadGateway
+		}
+		errors = append(errors, err.Error())
+	}
 
-	// No successful responses, so return the error
-	httpserver.Errors(w, errors, errorCode)
+	for {
+		select {
+		case newResp = <-success:
+			h.handler.proxy.ForwardResponse(w, newResp, nil)
+			return
+		case err := <-errorChan:
+			recordError(err)
+		case <-sharedContext.Done():
+			// select picks at random among ready cases, so errors
+			// may still be buffered here; collect them as well.
+			for len(errorChan) > 0 {
+				recordError(<-errorChan)
+			}
+			httpserver.Errors(w, errors, errorCode)
+			return
+		}
+	}
 }