X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/18e79caebbda5a4350462a2b22e5af36f8e4b699..2930c49ba1a06f947d06424a37b0dc93505fad18:/lib/controller/localdb/login_oidc_test.go diff --git a/lib/controller/localdb/login_oidc_test.go b/lib/controller/localdb/login_oidc_test.go index 0fe3bdf7f6..40cdde76f6 100644 --- a/lib/controller/localdb/login_oidc_test.go +++ b/lib/controller/localdb/login_oidc_test.go @@ -256,7 +256,16 @@ func (s *OIDCLoginSuite) TestOIDCAuthorizer(c *check.C) { io.WriteString(mac, accessToken) apiToken := fmt.Sprintf("%x", mac.Sum(nil)) + checkTokenInDB := func() time.Time { + var exp time.Time + err := db.QueryRow(`select expires_at at time zone 'UTC' from api_client_authorizations where api_token=$1`, apiToken).Scan(&exp) + c.Check(err, check.IsNil) + c.Check(exp.Sub(time.Now()) > -time.Second, check.Equals, true) + c.Check(exp.Sub(time.Now()) < time.Second, check.Equals, true) + return exp + } cleanup := func() { + oidcAuthorizer.cache.Purge() _, err := db.Exec(`delete from api_client_authorizations where api_token=$1`, apiToken) c.Check(err, check.IsNil) } @@ -264,8 +273,56 @@ func (s *OIDCLoginSuite) TestOIDCAuthorizer(c *check.C) { defer cleanup() ctx := auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{accessToken}}) - var exp1 time.Time + // Check behavior on 5xx/network errors (don't cache) vs 4xx + // (do cache) + { + call := oidcAuthorizer.WrapCalls(func(ctx context.Context, opts interface{}) (interface{}, error) { + return nil, nil + }) + + // If fakeProvider UserInfo endpoint returns 502, we + // should fail, return an error, and *not* cache the + // negative result. + tokenCacheNegativeTTL = time.Minute + s.fakeProvider.UserInfoErrorStatus = 502 + _, err := call(ctx, nil) + c.Check(err, check.NotNil) + + // The negative result was not cached, so retrying + // immediately (with UserInfo working now) should + // succeed. + s.fakeProvider.UserInfoErrorStatus = 0 + _, err = call(ctx, nil) + c.Check(err, check.IsNil) + checkTokenInDB() + + cleanup() + + // UserInfo 401 => cache the negative result, but + // don't return an error (just pass the token through + // as a v1 token) + s.fakeProvider.UserInfoErrorStatus = 401 + _, err = call(ctx, nil) + c.Check(err, check.IsNil) + ent, ok := oidcAuthorizer.cache.Get(accessToken) + c.Check(ok, check.Equals, true) + c.Check(ent, check.FitsTypeOf, time.Time{}) + + // UserInfo succeeds now, but we still have a cached + // negative result. + s.fakeProvider.UserInfoErrorStatus = 0 + _, err = call(ctx, nil) + c.Check(err, check.IsNil) + ent, ok = oidcAuthorizer.cache.Get(accessToken) + c.Check(ok, check.Equals, true) + c.Check(ent, check.FitsTypeOf, time.Time{}) + + tokenCacheNegativeTTL = time.Millisecond + cleanup() + } + + var exp1 time.Time concurrent := 4 s.fakeProvider.HoldUserInfo = make(chan *http.Request) s.fakeProvider.ReleaseUserInfo = make(chan struct{}) @@ -285,17 +342,12 @@ func (s *OIDCLoginSuite) TestOIDCAuthorizer(c *check.C) { defer wg.Done() _, err := oidcAuthorizer.WrapCalls(func(ctx context.Context, opts interface{}) (interface{}, error) { c.Logf("concurrent req %d/%d", i, concurrent) - var exp time.Time creds, ok := auth.FromContext(ctx) c.Assert(ok, check.Equals, true) c.Assert(creds.Tokens, check.HasLen, 1) c.Check(creds.Tokens[0], check.Equals, accessToken) - - err := db.QueryRowContext(ctx, `select expires_at at time zone 'UTC' from api_client_authorizations where api_token=$1`, apiToken).Scan(&exp) - c.Check(err, check.IsNil) - c.Check(exp.Sub(time.Now()) > -time.Second, check.Equals, true) - c.Check(exp.Sub(time.Now()) < time.Second, check.Equals, true) + exp := checkTokenInDB() if i == 0 { exp1 = exp } @@ -314,9 +366,7 @@ func (s *OIDCLoginSuite) TestOIDCAuthorizer(c *check.C) { // the expires_at value in the database. time.Sleep(3 * time.Millisecond) oidcAuthorizer.WrapCalls(func(ctx context.Context, opts interface{}) (interface{}, error) { - var exp time.Time - err := db.QueryRowContext(ctx, `select expires_at at time zone 'UTC' from api_client_authorizations where api_token=$1`, apiToken).Scan(&exp) - c.Check(err, check.IsNil) + exp := checkTokenInDB() c.Check(exp.Sub(exp1) > 0, check.Equals, true, check.Commentf("expect %v > 0", exp.Sub(exp1))) c.Check(exp.Sub(exp1) < time.Second, check.Equals, true, check.Commentf("expect %v < 1s", exp.Sub(exp1))) return nil, nil