19907: Don't cache network/5xx errors when checking UserInfo.
[arvados.git] / lib / controller / localdb / login_oidc_test.go
index 0fe3bdf7f6b684652cad9c71f3c0a63fba15b925..40cdde76f6bf99f4ce31eb41a536bdbd627126f7 100644 (file)
@@ -256,7 +256,16 @@ func (s *OIDCLoginSuite) TestOIDCAuthorizer(c *check.C) {
        io.WriteString(mac, accessToken)
        apiToken := fmt.Sprintf("%x", mac.Sum(nil))
 
+       checkTokenInDB := func() time.Time {
+               var exp time.Time
+               err := db.QueryRow(`select expires_at at time zone 'UTC' from api_client_authorizations where api_token=$1`, apiToken).Scan(&exp)
+               c.Check(err, check.IsNil)
+               c.Check(exp.Sub(time.Now()) > -time.Second, check.Equals, true)
+               c.Check(exp.Sub(time.Now()) < time.Second, check.Equals, true)
+               return exp
+       }
        cleanup := func() {
+               oidcAuthorizer.cache.Purge()
                _, err := db.Exec(`delete from api_client_authorizations where api_token=$1`, apiToken)
                c.Check(err, check.IsNil)
        }
@@ -264,8 +273,56 @@ func (s *OIDCLoginSuite) TestOIDCAuthorizer(c *check.C) {
        defer cleanup()
 
        ctx := auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{accessToken}})
-       var exp1 time.Time
 
+       // Check behavior on 5xx/network errors (don't cache) vs 4xx
+       // (do cache)
+       {
+               call := oidcAuthorizer.WrapCalls(func(ctx context.Context, opts interface{}) (interface{}, error) {
+                       return nil, nil
+               })
+
+               // If fakeProvider UserInfo endpoint returns 502, we
+               // should fail, return an error, and *not* cache the
+               // negative result.
+               tokenCacheNegativeTTL = time.Minute
+               s.fakeProvider.UserInfoErrorStatus = 502
+               _, err := call(ctx, nil)
+               c.Check(err, check.NotNil)
+
+               // The negative result was not cached, so retrying
+               // immediately (with UserInfo working now) should
+               // succeed.
+               s.fakeProvider.UserInfoErrorStatus = 0
+               _, err = call(ctx, nil)
+               c.Check(err, check.IsNil)
+               checkTokenInDB()
+
+               cleanup()
+
+               // UserInfo 401 => cache the negative result, but
+               // don't return an error (just pass the token through
+               // as a v1 token)
+               s.fakeProvider.UserInfoErrorStatus = 401
+               _, err = call(ctx, nil)
+               c.Check(err, check.IsNil)
+               ent, ok := oidcAuthorizer.cache.Get(accessToken)
+               c.Check(ok, check.Equals, true)
+               c.Check(ent, check.FitsTypeOf, time.Time{})
+
+               // UserInfo succeeds now, but we still have a cached
+               // negative result.
+               s.fakeProvider.UserInfoErrorStatus = 0
+               _, err = call(ctx, nil)
+               c.Check(err, check.IsNil)
+               ent, ok = oidcAuthorizer.cache.Get(accessToken)
+               c.Check(ok, check.Equals, true)
+               c.Check(ent, check.FitsTypeOf, time.Time{})
+
+               tokenCacheNegativeTTL = time.Millisecond
+               cleanup()
+       }
+
+       var exp1 time.Time
        concurrent := 4
        s.fakeProvider.HoldUserInfo = make(chan *http.Request)
        s.fakeProvider.ReleaseUserInfo = make(chan struct{})
@@ -285,17 +342,12 @@ func (s *OIDCLoginSuite) TestOIDCAuthorizer(c *check.C) {
                        defer wg.Done()
                        _, err := oidcAuthorizer.WrapCalls(func(ctx context.Context, opts interface{}) (interface{}, error) {
                                c.Logf("concurrent req %d/%d", i, concurrent)
-                               var exp time.Time
 
                                creds, ok := auth.FromContext(ctx)
                                c.Assert(ok, check.Equals, true)
                                c.Assert(creds.Tokens, check.HasLen, 1)
                                c.Check(creds.Tokens[0], check.Equals, accessToken)
-
-                               err := db.QueryRowContext(ctx, `select expires_at at time zone 'UTC' from api_client_authorizations where api_token=$1`, apiToken).Scan(&exp)
-                               c.Check(err, check.IsNil)
-                               c.Check(exp.Sub(time.Now()) > -time.Second, check.Equals, true)
-                               c.Check(exp.Sub(time.Now()) < time.Second, check.Equals, true)
+                               exp := checkTokenInDB()
                                if i == 0 {
                                        exp1 = exp
                                }
@@ -314,9 +366,7 @@ func (s *OIDCLoginSuite) TestOIDCAuthorizer(c *check.C) {
        // the expires_at value in the database.
        time.Sleep(3 * time.Millisecond)
        oidcAuthorizer.WrapCalls(func(ctx context.Context, opts interface{}) (interface{}, error) {
-               var exp time.Time
-               err := db.QueryRowContext(ctx, `select expires_at at time zone 'UTC' from api_client_authorizations where api_token=$1`, apiToken).Scan(&exp)
-               c.Check(err, check.IsNil)
+               exp := checkTokenInDB()
                c.Check(exp.Sub(exp1) > 0, check.Equals, true, check.Commentf("expect %v > 0", exp.Sub(exp1)))
                c.Check(exp.Sub(exp1) < time.Second, check.Equals, true, check.Commentf("expect %v < 1s", exp.Sub(exp1)))
                return nil, nil