20259: Add documentation for banner and tooltip features
[arvados.git] / sdk / go / arvadostest / oidc_provider.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvadostest
6
7 import (
8         "crypto/rand"
9         "crypto/rsa"
10         "encoding/base64"
11         "encoding/json"
12         "fmt"
13         "net/http"
14         "net/http/httptest"
15         "net/url"
16         "strings"
17         "time"
18
19         "gopkg.in/check.v1"
20         "gopkg.in/square/go-jose.v2"
21         "gopkg.in/square/go-jose.v2/jwt"
22 )
23
24 type OIDCProvider struct {
25         // expected token request
26         ValidCode         string
27         ValidClientID     string
28         ValidClientSecret string
29         // desired response from token endpoint
30         AuthEmail          string
31         AuthEmailVerified  bool
32         AuthName           string
33         AuthGivenName      string
34         AuthFamilyName     string
35         AccessTokenPayload map[string]interface{}
36
37         PeopleAPIResponse map[string]interface{}
38
39         // send incoming /userinfo requests to HoldUserInfo (if not
40         // nil), then receive from ReleaseUserInfo (if not nil),
41         // before responding (these are used to set up races)
42         HoldUserInfo        chan *http.Request
43         ReleaseUserInfo     chan struct{}
44         UserInfoErrorStatus int // if non-zero, return this http status (probably 5xx)
45
46         key       *rsa.PrivateKey
47         Issuer    *httptest.Server
48         PeopleAPI *httptest.Server
49         c         *check.C
50 }
51
52 func NewOIDCProvider(c *check.C) *OIDCProvider {
53         p := &OIDCProvider{c: c}
54         var err error
55         p.key, err = rsa.GenerateKey(rand.Reader, 2048)
56         c.Assert(err, check.IsNil)
57         p.Issuer = httptest.NewServer(http.HandlerFunc(p.serveOIDC))
58         p.PeopleAPI = httptest.NewServer(http.HandlerFunc(p.servePeopleAPI))
59         p.AccessTokenPayload = map[string]interface{}{"sub": "example"}
60         return p
61 }
62
63 func (p *OIDCProvider) ValidAccessToken() string {
64         buf, _ := json.Marshal(p.AccessTokenPayload)
65         return p.fakeToken(buf)
66 }
67
68 func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
69         req.ParseForm()
70         p.c.Logf("serveOIDC: got req: %s %s %s", req.Method, req.URL, req.Form)
71         w.Header().Set("Content-Type", "application/json")
72         switch req.URL.Path {
73         case "/.well-known/openid-configuration":
74                 json.NewEncoder(w).Encode(map[string]interface{}{
75                         "issuer":                 p.Issuer.URL,
76                         "authorization_endpoint": p.Issuer.URL + "/auth",
77                         "token_endpoint":         p.Issuer.URL + "/token",
78                         "jwks_uri":               p.Issuer.URL + "/jwks",
79                         "userinfo_endpoint":      p.Issuer.URL + "/userinfo",
80                 })
81         case "/token":
82                 var clientID, clientSecret string
83                 auth, _ := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic "))
84                 authsplit := strings.Split(string(auth), ":")
85                 if len(authsplit) == 2 {
86                         clientID, _ = url.QueryUnescape(authsplit[0])
87                         clientSecret, _ = url.QueryUnescape(authsplit[1])
88                 }
89                 if clientID != p.ValidClientID || clientSecret != p.ValidClientSecret {
90                         p.c.Logf("OIDCProvider: expected (%q, %q) got (%q, %q)", p.ValidClientID, p.ValidClientSecret, clientID, clientSecret)
91                         w.WriteHeader(http.StatusUnauthorized)
92                         return
93                 }
94
95                 if req.Form.Get("code") != p.ValidCode || p.ValidCode == "" {
96                         w.WriteHeader(http.StatusUnauthorized)
97                         return
98                 }
99                 idToken, _ := json.Marshal(map[string]interface{}{
100                         "iss":            p.Issuer.URL,
101                         "aud":            []string{clientID},
102                         "sub":            "fake-user-id",
103                         "exp":            time.Now().UTC().Add(time.Minute).Unix(),
104                         "iat":            time.Now().UTC().Unix(),
105                         "nonce":          "fake-nonce",
106                         "email":          p.AuthEmail,
107                         "email_verified": p.AuthEmailVerified,
108                         "name":           p.AuthName,
109                         "given_name":     p.AuthGivenName,
110                         "family_name":    p.AuthFamilyName,
111                         "alt_verified":   true,                    // for custom claim tests
112                         "alt_email":      "alt_email@example.com", // for custom claim tests
113                         "alt_username":   "desired-username",      // for custom claim tests
114                 })
115                 json.NewEncoder(w).Encode(struct {
116                         AccessToken  string `json:"access_token"`
117                         TokenType    string `json:"token_type"`
118                         RefreshToken string `json:"refresh_token"`
119                         ExpiresIn    int32  `json:"expires_in"`
120                         IDToken      string `json:"id_token"`
121                 }{
122                         AccessToken:  p.ValidAccessToken(),
123                         TokenType:    "Bearer",
124                         RefreshToken: "test-refresh-token",
125                         ExpiresIn:    30,
126                         IDToken:      p.fakeToken(idToken),
127                 })
128         case "/jwks":
129                 json.NewEncoder(w).Encode(jose.JSONWebKeySet{
130                         Keys: []jose.JSONWebKey{
131                                 {Key: p.key.Public(), Algorithm: string(jose.RS256), KeyID: ""},
132                         },
133                 })
134         case "/auth":
135                 w.WriteHeader(http.StatusInternalServerError)
136         case "/userinfo":
137                 if p.HoldUserInfo != nil {
138                         p.HoldUserInfo <- req
139                 }
140                 if p.ReleaseUserInfo != nil {
141                         <-p.ReleaseUserInfo
142                 }
143                 if p.UserInfoErrorStatus > 0 {
144                         w.WriteHeader(p.UserInfoErrorStatus)
145                         fmt.Fprintf(w, "%T error body", p)
146                         return
147                 }
148                 authhdr := req.Header.Get("Authorization")
149                 if _, err := jwt.ParseSigned(strings.TrimPrefix(authhdr, "Bearer ")); err != nil {
150                         p.c.Logf("OIDCProvider: bad auth %q", authhdr)
151                         w.WriteHeader(http.StatusUnauthorized)
152                         return
153                 }
154                 json.NewEncoder(w).Encode(map[string]interface{}{
155                         "sub":            "fake-user-id",
156                         "name":           p.AuthName,
157                         "given_name":     p.AuthGivenName,
158                         "family_name":    p.AuthFamilyName,
159                         "alt_username":   "desired-username",
160                         "email":          p.AuthEmail,
161                         "email_verified": p.AuthEmailVerified,
162                 })
163         default:
164                 w.WriteHeader(http.StatusNotFound)
165         }
166 }
167
168 func (p *OIDCProvider) servePeopleAPI(w http.ResponseWriter, req *http.Request) {
169         req.ParseForm()
170         p.c.Logf("servePeopleAPI: got req: %s %s %s", req.Method, req.URL, req.Form)
171         w.Header().Set("Content-Type", "application/json")
172         switch req.URL.Path {
173         case "/v1/people/me":
174                 if f := req.Form.Get("personFields"); f != "emailAddresses,names" {
175                         w.WriteHeader(http.StatusBadRequest)
176                         break
177                 }
178                 json.NewEncoder(w).Encode(p.PeopleAPIResponse)
179         default:
180                 w.WriteHeader(http.StatusNotFound)
181         }
182 }
183
184 func (p *OIDCProvider) fakeToken(payload []byte) string {
185         signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: p.key}, nil)
186         if err != nil {
187                 p.c.Error(err)
188                 return ""
189         }
190         object, err := signer.Sign(payload)
191         if err != nil {
192                 p.c.Error(err)
193                 return ""
194         }
195         t, err := object.CompactSerialize()
196         if err != nil {
197                 p.c.Error(err)
198                 return ""
199         }
200         p.c.Logf("fakeToken(%q) == %q", payload, t)
201         return t
202 }