Merge branch '20640-computed-permissions-api'
[arvados.git] / sdk / go / arvadostest / oidc_provider.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvadostest
6
7 import (
8         "crypto/rand"
9         "crypto/rsa"
10         "encoding/base64"
11         "encoding/json"
12         "fmt"
13         "net/http"
14         "net/http/httptest"
15         "net/url"
16         "strings"
17         "time"
18
19         "gopkg.in/check.v1"
20         "gopkg.in/go-jose/go-jose.v2"
21         "gopkg.in/go-jose/go-jose.v2/jwt"
22 )
23
// OIDCProvider is a fake OpenID Connect provider for tests.
//
// Issuer is an httptest server that serves discovery metadata, a token
// endpoint, a JWKS endpoint, and a userinfo endpoint (see serveOIDC).
// PeopleAPI is a second httptest server that answers /v1/people/me
// (presumably imitating the Google People API -- confirm). The exported
// fields configure what the fake expects and what it returns. The
// handlers read them without locking, so set them before issuing
// requests.
//
// NOTE(review): nothing in this file closes Issuer or PeopleAPI;
// callers presumably need to Close them when finished -- confirm.
type OIDCProvider struct {
	// expected token request
	ValidCode         string
	ValidClientID     string
	ValidClientSecret string
	// desired response from token endpoint (AccessTokenPayload is the
	// claim set signed into the access token; NewOIDCProvider presets
	// it to {"sub": "example"})
	AuthEmail          string
	AuthEmailVerified  bool
	AuthName           string
	AuthGivenName      string
	AuthFamilyName     string
	AccessTokenPayload map[string]interface{}
	// end_session_endpoint metadata URL.
	// If nil or empty, not included in discovery.
	// If relative, built from Issuer.URL (with scheme forced to
	// "https" and only the Path of this URL appended).
	EndSessionEndpoint *url.URL

	// JSON body returned by PeopleAPI for /v1/people/me.
	PeopleAPIResponse map[string]interface{}

	// send incoming /userinfo requests to HoldUserInfo (if not
	// nil), then receive from ReleaseUserInfo (if not nil),
	// before responding (these are used to set up races)
	HoldUserInfo        chan *http.Request
	ReleaseUserInfo     chan struct{}
	UserInfoErrorStatus int // if non-zero, return this http status (probably 5xx)

	// key signs every token this fake issues and its public half is
	// published at /jwks. Issuer and PeopleAPI are the running test
	// servers created by NewOIDCProvider. c is used for logging and
	// for reporting failures to the current test.
	key       *rsa.PrivateKey
	Issuer    *httptest.Server
	PeopleAPI *httptest.Server
	c         *check.C
}
55
56 func NewOIDCProvider(c *check.C) *OIDCProvider {
57         p := &OIDCProvider{c: c}
58         var err error
59         p.key, err = rsa.GenerateKey(rand.Reader, 2048)
60         c.Assert(err, check.IsNil)
61         p.Issuer = httptest.NewServer(http.HandlerFunc(p.serveOIDC))
62         p.PeopleAPI = httptest.NewServer(http.HandlerFunc(p.servePeopleAPI))
63         p.AccessTokenPayload = map[string]interface{}{"sub": "example"}
64         return p
65 }
66
67 func (p *OIDCProvider) ValidAccessToken() string {
68         buf, _ := json.Marshal(p.AccessTokenPayload)
69         return p.fakeToken(buf)
70 }
71
72 func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
73         req.ParseForm()
74         p.c.Logf("serveOIDC: got req: %s %s %s", req.Method, req.URL, req.Form)
75         w.Header().Set("Content-Type", "application/json")
76         switch req.URL.Path {
77         case "/.well-known/openid-configuration":
78                 configuration := map[string]interface{}{
79                         "issuer":                 p.Issuer.URL,
80                         "authorization_endpoint": p.Issuer.URL + "/auth",
81                         "token_endpoint":         p.Issuer.URL + "/token",
82                         "jwks_uri":               p.Issuer.URL + "/jwks",
83                         "userinfo_endpoint":      p.Issuer.URL + "/userinfo",
84                 }
85                 if p.EndSessionEndpoint == nil {
86                         // Not included in configuration
87                 } else if p.EndSessionEndpoint.Scheme != "" {
88                         configuration["end_session_endpoint"] = p.EndSessionEndpoint.String()
89                 } else {
90                         u, err := url.Parse(p.Issuer.URL)
91                         p.c.Check(err, check.IsNil,
92                                 check.Commentf("error parsing IssuerURL for EndSessionEndpoint"))
93                         u.Scheme = "https"
94                         u.Path = u.Path + p.EndSessionEndpoint.Path
95                         configuration["end_session_endpoint"] = u.String()
96                 }
97                 json.NewEncoder(w).Encode(configuration)
98         case "/token":
99                 var clientID, clientSecret string
100                 auth, _ := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic "))
101                 authsplit := strings.Split(string(auth), ":")
102                 if len(authsplit) == 2 {
103                         clientID, _ = url.QueryUnescape(authsplit[0])
104                         clientSecret, _ = url.QueryUnescape(authsplit[1])
105                 }
106                 if clientID != p.ValidClientID || clientSecret != p.ValidClientSecret {
107                         p.c.Logf("OIDCProvider: expected (%q, %q) got (%q, %q)", p.ValidClientID, p.ValidClientSecret, clientID, clientSecret)
108                         w.WriteHeader(http.StatusUnauthorized)
109                         return
110                 }
111
112                 if req.Form.Get("code") != p.ValidCode || p.ValidCode == "" {
113                         w.WriteHeader(http.StatusUnauthorized)
114                         return
115                 }
116                 idToken, _ := json.Marshal(map[string]interface{}{
117                         "iss":            p.Issuer.URL,
118                         "aud":            []string{clientID},
119                         "sub":            "fake-user-id",
120                         "exp":            time.Now().UTC().Add(time.Minute).Unix(),
121                         "iat":            time.Now().UTC().Unix(),
122                         "nonce":          "fake-nonce",
123                         "email":          p.AuthEmail,
124                         "email_verified": p.AuthEmailVerified,
125                         "name":           p.AuthName,
126                         "given_name":     p.AuthGivenName,
127                         "family_name":    p.AuthFamilyName,
128                         "alt_verified":   true,                    // for custom claim tests
129                         "alt_email":      "alt_email@example.com", // for custom claim tests
130                         "alt_username":   "desired-username",      // for custom claim tests
131                 })
132                 json.NewEncoder(w).Encode(struct {
133                         AccessToken  string `json:"access_token"`
134                         TokenType    string `json:"token_type"`
135                         RefreshToken string `json:"refresh_token"`
136                         ExpiresIn    int32  `json:"expires_in"`
137                         IDToken      string `json:"id_token"`
138                 }{
139                         AccessToken:  p.ValidAccessToken(),
140                         TokenType:    "Bearer",
141                         RefreshToken: "test-refresh-token",
142                         ExpiresIn:    30,
143                         IDToken:      p.fakeToken(idToken),
144                 })
145         case "/jwks":
146                 json.NewEncoder(w).Encode(jose.JSONWebKeySet{
147                         Keys: []jose.JSONWebKey{
148                                 {Key: p.key.Public(), Algorithm: string(jose.RS256), KeyID: ""},
149                         },
150                 })
151         case "/auth":
152                 w.WriteHeader(http.StatusInternalServerError)
153         case "/userinfo":
154                 if p.HoldUserInfo != nil {
155                         p.HoldUserInfo <- req
156                 }
157                 if p.ReleaseUserInfo != nil {
158                         <-p.ReleaseUserInfo
159                 }
160                 if p.UserInfoErrorStatus > 0 {
161                         w.WriteHeader(p.UserInfoErrorStatus)
162                         fmt.Fprintf(w, "%T error body", p)
163                         return
164                 }
165                 authhdr := req.Header.Get("Authorization")
166                 if _, err := jwt.ParseSigned(strings.TrimPrefix(authhdr, "Bearer ")); err != nil {
167                         p.c.Logf("OIDCProvider: bad auth %q", authhdr)
168                         w.WriteHeader(http.StatusUnauthorized)
169                         return
170                 }
171                 json.NewEncoder(w).Encode(map[string]interface{}{
172                         "sub":            "fake-user-id",
173                         "name":           p.AuthName,
174                         "given_name":     p.AuthGivenName,
175                         "family_name":    p.AuthFamilyName,
176                         "alt_username":   "desired-username",
177                         "email":          p.AuthEmail,
178                         "email_verified": p.AuthEmailVerified,
179                 })
180         default:
181                 w.WriteHeader(http.StatusNotFound)
182         }
183 }
184
185 func (p *OIDCProvider) servePeopleAPI(w http.ResponseWriter, req *http.Request) {
186         req.ParseForm()
187         p.c.Logf("servePeopleAPI: got req: %s %s %s", req.Method, req.URL, req.Form)
188         w.Header().Set("Content-Type", "application/json")
189         switch req.URL.Path {
190         case "/v1/people/me":
191                 if f := req.Form.Get("personFields"); f != "emailAddresses,names" {
192                         w.WriteHeader(http.StatusBadRequest)
193                         break
194                 }
195                 json.NewEncoder(w).Encode(p.PeopleAPIResponse)
196         default:
197                 w.WriteHeader(http.StatusNotFound)
198         }
199 }
200
201 func (p *OIDCProvider) fakeToken(payload []byte) string {
202         signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: p.key}, nil)
203         if err != nil {
204                 p.c.Error(err)
205                 return ""
206         }
207         object, err := signer.Sign(payload)
208         if err != nil {
209                 p.c.Error(err)
210                 return ""
211         }
212         t, err := object.CompactSerialize()
213         if err != nil {
214                 p.c.Error(err)
215                 return ""
216         }
217         p.c.Logf("fakeToken(%q) == %q", payload, t)
218         return t
219 }