14262: saltAuthToken returns copy of request object
author Peter Amstutz <pamstutz@veritasgenetics.com>
Tue, 16 Oct 2018 18:51:12 +0000 (14:51 -0400)
committer Peter Amstutz <pamstutz@veritasgenetics.com>
Tue, 30 Oct 2018 18:12:02 +0000 (14:12 -0400)
Arvados-DCO-1.1-Signed-off-by: Peter Amstutz <pamstutz@veritasgenetics.com>

lib/controller/federation.go

index 5c6f6bf7ab9d503c395701688555359a9e925e6b..e5c56bd837dbf458d84897e09e76daebe471729d 100644 (file)
@@ -14,6 +14,7 @@ import (
        "fmt"
        "io"
        "io/ioutil"
+       "log"
        "net/http"
        "net/url"
        "regexp"
@@ -47,16 +48,27 @@ type collectionFederatedRequestHandler struct {
 func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, req *http.Request, filter ResponseFilter) {
        remote, ok := h.Cluster.RemoteClusters[remoteID]
        if !ok {
-               httpserver.Error(w, "no proxy available for cluster "+remoteID, http.StatusNotFound)
+               err := fmt.Errorf("no proxy available for cluster %v", remoteID)
+               if filter != nil {
+                       _, err = filter(nil, err)
+               }
+               if err != nil {
+                       httpserver.Error(w, err.Error(), http.StatusNotFound)
+               }
                return
        }
        scheme := remote.Scheme
        if scheme == "" {
                scheme = "https"
        }
-       err := h.saltAuthToken(req, remoteID)
+       req, err := h.saltAuthToken(req, remoteID)
        if err != nil {
-               httpserver.Error(w, err.Error(), http.StatusBadRequest)
+               if filter != nil {
+                       _, err = filter(nil, err)
+               }
+               if err != nil {
+                       httpserver.Error(w, err.Error(), http.StatusBadRequest)
+               }
                return
        }
        urlOut := &url.URL{
@@ -655,6 +667,10 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
        sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
        defer close(sem)
        for remoteID := range h.handler.Cluster.RemoteClusters {
+               if remoteID == h.handler.Cluster.ClusterID {
+                       // No need to query local cluster again
+                       continue
+               }
                // blocks until it can put a value into the
                // channel (which has a max queue capacity)
                sem <- true
@@ -728,28 +744,40 @@ func (h *Handler) validateAPItoken(req *http.Request, user *CurrentUser) error {
 
 // Extract the auth token supplied in req, and replace it with a
 // salted token for the remote cluster.
-func (h *Handler) saltAuthToken(req *http.Request, remote string) error {
+func (h *Handler) saltAuthToken(req *http.Request, remote string) (updatedReq *http.Request, err error) {
+       updatedReq = (&http.Request{
+               Method:        req.Method,
+               URL:           req.URL,
+               Header:        req.Header,
+               Body:          req.Body,
+               ContentLength: req.ContentLength,
+               Host:          req.Host,
+       }).WithContext(req.Context())
+
        creds := auth.NewCredentials()
-       creds.LoadTokensFromHTTPRequest(req)
-       if len(creds.Tokens) == 0 && req.Header.Get("Content-Type") == "application/x-www-form-encoded" {
+       creds.LoadTokensFromHTTPRequest(updatedReq)
+       if len(creds.Tokens) == 0 && updatedReq.Header.Get("Content-Type") == "application/x-www-form-encoded" {
                // Override ParseForm's 10MiB limit by ensuring
                // req.Body is a *http.maxBytesReader.
-               req.Body = http.MaxBytesReader(nil, req.Body, 1<<28) // 256MiB. TODO: use MaxRequestSize from discovery doc or config.
-               if err := creds.LoadTokensFromHTTPRequestBody(req); err != nil {
-                       return err
+               updatedReq.Body = http.MaxBytesReader(nil, updatedReq.Body, 1<<28) // 256MiB. TODO: use MaxRequestSize from discovery doc or config.
+               if err := creds.LoadTokensFromHTTPRequestBody(updatedReq); err != nil {
+                       return nil, err
                }
                // Replace req.Body with a buffer that re-encodes the
                // form without api_token, in case we end up
                // forwarding the request.
-               if req.PostForm != nil {
-                       req.PostForm.Del("api_token")
+               if updatedReq.PostForm != nil {
+                       updatedReq.PostForm.Del("api_token")
                }
-               req.Body = ioutil.NopCloser(bytes.NewBufferString(req.PostForm.Encode()))
+               updatedReq.Body = ioutil.NopCloser(bytes.NewBufferString(updatedReq.PostForm.Encode()))
        }
        if len(creds.Tokens) == 0 {
-               return nil
+               return updatedReq, nil
        }
+
        token, err := auth.SaltToken(creds.Tokens[0], remote)
+
+       log.Printf("Salting %q %q to get %q %q", creds.Tokens[0], remote, token, err)
        if err == auth.ErrObsoleteToken {
                // If the token exists in our own database, salt it
                // for the remote. Otherwise, assume it was issued by
@@ -760,26 +788,41 @@ func (h *Handler) saltAuthToken(req *http.Request, remote string) error {
                        // Not ours; pass through unmodified.
                        token = currentUser.Authorization.APIToken
                } else if err != nil {
-                       return err
+                       return nil, err
                } else {
                        // Found; make V2 version and salt it.
                        token, err = auth.SaltToken(currentUser.Authorization.TokenV2(), remote)
                        if err != nil {
-                               return err
+                               return nil, err
                        }
                }
        } else if err != nil {
-               return err
+               return nil, err
+       }
+       updatedReq.Header = http.Header{}
+       for k, v := range req.Header {
+               if k == "Authorization" {
+                       updatedReq.Header[k] = []string{"Bearer " + token}
+               } else {
+                       updatedReq.Header[k] = v
+               }
        }
-       req.Header.Set("Authorization", "Bearer "+token)
+
+       log.Printf("Salted %q %q to get %q", creds.Tokens[0], remote, token)
 
        // Remove api_token=... from the the query string, in case we
        // end up forwarding the request.
-       if values, err := url.ParseQuery(req.URL.RawQuery); err != nil {
-               return err
+       if values, err := url.ParseQuery(updatedReq.URL.RawQuery); err != nil {
+               return nil, err
        } else if _, ok := values["api_token"]; ok {
                delete(values, "api_token")
-               req.URL.RawQuery = values.Encode()
+               updatedReq.URL = &url.URL{
+                       Scheme:   req.URL.Scheme,
+                       Host:     req.URL.Host,
+                       Path:     req.URL.Path,
+                       RawPath:  req.URL.RawPath,
+                       RawQuery: values.Encode(),
+               }
        }
-       return nil
+       return updatedReq, nil
 }