13619: MultiClusterQuery passes test
author Peter Amstutz <pamstutz@veritasgenetics.com>
Thu, 27 Sep 2018 16:29:35 +0000 (12:29 -0400)
committer Peter Amstutz <pamstutz@veritasgenetics.com>
Thu, 27 Sep 2018 16:29:35 +0000 (12:29 -0400)
Arvados-DCO-1.1-Signed-off-by: Peter Amstutz <pamstutz@veritasgenetics.com>

lib/controller/federation.go
lib/controller/federation_test.go
lib/controller/handler.go
lib/controller/proxy.go

index dc03e039faea935289ecd4f4ffe334d3bbd24b0b..caa84ca5f29073c78b8b5322222b4c23e06bb70f 100644 (file)
@@ -14,7 +14,6 @@ import (
        "fmt"
        "io"
        "io/ioutil"
-       "log"
        "net/http"
        "net/url"
        "regexp"
@@ -129,7 +128,11 @@ func (c *responseCollector) collectResponse(resp *http.Response, requestError er
        defer c.mtx.Unlock()
 
        if err == nil {
-               c.responses = append(c.responses, loadInto["items"].([]interface{})...)
+               if resp.StatusCode != http.StatusOK {
+                       c.errors = append(c.errors, fmt.Errorf("error %v", loadInto["errors"]))
+               } else {
+                       c.responses = append(c.responses, loadInto["items"].([]interface{})...)
+               }
        } else {
                c.errors = append(c.errors, err)
        }
@@ -186,19 +189,17 @@ func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.Response
                                wg.Done()
                                <-sem
                        }()
-                       remoteReq := *req
+                       var remoteReq http.Request
+                       remoteReq.Header = req.Header
                        remoteReq.Method = "POST"
-                       remoteReq.URL = &url.URL{
-                               Path:    req.URL.Path,
-                               RawPath: req.URL.RawPath,
-                       }
+                       remoteReq.URL = &url.URL{Path: req.URL.Path}
                        remoteParams := make(url.Values)
                        remoteParams["_method"] = []string{"GET"}
                        content, err := json.Marshal(v)
                        if err != nil {
                                rc.mtx.Lock()
+                               defer rc.mtx.Unlock()
                                rc.errors = append(rc.errors, err)
-                               rc.mtx.Unlock()
                                return
                        }
                        remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
@@ -206,8 +207,8 @@ func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.Response
                        remoteReq.Body = ioutil.NopCloser(bytes.NewBufferString(enc))
 
                        if k == h.handler.Cluster.ClusterID {
-                               h.handler.proxy.Do(w, &remoteReq, remoteReq.URL,
-                                       h.handler.secureClient, rc.collectResponse)
+                               h.handler.localClusterRequest(w, &remoteReq,
+                                       rc.collectResponse)
                        } else {
                                h.handler.remoteClusterRequest(k, w, &remoteReq,
                                        rc.collectResponse)
@@ -222,17 +223,13 @@ func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.Response
                for _, e := range rc.errors {
                        strerr = append(strerr, e.Error())
                }
-               httpserver.Errors(w, strerr, http.StatusBadRequest)
+               httpserver.Errors(w, strerr, http.StatusBadGateway)
        } else {
-               log.Printf("Sending status ok %+v", rc)
                w.Header().Set("Content-Type", "application/json")
                w.WriteHeader(http.StatusOK)
                itemList := make(map[string]interface{})
                itemList["items"] = rc.responses
-               //x, _ := json.Marshal(itemList)
-               //log.Printf("Sending response %v", string(x))
                json.NewEncoder(w).Encode(itemList)
-               log.Printf("Sent?")
        }
 
        return true
@@ -285,7 +282,7 @@ func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *h
                        return
                }
        }
-       log.Printf("Clusterid is %q", clusterId)
+       //log.Printf("Clusterid is %q", clusterId)
 
        if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
                h.next.ServeHTTP(w, req)
@@ -405,11 +402,7 @@ func (rw rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requ
        return resp, nil
 }
 
-type searchLocalClusterForPDH struct {
-       sentResponse bool
-}
-
-func (s *searchLocalClusterForPDH) filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
+func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
        if requestError != nil {
                return resp, requestError
        }
