14285: Add tests for LoadToken and RequireLiteralToken.
author    Tom Clegg <tclegg@veritasgenetics.com>
          Mon, 15 Oct 2018 17:50:10 +0000 (13:50 -0400)
committer Tom Clegg <tclegg@veritasgenetics.com>
          Mon, 15 Oct 2018 17:50:10 +0000 (13:50 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tclegg@veritasgenetics.com>

build/run-tests.sh
sdk/go/auth/handlers.go
sdk/go/auth/handlers_test.go [new file with mode: 0644]

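diff --git a/build/run-tests.sh b/build/run-tests.sh
index 4ddbf89c1d7ccb286fcfe887fb941734bffbb519..26a907fc2f99f36b7cc7d0aa7fa8d2bf097af5f0 100755 (executable)
--- a/build/run-tests.sh
+++ b/build/run-tests.sh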
@@ -101,6 +101,7 @@ sdk/python:py3
 sdk/ruby
 sdk/go/arvados
 sdk/go/arvadosclient
+sdk/go/auth
 sdk/go/dispatch
 sdk/go/keepclient
 sdk/go/health
@@ -925,6 +926,7 @@ gostuff=(
     lib/dispatchcloud
     sdk/go/arvados
     sdk/go/arvadosclient
+    sdk/go/auth
     sdk/go/blockdigest
     sdk/go/dispatch
     sdk/go/health
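diff --git a/sdk/go/auth/handlers.go b/sdk/go/auth/handlers.go
index 7b1760f4b8192dd28867e2f96a7dcccf4cffbc1a..ad1fa5141a1cf268729e33bdef3cacc3fb14d76c 100644 (file)
--- a/sdk/go/auth/handlers.go
+++ b/sdk/go/auth/handlers.go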
@@ -18,7 +18,10 @@ var contextKeyCredentials contextKey = "credentials"
 // CredentialsFromRequest.
 func LoadToken(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-               next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKeyCredentials, CredentialsFromRequest(r))))
+               if _, ok := r.Context().Value(contextKeyCredentials).(*Credentials); !ok {
+                       r = r.WithContext(context.WithValue(r.Context(), contextKeyCredentials, CredentialsFromRequest(r)))
+               }
+               next.ServeHTTP(w, r)
        })
 }
 
diff --git a/sdk/go/auth/handlers_test.go b/sdk/go/auth/handlers_test.go
new file mode 100644 (file)
index 0000000..362aeb7
--- /dev/null
+++ b/sdk/go/auth/handlers_test.go
@@ -0,0 +1,79 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package auth
+
+import (
+       "net/http"
+       "net/http/httptest"
+       "testing"
+
+       check "gopkg.in/check.v1"
+)
+
+// Test hooks gocheck into "go test" so the suites registered below run.
+func Test(t *testing.T) { check.TestingT(t) }
+
+// Register HandlersSuite with gocheck.
+var _ = check.Suite(new(HandlersSuite))
+
+// HandlersSuite tests the auth middleware. The suite value itself is
+// used as the wrapped handler, and records what reached it.
+type HandlersSuite struct {
+       served         int          // number of requests that reached the handler
+       gotCredentials *Credentials // credentials from the most recent request served
+}
+
+// SetUpTest clears the recorded state before each test.
+func (s *HandlersSuite) SetUpTest(c *check.C) {
+       *s = HandlersSuite{}
+}
+
+// TestLoadToken checks that a token given in the query string is
+// loaded and visible to the wrapped handler.
+func (s *HandlersSuite) TestLoadToken(c *check.C) {
+       handler := LoadToken(s)
+       handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/foo/bar?api_token=xyzzy", nil))
+       c.Check(s.served, check.Equals, 1)
+       c.Assert(s.gotCredentials, check.NotNil)
+       c.Assert(s.gotCredentials.Tokens, check.HasLen, 1)
+       c.Check(s.gotCredentials.Tokens[0], check.Equals, "xyzzy")
+}
+
+// TestLoadTokenNested checks that LoadToken stores a *Credentials in
+// the request context, and that a second LoadToken further down the
+// chain passes that same value through instead of replacing it.
+func (s *HandlersSuite) TestLoadTokenNested(c *check.C) {
+       var outer, inner *Credentials
+       innerHandler := LoadToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               inner, _ = r.Context().Value(contextKeyCredentials).(*Credentials)
+       }))
+       handler := LoadToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               outer, _ = r.Context().Value(contextKeyCredentials).(*Credentials)
+               innerHandler.ServeHTTP(w, r)
+       }))
+       handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/foo/bar?api_token=xyzzy", nil))
+       c.Assert(outer, check.NotNil)
+       // Pointer identity: the inner LoadToken must not substitute a new value.
+       c.Check(inner, check.Equals, outer)
+}
+
+func (s *HandlersSuite) TestRequireLiteralTokenEmpty(c *check.C) {
+       // With an empty required token, every request is passed through,
+       // whether or not it carries a token.
+       handler := RequireLiteralToken("", s)
+       for i, target := range []string{"/foo/bar?api_token=abcdef", "/foo/bar"} {
+               rec := httptest.NewRecorder()
+               handler.ServeHTTP(rec, httptest.NewRequest("GET", target, nil))
+               c.Check(s.served, check.Equals, i+1)
+               c.Check(rec.Code, check.Equals, http.StatusOK)
+       }
+}
+
+func (s *HandlersSuite) TestRequireLiteralToken(c *check.C) {
+       handler := RequireLiteralToken("xyzzy", s)
+
+       // A wrong token is forbidden, a missing token is unauthorized, and
+       // only the matching token reaches the wrapped handler.
+       for _, trial := range []struct {
+               target     string
+               wantServed int
+               wantCode   int
+       }{
+               {"/foo/bar?api_token=abcdef", 0, http.StatusForbidden},
+               {"/foo/bar", 0, http.StatusUnauthorized},
+               {"/foo/bar?api_token=xyzzy", 1, http.StatusOK},
+       } {
+               rec := httptest.NewRecorder()
+               handler.ServeHTTP(rec, httptest.NewRequest("GET", trial.target, nil))
+               c.Check(s.served, check.Equals, trial.wantServed)
+               c.Check(rec.Code, check.Equals, trial.wantCode)
+       }
+
+       // The request that got through carried exactly the required token.
+       c.Assert(s.gotCredentials, check.NotNil)
+       c.Assert(s.gotCredentials.Tokens, check.HasLen, 1)
+       c.Check(s.gotCredentials.Tokens[0], check.Equals, "xyzzy")
+}
+
+// ServeHTTP lets the suite stand in as the handler wrapped by the
+// middleware under test: it counts each request that gets through and
+// keeps the credentials CredentialsFromRequest reports for it.
+func (s *HandlersSuite) ServeHTTP(_ http.ResponseWriter, r *http.Request) {
+       s.served++
+       s.gotCredentials = CredentialsFromRequest(r)
+}