16552: change default db name to just arvados.
[arvados.git] / sdk / go / arvadostest / oidc_provider.go
index 0632010ba4d3a2f8065e23f126ec592ec7748dd3..087adc4b2441648111c0857b93c84eeb48d58cca 100644 (file)
@@ -17,6 +17,7 @@ import (
 
        "gopkg.in/check.v1"
        "gopkg.in/square/go-jose.v2"
+       "gopkg.in/square/go-jose.v2/jwt"
 )
 
 type OIDCProvider struct {
@@ -25,12 +26,21 @@ type OIDCProvider struct {
        ValidClientID     string
        ValidClientSecret string
        // desired response from token endpoint
-       AuthEmail         string
-       AuthEmailVerified bool
-       AuthName          string
+       AuthEmail          string
+       AuthEmailVerified  bool
+       AuthName           string
+       AuthGivenName      string
+       AuthFamilyName     string
+       AccessTokenPayload map[string]interface{}
 
        PeopleAPIResponse map[string]interface{}
 
+       // send incoming /userinfo requests to HoldUserInfo (if not
+       // nil), then receive from ReleaseUserInfo (if not nil),
+       // before responding (these are used to set up races)
+       HoldUserInfo    chan *http.Request
+       ReleaseUserInfo chan struct{}
+
        key       *rsa.PrivateKey
        Issuer    *httptest.Server
        PeopleAPI *httptest.Server
@@ -44,9 +54,15 @@ func NewOIDCProvider(c *check.C) *OIDCProvider {
        c.Assert(err, check.IsNil)
        p.Issuer = httptest.NewServer(http.HandlerFunc(p.serveOIDC))
        p.PeopleAPI = httptest.NewServer(http.HandlerFunc(p.servePeopleAPI))
+       p.AccessTokenPayload = map[string]interface{}{"sub": "example"}
        return p
 }
 
+func (p *OIDCProvider) ValidAccessToken() string {
+       buf, _ := json.Marshal(p.AccessTokenPayload)
+       return p.fakeToken(buf)
+}
+
 func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
        req.ParseForm()
        p.c.Logf("serveOIDC: got req: %s %s %s", req.Method, req.URL, req.Form)
@@ -88,6 +104,8 @@ func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
                        "email":          p.AuthEmail,
                        "email_verified": p.AuthEmailVerified,
                        "name":           p.AuthName,
+                       "given_name":     p.AuthGivenName,
+                       "family_name":    p.AuthFamilyName,
                        "alt_verified":   true,                    // for custom claim tests
                        "alt_email":      "alt_email@example.com", // for custom claim tests
                        "alt_username":   "desired-username",      // for custom claim tests
@@ -99,7 +117,7 @@ func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
                        ExpiresIn    int32  `json:"expires_in"`
                        IDToken      string `json:"id_token"`
                }{
-                       AccessToken:  p.fakeToken([]byte("fake access token")),
+                       AccessToken:  p.ValidAccessToken(),
                        TokenType:    "Bearer",
                        RefreshToken: "test-refresh-token",
                        ExpiresIn:    30,
@@ -114,7 +132,27 @@ func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
        case "/auth":
                w.WriteHeader(http.StatusInternalServerError)
        case "/userinfo":
-               w.WriteHeader(http.StatusInternalServerError)
+               if p.HoldUserInfo != nil {
+                       p.HoldUserInfo <- req
+               }
+               if p.ReleaseUserInfo != nil {
+                       <-p.ReleaseUserInfo
+               }
+               authhdr := req.Header.Get("Authorization")
+               if _, err := jwt.ParseSigned(strings.TrimPrefix(authhdr, "Bearer ")); err != nil {
+                       p.c.Logf("OIDCProvider: bad auth %q", authhdr)
+                       w.WriteHeader(http.StatusUnauthorized)
+                       return
+               }
+               json.NewEncoder(w).Encode(map[string]interface{}{
+                       "sub":            "fake-user-id",
+                       "name":           p.AuthName,
+                       "given_name":     p.AuthGivenName,
+                       "family_name":    p.AuthFamilyName,
+                       "alt_username":   "desired-username",
+                       "email":          p.AuthEmail,
+                       "email_verified": p.AuthEmailVerified,
+               })
        default:
                w.WriteHeader(http.StatusNotFound)
        }