statusCode *int
}
-func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
- s.mtx.Lock()
- defer s.mtx.Unlock()
-
- if *s.sentResponse {
- // Another request already returned a response
- return nil, nil
- }
-
- if requestError != nil {
- *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
- // Record the error and suppress response
- return nil, nil
- }
-
- if resp.StatusCode != http.StatusOK {
- // Suppress returning unsuccessful result. Maybe
- // another request will find it.
- *s.errors = append(*s.errors, fmt.Sprintf("Response to %q from %q: %v", resp.Header.Get(httpserver.HeaderRequestID), s.remoteID, resp.Status))
- if resp.StatusCode != http.StatusNotFound {
- // Got a non-404 error response, convert into BadGateway
- *s.statusCode = http.StatusBadGateway
- }
- return nil, nil
- }
-
- s.mtx.Unlock()
-
- // This reads the response body. We don't want to hold the
- // lock while doing this because other remote requests could
- // also have made it to this point, and we don't want a
- // slow response holding the lock to block a faster response
- // that is waiting on the lock.
- newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
-
- s.mtx.Lock()
-
- if *s.sentResponse {
- // Another request already returned a response
- return nil, nil
- }
-
- if err != nil {
- // Suppress returning unsuccessful result. Maybe
- // another request will be successful.
- *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
- return nil, nil
- }
-
- // We have a successful response. Suppress/cancel all the
- // other requests/responses.
- *s.sentResponse = true
- s.cancelFunc()
-
- return newResponse, nil
-}
-
func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != "GET" {
// Only handle GET requests right now
return
}
- sharedContext, cancelFunc := context.WithCancel(req.Context())
- defer cancelFunc()
- req = req.WithContext(sharedContext)
-
// Create a goroutine for each cluster in the
// RemoteClusters map. The first valid result gets
// returned to the client. When that happens, all
- // other outstanding requests are cancelled or
- // suppressed.
- sentResponse := false
- mtx := sync.Mutex{}
+ // other outstanding requests are cancelled.
+ sharedContext, cancelFunc := context.WithCancel(req.Context())
+ defer cancelFunc()
+ req = req.WithContext(sharedContext)
wg := sync.WaitGroup{}
- var errors []string
- var errorCode int = http.StatusNotFound
+ pdh := m[1]
+ // Unbuffered: a goroutine hands over its response only while
+ // ServeHTTP is still waiting for one; otherwise it gives up
+ // when sharedContext is cancelled.
+ success := make(chan *http.Response)
+ // Buffered so that no goroutine ever blocks reporting its (at
+ // most one) error, even after ServeHTTP has returned. Never
+ // closed, because goroutines may still send after we return.
+ errorChan := make(chan error, len(h.handler.Cluster.RemoteClusters))
// use channel as a semaphore to limit the number of concurrent
// requests at a time
sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
- defer close(sem)
+ // sem is deliberately never closed: a goroutine still waiting
+ // for a slot when ServeHTTP returns would panic sending on a
+ // closed channel.
+
for remoteID := range h.handler.Cluster.RemoteClusters {
if remoteID == h.handler.Cluster.ClusterID {
// No need to query local cluster again
continue
}
- // blocks until it can put a value into the
- // channel (which has a max queue capacity)
- sem <- true
- if sentResponse {
- break
- }
- search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
- &sharedContext, cancelFunc, &errors, &errorCode}
+
wg.Add(1)
- go func() {
- resp, cancel, err := h.handler.remoteClusterRequest(search.remoteID, req)
- if cancel != nil {
- defer cancel()
+ go func(remote string) {
+ defer wg.Done()
+ // Wait for a free semaphore slot (the channel's
+ // capacity bounds concurrency), unless the search
+ // is already over.
+ select {
+ case sem <- true:
+ case <-sharedContext.Done():
+ return
+ }
+ // Release the slot on every return path. If only
+ // the success path released it, a few failing
+ // remotes would starve the rest and wg.Wait()
+ // would never return.
+ defer func() { <-sem }()
+ select {
+ case <-sharedContext.Done():
+ return
+ default:
+ }
+
+ resp, _, err := h.handler.remoteClusterRequest(remote, req)
+ wasSuccess := false
+ defer func() {
+ if resp != nil && !wasSuccess {
+ resp.Body.Close()
+ }
+ }()
+ // Don't need to do anything with the cancel
+ // function returned by remoteClusterRequest
+ // because the context inherits from
+ // sharedContext, so when sharedContext is
+ // cancelled it should cancel that one as
+ // well.
+ if err != nil {
+ errorChan <- fmt.Errorf("Request error contacting %q: %v", remote, err)
+ return
}
- newResp, err := search.filterRemoteClusterResponse(resp, err)
- if newResp != nil || err != nil {
- h.handler.proxy.ForwardResponse(w, newResp, err)
+ if resp.StatusCode != http.StatusOK {
+ // Suppress unsuccessful results; another
+ // remote may still find the collection.
+ errorChan <- HTTPError{fmt.Sprintf("Response to %q from %q: %v", resp.Header.Get(httpserver.HeaderRequestID), remote, resp.Status), resp.StatusCode}
+ return
+ }
+ select {
+ case <-sharedContext.Done():
+ return
+ default:
+ }
+
+ newResponse, err := rewriteSignatures(remote, pdh, resp, nil)
+ if err != nil {
+ errorChan <- fmt.Errorf("Error parsing response from %q: %v", remote, err)
+ return
+ }
+ // Hand the response over unless another remote
+ // already won or the client went away.
+ select {
+ case <-sharedContext.Done():
+ case success <- newResponse:
+ wasSuccess = true
}
- wg.Done()
- <-sem
- }()
+ }(remoteID)
}
- wg.Wait()
+ // Once every goroutine has exited, cancel sharedContext so the
+ // collection loop below knows no response is coming.
+ go func() {
+ wg.Wait()
+ cancelFunc()
+ }()
- if sentResponse {
- return
- }
+ var errorMessages []string
+ errorCode := http.StatusNotFound
+ // collectError records one failed lookup. Any non-404 HTTP
+ // error status from a remote turns the overall failure into
+ // 502 Bad Gateway; otherwise it stays 404.
+ collectError := func(err error) {
+ if httperr, ok := err.(HTTPError); ok && httperr.Code != http.StatusNotFound {
+ errorCode = http.StatusBadGateway
+ }
+ errorMessages = append(errorMessages, err.Error())
+ }
- // No successful responses, so return the error
- httpserver.Errors(w, errors, errorCode)
+ for {
+ select {
+ case newResp := <-success:
+ h.handler.proxy.ForwardResponse(w, newResp, nil)
+ return
+ case err := <-errorChan:
+ collectError(err)
+ case <-sharedContext.Done():
+ // select picks randomly among ready cases, so
+ // errors may still be sitting in the buffer;
+ // drain them before reporting.
+ for len(errorChan) > 0 {
+ collectError(<-errorChan)
+ }
+ // No successful responses, so return the errors.
+ httpserver.Errors(w, errorMessages, errorCode)
+ return
+ }
+ }
}