1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: Apache-2.0
20 "gopkg.in/go-jose/go-jose.v2"
21 "gopkg.in/go-jose/go-jose.v2/jwt"
24 type OIDCProvider struct {
25 // expected token request
28 ValidClientSecret string
29 // desired response from token endpoint
31 AuthEmailVerified bool
35 AccessTokenPayload map[string]interface{}
36 // end_session_endpoint metadata URL.
37 // If nil or empty, not included in discovery.
38 // If relative, built from Issuer.URL.
39 EndSessionEndpoint *url.URL
41 PeopleAPIResponse map[string]interface{}
43 // send incoming /userinfo requests to HoldUserInfo (if not
44 // nil), then receive from ReleaseUserInfo (if not nil),
45 // before responding (these are used to set up races)
46 HoldUserInfo chan *http.Request
47 ReleaseUserInfo chan struct{}
48 UserInfoErrorStatus int // if non-zero, return this http status (probably 5xx)
51 Issuer *httptest.Server
52 PeopleAPI *httptest.Server
56 func NewOIDCProvider(c *check.C) *OIDCProvider {
57 p := &OIDCProvider{c: c}
59 p.key, err = rsa.GenerateKey(rand.Reader, 2048)
60 c.Assert(err, check.IsNil)
61 p.Issuer = httptest.NewServer(http.HandlerFunc(p.serveOIDC))
62 p.PeopleAPI = httptest.NewServer(http.HandlerFunc(p.servePeopleAPI))
63 p.AccessTokenPayload = map[string]interface{}{"sub": "example"}
67 func (p *OIDCProvider) ValidAccessToken() string {
68 buf, _ := json.Marshal(p.AccessTokenPayload)
69 return p.fakeToken(buf)
72 func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
74 p.c.Logf("serveOIDC: got req: %s %s %s", req.Method, req.URL, req.Form)
75 w.Header().Set("Content-Type", "application/json")
77 case "/.well-known/openid-configuration":
78 configuration := map[string]interface{}{
79 "issuer": p.Issuer.URL,
80 "authorization_endpoint": p.Issuer.URL + "/auth",
81 "token_endpoint": p.Issuer.URL + "/token",
82 "jwks_uri": p.Issuer.URL + "/jwks",
83 "userinfo_endpoint": p.Issuer.URL + "/userinfo",
85 if p.EndSessionEndpoint == nil {
86 // Not included in configuration
87 } else if p.EndSessionEndpoint.Scheme != "" {
88 configuration["end_session_endpoint"] = p.EndSessionEndpoint.String()
90 u, err := url.Parse(p.Issuer.URL)
91 p.c.Check(err, check.IsNil,
92 check.Commentf("error parsing IssuerURL for EndSessionEndpoint"))
94 u.Path = u.Path + p.EndSessionEndpoint.Path
95 configuration["end_session_endpoint"] = u.String()
97 json.NewEncoder(w).Encode(configuration)
99 var clientID, clientSecret string
100 auth, _ := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic "))
101 authsplit := strings.Split(string(auth), ":")
102 if len(authsplit) == 2 {
103 clientID, _ = url.QueryUnescape(authsplit[0])
104 clientSecret, _ = url.QueryUnescape(authsplit[1])
106 if clientID != p.ValidClientID || clientSecret != p.ValidClientSecret {
107 p.c.Logf("OIDCProvider: expected (%q, %q) got (%q, %q)", p.ValidClientID, p.ValidClientSecret, clientID, clientSecret)
108 w.WriteHeader(http.StatusUnauthorized)
112 if req.Form.Get("code") != p.ValidCode || p.ValidCode == "" {
113 w.WriteHeader(http.StatusUnauthorized)
116 idToken, _ := json.Marshal(map[string]interface{}{
118 "aud": []string{clientID},
119 "sub": "fake-user-id",
120 "exp": time.Now().UTC().Add(time.Minute).Unix(),
121 "iat": time.Now().UTC().Unix(),
122 "nonce": "fake-nonce",
123 "email": p.AuthEmail,
124 "email_verified": p.AuthEmailVerified,
126 "given_name": p.AuthGivenName,
127 "family_name": p.AuthFamilyName,
128 "alt_verified": true, // for custom claim tests
129 "alt_email": "alt_email@example.com", // for custom claim tests
130 "alt_username": "desired-username", // for custom claim tests
132 json.NewEncoder(w).Encode(struct {
133 AccessToken string `json:"access_token"`
134 TokenType string `json:"token_type"`
135 RefreshToken string `json:"refresh_token"`
136 ExpiresIn int32 `json:"expires_in"`
137 IDToken string `json:"id_token"`
139 AccessToken: p.ValidAccessToken(),
141 RefreshToken: "test-refresh-token",
143 IDToken: p.fakeToken(idToken),
146 json.NewEncoder(w).Encode(jose.JSONWebKeySet{
147 Keys: []jose.JSONWebKey{
148 {Key: p.key.Public(), Algorithm: string(jose.RS256), KeyID: ""},
152 w.WriteHeader(http.StatusInternalServerError)
154 if p.HoldUserInfo != nil {
155 p.HoldUserInfo <- req
157 if p.ReleaseUserInfo != nil {
160 if p.UserInfoErrorStatus > 0 {
161 w.WriteHeader(p.UserInfoErrorStatus)
162 fmt.Fprintf(w, "%T error body", p)
165 authhdr := req.Header.Get("Authorization")
166 if _, err := jwt.ParseSigned(strings.TrimPrefix(authhdr, "Bearer ")); err != nil {
167 p.c.Logf("OIDCProvider: bad auth %q", authhdr)
168 w.WriteHeader(http.StatusUnauthorized)
171 json.NewEncoder(w).Encode(map[string]interface{}{
172 "sub": "fake-user-id",
174 "given_name": p.AuthGivenName,
175 "family_name": p.AuthFamilyName,
176 "alt_username": "desired-username",
177 "email": p.AuthEmail,
178 "email_verified": p.AuthEmailVerified,
181 w.WriteHeader(http.StatusNotFound)
185 func (p *OIDCProvider) servePeopleAPI(w http.ResponseWriter, req *http.Request) {
187 p.c.Logf("servePeopleAPI: got req: %s %s %s", req.Method, req.URL, req.Form)
188 w.Header().Set("Content-Type", "application/json")
189 switch req.URL.Path {
190 case "/v1/people/me":
191 if f := req.Form.Get("personFields"); f != "emailAddresses,names" {
192 w.WriteHeader(http.StatusBadRequest)
195 json.NewEncoder(w).Encode(p.PeopleAPIResponse)
197 w.WriteHeader(http.StatusNotFound)
201 func (p *OIDCProvider) fakeToken(payload []byte) string {
202 signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: p.key}, nil)
207 object, err := signer.Sign(payload)
212 t, err := object.CompactSerialize()
217 p.c.Logf("fakeToken(%q) == %q", payload, t)