1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: Apache-2.0
19 "gopkg.in/square/go-jose.v2"
// OIDCProvider is a fake OpenID Connect provider for tests. The Issuer
// server answers discovery, token, JWKS and userinfo-style requests; the
// PeopleAPI server serves a canned People API response. Tests set the
// exported fields to choose which requests are accepted and which identity
// is reported back.
22 type OIDCProvider struct {
23 // expected token request
26 ValidClientSecret string
27 // desired response from token endpoint
// Reported as the "email_verified" claim in both the ID token and the
// userinfo-style response.
29 AuthEmailVerified bool
// JSON body returned by the PeopleAPI server for /v1/people/me.
32 PeopleAPIResponse map[string]interface{}
// Test server handling the OIDC endpoints; its URL is advertised as the
// "issuer" in the discovery document.
35 Issuer *httptest.Server
// Test server handling the fake People API.
36 PeopleAPI *httptest.Server
40 func NewOIDCProvider(c *check.C) *OIDCProvider {
41 p := &OIDCProvider{c: c}
43 p.key, err = rsa.GenerateKey(rand.Reader, 2048)
44 c.Assert(err, check.IsNil)
45 p.Issuer = httptest.NewServer(http.HandlerFunc(p.serveOIDC))
46 p.PeopleAPI = httptest.NewServer(http.HandlerFunc(p.servePeopleAPI))
50 func (p *OIDCProvider) ValidAccessToken() string {
51 return p.fakeToken([]byte("fake access token"))
// serveOIDC is the HTTP handler for the Issuer test server. It implements
// just enough of an OpenID Connect provider for client tests:
//   - a discovery document pointing every endpoint back at this server,
//   - a token response, issued only after HTTP Basic client credentials and
//     the authorization code match the configured Valid* fields,
//   - a JWKS document publishing the public half of p.key,
//   - a userinfo-style response, gated on the bearer token from
//     ValidAccessToken.
// Every request is logged via p.c.Logf, and responses are JSON.
54 func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
56 p.c.Logf("serveOIDC: got req: %s %s %s", req.Method, req.URL, req.Form)
57 w.Header().Set("Content-Type", "application/json")
// Discovery document: all advertised endpoints live on this same server.
59 case "/.well-known/openid-configuration":
60 json.NewEncoder(w).Encode(map[string]interface{}{
61 "issuer": p.Issuer.URL,
62 "authorization_endpoint": p.Issuer.URL + "/auth",
63 "token_endpoint": p.Issuer.URL + "/token",
64 "jwks_uri": p.Issuer.URL + "/jwks",
65 "userinfo_endpoint": p.Issuer.URL + "/userinfo",
// Token response. Client credentials arrive as HTTP Basic auth
// ("Basic base64(id:secret)"), with each half URL-escaped, hence the
// QueryUnescape calls.
68 var clientID, clientSecret string
// Decode and unescape errors are deliberately ignored: malformed
// credentials leave clientID/clientSecret empty or wrong, and the
// comparison below rejects them with 401.
69 auth, _ := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic "))
70 authsplit := strings.Split(string(auth), ":")
71 if len(authsplit) == 2 {
72 clientID, _ = url.QueryUnescape(authsplit[0])
73 clientSecret, _ = url.QueryUnescape(authsplit[1])
75 if clientID != p.ValidClientID || clientSecret != p.ValidClientSecret {
76 p.c.Logf("OIDCProvider: expected (%q, %q) got (%q, %q)", p.ValidClientID, p.ValidClientSecret, clientID, clientSecret)
77 w.WriteHeader(http.StatusUnauthorized)
// The "p.ValidCode == """ clause means an unset ValidCode rejects every
// code, including an empty one.
81 if req.Form.Get("code") != p.ValidCode || p.ValidCode == "" {
82 w.WriteHeader(http.StatusUnauthorized)
// ID token claims. These are signed with fakeToken below. "aud" echoes the
// authenticated client ID, and the token expires one minute after issue.
85 idToken, _ := json.Marshal(map[string]interface{}{
87 "aud": []string{clientID},
88 "sub": "fake-user-id",
89 "exp": time.Now().UTC().Add(time.Minute).Unix(),
90 "iat": time.Now().UTC().Unix(),
91 "nonce": "fake-nonce",
93 "email_verified": p.AuthEmailVerified,
95 "alt_verified": true, // for custom claim tests
96 "alt_email": "alt_email@example.com", // for custom claim tests
97 "alt_username": "desired-username", // for custom claim tests
99 json.NewEncoder(w).Encode(struct {
100 AccessToken string `json:"access_token"`
101 TokenType string `json:"token_type"`
102 RefreshToken string `json:"refresh_token"`
103 ExpiresIn int32 `json:"expires_in"`
104 IDToken string `json:"id_token"`
106 AccessToken: p.ValidAccessToken(),
108 RefreshToken: "test-refresh-token",
110 IDToken: p.fakeToken(idToken),
// JWKS document: publishes p.key's public key as an RS256 key with no key
// ID, so clients can verify tokens produced by fakeToken.
113 json.NewEncoder(w).Encode(jose.JSONWebKeySet{
114 Keys: []jose.JSONWebKey{
115 {Key: p.key.Public(), Algorithm: string(jose.RS256), KeyID: ""},
119 w.WriteHeader(http.StatusInternalServerError)
// Userinfo-style response: requires exactly ValidAccessToken() as the bearer
// token. It reports the same subject and email fields as the ID token.
121 if authhdr := req.Header.Get("Authorization"); strings.TrimPrefix(authhdr, "Bearer ") != p.ValidAccessToken() {
122 p.c.Logf("OIDCProvider: bad auth %q", authhdr)
123 w.WriteHeader(http.StatusUnauthorized)
126 json.NewEncoder(w).Encode(map[string]interface{}{
127 "sub": "fake-user-id",
129 "given_name": p.AuthName,
131 "alt_username": "desired-username",
132 "email": p.AuthEmail,
133 "email_verified": p.AuthEmailVerified,
136 w.WriteHeader(http.StatusNotFound)
140 func (p *OIDCProvider) servePeopleAPI(w http.ResponseWriter, req *http.Request) {
142 p.c.Logf("servePeopleAPI: got req: %s %s %s", req.Method, req.URL, req.Form)
143 w.Header().Set("Content-Type", "application/json")
144 switch req.URL.Path {
145 case "/v1/people/me":
146 if f := req.Form.Get("personFields"); f != "emailAddresses,names" {
147 w.WriteHeader(http.StatusBadRequest)
150 json.NewEncoder(w).Encode(p.PeopleAPIResponse)
152 w.WriteHeader(http.StatusNotFound)
// fakeToken signs payload with p.key using RS256 and returns the resulting
// JWS in compact serialization. Clients can verify the signature with the
// public key published in the Issuer server's JWKS document.
// RS256 (RSASSA-PKCS1-v1_5) signing is deterministic, so signing the same
// payload twice yields the same token string.
156 func (p *OIDCProvider) fakeToken(payload []byte) string {
157 signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: p.key}, nil)
162 object, err := signer.Sign(payload)
167 t, err := object.CompactSerialize()
// Log each payload/token pair, presumably so failing tests show which
// token was issued for which claims.
172 p.c.Logf("fakeToken(%q) == %q", payload, t)