19388: Cache token lookups.
authorTom Clegg <tom@curii.com>
Thu, 22 Sep 2022 19:19:48 +0000 (15:19 -0400)
committerTom Clegg <tom@curii.com>
Thu, 22 Sep 2022 19:19:48 +0000 (15:19 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

lib/ctrlctx/auth.go
lib/ctrlctx/auth_test.go

index 5b96463cce3879c72c3f1174055dab66ebe44a20..61c6253d419472924d63d0e90f26da2c8f9e0fa9 100644 (file)
@@ -15,6 +15,7 @@ import (
        "io"
        "strings"
        "sync"
+       "time"
 
        "git.arvados.org/arvados.git/lib/controller/api"
        "git.arvados.org/arvados.git/sdk/go/arvados"
@@ -33,13 +34,18 @@ var (
 // The incoming context must come from WrapCallsInTransactions or
 // NewWithTransaction.
 func WrapCallsWithAuth(cluster *arvados.Cluster) func(api.RoutableFunc) api.RoutableFunc {
+       var authcache authcache
        return func(origFunc api.RoutableFunc) api.RoutableFunc {
                return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
                        var tokens []string
                        if creds, ok := auth.FromContext(ctx); ok {
                                tokens = creds.Tokens
                        }
-                       return origFunc(context.WithValue(ctx, contextKeyAuth, &authcontext{cluster: cluster, tokens: tokens}), opts)
+                       return origFunc(context.WithValue(ctx, contextKeyAuth, &authcontext{
+                               authcache: &authcache,
+                               cluster:   cluster,
+                               tokens:    tokens,
+                       }), opts)
                }
        }
 }
@@ -55,7 +61,25 @@ func CurrentAuth(ctx context.Context) (*arvados.User, *arvados.APIClientAuthoriz
        if !ok {
                return nil, nil, ErrNoAuthContext
        }
-       ac.lookupOnce.Do(func() { ac.user, ac.apiClientAuthorization, ac.err = aclookup(ctx, ac.cluster, ac.tokens) })
+       ac.lookupOnce.Do(func() {
+               // We only validate/lookup the token once per API
+               // call, even though authcache should be efficient
+               // enough to do a lookup each time. This guarantees we
+               // always return the same result when called multiple
+               // times in the course of handling a single API call.
+               for _, token := range ac.tokens {
+                       user, aca, err := ac.authcache.lookup(ctx, ac.cluster, token)
+                       if err != nil {
+                               ac.err = err
+                               return
+                       }
+                       if user != nil {
+                               ac.user, ac.apiClientAuthorization = user, aca
+                               return
+                       }
+               }
+               ac.err = ErrUnauthenticated
+       })
        return ac.user, ac.apiClientAuthorization, ac.err
 }
 
@@ -64,6 +88,7 @@ type contextKeyA string
 var contextKeyAuth = contextKeyT("auth")
 
 type authcontext struct {
+       authcache              *authcache
        cluster                *arvados.Cluster
        tokens                 []string
        user                   *arvados.User
@@ -72,9 +97,32 @@ type authcontext struct {
        lookupOnce             sync.Once
 }
 
-func aclookup(ctx context.Context, cluster *arvados.Cluster, tokens []string) (*arvados.User, *arvados.APIClientAuthorization, error) {
-       if len(tokens) == 0 {
-               return nil, nil, ErrUnauthenticated
+var authcacheTTL = time.Minute
+
+type authcacheent struct {
+       expireTime             time.Time
+       apiClientAuthorization arvados.APIClientAuthorization
+       user                   arvados.User
+}
+
+type authcache struct {
+       mtx         sync.Mutex
+       entries     map[string]*authcacheent
+       nextCleanup time.Time
+}
+
+// lookup returns the user and aca info for a given token. Returns nil
+// if the token is not valid. Returns a non-nil error if there was an
+// unexpected error from the database, etc.
+func (ac *authcache) lookup(ctx context.Context, cluster *arvados.Cluster, token string) (*arvados.User, *arvados.APIClientAuthorization, error) {
+       ac.mtx.Lock()
+       ent := ac.entries[token]
+       ac.mtx.Unlock()
+       if ent != nil && ent.expireTime.After(time.Now()) {
+               return &ent.user, &ent.apiClientAuthorization, nil
+       }
+       if token == "" {
+               return nil, nil, nil
        }
        tx, err := CurrentTx(ctx)
        if err != nil {
@@ -82,44 +130,59 @@ func aclookup(ctx context.Context, cluster *arvados.Cluster, tokens []string) (*
        }
        var aca arvados.APIClientAuthorization
        var user arvados.User
-       for _, token := range tokens {
-               var cond string
-               var args []interface{}
-               if token == "" {
-                       continue
-               } else if len(token) > 30 && strings.HasPrefix(token, "v2/") && token[30] == '/' {
-                       fields := strings.Split(token, "/")
-                       cond = `aca.uuid=$1 and aca.api_token=$2`
-                       args = []interface{}{fields[1], fields[2]}
-               } else {
-                       // Bare token or OIDC access token
-                       mac := hmac.New(sha256.New, []byte(cluster.SystemRootToken))
-                       io.WriteString(mac, token)
-                       hmac := fmt.Sprintf("%x", mac.Sum(nil))
-                       cond = `aca.api_token in ($1, $2)`
-                       args = []interface{}{token, hmac}
-               }
-               var scopesJSON []byte
-               err = tx.QueryRowContext(ctx, `
+
+       var cond string
+       var args []interface{}
+       if len(token) > 30 && strings.HasPrefix(token, "v2/") && token[30] == '/' {
+               fields := strings.Split(token, "/")
+               cond = `aca.uuid=$1 and aca.api_token=$2`
+               args = []interface{}{fields[1], fields[2]}
+       } else {
+               // Bare token or OIDC access token
+               mac := hmac.New(sha256.New, []byte(cluster.SystemRootToken))
+               io.WriteString(mac, token)
+               hmac := fmt.Sprintf("%x", mac.Sum(nil))
+               cond = `aca.api_token in ($1, $2)`
+               args = []interface{}{token, hmac}
+       }
+       var scopesJSON []byte
+       err = tx.QueryRowContext(ctx, `
 select aca.uuid, aca.expires_at, aca.api_token, aca.scopes, users.uuid, users.is_active, users.is_admin
  from api_client_authorizations aca
  left join users on aca.user_id = users.id
  where `+cond+`
  and (expires_at is null or expires_at > current_timestamp at time zone 'UTC')`, args...).Scan(
-                       &aca.UUID, &aca.ExpiresAt, &aca.APIToken, &scopesJSON,
-                       &user.UUID, &user.IsActive, &user.IsAdmin)
-               if err == sql.ErrNoRows {
-                       continue
-               } else if err != nil {
+               &aca.UUID, &aca.ExpiresAt, &aca.APIToken, &scopesJSON,
+               &user.UUID, &user.IsActive, &user.IsAdmin)
+       if err == sql.ErrNoRows {
+               return nil, nil, nil
+       } else if err != nil {
+               return nil, nil, err
+       }
+       if len(scopesJSON) > 0 {
+               err = json.Unmarshal(scopesJSON, &aca.Scopes)
+               if err != nil {
                        return nil, nil, err
                }
-               if len(scopesJSON) > 0 {
-                       err = json.Unmarshal(scopesJSON, &aca.Scopes)
-                       if err != nil {
-                               return nil, nil, err
+       }
+       ent = &authcacheent{
+               expireTime:             time.Now().Add(authcacheTTL),
+               apiClientAuthorization: aca,
+               user:                   user,
+       }
+       ac.mtx.Lock()
+       defer ac.mtx.Unlock()
+       if ac.entries == nil {
+               ac.entries = map[string]*authcacheent{}
+       }
+       if ac.nextCleanup.IsZero() || ac.nextCleanup.Before(time.Now()) {
+               for token, ent := range ac.entries {
+                       if !ent.expireTime.After(time.Now()) {
+                               delete(ac.entries, token)
                        }
                }
-               return &user, &aca, nil
+               ac.nextCleanup = time.Now().Add(authcacheTTL)
        }
-       return nil, nil, ErrUnauthenticated
+       ac.entries[token] = ent
+       return &ent.user, &ent.apiClientAuthorization, nil
 }
index add7a67d172bd42ab87313d92c411ab379c62965..5b0b0679821cf6ec589f178b70d253cf87b0376d 100644 (file)
@@ -32,6 +32,7 @@ func (*DatabaseSuite) TestAuthContext(c *check.C) {
                "3kg6k6lzmp9kj5cpkcoxie963cmvjahbt2fod9zru30k1jqdmi",
                "v2/zzzzz-gj3su-077z32aux8dg2s1/3kg6k6lzmp9kj5cpkcoxie963cmvjahbt2fod9zru30k1jqdmi",
                "v2/zzzzz-gj3su-077z32aux8dg2s1/3kg6k6lzmp9kj5cpkcoxie963cmvjahbt2fod9zru30k1jqdmi/asdfasdfasdf",
+               "v2/zzzzz-gj3su-077z32aux8dg2s1/3kg6k6lzmp9kj5cpkcoxie963cmvjahbt2fod9zru30k1jqdmi", // cached
        } {
                ok, err := dbwrapper(authwrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
                        user, aca, err := CurrentAuth(ctx)