Merge branch '19466-cwl-io' refs #19466
[arvados.git] / sdk / go / arvadostest / oidc_provider.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvadostest
6
7 import (
8         "crypto/rand"
9         "crypto/rsa"
10         "encoding/base64"
11         "encoding/json"
12         "net/http"
13         "net/http/httptest"
14         "net/url"
15         "strings"
16         "time"
17
18         "gopkg.in/check.v1"
19         "gopkg.in/square/go-jose.v2"
20         "gopkg.in/square/go-jose.v2/jwt"
21 )
22
// OIDCProvider is a fake OpenID Connect provider for tests. It runs two
// httptest servers: Issuer, which serves the discovery document and the
// /token, /jwks, /auth and /userinfo endpoints, and PeopleAPI, which
// serves /v1/people/me. Use NewOIDCProvider to create one, then set the
// exported fields to control what the fake accepts and returns.
type OIDCProvider struct {
	// expected token request: /token responds 401 unless the client
	// authenticates with HTTP Basic auth as ValidClientID /
	// ValidClientSecret and sends code=ValidCode. An empty ValidCode
	// makes every token request fail.
	ValidCode         string
	ValidClientID     string
	ValidClientSecret string
	// desired response from token endpoint: the Auth* fields become
	// claims in the ID token issued by /token and in the /userinfo
	// response. AccessTokenPayload is the JSON payload of the signed
	// access token returned by ValidAccessToken (which /token also hands
	// out as access_token); NewOIDCProvider initializes it to
	// {"sub": "example"}.
	AuthEmail          string
	AuthEmailVerified  bool
	AuthName           string
	AuthGivenName      string
	AuthFamilyName     string
	AccessTokenPayload map[string]interface{}

	// PeopleAPIResponse is the JSON body served by PeopleAPI at
	// /v1/people/me when personFields=emailAddresses,names.
	// NOTE(review): the endpoint shape looks like the Google People API;
	// confirm that is what this stands in for.
	PeopleAPIResponse map[string]interface{}

	// send incoming /userinfo requests to HoldUserInfo (if not
	// nil), then receive from ReleaseUserInfo (if not nil),
	// before responding (these are used to set up races)
	HoldUserInfo    chan *http.Request
	ReleaseUserInfo chan struct{}

	// key signs every token this provider issues (RS256); its public
	// half is served at /jwks. Issuer and PeopleAPI are the fake servers
	// started by NewOIDCProvider. c is the test context used for logging
	// and for reporting errors.
	// NOTE(review): nothing in this file closes Issuer or PeopleAPI;
	// confirm callers do.
	key       *rsa.PrivateKey
	Issuer    *httptest.Server
	PeopleAPI *httptest.Server
	c         *check.C
}
49
50 func NewOIDCProvider(c *check.C) *OIDCProvider {
51         p := &OIDCProvider{c: c}
52         var err error
53         p.key, err = rsa.GenerateKey(rand.Reader, 2048)
54         c.Assert(err, check.IsNil)
55         p.Issuer = httptest.NewServer(http.HandlerFunc(p.serveOIDC))
56         p.PeopleAPI = httptest.NewServer(http.HandlerFunc(p.servePeopleAPI))
57         p.AccessTokenPayload = map[string]interface{}{"sub": "example"}
58         return p
59 }
60
61 func (p *OIDCProvider) ValidAccessToken() string {
62         buf, _ := json.Marshal(p.AccessTokenPayload)
63         return p.fakeToken(buf)
64 }
65
66 func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
67         req.ParseForm()
68         p.c.Logf("serveOIDC: got req: %s %s %s", req.Method, req.URL, req.Form)
69         w.Header().Set("Content-Type", "application/json")
70         switch req.URL.Path {
71         case "/.well-known/openid-configuration":
72                 json.NewEncoder(w).Encode(map[string]interface{}{
73                         "issuer":                 p.Issuer.URL,
74                         "authorization_endpoint": p.Issuer.URL + "/auth",
75                         "token_endpoint":         p.Issuer.URL + "/token",
76                         "jwks_uri":               p.Issuer.URL + "/jwks",
77                         "userinfo_endpoint":      p.Issuer.URL + "/userinfo",
78                 })
79         case "/token":
80                 var clientID, clientSecret string
81                 auth, _ := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic "))
82                 authsplit := strings.Split(string(auth), ":")
83                 if len(authsplit) == 2 {
84                         clientID, _ = url.QueryUnescape(authsplit[0])
85                         clientSecret, _ = url.QueryUnescape(authsplit[1])
86                 }
87                 if clientID != p.ValidClientID || clientSecret != p.ValidClientSecret {
88                         p.c.Logf("OIDCProvider: expected (%q, %q) got (%q, %q)", p.ValidClientID, p.ValidClientSecret, clientID, clientSecret)
89                         w.WriteHeader(http.StatusUnauthorized)
90                         return
91                 }
92
93                 if req.Form.Get("code") != p.ValidCode || p.ValidCode == "" {
94                         w.WriteHeader(http.StatusUnauthorized)
95                         return
96                 }
97                 idToken, _ := json.Marshal(map[string]interface{}{
98                         "iss":            p.Issuer.URL,
99                         "aud":            []string{clientID},
100                         "sub":            "fake-user-id",
101                         "exp":            time.Now().UTC().Add(time.Minute).Unix(),
102                         "iat":            time.Now().UTC().Unix(),
103                         "nonce":          "fake-nonce",
104                         "email":          p.AuthEmail,
105                         "email_verified": p.AuthEmailVerified,
106                         "name":           p.AuthName,
107                         "given_name":     p.AuthGivenName,
108                         "family_name":    p.AuthFamilyName,
109                         "alt_verified":   true,                    // for custom claim tests
110                         "alt_email":      "alt_email@example.com", // for custom claim tests
111                         "alt_username":   "desired-username",      // for custom claim tests
112                 })
113                 json.NewEncoder(w).Encode(struct {
114                         AccessToken  string `json:"access_token"`
115                         TokenType    string `json:"token_type"`
116                         RefreshToken string `json:"refresh_token"`
117                         ExpiresIn    int32  `json:"expires_in"`
118                         IDToken      string `json:"id_token"`
119                 }{
120                         AccessToken:  p.ValidAccessToken(),
121                         TokenType:    "Bearer",
122                         RefreshToken: "test-refresh-token",
123                         ExpiresIn:    30,
124                         IDToken:      p.fakeToken(idToken),
125                 })
126         case "/jwks":
127                 json.NewEncoder(w).Encode(jose.JSONWebKeySet{
128                         Keys: []jose.JSONWebKey{
129                                 {Key: p.key.Public(), Algorithm: string(jose.RS256), KeyID: ""},
130                         },
131                 })
132         case "/auth":
133                 w.WriteHeader(http.StatusInternalServerError)
134         case "/userinfo":
135                 if p.HoldUserInfo != nil {
136                         p.HoldUserInfo <- req
137                 }
138                 if p.ReleaseUserInfo != nil {
139                         <-p.ReleaseUserInfo
140                 }
141                 authhdr := req.Header.Get("Authorization")
142                 if _, err := jwt.ParseSigned(strings.TrimPrefix(authhdr, "Bearer ")); err != nil {
143                         p.c.Logf("OIDCProvider: bad auth %q", authhdr)
144                         w.WriteHeader(http.StatusUnauthorized)
145                         return
146                 }
147                 json.NewEncoder(w).Encode(map[string]interface{}{
148                         "sub":            "fake-user-id",
149                         "name":           p.AuthName,
150                         "given_name":     p.AuthGivenName,
151                         "family_name":    p.AuthFamilyName,
152                         "alt_username":   "desired-username",
153                         "email":          p.AuthEmail,
154                         "email_verified": p.AuthEmailVerified,
155                 })
156         default:
157                 w.WriteHeader(http.StatusNotFound)
158         }
159 }
160
161 func (p *OIDCProvider) servePeopleAPI(w http.ResponseWriter, req *http.Request) {
162         req.ParseForm()
163         p.c.Logf("servePeopleAPI: got req: %s %s %s", req.Method, req.URL, req.Form)
164         w.Header().Set("Content-Type", "application/json")
165         switch req.URL.Path {
166         case "/v1/people/me":
167                 if f := req.Form.Get("personFields"); f != "emailAddresses,names" {
168                         w.WriteHeader(http.StatusBadRequest)
169                         break
170                 }
171                 json.NewEncoder(w).Encode(p.PeopleAPIResponse)
172         default:
173                 w.WriteHeader(http.StatusNotFound)
174         }
175 }
176
177 func (p *OIDCProvider) fakeToken(payload []byte) string {
178         signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: p.key}, nil)
179         if err != nil {
180                 p.c.Error(err)
181                 return ""
182         }
183         object, err := signer.Sign(payload)
184         if err != nil {
185                 p.c.Error(err)
186                 return ""
187         }
188         t, err := object.CompactSerialize()
189         if err != nil {
190                 p.c.Error(err)
191                 return ""
192         }
193         p.c.Logf("fakeToken(%q) == %q", payload, t)
194         return t
195 }