Merge branch '17106-s3-fed-token'
[arvados.git] / sdk / go / arvadostest / oidc_provider.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvadostest
6
7 import (
8         "crypto/rand"
9         "crypto/rsa"
10         "encoding/base64"
11         "encoding/json"
12         "net/http"
13         "net/http/httptest"
14         "net/url"
15         "strings"
16         "time"
17
18         "gopkg.in/check.v1"
19         "gopkg.in/square/go-jose.v2"
20 )
21
// OIDCProvider is a fake OpenID Connect issuer, plus a fake People API
// server, for use in tests. Set the Valid* fields to control which token
// requests the issuer accepts, and the Auth* fields to control the
// identity it reports.
type OIDCProvider struct {
	// expected token request: the /token endpoint answers 401 unless the
	// client credentials and authorization code match these exactly
	// (an empty ValidCode matches nothing)
	ValidCode         string
	ValidClientID     string
	ValidClientSecret string
	// desired response from token endpoint (the same values are also
	// returned by the /userinfo endpoint)
	AuthEmail         string
	AuthEmailVerified bool
	AuthName          string

	// PeopleAPIResponse is encoded as the JSON body returned by
	// PeopleAPI for /v1/people/me.
	PeopleAPIResponse map[string]interface{}

	key       *rsa.PrivateKey  // signs issued tokens; public half is served at Issuer's /jwks
	Issuer    *httptest.Server // test server whose handler is serveOIDC
	PeopleAPI *httptest.Server // test server whose handler is servePeopleAPI
	c         *check.C         // test context used for logging and error reporting
}
39
40 func NewOIDCProvider(c *check.C) *OIDCProvider {
41         p := &OIDCProvider{c: c}
42         var err error
43         p.key, err = rsa.GenerateKey(rand.Reader, 2048)
44         c.Assert(err, check.IsNil)
45         p.Issuer = httptest.NewServer(http.HandlerFunc(p.serveOIDC))
46         p.PeopleAPI = httptest.NewServer(http.HandlerFunc(p.servePeopleAPI))
47         return p
48 }
49
50 func (p *OIDCProvider) ValidAccessToken() string {
51         return p.fakeToken([]byte("fake access token"))
52 }
53
54 func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
55         req.ParseForm()
56         p.c.Logf("serveOIDC: got req: %s %s %s", req.Method, req.URL, req.Form)
57         w.Header().Set("Content-Type", "application/json")
58         switch req.URL.Path {
59         case "/.well-known/openid-configuration":
60                 json.NewEncoder(w).Encode(map[string]interface{}{
61                         "issuer":                 p.Issuer.URL,
62                         "authorization_endpoint": p.Issuer.URL + "/auth",
63                         "token_endpoint":         p.Issuer.URL + "/token",
64                         "jwks_uri":               p.Issuer.URL + "/jwks",
65                         "userinfo_endpoint":      p.Issuer.URL + "/userinfo",
66                 })
67         case "/token":
68                 var clientID, clientSecret string
69                 auth, _ := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic "))
70                 authsplit := strings.Split(string(auth), ":")
71                 if len(authsplit) == 2 {
72                         clientID, _ = url.QueryUnescape(authsplit[0])
73                         clientSecret, _ = url.QueryUnescape(authsplit[1])
74                 }
75                 if clientID != p.ValidClientID || clientSecret != p.ValidClientSecret {
76                         p.c.Logf("OIDCProvider: expected (%q, %q) got (%q, %q)", p.ValidClientID, p.ValidClientSecret, clientID, clientSecret)
77                         w.WriteHeader(http.StatusUnauthorized)
78                         return
79                 }
80
81                 if req.Form.Get("code") != p.ValidCode || p.ValidCode == "" {
82                         w.WriteHeader(http.StatusUnauthorized)
83                         return
84                 }
85                 idToken, _ := json.Marshal(map[string]interface{}{
86                         "iss":            p.Issuer.URL,
87                         "aud":            []string{clientID},
88                         "sub":            "fake-user-id",
89                         "exp":            time.Now().UTC().Add(time.Minute).Unix(),
90                         "iat":            time.Now().UTC().Unix(),
91                         "nonce":          "fake-nonce",
92                         "email":          p.AuthEmail,
93                         "email_verified": p.AuthEmailVerified,
94                         "name":           p.AuthName,
95                         "alt_verified":   true,                    // for custom claim tests
96                         "alt_email":      "alt_email@example.com", // for custom claim tests
97                         "alt_username":   "desired-username",      // for custom claim tests
98                 })
99                 json.NewEncoder(w).Encode(struct {
100                         AccessToken  string `json:"access_token"`
101                         TokenType    string `json:"token_type"`
102                         RefreshToken string `json:"refresh_token"`
103                         ExpiresIn    int32  `json:"expires_in"`
104                         IDToken      string `json:"id_token"`
105                 }{
106                         AccessToken:  p.ValidAccessToken(),
107                         TokenType:    "Bearer",
108                         RefreshToken: "test-refresh-token",
109                         ExpiresIn:    30,
110                         IDToken:      p.fakeToken(idToken),
111                 })
112         case "/jwks":
113                 json.NewEncoder(w).Encode(jose.JSONWebKeySet{
114                         Keys: []jose.JSONWebKey{
115                                 {Key: p.key.Public(), Algorithm: string(jose.RS256), KeyID: ""},
116                         },
117                 })
118         case "/auth":
119                 w.WriteHeader(http.StatusInternalServerError)
120         case "/userinfo":
121                 if authhdr := req.Header.Get("Authorization"); strings.TrimPrefix(authhdr, "Bearer ") != p.ValidAccessToken() {
122                         p.c.Logf("OIDCProvider: bad auth %q", authhdr)
123                         w.WriteHeader(http.StatusUnauthorized)
124                         return
125                 }
126                 json.NewEncoder(w).Encode(map[string]interface{}{
127                         "sub":            "fake-user-id",
128                         "name":           p.AuthName,
129                         "given_name":     p.AuthName,
130                         "family_name":    "",
131                         "alt_username":   "desired-username",
132                         "email":          p.AuthEmail,
133                         "email_verified": p.AuthEmailVerified,
134                 })
135         default:
136                 w.WriteHeader(http.StatusNotFound)
137         }
138 }
139
140 func (p *OIDCProvider) servePeopleAPI(w http.ResponseWriter, req *http.Request) {
141         req.ParseForm()
142         p.c.Logf("servePeopleAPI: got req: %s %s %s", req.Method, req.URL, req.Form)
143         w.Header().Set("Content-Type", "application/json")
144         switch req.URL.Path {
145         case "/v1/people/me":
146                 if f := req.Form.Get("personFields"); f != "emailAddresses,names" {
147                         w.WriteHeader(http.StatusBadRequest)
148                         break
149                 }
150                 json.NewEncoder(w).Encode(p.PeopleAPIResponse)
151         default:
152                 w.WriteHeader(http.StatusNotFound)
153         }
154 }
155
156 func (p *OIDCProvider) fakeToken(payload []byte) string {
157         signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: p.key}, nil)
158         if err != nil {
159                 p.c.Error(err)
160                 return ""
161         }
162         object, err := signer.Sign(payload)
163         if err != nil {
164                 p.c.Error(err)
165                 return ""
166         }
167         t, err := object.CompactSerialize()
168         if err != nil {
169                 p.c.Error(err)
170                 return ""
171         }
172         p.c.Logf("fakeToken(%q) == %q", payload, t)
173         return t
174 }