return resp, requestError
}
- if resp.StatusCode != 200 {
+ if resp.StatusCode != http.StatusOK {
return resp, nil
}
return resp, requestError
}
- if resp.StatusCode == 404 {
+ if resp.StatusCode == http.StatusNotFound {
// Suppress returning this result, because we want to
// search the federation.
return nil, nil
statusCode *int
}
-func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
- s.mtx.Lock()
- defer s.mtx.Unlock()
-
- if *s.sentResponse {
- // Another request already returned a response
- return nil, nil
- }
-
- if requestError != nil {
- *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
- // Record the error and suppress response
- return nil, nil
- }
-
- if resp.StatusCode != 200 {
- // Suppress returning unsuccessful result. Maybe
- // another request will find it.
- // TODO collect and return error responses.
- *s.errors = append(*s.errors, fmt.Sprintf("Response to %q from %q: %v", httpserver.GetRequestID(resp.Header), s.remoteID, resp.Status))
- if resp.StatusCode != 404 {
- // Got a non-404 error response, convert into BadGateway
- *s.statusCode = http.StatusBadGateway
- }
- return nil, nil
- }
-
- s.mtx.Unlock()
-
- // This reads the response body. We don't want to hold the
- // lock while doing this because other remote requests could
- // also have made it to this point, and we don't want a
- // slow response holding the lock to block a faster response
- // that is waiting on the lock.
- newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
-
- s.mtx.Lock()
-
- if *s.sentResponse {
- // Another request already returned a response
- return nil, nil
- }
-
- if err != nil {
- // Suppress returning unsuccessful result. Maybe
- // another request will be successful.
- *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
- return nil, nil
- }
-
- // We have a successful response. Suppress/cancel all the
- // other requests/responses.
- *s.sentResponse = true
- s.cancelFunc()
-
- return newResponse, nil
-}
-
+ // ServeHTTP searches the federation for a collection: it starts one
+ // worker goroutine per remote cluster (skipping the local cluster ID),
+ // bounded by the multi-cluster request concurrency limit, and forwards
+ // the first 200 response whose signatures are rewritten successfully.
+ // If no worker succeeds, the collected error strings are returned with
+ // 404, or with 502 if any remote answered with a non-404 HTTP error.
+ // NOTE(review): m (and newResp) are not declared in the lines visible
+ // in this patch; presumably they come from elided context -- confirm.
func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != "GET" {
// Only handle GET requests right now
return
}
- sharedContext, cancelFunc := context.WithCancel(req.Context())
- defer cancelFunc()
- req = req.WithContext(sharedContext)
-
// Create a goroutine for each cluster in the
// RemoteClusters map. The first valid result gets
// returned to the client. When that happens, all
- // other outstanding requests are cancelled or
- // suppressed.
- sentResponse := false
- mtx := sync.Mutex{}
+ // other outstanding requests are cancelled
+ sharedContext, cancelFunc := context.WithCancel(req.Context())
+ req = req.WithContext(sharedContext)
wg := sync.WaitGroup{}
- var errors []string
- var errorCode int = 404
+ pdh := m[1]
+ success := make(chan *http.Response)
+ errorChan := make(chan error, len(h.handler.Cluster.RemoteClusters))
// use channel as a semaphore to limit the number of concurrent
// requests at a time
sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
- defer close(sem)
+
+ defer cancelFunc()
+
for remoteID := range h.handler.Cluster.RemoteClusters {
if remoteID == h.handler.Cluster.ClusterID {
// No need to query local cluster again
continue
}
- // blocks until it can put a value into the
- // channel (which has a max queue capacity)
- sem <- true
- if sentResponse {
- break
- }
- search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
- &sharedContext, cancelFunc, &errors, &errorCode}
+
wg.Add(1)
- go func() {
- resp, err := h.handler.remoteClusterRequest(search.remoteID, req)
- newResp, err := search.filterRemoteClusterResponse(resp, err)
- if newResp != nil || err != nil {
- h.handler.proxy.ForwardResponse(w, newResp, err)
+ go func(remote string) {
+ defer wg.Done()
+ // Wait for a semaphore slot (the channel has a max
+ // queue capacity), but stop waiting as soon as the
+ // search is over so a queued worker can never block
+ // forever.
+ select {
+ case sem <- true:
+ case <-sharedContext.Done():
+ return
+ }
+ // Release the slot on every exit path. Releasing it
+ // only after a success would let failed requests keep
+ // their slots and starve the remaining workers.
+ defer func() { <-sem }()
+ if sharedContext.Err() != nil {
+ return
+ }
+
+ resp, err := h.handler.remoteClusterRequest(remote, req)
+ wasSuccess := false
+ defer func() {
+ // A response that is not handed to the client
+ // must be closed here.
+ if resp != nil && !wasSuccess {
+ resp.Body.Close()
+ }
+ }()
+ if err != nil {
+ errorChan <- err
+ return
+ }
+ if resp.StatusCode != http.StatusOK {
+ errorChan <- HTTPError{resp.Status, resp.StatusCode}
+ return
+ }
+ select {
+ case <-sharedContext.Done():
+ return
+ default:
+ }
+
+ newResponse, err := rewriteSignatures(remote, pdh, resp, nil)
+ if err != nil {
+ errorChan <- err
+ return
+ }
+ select {
+ case <-sharedContext.Done():
+ case success <- newResponse:
+ wasSuccess = true
}
- wg.Done()
- <-sem
- }()
+ }(remoteID)
}
- wg.Wait()
+ // Once every worker has finished, cancel the shared context so
+ // the select below wakes up and reports the collected errors.
+ go func() {
+ wg.Wait()
+ cancelFunc()
+ }()
- if sentResponse {
- return
- }
+ errorCode := http.StatusNotFound
- // No successful responses, so return the error
- httpserver.Errors(w, errors, errorCode)
+ for {
+ select {
+ case newResp = <-success:
+ h.handler.proxy.ForwardResponse(w, newResp, nil)
+ return
+ case <-sharedContext.Done():
+ // Drain only what is already buffered; never
+ // block waiting for more errors.
+ var errors []string
+ for len(errorChan) > 0 {
+ err := <-errorChan
+ if httperr, ok := err.(HTTPError); ok {
+ if httperr.Code != http.StatusNotFound {
+ errorCode = http.StatusBadGateway
+ }
+ }
+ errors = append(errors, err.Error())
+ }
+ httpserver.Errors(w, errors, errorCode)
+ return
+ }
+ }
}