21717: Allow cross-origin keepstore requests.
Author: Tom Clegg <tom@curii.com>
Thu, 25 Apr 2024 13:59:54 +0000 (09:59 -0400)
Committer: Tom Clegg <tom@curii.com>
Thu, 25 Apr 2024 13:59:54 +0000 (09:59 -0400)
Also add some missing entries to the Access-Control-Allow-Headers and
Access-Control-Expose-Headers lists in the keepproxy CORS headers.

Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

sdk/go/keepclient/keepclient.go
services/keepproxy/keepproxy.go
services/keepproxy/keepproxy_test.go
services/keepstore/router.go
services/keepstore/router_test.go

index d97a2d1fcd2096a7f44983bbc7349ce11c24d307..1720096aee96fbbf4689cc743f1bb2776a118e14 100644 (file)
@@ -100,6 +100,8 @@ const (
        XKeepReplicasStored          = "X-Keep-Replicas-Stored"
        XKeepStorageClasses          = "X-Keep-Storage-Classes"
        XKeepStorageClassesConfirmed = "X-Keep-Storage-Classes-Confirmed"
+       XKeepSignature               = "X-Keep-Signature" // request: if the first comma-separated field is "local", keepstore sets XKeepLocator
+       XKeepLocator                 = "X-Keep-Locator"   // response: locator returned by keepstore (see XKeepSignature)
 )
 
 type HTTPClient interface {
index 39ffd45cbe37b69f663dc6093acd4dfe221c74a1..97a5ad65929094897e4b8fbf9b9c12b3de1633e2 100644 (file)
@@ -23,6 +23,7 @@ import (
        "git.arvados.org/arvados.git/sdk/go/health"
        "git.arvados.org/arvados.git/sdk/go/httpserver"
        "git.arvados.org/arvados.git/sdk/go/keepclient"
+       "git.arvados.org/arvados.git/services/keepstore"
        "github.com/gorilla/mux"
        lru "github.com/hashicorp/golang-lru"
        "github.com/prometheus/client_golang/prometheus"
@@ -271,10 +272,9 @@ func (h *proxyHandler) checkLoop(resp http.ResponseWriter, req *http.Request) er
 }
 
 func setCORSHeaders(resp http.ResponseWriter) {
-       resp.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS")
-       resp.Header().Set("Access-Control-Allow-Origin", "*")
-       resp.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas")
-       resp.Header().Set("Access-Control-Max-Age", "86486400")
+       keepstore.SetCORSHeaders(resp)
+       // keepproxy's CORS policy also allows POST, which keepstore's method list omits.
+       resp.Header().Set("Access-Control-Allow-Methods", resp.Header().Get("Access-Control-Allow-Methods")+", POST")
 }
 
 type invalidPathHandler struct{}
@@ -419,9 +419,9 @@ func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
        locatorIn := mux.Vars(req)["locator"]
 
        // Check if the client specified storage classes
-       if req.Header.Get("X-Keep-Storage-Classes") != "" {
+       if classesHdr := req.Header.Get(keepclient.XKeepStorageClasses); classesHdr != "" {
                var scl []string
-               for _, sc := range strings.Split(req.Header.Get("X-Keep-Storage-Classes"), ",") {
+               for _, sc := range strings.Split(classesHdr, ",") {
                        scl = append(scl, strings.Trim(sc, " "))
                }
                kc.SetStorageClasses(scl)
index 2c73e2d1040d1b37df4e77375b5a859d3187565e..ea8c9ba6ca4cfb6c1ec0b6e5f6734bc0aee7bef4 100644 (file)
@@ -558,14 +558,15 @@ func (s *ServerRequiredSuite) TestCorsHeaders(c *C) {
                body, err := ioutil.ReadAll(resp.Body)
                c.Check(err, IsNil)
                c.Check(string(body), Equals, "")
-               c.Check(resp.Header.Get("Access-Control-Allow-Methods"), Equals, "GET, HEAD, POST, PUT, OPTIONS")
+               c.Check(resp.Header.Get("Access-Control-Allow-Methods"), Equals, "GET, HEAD, PUT, OPTIONS, POST")
                c.Check(resp.Header.Get("Access-Control-Allow-Origin"), Equals, "*")
        }
 
        {
                resp, err := http.Get(fmt.Sprintf("http://%s/%x+3", srv.Addr, md5.Sum([]byte("foo"))))
                c.Check(err, Equals, nil)
-               c.Check(resp.Header.Get("Access-Control-Allow-Headers"), Equals, "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas")
+               c.Check(resp.Header.Get("Access-Control-Allow-Headers"), Equals, "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas, X-Keep-Signature, X-Keep-Storage-Classes")
+               c.Check(resp.Header.Get("Access-Control-Expose-Headers"), Equals, "X-Keep-Locator, X-Keep-Replicas-Stored, X-Keep-Storage-Classes-Confirmed")
                c.Check(resp.Header.Get("Access-Control-Allow-Origin"), Equals, "*")
        }
 }
index 0c8182c6ea31c91a8c20056e8ca886df43d27712..b462487a3d8a41337a73c7e1c9079dd8bcc34770 100644 (file)
@@ -19,6 +19,7 @@ import (
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/auth"
        "git.arvados.org/arvados.git/sdk/go/httpserver"
+       "git.arvados.org/arvados.git/sdk/go/keepclient"
        "github.com/gorilla/mux"
 )
 
@@ -57,9 +58,11 @@ func newRouter(keepstore *keepstore, puller *puller, trasher *trasher) service.H
        touch.HandleFunc(locatorPath, adminonly(rtr.handleBlockTouch))
        delete := r.Methods(http.MethodDelete).Subrouter()
        delete.HandleFunc(locatorPath, adminonly(rtr.handleBlockTrash))
+       preflight := r.Methods(http.MethodOptions).Subrouter()
+       preflight.PathPrefix(`/`).HandlerFunc(rtr.handleOptions)
        r.NotFoundHandler = http.HandlerFunc(rtr.handleBadRequest)
        r.MethodNotAllowedHandler = http.HandlerFunc(rtr.handleBadRequest)
-       rtr.Handler = auth.LoadToken(r)
+       rtr.Handler = corsHandler(auth.LoadToken(r)) // outermost, so every response (errors included) gets CORS headers
        return rtr
 }
 
@@ -75,11 +78,11 @@ func (rtr *router) handleBlockRead(w http.ResponseWriter, req *http.Request) {
        // Intervening proxies must not return a cached GET response
        // to a prior request if a X-Keep-Signature request header has
        // been added or changed.
-       w.Header().Add("Vary", "X-Keep-Signature")
+       w.Header().Add("Vary", keepclient.XKeepSignature)
        var localLocator func(string)
-       if strings.SplitN(req.Header.Get("X-Keep-Signature"), ",", 2)[0] == "local" {
+       if sigFields := strings.SplitN(req.Header.Get(keepclient.XKeepSignature), ",", 2); sigFields[0] == "local" {
                localLocator = func(locator string) {
-                       w.Header().Set("X-Keep-Locator", locator)
+                       w.Header().Set(keepclient.XKeepLocator, locator)
                }
        }
        out := w
@@ -107,20 +110,20 @@ func (rtr *router) handleBlockRead(w http.ResponseWriter, req *http.Request) {
 
 func (rtr *router) handleBlockWrite(w http.ResponseWriter, req *http.Request) {
        dataSize, _ := strconv.Atoi(req.Header.Get("Content-Length"))
-       replicas, _ := strconv.Atoi(req.Header.Get("X-Arvados-Replicas-Desired"))
+       replicas, _ := strconv.Atoi(req.Header.Get(keepclient.XKeepDesiredReplicas)) // same header name that corsHeaders allows
        resp, err := rtr.keepstore.BlockWrite(req.Context(), arvados.BlockWriteOptions{
                Hash:           mux.Vars(req)["locator"],
                Reader:         req.Body,
                DataSize:       dataSize,
                RequestID:      req.Header.Get("X-Request-Id"),
-               StorageClasses: trimSplit(req.Header.Get("X-Keep-Storage-Classes"), ","),
+               StorageClasses: trimSplit(req.Header.Get(keepclient.XKeepStorageClasses), ","),
                Replicas:       replicas,
        })
        if err != nil {
                rtr.handleError(w, req, err)
                return
        }
-       w.Header().Set("X-Keep-Replicas-Stored", fmt.Sprintf("%d", resp.Replicas))
+       w.Header().Set(keepclient.XKeepReplicasStored, fmt.Sprintf("%d", resp.Replicas))
        scc := ""
        for k, n := range resp.StorageClasses {
                if n > 0 {
@@ -130,7 +133,7 @@ func (rtr *router) handleBlockWrite(w http.ResponseWriter, req *http.Request) {
                        scc += fmt.Sprintf("%s=%d", k, n)
                }
        }
-       w.Header().Set("X-Keep-Storage-Classes-Confirmed", scc)
+       w.Header().Set(keepclient.XKeepStorageClassesConfirmed, scc) // "class=N" items, only for classes with N > 0
        w.WriteHeader(http.StatusOK)
        fmt.Fprintln(w, resp.Locator)
 }
@@ -210,6 +213,9 @@ func (rtr *router) handleBadRequest(w http.ResponseWriter, req *http.Request) {
        http.Error(w, "Bad Request", http.StatusBadRequest)
 }
 
+// handleOptions answers OPTIONS (CORS preflight) requests with an empty 200 response; corsHandler sets the CORS headers.
+func (rtr *router) handleOptions(w http.ResponseWriter, req *http.Request) {}
+
 func (rtr *router) handleError(w http.ResponseWriter, req *http.Request, err error) {
        if req.Context().Err() != nil {
                w.WriteHeader(499)
@@ -274,3 +280,39 @@ type discardWrite struct {
 func (discardWrite) Write(p []byte) (int, error) {
        return len(p), nil
 }
+
+// corsHandler wraps h so that every response carries the CORS headers
+// from SetCORSHeaders. The headers are set before h runs, so error
+// responses (e.g., 401 for a missing token) include them too.
+func corsHandler(h http.Handler) http.Handler {
+       return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               SetCORSHeaders(w)
+               h.ServeHTTP(w, r)
+       })
+}
+
+// corsHeaders maps each CORS response header name to the value that
+// SetCORSHeaders assigns to it. Allow-Origin "*" admits requests from
+// any origin.
+//
+// NOTE(review): Max-Age 86486400 (about 1000 days) looks like a typo
+// for 86400 (one day). Browsers clamp Max-Age to a much shorter limit,
+// so the difference is likely moot -- confirm the intended value.
+var corsHeaders = map[string]string{
+       "Access-Control-Allow-Methods":  "GET, HEAD, PUT, OPTIONS",
+       "Access-Control-Allow-Origin":   "*",
+       "Access-Control-Allow-Headers":  "Authorization, Content-Length, Content-Type, " + keepclient.XKeepDesiredReplicas + ", " + keepclient.XKeepSignature + ", " + keepclient.XKeepStorageClasses,
+       "Access-Control-Expose-Headers": keepclient.XKeepLocator + ", " + keepclient.XKeepReplicasStored + ", " + keepclient.XKeepStorageClassesConfirmed,
+       "Access-Control-Max-Age":        "86486400",
+}
+
+// SetCORSHeaders sets the CORS response headers that let browser
+// clients on any origin use the keepstore API, replacing any values
+// already present. Other services, such as keepproxy, call it to send
+// the same headers.
+func SetCORSHeaders(w http.ResponseWriter) {
+       hdr := w.Header()
+       for k, v := range corsHeaders {
+               hdr.Set(k, v)
+       }
+}
index 15a055d55ef3bddf23224cb70c8fac86a0b7777f..9fc6e1b7dba963a161b3fe0924cc678a0d8d5980 100644 (file)
@@ -78,22 +78,26 @@ func (s *routerSuite) TestBlockRead_Token(c *C) {
        resp := call(router, "GET", "http://example/"+locSigned, "", nil, nil)
        c.Check(resp.Code, Equals, http.StatusUnauthorized)
        c.Check(resp.Body.String(), Matches, "no token provided in Authorization header\n")
+       checkCORSHeaders(c, resp.Header()) // CORS headers are expected on error responses too
 
        // Different token => invalid signature
        resp = call(router, "GET", "http://example/"+locSigned, "badtoken", nil, nil)
        c.Check(resp.Code, Equals, http.StatusBadRequest)
        c.Check(resp.Body.String(), Equals, "invalid signature\n")
+       checkCORSHeaders(c, resp.Header())
 
        // Correct token
        resp = call(router, "GET", "http://example/"+locSigned, arvadostest.ActiveTokenV2, nil, nil)
        c.Check(resp.Code, Equals, http.StatusOK)
        c.Check(resp.Body.String(), Equals, "foo")
+       checkCORSHeaders(c, resp.Header())
 
        // HEAD
        resp = call(router, "HEAD", "http://example/"+locSigned, arvadostest.ActiveTokenV2, nil, nil)
        c.Check(resp.Code, Equals, http.StatusOK)
        c.Check(resp.Result().ContentLength, Equals, int64(3))
        c.Check(resp.Body.String(), Equals, "")
+       checkCORSHeaders(c, resp.Header())
 }
 
 // As a special case we allow HEAD requests that only provide a hash
@@ -165,13 +169,16 @@ func (s *routerSuite) TestBlockRead_ChecksumMismatch(c *C) {
                }
                c.Check(resp.Body.Len(), Not(Equals), len(gooddata))
                c.Check(resp.Result().ContentLength, Equals, int64(len(gooddata)))
+               checkCORSHeaders(c, resp.Header())
 
                resp = call(router, "HEAD", "http://example/"+locSigned, arvadostest.ActiveTokenV2, nil, nil)
                c.Check(resp.Code, Equals, http.StatusBadGateway)
+               checkCORSHeaders(c, resp.Header())
 
                hashSigned := router.keepstore.signLocator(arvadostest.ActiveTokenV2, hash)
                resp = call(router, "HEAD", "http://example/"+hashSigned, arvadostest.ActiveTokenV2, nil, nil)
                c.Check(resp.Code, Equals, http.StatusBadGateway)
+               checkCORSHeaders(c, resp.Header())
        }
 }
 
@@ -181,6 +188,7 @@ func (s *routerSuite) TestBlockWrite(c *C) {
 
        resp := call(router, "PUT", "http://example/"+fooHash, arvadostest.ActiveTokenV2, []byte("foo"), nil)
        c.Check(resp.Code, Equals, http.StatusOK)
+       checkCORSHeaders(c, resp.Header())
        locator := strings.TrimSpace(resp.Body.String())
 
        resp = call(router, "GET", "http://example/"+locator, arvadostest.ActiveTokenV2, nil, nil)
@@ -192,7 +200,7 @@ func (s *routerSuite) TestBlockWrite_Headers(c *C) {
        router, cancel := testRouter(c, s.cluster, nil)
        defer cancel()
 
-       resp := call(router, "PUT", "http://example/"+fooHash, arvadostest.ActiveTokenV2, []byte("foo"), http.Header{"X-Arvados-Replicas-Desired": []string{"2"}})
+       resp := call(router, "PUT", "http://example/"+fooHash, arvadostest.ActiveTokenV2, []byte("foo"), http.Header{"X-Keep-Desired-Replicas": []string{"2"}})
        c.Check(resp.Code, Equals, http.StatusOK)
        c.Check(resp.Header().Get("X-Keep-Replicas-Stored"), Equals, "1")
        c.Check(sortCommaSeparated(resp.Header().Get("X-Keep-Storage-Classes-Confirmed")), Equals, "testclass1=1")
@@ -469,7 +477,6 @@ func (s *routerSuite) TestIndex(c *C) {
                c.Check(resp.Code, Equals, http.StatusOK)
                c.Check(strings.Split(resp.Body.String(), "\n"), HasLen, 5)
        }
-
 }
 
 // Check that the context passed to a volume method gets cancelled
@@ -500,6 +507,24 @@ func (s *routerSuite) TestCancelOnDisconnect(c *C) {
        c.Check(resp.Code, Equals, 499)
 }
 
+// TestCORSPreflight checks that an OPTIONS request to any path gets
+// an empty 200 response with CORS headers, both with and without a
+// token -- browsers send preflight requests without credentials.
+func (s *routerSuite) TestCORSPreflight(c *C) {
+       router, cancel := testRouter(c, s.cluster, nil)
+       defer cancel()
+
+       for _, path := range []string{"/", "/whatever", "/aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa+123"} {
+               for _, tok := range []string{"", arvadostest.ActiveTokenV2} {
+                       c.Logf("=== %s (with token: %v)", path, tok != "")
+                       resp := call(router, http.MethodOptions, "http://example"+path, tok, nil, nil)
+                       c.Check(resp.Code, Equals, http.StatusOK)
+                       c.Check(resp.Body.String(), Equals, "")
+                       checkCORSHeaders(c, resp.Header())
+               }
+       }
+}
+
 func call(handler http.Handler, method, path, tok string, body []byte, hdr http.Header) *httptest.ResponseRecorder {
        resp := httptest.NewRecorder()
        req, err := http.NewRequest(method, path, bytes.NewReader(body))
@@ -515,3 +540,12 @@ func call(handler http.Handler, method, path, tok string, body []byte, hdr http.
        handler.ServeHTTP(resp, req)
        return resp
 }
+
+// checkCORSHeaders checks that h contains the CORS headers keepstore
+// is expected to set on every response.
+func checkCORSHeaders(c *C, h http.Header) {
+       c.Check(h.Get("Access-Control-Allow-Methods"), Equals, "GET, HEAD, PUT, OPTIONS")
+       c.Check(h.Get("Access-Control-Allow-Origin"), Equals, "*")
+       c.Check(h.Get("Access-Control-Allow-Headers"), Equals, "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas, X-Keep-Signature, X-Keep-Storage-Classes")
+       c.Check(h.Get("Access-Control-Expose-Headers"), Equals, "X-Keep-Locator, X-Keep-Replicas-Stored, X-Keep-Storage-Classes-Confirmed")
+}