17417: Merge branch 'main' into 17417-add-arm64
[arvados.git] / sdk / go / arvadostest / oidc_provider.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvadostest
6
7 import (
8         "crypto/rand"
9         "crypto/rsa"
10         "encoding/base64"
11         "encoding/json"
12         "net/http"
13         "net/http/httptest"
14         "net/url"
15         "strings"
16         "time"
17
18         "gopkg.in/check.v1"
19         "gopkg.in/square/go-jose.v2"
20         "gopkg.in/square/go-jose.v2/jwt"
21 )
22
23 type OIDCProvider struct {
24         // expected token request
25         ValidCode         string
26         ValidClientID     string
27         ValidClientSecret string
28         // desired response from token endpoint
29         AuthEmail          string
30         AuthEmailVerified  bool
31         AuthName           string
32         AuthGivenName      string
33         AuthFamilyName     string
34         AccessTokenPayload map[string]interface{}
35
36         PeopleAPIResponse map[string]interface{}
37
38         key       *rsa.PrivateKey
39         Issuer    *httptest.Server
40         PeopleAPI *httptest.Server
41         c         *check.C
42 }
43
44 func NewOIDCProvider(c *check.C) *OIDCProvider {
45         p := &OIDCProvider{c: c}
46         var err error
47         p.key, err = rsa.GenerateKey(rand.Reader, 2048)
48         c.Assert(err, check.IsNil)
49         p.Issuer = httptest.NewServer(http.HandlerFunc(p.serveOIDC))
50         p.PeopleAPI = httptest.NewServer(http.HandlerFunc(p.servePeopleAPI))
51         p.AccessTokenPayload = map[string]interface{}{"sub": "example"}
52         return p
53 }
54
55 func (p *OIDCProvider) ValidAccessToken() string {
56         buf, _ := json.Marshal(p.AccessTokenPayload)
57         return p.fakeToken(buf)
58 }
59
60 func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
61         req.ParseForm()
62         p.c.Logf("serveOIDC: got req: %s %s %s", req.Method, req.URL, req.Form)
63         w.Header().Set("Content-Type", "application/json")
64         switch req.URL.Path {
65         case "/.well-known/openid-configuration":
66                 json.NewEncoder(w).Encode(map[string]interface{}{
67                         "issuer":                 p.Issuer.URL,
68                         "authorization_endpoint": p.Issuer.URL + "/auth",
69                         "token_endpoint":         p.Issuer.URL + "/token",
70                         "jwks_uri":               p.Issuer.URL + "/jwks",
71                         "userinfo_endpoint":      p.Issuer.URL + "/userinfo",
72                 })
73         case "/token":
74                 var clientID, clientSecret string
75                 auth, _ := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic "))
76                 authsplit := strings.Split(string(auth), ":")
77                 if len(authsplit) == 2 {
78                         clientID, _ = url.QueryUnescape(authsplit[0])
79                         clientSecret, _ = url.QueryUnescape(authsplit[1])
80                 }
81                 if clientID != p.ValidClientID || clientSecret != p.ValidClientSecret {
82                         p.c.Logf("OIDCProvider: expected (%q, %q) got (%q, %q)", p.ValidClientID, p.ValidClientSecret, clientID, clientSecret)
83                         w.WriteHeader(http.StatusUnauthorized)
84                         return
85                 }
86
87                 if req.Form.Get("code") != p.ValidCode || p.ValidCode == "" {
88                         w.WriteHeader(http.StatusUnauthorized)
89                         return
90                 }
91                 idToken, _ := json.Marshal(map[string]interface{}{
92                         "iss":            p.Issuer.URL,
93                         "aud":            []string{clientID},
94                         "sub":            "fake-user-id",
95                         "exp":            time.Now().UTC().Add(time.Minute).Unix(),
96                         "iat":            time.Now().UTC().Unix(),
97                         "nonce":          "fake-nonce",
98                         "email":          p.AuthEmail,
99                         "email_verified": p.AuthEmailVerified,
100                         "name":           p.AuthName,
101                         "given_name":     p.AuthGivenName,
102                         "family_name":    p.AuthFamilyName,
103                         "alt_verified":   true,                    // for custom claim tests
104                         "alt_email":      "alt_email@example.com", // for custom claim tests
105                         "alt_username":   "desired-username",      // for custom claim tests
106                 })
107                 json.NewEncoder(w).Encode(struct {
108                         AccessToken  string `json:"access_token"`
109                         TokenType    string `json:"token_type"`
110                         RefreshToken string `json:"refresh_token"`
111                         ExpiresIn    int32  `json:"expires_in"`
112                         IDToken      string `json:"id_token"`
113                 }{
114                         AccessToken:  p.ValidAccessToken(),
115                         TokenType:    "Bearer",
116                         RefreshToken: "test-refresh-token",
117                         ExpiresIn:    30,
118                         IDToken:      p.fakeToken(idToken),
119                 })
120         case "/jwks":
121                 json.NewEncoder(w).Encode(jose.JSONWebKeySet{
122                         Keys: []jose.JSONWebKey{
123                                 {Key: p.key.Public(), Algorithm: string(jose.RS256), KeyID: ""},
124                         },
125                 })
126         case "/auth":
127                 w.WriteHeader(http.StatusInternalServerError)
128         case "/userinfo":
129                 authhdr := req.Header.Get("Authorization")
130                 if _, err := jwt.ParseSigned(strings.TrimPrefix(authhdr, "Bearer ")); err != nil {
131                         p.c.Logf("OIDCProvider: bad auth %q", authhdr)
132                         w.WriteHeader(http.StatusUnauthorized)
133                         return
134                 }
135                 json.NewEncoder(w).Encode(map[string]interface{}{
136                         "sub":            "fake-user-id",
137                         "name":           p.AuthName,
138                         "given_name":     p.AuthGivenName,
139                         "family_name":    p.AuthFamilyName,
140                         "alt_username":   "desired-username",
141                         "email":          p.AuthEmail,
142                         "email_verified": p.AuthEmailVerified,
143                 })
144         default:
145                 w.WriteHeader(http.StatusNotFound)
146         }
147 }
148
149 func (p *OIDCProvider) servePeopleAPI(w http.ResponseWriter, req *http.Request) {
150         req.ParseForm()
151         p.c.Logf("servePeopleAPI: got req: %s %s %s", req.Method, req.URL, req.Form)
152         w.Header().Set("Content-Type", "application/json")
153         switch req.URL.Path {
154         case "/v1/people/me":
155                 if f := req.Form.Get("personFields"); f != "emailAddresses,names" {
156                         w.WriteHeader(http.StatusBadRequest)
157                         break
158                 }
159                 json.NewEncoder(w).Encode(p.PeopleAPIResponse)
160         default:
161                 w.WriteHeader(http.StatusNotFound)
162         }
163 }
164
165 func (p *OIDCProvider) fakeToken(payload []byte) string {
166         signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: p.key}, nil)
167         if err != nil {
168                 p.c.Error(err)
169                 return ""
170         }
171         object, err := signer.Sign(payload)
172         if err != nil {
173                 p.c.Error(err)
174                 return ""
175         }
176         t, err := object.CompactSerialize()
177         if err != nil {
178                 p.c.Error(err)
179                 return ""
180         }
181         p.c.Logf("fakeToken(%q) == %q", payload, t)
182         return t
183 }