@@ -417,10 +410,8 @@ func (s *searchLocalClusterForPDH) filterLocalClusterResponse(resp *http.Respons
        if resp.StatusCode == 404 {
                // Suppress returning this result, because we want to
                // search the federation.
-               s.sentResponse = false
                return nil, nil
        }
-       s.sentResponse = true
        return resp, nil
 }
 
@@ -526,26 +517,7 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
        // Request for collection by PDH.  Search the federation.
 
        // First, query the local cluster.
-       urlOut, insecure, err := findRailsAPI(h.handler.Cluster, h.handler.NodeProfile)
-       if err != nil {
-               httpserver.Error(w, err.Error(), http.StatusInternalServerError)
-               return
-       }
-
-       urlOut = &url.URL{
-               Scheme:   urlOut.Scheme,
-               Host:     urlOut.Host,
-               Path:     req.URL.Path,
-               RawPath:  req.URL.RawPath,
-               RawQuery: req.URL.RawQuery,
-       }
-       client := h.handler.secureClient
-       if insecure {
-               client = h.handler.insecureClient
-       }
-       sf := &searchLocalClusterForPDH{}
-       h.handler.proxy.Do(w, req, urlOut, client, sf.filterLocalClusterResponse)
-       if sf.sentResponse {
+       if h.handler.localClusterRequest(w, req, filterLocalClusterResponse) {
                return
        }
 
index b94efa6aecc8076e8cabb5c0f899360a3a67917e..113fa9eebdbce4a4870a1ca006f70c5c72d85f09 100644 (file)
@@ -8,7 +8,6 @@ import (
        "encoding/json"
        "fmt"
        "io/ioutil"
-       "log"
        "net/http"
        "net/http/httptest"
        "net/url"
@@ -304,12 +303,10 @@ func (s *FederationSuite) checkJSONErrorMatches(c *check.C, resp *http.Response,
        c.Check(jresp.Errors[0], check.Matches, re)
 }
 
-func (s *FederationSuite) localServiceReturns404(c *check.C) *httpserver.Server {
+func (s *FederationSuite) localServiceHandler(c *check.C, h http.Handler) *httpserver.Server {
        srv := &httpserver.Server{
                Server: http.Server{
-                       Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-                               w.WriteHeader(404)
-                       }),
+                       Handler: h,
                },
        }
 
@@ -325,6 +322,12 @@ func (s *FederationSuite) localServiceReturns404(c *check.C) *httpserver.Server
        return srv
 }
 
+// localServiceReturns404 returns a stub local service whose handler answers every request with 404.
+func (s *FederationSuite) localServiceReturns404(c *check.C) *httpserver.Server {
+       notFound := func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(404) }
+       return s.localServiceHandler(c, http.HandlerFunc(notFound))
+}
+
 func (s *FederationSuite) TestGetLocalCollection(c *check.C) {
        np := arvados.NodeProfile{
                Controller: arvados.SystemServiceInstance{Listen: ":"},
@@ -629,14 +632,28 @@ func (s *FederationSuite) TestListRemoteContainer(c *check.C) {
 }
 
 func (s *FederationSuite) TestListMultiRemoteContainers(c *check.C) {
-       defer s.localServiceReturns404(c).Close()
+       defer s.localServiceHandler(c, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+               // The federated list query is tunnelled as a POST carrying _method=GET.
+               c.Check(req.Method, check.Equals, "POST")
+               bd, err := ioutil.ReadAll(req.Body)
+               c.Check(err, check.IsNil)
+               c.Check(string(bd), check.Equals, `_method=GET&filters=%5B%5B%22uuid%22%2C+%22in%22%2C+%5B%22zhome-xvhdp-cr5queuedcontnr%22%5D%5D%5D`)
+               w.WriteHeader(200)
+               w.Write([]byte(`{"items": [{"uuid": "zhome-xvhdp-cr5queuedcontnr"}]}`))
+       })).Close()
        req := httptest.NewRequest("GET", "/arvados/v1/containers?filters="+
                url.QueryEscape(fmt.Sprintf(`[["uuid", "in", ["%v", "zhome-xvhdp-cr5queuedcontnr"]]]`, arvadostest.QueuedContainerUUID)), nil)
        req.Header.Set("Authorization", "Bearer "+arvadostest.ActiveToken)
        resp := s.testRequest(req)
-       log.Printf("got %+v", resp)
-       c.Assert(resp.StatusCode, check.Equals, http.StatusOK)
+       c.Check(resp.StatusCode, check.Equals, http.StatusOK)
        var cn arvados.ContainerList
        c.Check(json.NewDecoder(resp.Body).Decode(&cn), check.IsNil)
-       c.Check(cn.Items[0].UUID, check.Equals, arvadostest.QueuedContainerUUID)
+       // The two clusters' results may arrive in either order. Require
+       // exactly two items first so a short list fails instead of panicking.
+       c.Assert(cn.Items, check.HasLen, 2)
+       if cn.Items[0].UUID == arvadostest.QueuedContainerUUID {
+               c.Check(cn.Items[1].UUID, check.Equals, "zhome-xvhdp-cr5queuedcontnr")
+       } else {
+               c.Check(cn.Items[1].UUID, check.Equals, arvadostest.QueuedContainerUUID)
+               c.Check(cn.Items[0].UUID, check.Equals, "zhome-xvhdp-cr5queuedcontnr")
+       }
 }
