1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: Apache-2.0
19 "gopkg.in/square/go-jose.v2"
20 "gopkg.in/square/go-jose.v2/jwt"
23 type OIDCProvider struct {
24 // expected token request
27 ValidClientSecret string
28 // desired response from token endpoint
30 AuthEmailVerified bool
34 AccessTokenPayload map[string]interface{}
36 PeopleAPIResponse map[string]interface{}
38 // send incoming /userinfo requests to HoldUserInfo (if not
39 // nil), then receive from ReleaseUserInfo (if not nil),
40 // before responding (these are used to set up races)
41 HoldUserInfo chan *http.Request
42 ReleaseUserInfo chan struct{}
45 Issuer *httptest.Server
46 PeopleAPI *httptest.Server
50 func NewOIDCProvider(c *check.C) *OIDCProvider {
51 p := &OIDCProvider{c: c}
53 p.key, err = rsa.GenerateKey(rand.Reader, 2048)
54 c.Assert(err, check.IsNil)
55 p.Issuer = httptest.NewServer(http.HandlerFunc(p.serveOIDC))
56 p.PeopleAPI = httptest.NewServer(http.HandlerFunc(p.servePeopleAPI))
57 p.AccessTokenPayload = map[string]interface{}{"sub": "example"}
61 func (p *OIDCProvider) ValidAccessToken() string {
62 buf, _ := json.Marshal(p.AccessTokenPayload)
63 return p.fakeToken(buf)
66 func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
68 p.c.Logf("serveOIDC: got req: %s %s %s", req.Method, req.URL, req.Form)
69 w.Header().Set("Content-Type", "application/json")
71 case "/.well-known/openid-configuration":
72 json.NewEncoder(w).Encode(map[string]interface{}{
73 "issuer": p.Issuer.URL,
74 "authorization_endpoint": p.Issuer.URL + "/auth",
75 "token_endpoint": p.Issuer.URL + "/token",
76 "jwks_uri": p.Issuer.URL + "/jwks",
77 "userinfo_endpoint": p.Issuer.URL + "/userinfo",
80 var clientID, clientSecret string
81 auth, _ := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic "))
82 authsplit := strings.Split(string(auth), ":")
83 if len(authsplit) == 2 {
84 clientID, _ = url.QueryUnescape(authsplit[0])
85 clientSecret, _ = url.QueryUnescape(authsplit[1])
87 if clientID != p.ValidClientID || clientSecret != p.ValidClientSecret {
88 p.c.Logf("OIDCProvider: expected (%q, %q) got (%q, %q)", p.ValidClientID, p.ValidClientSecret, clientID, clientSecret)
89 w.WriteHeader(http.StatusUnauthorized)
93 if req.Form.Get("code") != p.ValidCode || p.ValidCode == "" {
94 w.WriteHeader(http.StatusUnauthorized)
97 idToken, _ := json.Marshal(map[string]interface{}{
99 "aud": []string{clientID},
100 "sub": "fake-user-id",
101 "exp": time.Now().UTC().Add(time.Minute).Unix(),
102 "iat": time.Now().UTC().Unix(),
103 "nonce": "fake-nonce",
104 "email": p.AuthEmail,
105 "email_verified": p.AuthEmailVerified,
107 "given_name": p.AuthGivenName,
108 "family_name": p.AuthFamilyName,
109 "alt_verified": true, // for custom claim tests
110 "alt_email": "alt_email@example.com", // for custom claim tests
111 "alt_username": "desired-username", // for custom claim tests
113 json.NewEncoder(w).Encode(struct {
114 AccessToken string `json:"access_token"`
115 TokenType string `json:"token_type"`
116 RefreshToken string `json:"refresh_token"`
117 ExpiresIn int32 `json:"expires_in"`
118 IDToken string `json:"id_token"`
120 AccessToken: p.ValidAccessToken(),
122 RefreshToken: "test-refresh-token",
124 IDToken: p.fakeToken(idToken),
127 json.NewEncoder(w).Encode(jose.JSONWebKeySet{
128 Keys: []jose.JSONWebKey{
129 {Key: p.key.Public(), Algorithm: string(jose.RS256), KeyID: ""},
133 w.WriteHeader(http.StatusInternalServerError)
135 if p.HoldUserInfo != nil {
136 p.HoldUserInfo <- req
138 if p.ReleaseUserInfo != nil {
141 authhdr := req.Header.Get("Authorization")
142 if _, err := jwt.ParseSigned(strings.TrimPrefix(authhdr, "Bearer ")); err != nil {
143 p.c.Logf("OIDCProvider: bad auth %q", authhdr)
144 w.WriteHeader(http.StatusUnauthorized)
147 json.NewEncoder(w).Encode(map[string]interface{}{
148 "sub": "fake-user-id",
150 "given_name": p.AuthGivenName,
151 "family_name": p.AuthFamilyName,
152 "alt_username": "desired-username",
153 "email": p.AuthEmail,
154 "email_verified": p.AuthEmailVerified,
157 w.WriteHeader(http.StatusNotFound)
161 func (p *OIDCProvider) servePeopleAPI(w http.ResponseWriter, req *http.Request) {
163 p.c.Logf("servePeopleAPI: got req: %s %s %s", req.Method, req.URL, req.Form)
164 w.Header().Set("Content-Type", "application/json")
165 switch req.URL.Path {
166 case "/v1/people/me":
167 if f := req.Form.Get("personFields"); f != "emailAddresses,names" {
168 w.WriteHeader(http.StatusBadRequest)
171 json.NewEncoder(w).Encode(p.PeopleAPIResponse)
173 w.WriteHeader(http.StatusNotFound)
177 func (p *OIDCProvider) fakeToken(payload []byte) string {
178 signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: p.key}, nil)
183 object, err := signer.Sign(payload)
188 t, err := object.CompactSerialize()
193 p.c.Logf("fakeToken(%q) == %q", payload, t)