14262: Rewrite collectionFederatedRequestHandler PDH search to use channels
author Peter Amstutz <pamstutz@veritasgenetics.com>
Mon, 29 Oct 2018 21:09:08 +0000 (17:09 -0400)
committer Peter Amstutz <pamstutz@veritasgenetics.com>
Tue, 30 Oct 2018 18:12:02 +0000 (14:12 -0400)
Arvados-DCO-1.1-Signed-off-by: Peter Amstutz <pamstutz@veritasgenetics.com>

lib/controller/fed_collections.go

index 8a97c25c94eb35eb57cfa0e751950db03ec92fd1..88b0f95a0267f3979e7b5cfff1f56fbcf3dc32e2 100644 (file)
@@ -159,63 +159,6 @@ type searchRemoteClusterForPDH struct {
        statusCode    *int
 }
 
-func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
-       s.mtx.Lock()
-       defer s.mtx.Unlock()
-
-       if *s.sentResponse {
-               // Another request already returned a response
-               return nil, nil
-       }
-
-       if requestError != nil {
-               *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
-               // Record the error and suppress response
-               return nil, nil
-       }
-
-       if resp.StatusCode != http.StatusOK {
-               // Suppress returning unsuccessful result.  Maybe
-               // another request will find it.
-               *s.errors = append(*s.errors, fmt.Sprintf("Response to %q from %q: %v", resp.Header.Get(httpserver.HeaderRequestID), s.remoteID, resp.Status))
-               if resp.StatusCode != http.StatusNotFound {
-                       // Got a non-404 error response, convert into BadGateway
-                       *s.statusCode = http.StatusBadGateway
-               }
-               return nil, nil
-       }
-
-       s.mtx.Unlock()
-
-       // This reads the response body.  We don't want to hold the
-       // lock while doing this because other remote requests could
-       // also have made it to this point, and we don't want a
-       // slow response holding the lock to block a faster response
-       // that is waiting on the lock.
-       newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
-
-       s.mtx.Lock()
-
-       if *s.sentResponse {
-               // Another request already returned a response
-               return nil, nil
-       }
-
-       if err != nil {
-               // Suppress returning unsuccessful result.  Maybe
-               // another request will be successful.
-               *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
-               return nil, nil
-       }
-
-       // We have a successful response.  Suppress/cancel all the
-       // other requests/responses.
-       *s.sentResponse = true
-       s.cancelFunc()
-
-       return newResponse, nil
-}
-
 func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
        if req.Method != "GET" {
                // Only handle GET requests right now
@@ -263,58 +206,107 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
                return
        }
 
-       sharedContext, cancelFunc := context.WithCancel(req.Context())
-       defer cancelFunc()
-       req = req.WithContext(sharedContext)
-
        // Create a goroutine for each cluster in the
        // RemoteClusters map.  The first valid result gets
        // returned to the client.  When that happens, all
-       // other outstanding requests are cancelled or
-       // suppressed.
-       sentResponse := false
-       mtx := sync.Mutex{}
+       // other outstanding requests are cancelled
+       sharedContext, cancelFunc := context.WithCancel(req.Context())
+       req = req.WithContext(sharedContext)
        wg := sync.WaitGroup{}
-       var errors []string
-       var errorCode int = http.StatusNotFound
+       pdh := m[1]
+       success := make(chan *http.Response)
+       errorChan := make(chan error)
 
        // use channel as a semaphore to limit the number of concurrent
        // requests at a time
        sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
+
+       defer close(errorChan)
+       defer close(success)
        defer close(sem)
+       defer cancelFunc()
+
        for remoteID := range h.handler.Cluster.RemoteClusters {
                if remoteID == h.handler.Cluster.ClusterID {
                        // No need to query local cluster again
                        continue
                }
-               // blocks until it can put a value into the
-               // channel (which has a max queue capacity)
-               sem <- true
-               if sentResponse {
-                       break
-               }
-               search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
-                       &sharedContext, cancelFunc, &errors, &errorCode}
+
                wg.Add(1)
-               go func() {
-                       resp, cancel, err := h.handler.remoteClusterRequest(search.remoteID, req)
-                       if cancel != nil {
-                               defer cancel()
+               go func(remote string) {
+                       defer wg.Done()
+                       // blocks until it can put a value into the
+                       // channel (which has a max queue capacity)
+                       sem <- true
+                       select {
+                       case <-sharedContext.Done():
+                               return
+                       default:
+                       }
+
+                       resp, _, err := h.handler.remoteClusterRequest(remote, req)
+                       wasSuccess := false
+                       defer func() {
+                               if resp != nil && !wasSuccess {
+                                       resp.Body.Close()
+                               }
+                       }()
+                       // Don't need to do anything with the cancel
+                       // function returned by remoteClusterRequest
+                       // because the context inherits from
+                       // sharedContext, so when sharedContext is
+                       // cancelled it should cancel that one as
+                       // well.
+                       if err != nil {
+                               errorChan <- err
+                               return
                        }
-                       newResp, err := search.filterRemoteClusterResponse(resp, err)
-                       if newResp != nil || err != nil {
-                               h.handler.proxy.ForwardResponse(w, newResp, err)
+                       if resp.StatusCode != http.StatusOK {
+                               errorChan <- HTTPError{resp.Status, resp.StatusCode}
+                               return
+                       }
+                       select {
+                       case <-sharedContext.Done():
+                               return
+                       default:
+                       }
+
+                       newResponse, err := rewriteSignatures(remote, pdh, resp, nil)
+                       if err != nil {
+                               errorChan <- err
+                               return
+                       }
+                       select {
+                       case <-sharedContext.Done():
+                       case success <- newResponse:
+                               wasSuccess = true
                        }
-                       wg.Done()
                        <-sem
-               }()
+               }(remoteID)
        }
-       wg.Wait()
+       go func() {
+               wg.Wait()
+               cancelFunc()
+       }()
 
-       if sentResponse {
-               return
-       }
+       var errors []string
+       errorCode := http.StatusNotFound
 
-       // No successful responses, so return the error
-       httpserver.Errors(w, errors, errorCode)
+       for {
+               select {
+               case newResp = <-success:
+                       h.handler.proxy.ForwardResponse(w, newResp, nil)
+                       return
+               case err := <-errorChan:
+                       if httperr, ok := err.(HTTPError); ok {
+                               if httperr.Code != http.StatusNotFound {
+                                       errorCode = http.StatusBadGateway
+                               }
+                       }
+                       errors = append(errors, err.Error())
+               case <-sharedContext.Done():
+                       httpserver.Errors(w, errors, errorCode)
+                       return
+               }
+       }
 }