index 2b41aba6bfabf2bf9a76d5b8c483d146eef5cc6a..0c31815cba21f2869e7ae4ddf73c880bf4d0a5c8 100644 (file)
@@ -121,11 +121,14 @@ func prepend(next http.Handler, middleware middlewareFunc) http.Handler {
        })
 }
 
-func (h *Handler) proxyRailsAPI(w http.ResponseWriter, req *http.Request, next http.Handler) {
+// localClusterRequest sets up a request so it can be proxied to the
+// local API server using proxy.Do().  Returns true if a response was
+// written, false if not.
+func (h *Handler) localClusterRequest(w http.ResponseWriter, req *http.Request, filter ResponseFilter) bool {
        urlOut, insecure, err := findRailsAPI(h.Cluster, h.NodeProfile)
        if err != nil {
                httpserver.Error(w, err.Error(), http.StatusInternalServerError)
-               return
+               return true
        }
        urlOut = &url.URL{
                Scheme:   urlOut.Scheme,
@@ -138,7 +141,13 @@ func (h *Handler) proxyRailsAPI(w http.ResponseWriter, req *http.Request, next h
        if insecure {
                client = h.insecureClient
        }
-       h.proxy.Do(w, req, urlOut, client, nil)
+       return h.proxy.Do(w, req, urlOut, client, filter)
+}
+
+func (h *Handler) proxyRailsAPI(w http.ResponseWriter, req *http.Request, next http.Handler) {
+       if wrote := h.localClusterRequest(w, req, nil); !wrote && next != nil {
+               next.ServeHTTP(w, req)
+       }
 }
 
 // For now, findRailsAPI always uses the rails API running on this
index 373b42e8f4fe285f5c89f114d7d2997220618943..951cb9d25fe24ba74a5697d54187847cfc84ae1a 100644 (file)
@@ -36,11 +36,15 @@ var dropHeaders = map[string]bool{
 
 type ResponseFilter func(*http.Response, error) (*http.Response, error)
 
+// Do sends a request, passes the result to the filter (if provided)
+// and then if the result is not suppressed by the filter, sends the
+// request to the ResponseWriter.  Returns true if a response was written,
+// false if not.
 func (p *proxy) Do(w http.ResponseWriter,
        reqIn *http.Request,
        urlOut *url.URL,
        client *http.Client,
-       filter ResponseFilter) {
+       filter ResponseFilter) bool {
 
        // Copy headers from incoming request, then add/replace proxy
        // headers like Via and X-Forwarded-For.
@@ -78,7 +82,7 @@ func (p *proxy) Do(w http.ResponseWriter,
        resp, err := client.Do(reqOut)
        if filter == nil && err != nil {
                httpserver.Error(w, err.Error(), http.StatusBadGateway)
-               return
+               return true
        }
 
        // make sure original response body gets closed
@@ -95,13 +99,13 @@ func (p *proxy) Do(w http.ResponseWriter,
 
                if err != nil {
                        httpserver.Error(w, err.Error(), http.StatusBadGateway)
-                       return
+                       return true
                }
                if resp == nil {
                        // filter() returned a nil response, this means suppress
                        // writing a response, for the case where there might
                        // be multiple response writers.
-                       return
+                       return false
                }
 
                // the filter gave us a new response body, make sure that gets closed too.
@@ -120,4 +124,5 @@ func (p *proxy) Do(w http.ResponseWriter,
        if err != nil {
                httpserver.Logger(reqIn).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
        }
+       return true
 }