Revert "Add license file to all the git.arvados.org/arvados.git/sdk/go
[arvados.git] / sdk / go / auth / handlers_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package auth
6
7 import (
8         "net/http"
9         "net/http/httptest"
10         "testing"
11
12         check "gopkg.in/check.v1"
13 )
14
15 // Gocheck boilerplate
16 func Test(t *testing.T) {
17         check.TestingT(t)
18 }
19
20 var _ = check.Suite(&HandlersSuite{})
21
22 type HandlersSuite struct {
23         served         int
24         gotCredentials *Credentials
25 }
26
27 func (s *HandlersSuite) SetUpTest(c *check.C) {
28         s.served = 0
29         s.gotCredentials = nil
30 }
31
32 func (s *HandlersSuite) TestLoadToken(c *check.C) {
33         handler := LoadToken(s)
34         handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/foo/bar?api_token=xyzzy", nil))
35         c.Assert(s.gotCredentials, check.NotNil)
36         c.Assert(s.gotCredentials.Tokens, check.HasLen, 1)
37         c.Check(s.gotCredentials.Tokens[0], check.Equals, "xyzzy")
38 }
39
40 func (s *HandlersSuite) TestRequireLiteralTokenEmpty(c *check.C) {
41         handler := RequireLiteralToken("", s)
42
43         w := httptest.NewRecorder()
44         handler.ServeHTTP(w, httptest.NewRequest("GET", "/foo/bar?api_token=abcdef", nil))
45         c.Check(s.served, check.Equals, 1)
46         c.Check(w.Code, check.Equals, http.StatusOK)
47
48         w = httptest.NewRecorder()
49         handler.ServeHTTP(w, httptest.NewRequest("GET", "/foo/bar", nil))
50         c.Check(s.served, check.Equals, 2)
51         c.Check(w.Code, check.Equals, http.StatusOK)
52 }
53
54 func (s *HandlersSuite) TestRequireLiteralToken(c *check.C) {
55         handler := RequireLiteralToken("xyzzy", s)
56
57         w := httptest.NewRecorder()
58         handler.ServeHTTP(w, httptest.NewRequest("GET", "/foo/bar?api_token=abcdef", nil))
59         c.Check(s.served, check.Equals, 0)
60         c.Check(w.Code, check.Equals, http.StatusForbidden)
61
62         w = httptest.NewRecorder()
63         handler.ServeHTTP(w, httptest.NewRequest("GET", "/foo/bar", nil))
64         c.Check(s.served, check.Equals, 0)
65         c.Check(w.Code, check.Equals, http.StatusUnauthorized)
66
67         w = httptest.NewRecorder()
68         handler.ServeHTTP(w, httptest.NewRequest("GET", "/foo/bar?api_token=xyzzy", nil))
69         c.Check(s.served, check.Equals, 1)
70         c.Check(w.Code, check.Equals, http.StatusOK)
71         c.Assert(s.gotCredentials, check.NotNil)
72         c.Assert(s.gotCredentials.Tokens, check.HasLen, 1)
73         c.Check(s.gotCredentials.Tokens[0], check.Equals, "xyzzy")
74 }
75
76 func (s *HandlersSuite) ServeHTTP(w http.ResponseWriter, r *http.Request) {
77         s.served++
78         s.gotCredentials = CredentialsFromRequest(r)
79 }