handler *Handler
}
-func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, req *http.Request, filter ResponseFilter) {
+func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, error) {
remote, ok := h.Cluster.RemoteClusters[remoteID]
if !ok {
- err := fmt.Errorf("no proxy available for cluster %v", remoteID)
- if filter != nil {
- _, err = filter(nil, err)
- }
- if err != nil {
- httpserver.Error(w, err.Error(), http.StatusNotFound)
- }
- return
+ return nil, HTTPError{fmt.Sprintf("no proxy available for cluster %v", remoteID), http.StatusNotFound}
}
scheme := remote.Scheme
if scheme == "" {
}
saltedReq, err := h.saltAuthToken(req, remoteID)
if err != nil {
- if filter != nil {
- _, err = filter(nil, err)
- }
- if err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadRequest)
- }
- return
+ return nil, err
}
urlOut := &url.URL{
Scheme: scheme,
if remote.Insecure {
client = h.insecureClient
}
- h.proxy.Do(w, saltedReq, urlOut, client, filter)
+ return h.proxy.ForwardRequest(saltedReq, urlOut, client)
}
// Buffer request body, parse form parameters in request, and then
rc := multiClusterQueryResponseCollector{clusterID: clusterID}
+ var resp *http.Response
if clusterID == h.handler.Cluster.ClusterID {
- h.handler.localClusterRequest(w, &remoteReq,
- rc.collectResponse)
+ resp, err = h.handler.localClusterRequest(&remoteReq)
} else {
- h.handler.remoteClusterRequest(clusterID, w, &remoteReq,
- rc.collectResponse)
+ resp, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
}
+ rc.collectResponse(resp, err)
+
if rc.error != nil {
return nil, "", rc.error
}
if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
h.next.ServeHTTP(w, req)
} else {
- h.handler.remoteClusterRequest(clusterId, w, req, nil)
+ resp, err := h.handler.remoteClusterRequest(clusterId, req)
+ h.handler.proxy.ForwardResponse(w, resp, err)
}
}
-type rewriteSignaturesClusterId struct {
- clusterID string
- expectHash string
-}
+func rewriteSignatures(clusterID string, expectHash string,
+ resp *http.Response, requestError error) (newResponse *http.Response, err error) {
-func (rw rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
if requestError != nil {
return resp, requestError
}
m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
if m != nil {
// Rewrite the block signature to be a remote signature
- _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], rw.clusterID, m[5][2:], m[8])
+ _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterID, m[5][2:], m[8])
if err != nil {
return nil, fmt.Errorf("Error updating manifest: %v", err)
}
// Check that expected hash is consistent with
// portable_data_hash field of the returned record
- if rw.expectHash == "" {
- rw.expectHash = col.PortableDataHash
- } else if rw.expectHash != col.PortableDataHash {
- return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", rw.expectHash, col.PortableDataHash)
+ if expectHash == "" {
+ expectHash = col.PortableDataHash
+ } else if expectHash != col.PortableDataHash {
+ return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", expectHash, col.PortableDataHash)
}
// Certify that the computed hash of the manifest_text matches our expectation
sum := hasher.Sum(nil)
computedHash := fmt.Sprintf("%x+%v", sum, sz)
- if computedHash != rw.expectHash {
- return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, rw.expectHash)
+ if computedHash != expectHash {
+ return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, expectHash)
}
col.ManifestText = updatedManifest.String()
// also have made it to this point, and we don't want a
// slow response holding the lock to block a faster response
// that is waiting on the lock.
- newResponse, err = rewriteSignaturesClusterId{s.remoteID, s.pdh}.rewriteSignatures(resp, nil)
+ newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
s.mtx.Lock()
if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
// request for remote collection by uuid
- h.handler.remoteClusterRequest(clusterId, w, req,
- rewriteSignaturesClusterId{clusterId, ""}.rewriteSignatures)
+ resp, err := h.handler.remoteClusterRequest(clusterId, req)
+ newResponse, err := rewriteSignatures(clusterId, "", resp, err)
+ h.handler.proxy.ForwardResponse(w, newResponse, err)
return
}
// not a collection UUID request, or it is a request
// Request for collection by PDH. Search the federation.
// First, query the local cluster.
- if h.handler.localClusterRequest(w, req, filterLocalClusterResponse) {
+ resp, err := h.handler.localClusterRequest(req)
+ newResp, err := filterLocalClusterResponse(resp, err)
+ if newResp != nil || err != nil {
+ h.handler.proxy.ForwardResponse(w, newResp, err)
return
}
&sharedContext, cancelFunc, &errors, &errorCode}
wg.Add(1)
go func() {
- h.handler.remoteClusterRequest(search.remoteID, w, req, search.filterRemoteClusterResponse)
+ resp, err := h.handler.remoteClusterRequest(search.remoteID, req)
+ newResp, err := search.filterRemoteClusterResponse(resp, err)
+ if newResp != nil || err != nil {
+ h.handler.proxy.ForwardResponse(w, newResp, err)
+ }
wg.Done()
<-sem
}()
})
}
-// localClusterRequest sets up a request so it can be proxied to the
-// local API server using proxy.Do(). Returns true if a response was
-// written, false if not.
-func (h *Handler) localClusterRequest(w http.ResponseWriter, req *http.Request, filter ResponseFilter) bool {
+func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, error) {
urlOut, insecure, err := findRailsAPI(h.Cluster, h.NodeProfile)
if err != nil {
- httpserver.Error(w, err.Error(), http.StatusInternalServerError)
- return true
+ return nil, err
}
urlOut = &url.URL{
Scheme: urlOut.Scheme,
if insecure {
client = h.insecureClient
}
- return h.proxy.Do(w, req, urlOut, client, filter)
+ return h.proxy.ForwardRequest(req, urlOut, client)
}
+// proxyRailsAPI sends req to the local cluster via localClusterRequest
+// and relays the resulting response or error to w through
+// proxy.ForwardResponse. A non-nil error returned by ForwardResponse
+// is logged together with the number of bytes copied.
+//
+// After this change the next handler is no longer invoked here.
func (h *Handler) proxyRailsAPI(w http.ResponseWriter, req *http.Request, next http.Handler) {
- if !h.localClusterRequest(w, req, nil) && next != nil {
- next.ServeHTTP(w, req)
+ resp, err := h.localClusterRequest(req)
+ n, err := h.proxy.ForwardResponse(w, resp, err)
+ if err != nil {
+ httpserver.Logger(req).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
}
}
RequestTimeout time.Duration
}
+// HTTPError is an error that carries, alongside its message, the HTTP
+// status code that should be reported for it.
+type HTTPError struct {
+ // Message is the text returned by Error().
+ Message string
+ // Code is the HTTP status code, e.g. http.StatusNotFound.
+ Code int
+}
+
+// Error returns the Message field, so that HTTPError satisfies the
+// error interface.
+func (h HTTPError) Error() string {
+ return h.Message
+}
+
// headers that shouldn't be forwarded when proxying. See
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers
var dropHeaders = map[string]bool{
type ResponseFilter func(*http.Response, error) (*http.Response, error)
-// Do sends a request, passes the result to the filter (if provided)
-// and then if the result is not suppressed by the filter, sends the
-// request to the ResponseWriter. Returns true if a response was written,
-// false if not.
-func (p *proxy) Do(w http.ResponseWriter,
+// ForwardRequest forwards a request to a downstream service and returns the response or error.
+func (p *proxy) ForwardRequest(
reqIn *http.Request,
urlOut *url.URL,
- client *http.Client,
- filter ResponseFilter) bool {
+ client *http.Client) (*http.Response, error) {
// Copy headers from incoming request, then add/replace proxy
// headers like Via and X-Forwarded-For.
Body: reqIn.Body,
}).WithContext(ctx)
- resp, err := client.Do(reqOut)
- if filter == nil && err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadGateway)
- return true
- }
-
- // make sure original response body gets closed
- var originalBody io.ReadCloser
- if resp != nil {
- originalBody = resp.Body
- if originalBody != nil {
- defer originalBody.Close()
- }
- }
-
- if filter != nil {
- resp, err = filter(resp, err)
+ return client.Do(reqOut)
+}
- if err != nil {
+// ForwardResponse copies a response (or error) to the upstream client.
+//
+// If err is non-nil, an error response is written to w instead of
+// resp: an HTTPError is reported with its own Message and Code, any
+// other error with http.StatusBadGateway. In that case the return
+// value is (0, nil).
+//
+// If both resp and err are nil, nothing is written and (0, nil) is
+// returned.
+//
+// Otherwise the headers, status code and body of resp are copied to
+// w, and the number of body bytes copied is returned along with any
+// error from io.Copy.
+//
+// resp.Body, when present, is closed before returning on every path.
+func (p *proxy) ForwardResponse(w http.ResponseWriter, resp *http.Response, err error) (int64, error) {
+ if resp != nil && resp.Body != nil {
+ // Close the body even when err is non-nil, so a response that
+ // arrives together with an error is not leaked.
+ defer resp.Body.Close()
+ }
+ if err != nil {
+ if he, ok := err.(HTTPError); ok {
+ httpserver.Error(w, he.Message, he.Code)
+ } else {
httpserver.Error(w, err.Error(), http.StatusBadGateway)
- return true
- }
- if resp == nil {
- // filter() returned a nil response, this means suppress
- // writing a response, for the case where there might
- // be multiple response writers.
- return false
- }
-
- // the filter gave us a new response body, make sure that gets closed too.
- if resp.Body != originalBody {
- defer resp.Body.Close()
}
+ return 0, nil
}
+ if resp == nil {
+ // No response and no error: there is nothing to forward.
+ return 0, nil
+ }
for k, v := range resp.Header {
for _, v := range v {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
- n, err := io.Copy(w, resp.Body)
- if err != nil {
- httpserver.Logger(reqIn).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
- }
- return true
+ return io.Copy(w, resp.Body)
}