19907: Don't cache network/5xx errors when checking UserInfo.
[arvados.git] / sdk / go / arvadostest / oidc_provider.go
index de21302e5a048dfbca340abf24cb6c5359de7305..529c1dca12b15a9550dbffcbb9c37e145fa39cb7 100644 (file)
@@ -9,6 +9,7 @@ import (
        "crypto/rsa"
        "encoding/base64"
        "encoding/json"
+       "fmt"
        "net/http"
        "net/http/httptest"
        "net/url"
@@ -29,10 +30,19 @@ type OIDCProvider struct {
        AuthEmail          string
        AuthEmailVerified  bool
        AuthName           string
+       AuthGivenName      string
+       AuthFamilyName     string
        AccessTokenPayload map[string]interface{}
 
        PeopleAPIResponse map[string]interface{}
 
+       // Each incoming /userinfo request is sent to HoldUserInfo (if
+       // not nil); the handler then receives from ReleaseUserInfo (if
+       // not nil) before responding. These are used to set up races.
+       HoldUserInfo        chan *http.Request
+       ReleaseUserInfo     chan struct{}
+       UserInfoErrorStatus int // if non-zero, return this http status (probably 5xx)
+
        key       *rsa.PrivateKey
        Issuer    *httptest.Server
        PeopleAPI *httptest.Server
@@ -96,6 +106,8 @@ func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
                        "email":          p.AuthEmail,
                        "email_verified": p.AuthEmailVerified,
                        "name":           p.AuthName,
+                       "given_name":     p.AuthGivenName,
+                       "family_name":    p.AuthFamilyName,
                        "alt_verified":   true,                    // for custom claim tests
                        "alt_email":      "alt_email@example.com", // for custom claim tests
                        "alt_username":   "desired-username",      // for custom claim tests
@@ -122,6 +134,17 @@ func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
        case "/auth":
                w.WriteHeader(http.StatusInternalServerError)
        case "/userinfo":
+               if p.HoldUserInfo != nil {
+                       p.HoldUserInfo <- req
+               }
+               if p.ReleaseUserInfo != nil {
+                       <-p.ReleaseUserInfo
+               }
+               if p.UserInfoErrorStatus > 0 {
+                       w.WriteHeader(p.UserInfoErrorStatus)
+                       fmt.Fprintf(w, "%T error body", p)
+                       return
+               }
                authhdr := req.Header.Get("Authorization")
                if _, err := jwt.ParseSigned(strings.TrimPrefix(authhdr, "Bearer ")); err != nil {
                        p.c.Logf("OIDCProvider: bad auth %q", authhdr)
@@ -131,8 +154,8 @@ func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
                json.NewEncoder(w).Encode(map[string]interface{}{
                        "sub":            "fake-user-id",
                        "name":           p.AuthName,
-                       "given_name":     p.AuthName,
-                       "family_name":    "",
+                       "given_name":     p.AuthGivenName,
+                       "family_name":    p.AuthFamilyName,
                        "alt_username":   "desired-username",
                        "email":          p.AuthEmail,
                        "email_verified": p.AuthEmailVerified,