Arvados-DCO-1.1-Signed-off-by: Radhika Chippada <radhika@curoverse.com>
[arvados.git] / services / keep-web / handler_test.go
index 392de94ffb5af1399bada756ed4a8efc7f33506b..1d03b90a3a60e604ef2023a06dada39b000d6f97 100644 (file)
@@ -1,6 +1,11 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
 package main
 
 import (
+       "fmt"
        "html"
        "io/ioutil"
        "net/http"
@@ -18,6 +23,66 @@ var _ = check.Suite(&UnitSuite{})
 
 type UnitSuite struct{}
 
+// TestCORSPreflight checks keep-web's answer to CORS preflight
+// (OPTIONS) requests sent with a foreign Origin header. A preflight
+// asking to POST gets an empty 200 response carrying Access-Control-*
+// headers. A preflight asking to DELETE gets an empty 405 Method Not
+// Allowed.
+func (s *UnitSuite) TestCORSPreflight(c *check.C) {
+       target, _ := url.Parse("http://keep-web.example/c=" + arvadostest.FooCollection + "/foo")
+       preflightHeader := http.Header{}
+       preflightHeader.Set("Origin", "https://workbench.example")
+       preflightHeader.Set("Access-Control-Request-Method", "POST")
+       req := &http.Request{
+               Method:     "OPTIONS",
+               Host:       target.Host,
+               URL:        target,
+               RequestURI: target.RequestURI(),
+               Header:     preflightHeader,
+       }
+       h := handler{Config: DefaultConfig()}
+
+       // Allowed: the preflight asks for POST.
+       rec := httptest.NewRecorder()
+       h.ServeHTTP(rec, req)
+       got := rec.Header()
+       c.Check(rec.Code, check.Equals, http.StatusOK)
+       c.Check(rec.Body.String(), check.Equals, "")
+       c.Check(got.Get("Access-Control-Allow-Origin"), check.Equals, "*")
+       c.Check(got.Get("Access-Control-Allow-Methods"), check.Equals, "GET, POST")
+       c.Check(got.Get("Access-Control-Allow-Headers"), check.Equals, "Range")
+
+       // Disallowed: the same preflight, but asking for DELETE.
+       req.Header.Set("Access-Control-Request-Method", "DELETE")
+       rec = httptest.NewRecorder()
+       h.ServeHTTP(rec, req)
+       c.Check(rec.Body.String(), check.Equals, "")
+       c.Check(rec.Code, check.Equals, http.StatusMethodNotAllowed)
+}
+
+// TestInvalidUUID checks that a malformed collection identifier yields
+// 404 Not Found under every collection URL form, with and without a
+// token in the URL. The bogus identifier is FooPdh with its first "+"
+// changed to "-" and a trailing "-" appended.
+func (s *UnitSuite) TestInvalidUUID(c *check.C) {
+       bogusID := strings.Replace(arvadostest.FooPdh, "+", "-", 1) + "-"
+       token := arvadostest.ActiveToken
+       trials := []string{
+               "http://keep-web/c=" + bogusID + "/foo",
+               "http://keep-web/c=" + bogusID + "/t=" + token + "/foo",
+               "http://keep-web/collections/download/" + bogusID + "/" + token + "/foo",
+               "http://keep-web/collections/" + bogusID + "/foo",
+               "http://" + bogusID + ".keep-web/" + bogusID + "/foo",
+               "http://" + bogusID + ".keep-web/t=" + token + "/" + bogusID + "/foo",
+       }
+       for i := range trials {
+               rawURL := trials[i]
+               c.Log(rawURL)
+               target, err := url.Parse(rawURL)
+               c.Assert(err, check.IsNil)
+
+               // Each trial gets a fresh config and handler with an
+               // anonymous token configured (presumably so tokenless
+               // URLs are not rejected for lack of credentials before
+               // the ID is examined -- confirm).
+               cfg := DefaultConfig()
+               cfg.AnonymousTokens = []string{arvadostest.AnonymousToken}
+               h := handler{Config: cfg}
+
+               rec := httptest.NewRecorder()
+               h.ServeHTTP(rec, &http.Request{
+                       Method:     "GET",
+                       Host:       target.Host,
+                       URL:        target,
+                       RequestURI: target.RequestURI(),
+               })
+               c.Check(rec.Code, check.Equals, http.StatusNotFound)
+       }
+}
+
 func mustParseURL(s string) *url.URL {
        r, err := url.Parse(s)
        if err != nil {
@@ -32,11 +97,13 @@ func (s *IntegrationSuite) TestVhost404(c *check.C) {
                arvadostest.NonexistentCollection + ".example.com/t=" + arvadostest.ActiveToken + "/theperthcountyconspiracy",
        } {
                resp := httptest.NewRecorder()
+               u := mustParseURL(testURL)
                req := &http.Request{
-                       Method: "GET",
-                       URL:    mustParseURL(testURL),
+                       Method:     "GET",
+                       URL:        u,
+                       RequestURI: u.RequestURI(),
                }
-               (&handler{}).ServeHTTP(resp, req)
+               s.testServer.Handler.ServeHTTP(resp, req)
                c.Check(resp.Code, check.Equals, http.StatusNotFound)
                c.Check(resp.Body.String(), check.Equals, "")
        }
@@ -49,7 +116,7 @@ func (s *IntegrationSuite) TestVhost404(c *check.C) {
 type authorizer func(*http.Request, string) int
 
+// TestVhostViaAuthzHeader runs the vhost request matrix
+// (doVhostRequests) with the token sent in an
+// "Authorization: OAuth2 <token>" header.
 func (s *IntegrationSuite) TestVhostViaAuthzHeader(c *check.C) {
-       doVhostRequests(c, authzViaAuthzHeader)
+       s.doVhostRequests(c, authzViaAuthzHeader)
 }
 func authzViaAuthzHeader(r *http.Request, tok string) int {
        r.Header.Add("Authorization", "OAuth2 "+tok)
@@ -57,7 +124,7 @@ func authzViaAuthzHeader(r *http.Request, tok string) int {
 }
 
+// TestVhostViaCookieValue runs the vhost request matrix
+// (doVhostRequests) with the token sent in a cookie added by
+// authzViaCookieValue.
 func (s *IntegrationSuite) TestVhostViaCookieValue(c *check.C) {
-       doVhostRequests(c, authzViaCookieValue)
+       s.doVhostRequests(c, authzViaCookieValue)
 }
 func authzViaCookieValue(r *http.Request, tok string) int {
        r.AddCookie(&http.Cookie{
@@ -68,7 +135,7 @@ func authzViaCookieValue(r *http.Request, tok string) int {
 }
 
+// TestVhostViaPath runs the vhost request matrix (doVhostRequests)
+// with the token embedded in the URL path as a "/t=<token>" prefix.
 func (s *IntegrationSuite) TestVhostViaPath(c *check.C) {
-       doVhostRequests(c, authzViaPath)
+       s.doVhostRequests(c, authzViaPath)
 }
 func authzViaPath(r *http.Request, tok string) int {
        r.URL.Path = "/t=" + tok + r.URL.Path
@@ -76,7 +143,7 @@ func authzViaPath(r *http.Request, tok string) int {
 }
 
+// TestVhostViaQueryString runs the vhost request matrix
+// (doVhostRequests) with the token passed as the api_token query
+// parameter.
 func (s *IntegrationSuite) TestVhostViaQueryString(c *check.C) {
-       doVhostRequests(c, authzViaQueryString)
+       s.doVhostRequests(c, authzViaQueryString)
 }
 func authzViaQueryString(r *http.Request, tok string) int {
        r.URL.RawQuery = "api_token=" + tok
@@ -84,7 +151,7 @@ func authzViaQueryString(r *http.Request, tok string) int {
 }
 
+// TestVhostViaPOST runs the vhost request matrix (doVhostRequests)
+// with each request turned into a POST that carries the token (see
+// authzViaPOST).
 func (s *IntegrationSuite) TestVhostViaPOST(c *check.C) {
-       doVhostRequests(c, authzViaPOST)
+       s.doVhostRequests(c, authzViaPOST)
 }
 func authzViaPOST(r *http.Request, tok string) int {
        r.Method = "POST"
@@ -94,9 +161,24 @@ func authzViaPOST(r *http.Request, tok string) int {
        return http.StatusUnauthorized
 }
 
+// TestVhostViaXHRPOST runs the vhost request matrix (doVhostRequests)
+// with each request shaped like a browser XHR: a form-encoded POST
+// carrying api_token and disposition=attachment, plus an Origin
+// header.
+//
+// This must use authzViaXHRPOST. Passing authzViaPOST here would just
+// repeat TestVhostViaPOST and leave the XHR path untested.
+func (s *IntegrationSuite) TestVhostViaXHRPOST(c *check.C) {
+       s.doVhostRequests(c, authzViaXHRPOST)
+}
+// authzViaXHRPOST rewrites r into the kind of POST a browser XHR would
+// send. It sets a form-encoded body containing api_token=tok and
+// disposition=attachment, and adds an Origin header. It returns
+// http.StatusUnauthorized, which doVhostRequestsWithHostPath receives
+// as failCode.
+func authzViaXHRPOST(r *http.Request, tok string) int {
+       form := url.Values{}
+       form.Set("api_token", tok)
+       form.Set("disposition", "attachment")
+
+       r.Method = "POST"
+       r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
+       r.Header.Add("Origin", "https://origin.example")
+       r.Body = ioutil.NopCloser(strings.NewReader(form.Encode()))
+       return http.StatusUnauthorized
+}
+
 // Try some combinations of {url, token} using the given authorization
 // mechanism, and verify the result is correct.
-func doVhostRequests(c *check.C, authz authorizer) {
+func (s *IntegrationSuite) doVhostRequests(c *check.C, authz authorizer) {
        for _, hostPath := range []string{
                arvadostest.FooCollection + ".example.com/foo",
                arvadostest.FooCollection + "--collections.example.com/foo",
@@ -106,11 +188,11 @@ func doVhostRequests(c *check.C, authz authorizer) {
                arvadostest.FooBarDirCollection + ".example.com/dir1/foo",
        } {
                c.Log("doRequests: ", hostPath)
-               doVhostRequestsWithHostPath(c, authz, hostPath)
+               s.doVhostRequestsWithHostPath(c, authz, hostPath)
        }
 }
 
-func doVhostRequestsWithHostPath(c *check.C, authz authorizer, hostPath string) {
+func (s *IntegrationSuite) doVhostRequestsWithHostPath(c *check.C, authz authorizer, hostPath string) {
        for _, tok := range []string{
                arvadostest.ActiveToken,
                arvadostest.ActiveToken[:15],
@@ -120,17 +202,25 @@ func doVhostRequestsWithHostPath(c *check.C, authz authorizer, hostPath string)
        } {
                u := mustParseURL("http://" + hostPath)
                req := &http.Request{
-                       Method: "GET",
-                       Host:   u.Host,
-                       URL:    u,
-                       Header: http.Header{},
+                       Method:     "GET",
+                       Host:       u.Host,
+                       URL:        u,
+                       RequestURI: u.RequestURI(),
+                       Header:     http.Header{},
                }
                failCode := authz(req, tok)
-               resp := doReq(req)
+               req, resp := s.doReq(req)
                code, body := resp.Code, resp.Body.String()
+
+               // If the initial request had a (non-empty) token
+               // showing in the query string, we should have been
+               // redirected in order to hide it in a cookie.
+               c.Check(req.URL.String(), check.Not(check.Matches), `.*api_token=.+`)
+
                if tok == arvadostest.ActiveToken {
                        c.Check(code, check.Equals, http.StatusOK)
                        c.Check(body, check.Equals, "foo")
+
                } else {
                        c.Check(code >= 400, check.Equals, true)
                        c.Check(code < 500, check.Equals, true)
@@ -148,24 +238,25 @@ func doVhostRequestsWithHostPath(c *check.C, authz authorizer, hostPath string)
        }
 }
 
-func doReq(req *http.Request) *httptest.ResponseRecorder {
+// doReq serves req with the test server's handler. If the response is
+// 303 See Other, it follows the Location header with a fresh GET that
+// carries the cookies set by that response. It repeats this until a
+// non-303 response arrives. It returns the last request sent (so
+// callers can inspect the final URL) together with its recorded
+// response.
+//
+// NOTE(review): there is no limit on the number of redirects followed,
+// so a redirect loop would recurse indefinitely.
+func (s *IntegrationSuite) doReq(req *http.Request) (*http.Request, *httptest.ResponseRecorder) {
        resp := httptest.NewRecorder()
-       (&handler{}).ServeHTTP(resp, req)
+       s.testServer.Handler.ServeHTTP(resp, req)
+       // Only a 303 redirect needs following; anything else is final.
        if resp.Code != http.StatusSeeOther {
-               return resp
+               return req, resp
        }
+       // Re-issue as a GET to the redirect target, presenting the
+       // cookies this response set.
        cookies := (&http.Response{Header: resp.Header()}).Cookies()
        u, _ := req.URL.Parse(resp.Header().Get("Location"))
        req = &http.Request{
-               Method: "GET",
-               Host:   u.Host,
-               URL:    u,
-               Header: http.Header{},
+               Method:     "GET",
+               Host:       u.Host,
+               URL:        u,
+               RequestURI: u.RequestURI(),
+               Header:     http.Header{},
        }
        for _, c := range cookies {
                req.AddCookie(c)
        }
-       return doReq(req)
+       return s.doReq(req)
 }
 
 func (s *IntegrationSuite) TestVhostRedirectQueryTokenToCookie(c *check.C) {
@@ -228,11 +319,23 @@ func (s *IntegrationSuite) TestVhostRedirectQueryTokenSingleOriginError(c *check
        )
 }
 
+// TestVhostRedirectQueryTokenRequestAttachment covers a client that
+// asks for an attachment with ?disposition=attachment and is then
+// redirected to move the api_token out of the URL. The redirect
+// target must still respond with Content-Disposition: attachment.
+func (s *IntegrationSuite) TestVhostRedirectQueryTokenRequestAttachment(c *check.C) {
+       query := "?disposition=attachment&api_token=" + arvadostest.ActiveToken
+       resp := s.testVhostRedirectTokenToCookie(c, "GET",
+               arvadostest.FooCollection+".example.com/foo", query,
+               "", "", // no content type, no request body
+               http.StatusOK, "foo")
+       disposition := resp.Header().Get("Content-Disposition")
+       dispositionType := strings.SplitN(disposition, ";", 2)[0]
+       c.Check(dispositionType, check.Equals, "attachment")
+}
+
 func (s *IntegrationSuite) TestVhostRedirectQueryTokenTrustAllContent(c *check.C) {
-       defer func(orig bool) {
-               trustAllContent = orig
-       }(trustAllContent)
-       trustAllContent = true
+       s.testServer.Config.TrustAllContent = true
        s.testVhostRedirectTokenToCookie(c, "GET",
                "example.com/c="+arvadostest.FooCollection+"/foo",
                "?api_token="+arvadostest.ActiveToken,
@@ -244,10 +347,7 @@ func (s *IntegrationSuite) TestVhostRedirectQueryTokenTrustAllContent(c *check.C
 }
 
 func (s *IntegrationSuite) TestVhostRedirectQueryTokenAttachmentOnlyHost(c *check.C) {
-       defer func(orig string) {
-               attachmentOnlyHost = orig
-       }(attachmentOnlyHost)
-       attachmentOnlyHost = "example.com:1234"
+       s.testServer.Config.AttachmentOnlyHost = "example.com:1234"
 
        s.testVhostRedirectTokenToCookie(c, "GET",
                "example.com/c="+arvadostest.FooCollection+"/foo",
@@ -292,7 +392,7 @@ func (s *IntegrationSuite) TestVhostRedirectPOSTFormTokenToCookie404(c *check.C)
 }
 
 func (s *IntegrationSuite) TestAnonymousTokenOK(c *check.C) {
-       anonymousTokens = []string{arvadostest.AnonymousToken}
+       s.testServer.Config.AnonymousTokens = []string{arvadostest.AnonymousToken}
        s.testVhostRedirectTokenToCookie(c, "GET",
                "example.com/c="+arvadostest.HelloWorldCollection+"/Hello%20world.txt",
                "",
@@ -304,7 +404,7 @@ func (s *IntegrationSuite) TestAnonymousTokenOK(c *check.C) {
 }
 
 func (s *IntegrationSuite) TestAnonymousTokenError(c *check.C) {
-       anonymousTokens = []string{"anonymousTokenConfiguredButInvalid"}
+       s.testServer.Config.AnonymousTokens = []string{"anonymousTokenConfiguredButInvalid"}
        s.testVhostRedirectTokenToCookie(c, "GET",
                "example.com/c="+arvadostest.HelloWorldCollection+"/Hello%20world.txt",
                "",
@@ -315,14 +415,43 @@ func (s *IntegrationSuite) TestAnonymousTokenError(c *check.C) {
        )
 }
 
+// TestXHRNoRedirect covers the XHR path. An XHR cannot follow a
+// redirect that relies on setting a cookie. Instead, the client POSTs
+// its token with disposition=attachment, which says content is an
+// acceptable reply in place of a redirect. The browser adds an Origin
+// header automatically, which says content is what the caller wants.
+// keep-web should answer 200 with the file content and
+// Access-Control-Allow-Origin: *.
+func (s *IntegrationSuite) TestXHRNoRedirect(c *check.C) {
+       form := url.Values{}
+       form.Set("api_token", arvadostest.ActiveToken)
+       form.Set("disposition", "attachment")
+
+       target, _ := url.Parse("http://example.com/c=" + arvadostest.FooCollection + "/foo")
+       hdr := http.Header{}
+       hdr.Set("Origin", "https://origin.example")
+       hdr.Set("Content-Type", "application/x-www-form-urlencoded")
+
+       req := &http.Request{
+               Method:     "POST",
+               Host:       target.Host,
+               URL:        target,
+               RequestURI: target.RequestURI(),
+               Header:     hdr,
+               Body:       ioutil.NopCloser(strings.NewReader(form.Encode())),
+       }
+       rec := httptest.NewRecorder()
+       s.testServer.Handler.ServeHTTP(rec, req)
+
+       c.Check(rec.Code, check.Equals, http.StatusOK)
+       c.Check(rec.Body.String(), check.Equals, "foo")
+       c.Check(rec.Header().Get("Access-Control-Allow-Origin"), check.Equals, "*")
+}
+
 func (s *IntegrationSuite) testVhostRedirectTokenToCookie(c *check.C, method, hostPath, queryString, contentType, reqBody string, expectStatus int, expectRespBody string) *httptest.ResponseRecorder {
        u, _ := url.Parse(`http://` + hostPath + queryString)
        req := &http.Request{
-               Method: method,
-               Host:   u.Host,
-               URL:    u,
-               Header: http.Header{"Content-Type": {contentType}},
-               Body:   ioutil.NopCloser(strings.NewReader(reqBody)),
+               Method:     method,
+               Host:       u.Host,
+               URL:        u,
+               RequestURI: u.RequestURI(),
+               Header:     http.Header{"Content-Type": {contentType}},
+               Body:       ioutil.NopCloser(strings.NewReader(reqBody)),
        }
 
        resp := httptest.NewRecorder()
@@ -331,26 +460,132 @@ func (s *IntegrationSuite) testVhostRedirectTokenToCookie(c *check.C, method, ho
                c.Check(resp.Body.String(), check.Equals, expectRespBody)
        }()
 
-       (&handler{}).ServeHTTP(resp, req)
+       s.testServer.Handler.ServeHTTP(resp, req)
        if resp.Code != http.StatusSeeOther {
                return resp
        }
-       c.Check(resp.Body.String(), check.Matches, `.*href="//`+regexp.QuoteMeta(html.EscapeString(hostPath))+`".*`)
+       c.Check(resp.Body.String(), check.Matches, `.*href="//`+regexp.QuoteMeta(html.EscapeString(hostPath))+`(\?[^"]*)?".*`)
        cookies := (&http.Response{Header: resp.Header()}).Cookies()
 
        u, _ = u.Parse(resp.Header().Get("Location"))
        req = &http.Request{
-               Method: "GET",
-               Host:   u.Host,
-               URL:    u,
-               Header: http.Header{},
+               Method:     "GET",
+               Host:       u.Host,
+               URL:        u,
+               RequestURI: u.RequestURI(),
+               Header:     http.Header{},
        }
        for _, c := range cookies {
                req.AddCookie(c)
        }
 
        resp = httptest.NewRecorder()
-       (&handler{}).ServeHTTP(resp, req)
+       s.testServer.Handler.ServeHTTP(resp, req)
        c.Check(resp.Header().Get("Location"), check.Equals, "")
        return resp
 }
+
+// TestDirectoryListing requests collection directories through the
+// URL forms keep-web accepts (collection vhost, /collections/,
+// /collections/download/, /c=). It follows any 303 redirects while
+// accumulating the cookies they set, then checks the final response.
+// For a non-nil expect, the response must be 200, link to each
+// expected file, and contain "--cut-dirs=N " (presumably part of a
+// suggested download command -- confirm). For a nil expect, the
+// response must be 404.
+func (s *IntegrationSuite) TestDirectoryListing(c *check.C) {
+       // Bound on redirect hops, so a redirect loop fails the test
+       // instead of hanging it.
+       const maxRedirects = 10
+
+       // Several trials below address download.example.com.
+       s.testServer.Config.AttachmentOnlyHost = "download.example.com"
+       authHeader := http.Header{
+               "Authorization": {"OAuth2 " + arvadostest.ActiveToken},
+       }
+       for _, trial := range []struct {
+               uri     string
+               header  http.Header
+               expect  []string
+               cutDirs int
+       }{
+               {
+                       uri:     strings.Replace(arvadostest.FooAndBarFilesInDirPDH, "+", "-", -1) + ".example.com/",
+                       header:  authHeader,
+                       expect:  []string{"dir1/foo", "dir1/bar"},
+                       cutDirs: 0,
+               },
+               {
+                       uri:     strings.Replace(arvadostest.FooAndBarFilesInDirPDH, "+", "-", -1) + ".example.com/dir1/",
+                       header:  authHeader,
+                       expect:  []string{"foo", "bar"},
+                       cutDirs: 0,
+               },
+               {
+                       uri:     "download.example.com/collections/" + arvadostest.FooAndBarFilesInDirUUID + "/",
+                       header:  authHeader,
+                       expect:  []string{"dir1/foo", "dir1/bar"},
+                       cutDirs: 2,
+               },
+               {
+                       uri:     "collections.example.com/collections/download/" + arvadostest.FooAndBarFilesInDirUUID + "/" + arvadostest.ActiveToken + "/",
+                       header:  nil,
+                       expect:  []string{"dir1/foo", "dir1/bar"},
+                       cutDirs: 4,
+               },
+               {
+                       uri:     "collections.example.com/c=" + arvadostest.FooAndBarFilesInDirUUID + "/t=" + arvadostest.ActiveToken + "/",
+                       header:  nil,
+                       expect:  []string{"dir1/foo", "dir1/bar"},
+                       cutDirs: 2,
+               },
+               {
+                       uri:     "download.example.com/c=" + arvadostest.FooAndBarFilesInDirUUID + "/dir1/",
+                       header:  authHeader,
+                       expect:  []string{"foo", "bar"},
+                       cutDirs: 1,
+               },
+               {
+                       uri:     "download.example.com/c=" + arvadostest.FooAndBarFilesInDirUUID + "/_/dir1/",
+                       header:  authHeader,
+                       expect:  []string{"foo", "bar"},
+                       cutDirs: 2,
+               },
+               {
+                       uri:     arvadostest.FooAndBarFilesInDirUUID + ".example.com/dir1?api_token=" + arvadostest.ActiveToken,
+                       header:  authHeader,
+                       expect:  []string{"foo", "bar"},
+                       cutDirs: 0,
+               },
+               {
+                       uri:    "collections.example.com/c=" + arvadostest.FooAndBarFilesInDirUUID + "/theperthcountyconspiracydoesnotexist/",
+                       header: authHeader,
+                       expect: nil,
+               },
+       } {
+               c.Logf("%q => %q", trial.uri, trial.expect)
+               resp := httptest.NewRecorder()
+               u := mustParseURL("//" + trial.uri)
+               req := &http.Request{
+                       Method:     "GET",
+                       Host:       u.Host,
+                       URL:        u,
+                       RequestURI: u.RequestURI(),
+                       Header:     trial.header,
+               }
+               s.testServer.Handler.ServeHTTP(resp, req)
+
+               // Follow redirects (e.g. token-in-URL => cookie),
+               // presenting every cookie collected so far. If the
+               // hop limit is hit, resp is still a 303 and the status
+               // checks below report the failure.
+               var cookies []*http.Cookie
+               for hops := 0; resp.Code == http.StatusSeeOther && hops < maxRedirects; hops++ {
+                       next, err := req.URL.Parse(resp.Header().Get("Location"))
+                       c.Assert(err, check.IsNil)
+                       cookies = append(cookies, (&http.Response{Header: resp.Header()}).Cookies()...)
+                       req = &http.Request{
+                               Method:     "GET",
+                               Host:       next.Host,
+                               URL:        next,
+                               RequestURI: next.RequestURI(),
+                               Header:     http.Header{},
+                       }
+                       for _, cookie := range cookies {
+                               req.AddCookie(cookie)
+                       }
+                       resp = httptest.NewRecorder()
+                       s.testServer.Handler.ServeHTTP(resp, req)
+               }
+
+               if trial.expect == nil {
+                       c.Check(resp.Code, check.Equals, http.StatusNotFound)
+                       continue
+               }
+               c.Check(resp.Code, check.Equals, http.StatusOK)
+               for _, e := range trial.expect {
+                       // Quote the file name so it is matched literally.
+                       c.Check(resp.Body.String(), check.Matches, `(?ms).*href="`+regexp.QuoteMeta(e)+`".*`)
+               }
+               c.Check(resp.Body.String(), check.Matches, fmt.Sprintf(`(?ms).*--cut-dirs=%d .*`, trial.cutDirs))
+       }
+}