17464: Add tests for paths by /users/ and by PDH
[arvados.git] / sdk / go / arvadostest / oidc_provider.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvadostest
6
7 import (
8         "crypto/rand"
9         "crypto/rsa"
10         "encoding/base64"
11         "encoding/json"
12         "net/http"
13         "net/http/httptest"
14         "net/url"
15         "strings"
16         "time"
17
18         "gopkg.in/check.v1"
19         "gopkg.in/square/go-jose.v2"
20         "gopkg.in/square/go-jose.v2/jwt"
21 )
22
// OIDCProvider is a fake OpenID Connect provider for tests.
// NewOIDCProvider starts its two httptest servers: Issuer, which serves
// OIDC discovery plus the /auth, /token, /jwks and /userinfo endpoints,
// and PeopleAPI, which serves /v1/people/me. Tests set the exported
// fields to control which token requests succeed and what identity is
// reported.
type OIDCProvider struct {
	// expected token request: /token responds 401 unless the request's
	// Basic-auth client ID and secret equal ValidClientID and
	// ValidClientSecret, and its "code" form value equals ValidCode.
	// An empty ValidCode makes every token request fail.
	ValidCode         string
	ValidClientID     string
	ValidClientSecret string
	// desired response from token endpoint: AuthEmail,
	// AuthEmailVerified and AuthName become the email, email_verified
	// and name claims of the issued ID token (and are also returned by
	// /userinfo). AccessTokenPayload is JSON-encoded and signed to form
	// the issued access token; NewOIDCProvider sets it to
	// {"sub": "example"}.
	AuthEmail          string
	AuthEmailVerified  bool
	AuthName           string
	AccessTokenPayload map[string]interface{}

	// PeopleAPIResponse is sent as the JSON body of a successful
	// /v1/people/me request to the PeopleAPI server.
	PeopleAPIResponse map[string]interface{}

	// key signs every issued token (RS256); its public half is
	// published at the Issuer's /jwks endpoint. c is the test context
	// used for logging and error reporting.
	key       *rsa.PrivateKey
	Issuer    *httptest.Server
	PeopleAPI *httptest.Server
	c         *check.C
}
41
42 func NewOIDCProvider(c *check.C) *OIDCProvider {
43         p := &OIDCProvider{c: c}
44         var err error
45         p.key, err = rsa.GenerateKey(rand.Reader, 2048)
46         c.Assert(err, check.IsNil)
47         p.Issuer = httptest.NewServer(http.HandlerFunc(p.serveOIDC))
48         p.PeopleAPI = httptest.NewServer(http.HandlerFunc(p.servePeopleAPI))
49         p.AccessTokenPayload = map[string]interface{}{"sub": "example"}
50         return p
51 }
52
53 func (p *OIDCProvider) ValidAccessToken() string {
54         buf, _ := json.Marshal(p.AccessTokenPayload)
55         return p.fakeToken(buf)
56 }
57
58 func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
59         req.ParseForm()
60         p.c.Logf("serveOIDC: got req: %s %s %s", req.Method, req.URL, req.Form)
61         w.Header().Set("Content-Type", "application/json")
62         switch req.URL.Path {
63         case "/.well-known/openid-configuration":
64                 json.NewEncoder(w).Encode(map[string]interface{}{
65                         "issuer":                 p.Issuer.URL,
66                         "authorization_endpoint": p.Issuer.URL + "/auth",
67                         "token_endpoint":         p.Issuer.URL + "/token",
68                         "jwks_uri":               p.Issuer.URL + "/jwks",
69                         "userinfo_endpoint":      p.Issuer.URL + "/userinfo",
70                 })
71         case "/token":
72                 var clientID, clientSecret string
73                 auth, _ := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic "))
74                 authsplit := strings.Split(string(auth), ":")
75                 if len(authsplit) == 2 {
76                         clientID, _ = url.QueryUnescape(authsplit[0])
77                         clientSecret, _ = url.QueryUnescape(authsplit[1])
78                 }
79                 if clientID != p.ValidClientID || clientSecret != p.ValidClientSecret {
80                         p.c.Logf("OIDCProvider: expected (%q, %q) got (%q, %q)", p.ValidClientID, p.ValidClientSecret, clientID, clientSecret)
81                         w.WriteHeader(http.StatusUnauthorized)
82                         return
83                 }
84
85                 if req.Form.Get("code") != p.ValidCode || p.ValidCode == "" {
86                         w.WriteHeader(http.StatusUnauthorized)
87                         return
88                 }
89                 idToken, _ := json.Marshal(map[string]interface{}{
90                         "iss":            p.Issuer.URL,
91                         "aud":            []string{clientID},
92                         "sub":            "fake-user-id",
93                         "exp":            time.Now().UTC().Add(time.Minute).Unix(),
94                         "iat":            time.Now().UTC().Unix(),
95                         "nonce":          "fake-nonce",
96                         "email":          p.AuthEmail,
97                         "email_verified": p.AuthEmailVerified,
98                         "name":           p.AuthName,
99                         "alt_verified":   true,                    // for custom claim tests
100                         "alt_email":      "alt_email@example.com", // for custom claim tests
101                         "alt_username":   "desired-username",      // for custom claim tests
102                 })
103                 json.NewEncoder(w).Encode(struct {
104                         AccessToken  string `json:"access_token"`
105                         TokenType    string `json:"token_type"`
106                         RefreshToken string `json:"refresh_token"`
107                         ExpiresIn    int32  `json:"expires_in"`
108                         IDToken      string `json:"id_token"`
109                 }{
110                         AccessToken:  p.ValidAccessToken(),
111                         TokenType:    "Bearer",
112                         RefreshToken: "test-refresh-token",
113                         ExpiresIn:    30,
114                         IDToken:      p.fakeToken(idToken),
115                 })
116         case "/jwks":
117                 json.NewEncoder(w).Encode(jose.JSONWebKeySet{
118                         Keys: []jose.JSONWebKey{
119                                 {Key: p.key.Public(), Algorithm: string(jose.RS256), KeyID: ""},
120                         },
121                 })
122         case "/auth":
123                 w.WriteHeader(http.StatusInternalServerError)
124         case "/userinfo":
125                 authhdr := req.Header.Get("Authorization")
126                 if _, err := jwt.ParseSigned(strings.TrimPrefix(authhdr, "Bearer ")); err != nil {
127                         p.c.Logf("OIDCProvider: bad auth %q", authhdr)
128                         w.WriteHeader(http.StatusUnauthorized)
129                         return
130                 }
131                 json.NewEncoder(w).Encode(map[string]interface{}{
132                         "sub":            "fake-user-id",
133                         "name":           p.AuthName,
134                         "given_name":     p.AuthName,
135                         "family_name":    "",
136                         "alt_username":   "desired-username",
137                         "email":          p.AuthEmail,
138                         "email_verified": p.AuthEmailVerified,
139                 })
140         default:
141                 w.WriteHeader(http.StatusNotFound)
142         }
143 }
144
145 func (p *OIDCProvider) servePeopleAPI(w http.ResponseWriter, req *http.Request) {
146         req.ParseForm()
147         p.c.Logf("servePeopleAPI: got req: %s %s %s", req.Method, req.URL, req.Form)
148         w.Header().Set("Content-Type", "application/json")
149         switch req.URL.Path {
150         case "/v1/people/me":
151                 if f := req.Form.Get("personFields"); f != "emailAddresses,names" {
152                         w.WriteHeader(http.StatusBadRequest)
153                         break
154                 }
155                 json.NewEncoder(w).Encode(p.PeopleAPIResponse)
156         default:
157                 w.WriteHeader(http.StatusNotFound)
158         }
159 }
160
161 func (p *OIDCProvider) fakeToken(payload []byte) string {
162         signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: p.key}, nil)
163         if err != nil {
164                 p.c.Error(err)
165                 return ""
166         }
167         object, err := signer.Sign(payload)
168         if err != nil {
169                 p.c.Error(err)
170                 return ""
171         }
172         t, err := object.CompactSerialize()
173         if err != nil {
174                 p.c.Error(err)
175                 return ""
176         }
177         p.c.Logf("fakeToken(%q) == %q", payload, t)
178         return t
179 }