return resp, requestError
}
- if resp.StatusCode != 200 {
+ if resp.StatusCode != http.StatusOK {
return resp, nil
}
return resp, requestError
}
- if resp.StatusCode == 404 {
+ if resp.StatusCode == http.StatusNotFound {
// Suppress returning this result, because we want to
// search the federation.
return nil, nil
return nil, nil
}
- if resp.StatusCode != 200 {
+ if resp.StatusCode != http.StatusOK {
// Suppress returning unsuccessful result. Maybe
// another request will find it.
- // TODO collect and return error responses.
- *s.errors = append(*s.errors, fmt.Sprintf("Response to %q from %q: %v", httpserver.GetRequestID(resp.Header), s.remoteID, resp.Status))
- if resp.StatusCode != 404 {
+ *s.errors = append(*s.errors, fmt.Sprintf("Response to %q from %q: %v", resp.Header.Get(httpserver.HeaderRequestID), s.remoteID, resp.Status))
+ if resp.StatusCode != http.StatusNotFound {
// Got a non-404 error response, convert into BadGateway
*s.statusCode = http.StatusBadGateway
}
if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
// request for remote collection by uuid
- resp, err := h.handler.remoteClusterRequest(clusterId, req)
+ resp, cancel, err := h.handler.remoteClusterRequest(clusterId, req)
+ if cancel != nil {
+ defer cancel()
+ }
newResponse, err := rewriteSignatures(clusterId, "", resp, err)
h.handler.proxy.ForwardResponse(w, newResponse, err)
return
// Request for collection by PDH. Search the federation.
// First, query the local cluster.
- resp, err := h.handler.localClusterRequest(req)
+ resp, localClusterRequestCancel, err := h.handler.localClusterRequest(req)
+ if localClusterRequestCancel != nil {
+ defer localClusterRequestCancel()
+ }
newResp, err := filterLocalClusterResponse(resp, err)
if newResp != nil || err != nil {
h.handler.proxy.ForwardResponse(w, newResp, err)
mtx := sync.Mutex{}
wg := sync.WaitGroup{}
var errors []string
- var errorCode int = 404
+ var errorCode int = http.StatusNotFound
// use channel as a semaphore to limit the number of concurrent
// requests at a time
&sharedContext, cancelFunc, &errors, &errorCode}
wg.Add(1)
go func() {
- resp, err := h.handler.remoteClusterRequest(search.remoteID, req)
+ resp, cancel, err := h.handler.remoteClusterRequest(search.remoteID, req)
+ if cancel != nil {
+ defer cancel()
+ }
newResp, err := search.filterRemoteClusterResponse(resp, err)
if newResp != nil || err != nil {
h.handler.proxy.ForwardResponse(w, newResp, err)
req.ContentLength = int64(buf.Len())
req.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
- resp, err := h.handler.remoteClusterRequest(*clusterId, req)
+ resp, cancel, err := h.handler.remoteClusterRequest(*clusterId, req)
+ if cancel != nil {
+ defer cancel()
+ }
h.handler.proxy.ForwardResponse(w, resp, err)
return true
}
import (
"bytes"
+ "context"
"encoding/json"
"fmt"
"io/ioutil"
rc := multiClusterQueryResponseCollector{clusterID: clusterID}
var resp *http.Response
+ var cancel context.CancelFunc
if clusterID == h.handler.Cluster.ClusterID {
- resp, err = h.handler.localClusterRequest(&remoteReq)
+ resp, cancel, err = h.handler.localClusterRequest(&remoteReq)
} else {
- resp, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
+ resp, cancel, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
}
rc.collectResponse(resp, err)
+ if cancel != nil {
+ cancel()
+ }
if rc.error != nil {
return nil, "", rc.error
if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
h.next.ServeHTTP(w, req)
} else {
- resp, err := h.handler.remoteClusterRequest(clusterId, req)
+ resp, cancel, err := h.handler.remoteClusterRequest(clusterId, req)
+ if cancel != nil {
+ defer cancel()
+ }
h.handler.proxy.ForwardResponse(w, resp, err)
}
}
import (
"bytes"
+ "context"
"database/sql"
"encoding/json"
"fmt"
var collectionRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "collections", "4zz18"))
var collectionByPDHRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-fA-F]{32}\+[0-9]+)+$`)
-func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, error) {
+func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, context.CancelFunc, error) {
remote, ok := h.Cluster.RemoteClusters[remoteID]
if !ok {
- return nil, HTTPError{fmt.Sprintf("no proxy available for cluster %v", remoteID), http.StatusNotFound}
+ return nil, nil, HTTPError{fmt.Sprintf("no proxy available for cluster %v", remoteID), http.StatusNotFound}
}
scheme := remote.Scheme
if scheme == "" {
}
saltedReq, err := h.saltAuthToken(req, remoteID)
if err != nil {
- return nil, err
+ return nil, nil, err
}
urlOut := &url.URL{
Scheme: scheme,
if remote.Insecure {
client = h.insecureClient
}
- return h.proxy.ForwardRequest(saltedReq, urlOut, client)
+ return h.proxy.Do(saltedReq, urlOut, client)
}
// Buffer request body, parse form parameters in request, and then
+// remoteMockHandler records a copy of each request it receives in
+// s.remoteMockRequests, with the request body buffered in memory so the
+// stored copy's body is still readable after the original is closed.
func (s *FederationSuite) remoteMockHandler(w http.ResponseWriter, req *http.Request) {
b := &bytes.Buffer{}
io.Copy(b, req.Body)
- req.Body = ioutil.NopCloser(b)
req.Body.Close()
+ // Close the original body first, then swap in the buffered copy.
+ // Swapping first would only close the no-op NopCloser and leave the
+ // real request body open.
+ req.Body = ioutil.NopCloser(b)
s.remoteMockRequests = append(s.remoteMockRequests, *req)
}
package controller
import (
+ "context"
"database/sql"
"errors"
"net"
})
}
-func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, error) {
+func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, context.CancelFunc, error) {
urlOut, insecure, err := findRailsAPI(h.Cluster, h.NodeProfile)
if err != nil {
- return nil, err
+ return nil, nil, err
}
urlOut = &url.URL{
Scheme: urlOut.Scheme,
if insecure {
client = h.insecureClient
}
- return h.proxy.ForwardRequest(req, urlOut, client)
+ return h.proxy.Do(req, urlOut, client)
}
func (h *Handler) proxyRailsAPI(w http.ResponseWriter, req *http.Request, next http.Handler) {
- resp, err := h.localClusterRequest(req)
+ resp, cancel, err := h.localClusterRequest(req)
+ if cancel != nil {
+ defer cancel()
+ }
n, err := h.proxy.ForwardResponse(w, resp, err)
if err != nil {
httpserver.Logger(req).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
type ResponseFilter func(*http.Response, error) (*http.Response, error)
-// Forward a request to downstream service, and return response or error.
-func (p *proxy) ForwardRequest(
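+// Do forwards a request to an upstream service and returns the
+// response or error, along with a context.CancelFunc. The CancelFunc is
+// non-nil only when p.RequestTimeout > 0, where it releases the
+// request's deadline context. It may be returned together with a
+// non-nil error. When it is non-nil, callers should call it after they
+// have finished reading the response body, because cancelling earlier
+// aborts the body read.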
+func (p *proxy) Do(
reqIn *http.Request,
urlOut *url.URL,
- client *http.Client) (*http.Response, error) {
+ client *http.Client) (*http.Response, context.CancelFunc, error) {
// Copy headers from incoming request, then add/replace proxy
// headers like Via and X-Forwarded-For.
hdrOut.Add("Via", reqIn.Proto+" arvados-controller")
ctx := reqIn.Context()
+ var cancel context.CancelFunc
if p.RequestTimeout > 0 {
- ctx, _ = context.WithDeadline(ctx, time.Now().Add(time.Duration(p.RequestTimeout)))
+ ctx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Duration(p.RequestTimeout)))
}
reqOut := (&http.Request{
Body: reqIn.Body,
}).WithContext(ctx)
- return client.Do(reqOut)
+ resp, err := client.Do(reqOut)
+ return resp, cancel, err
}
-// Copy a response (or error) to the upstream client
+// ForwardResponse copies a response (or error) to the downstream client.
func (p *proxy) ForwardResponse(w http.ResponseWriter, resp *http.Response, err error) (int64, error) {
if err != nil {
if he, ok := err.(HTTPError); ok {
"time"
)
+const (
+ // HeaderRequestID is the HTTP header that AddRequestIDs checks on
+ // each incoming request and, when it is absent, sets to a newly
+ // generated request ID.
+ HeaderRequestID = "X-Request-Id"
+)
+
// IDGenerator generates alphanumeric strings suitable for use as
// unique IDs (a given IDGenerator will never return the same ID
// twice).
func AddRequestIDs(h http.Handler) http.Handler {
gen := &IDGenerator{Prefix: "req-"}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
- if req.Header.Get("X-Request-Id") == "" {
+ if req.Header.Get(HeaderRequestID) == "" {
if req.Header == nil {
req.Header = http.Header{}
}
- req.Header.Set("X-Request-Id", gen.Next())
+ req.Header.Set(HeaderRequestID, gen.Next())
}
h.ServeHTTP(w, req)
})
}

+// GetRequestID returns the value of the HeaderRequestID header in h,
+// or "" if it is not set.
+//
+// Deprecated: use h.Get(HeaderRequestID) instead. Kept so existing
+// callers of this exported function keep compiling.
func GetRequestID(h http.Header) string {
- return h.Get("X-Request-Id")
+ return h.Get(HeaderRequestID)
}