21361: Remove Debian 10 support from installer
[arvados.git] / lib / ctrlctx / auth.go
index 5b96463cce3879c72c3f1174055dab66ebe44a20..31746b64cca5c77a9d8aa695c67e0aa16d8f47ba 100644 (file)
@@ -9,16 +9,17 @@ import (
        "crypto/hmac"
        "crypto/sha256"
        "database/sql"
-       "encoding/json"
        "errors"
        "fmt"
        "io"
        "strings"
        "sync"
+       "time"
 
        "git.arvados.org/arvados.git/lib/controller/api"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/auth"
+       "github.com/ghodss/yaml"
 )
 
 var (
@@ -33,17 +34,43 @@ var (
 // The incoming context must come from WrapCallsInTransactions or
 // NewWithTransaction.
 func WrapCallsWithAuth(cluster *arvados.Cluster) func(api.RoutableFunc) api.RoutableFunc {
+       var authcache authcache
        return func(origFunc api.RoutableFunc) api.RoutableFunc {
                return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
                        var tokens []string
                        if creds, ok := auth.FromContext(ctx); ok {
                                tokens = creds.Tokens
                        }
-                       return origFunc(context.WithValue(ctx, contextKeyAuth, &authcontext{cluster: cluster, tokens: tokens}), opts)
+                       return origFunc(context.WithValue(ctx, contextKeyAuth, &authcontext{
+                               authcache: &authcache,
+                               cluster:   cluster,
+                               tokens:    tokens,
+                       }), opts)
                }
        }
 }
 
+// NewWithToken returns a context with the provided auth token.
+//
+// The incoming context must come from WrapCallsInTransactions or
+// NewWithTransaction.
+//
+// Used for attaching system auth to background threads.
+//
+// Also useful for tests, where context doesn't necessarily come from
+// a router that uses WrapCallsWithAuth.
+//
+// The returned context comes with its own token lookup cache, so
+// NewWithToken is not appropriate to use in a per-request code path.
+func NewWithToken(ctx context.Context, cluster *arvados.Cluster, token string) context.Context {
+       ctx = auth.NewContext(ctx, &auth.Credentials{Tokens: []string{token}})
+       return context.WithValue(ctx, contextKeyAuth, &authcontext{
+               authcache: &authcache{},
+               cluster:   cluster,
+               tokens:    []string{token},
+       })
+}
+
 // CurrentAuth returns the arvados.User whose privileges should be
 // used in the given context, and the arvados.APIClientAuthorization
 // the caller presented in order to authenticate the current request.
@@ -55,7 +82,25 @@ func CurrentAuth(ctx context.Context) (*arvados.User, *arvados.APIClientAuthoriz
        if !ok {
                return nil, nil, ErrNoAuthContext
        }
-       ac.lookupOnce.Do(func() { ac.user, ac.apiClientAuthorization, ac.err = aclookup(ctx, ac.cluster, ac.tokens) })
+       ac.lookupOnce.Do(func() {
+               // We only validate/lookup the token once per API
+               // call, even though authcache should be efficient
+               // enough to do a lookup each time. This guarantees we
+               // always return the same result when called multiple
+               // times in the course of handling a single API call.
+               for _, token := range ac.tokens {
+                       user, aca, err := ac.authcache.lookup(ctx, ac.cluster, token)
+                       if err != nil {
+                               ac.err = err
+                               return
+                       }
+                       if user != nil {
+                               ac.user, ac.apiClientAuthorization = user, aca
+                               return
+                       }
+               }
+               ac.err = ErrUnauthenticated
+       })
        return ac.user, ac.apiClientAuthorization, ac.err
 }
 
@@ -64,6 +109,7 @@ type contextKeyA string
 var contextKeyAuth = contextKeyT("auth")
 
 type authcontext struct {
+       authcache              *authcache
        cluster                *arvados.Cluster
        tokens                 []string
        user                   *arvados.User
@@ -72,9 +118,32 @@ type authcontext struct {
        lookupOnce             sync.Once
 }
 
-func aclookup(ctx context.Context, cluster *arvados.Cluster, tokens []string) (*arvados.User, *arvados.APIClientAuthorization, error) {
-       if len(tokens) == 0 {
-               return nil, nil, ErrUnauthenticated
+var authcacheTTL = time.Minute
+
+type authcacheent struct {
+       expireTime             time.Time
+       apiClientAuthorization arvados.APIClientAuthorization
+       user                   arvados.User
+}
+
+type authcache struct {
+       mtx         sync.Mutex
+       entries     map[string]*authcacheent
+       nextCleanup time.Time
+}
+
+// lookup returns the user and aca info for a given token. Returns nil
+// if the token is not valid. Returns a non-nil error if there was an
+// unexpected error from the database, etc.
+func (ac *authcache) lookup(ctx context.Context, cluster *arvados.Cluster, token string) (*arvados.User, *arvados.APIClientAuthorization, error) {
+       ac.mtx.Lock()
+       ent := ac.entries[token]
+       ac.mtx.Unlock()
+       if ent != nil && ent.expireTime.After(time.Now()) {
+               return &ent.user, &ent.apiClientAuthorization, nil
+       }
+       if token == "" {
+               return nil, nil, nil
        }
        tx, err := CurrentTx(ctx)
        if err != nil {
@@ -82,44 +151,61 @@ func aclookup(ctx context.Context, cluster *arvados.Cluster, tokens []string) (*
        }
        var aca arvados.APIClientAuthorization
        var user arvados.User
-       for _, token := range tokens {
-               var cond string
-               var args []interface{}
-               if token == "" {
-                       continue
-               } else if len(token) > 30 && strings.HasPrefix(token, "v2/") && token[30] == '/' {
-                       fields := strings.Split(token, "/")
-                       cond = `aca.uuid=$1 and aca.api_token=$2`
-                       args = []interface{}{fields[1], fields[2]}
-               } else {
-                       // Bare token or OIDC access token
-                       mac := hmac.New(sha256.New, []byte(cluster.SystemRootToken))
-                       io.WriteString(mac, token)
-                       hmac := fmt.Sprintf("%x", mac.Sum(nil))
-                       cond = `aca.api_token in ($1, $2)`
-                       args = []interface{}{token, hmac}
-               }
-               var scopesJSON []byte
-               err = tx.QueryRowContext(ctx, `
+
+       var cond string
+       var args []interface{}
+       if len(token) > 30 && strings.HasPrefix(token, "v2/") && token[30] == '/' {
+               fields := strings.Split(token, "/")
+               cond = `aca.uuid = $1 and aca.api_token = $2`
+               args = []interface{}{fields[1], fields[2]}
+       } else {
+               // Bare token or OIDC access token
+               mac := hmac.New(sha256.New, []byte(cluster.SystemRootToken))
+               io.WriteString(mac, token)
+               hmac := fmt.Sprintf("%x", mac.Sum(nil))
+               cond = `aca.api_token in ($1, $2)`
+               args = []interface{}{token, hmac}
+       }
+       var expiresAt sql.NullTime
+       var scopesYAML []byte
+       err = tx.QueryRowContext(ctx, `
 select aca.uuid, aca.expires_at, aca.api_token, aca.scopes, users.uuid, users.is_active, users.is_admin
  from api_client_authorizations aca
  left join users on aca.user_id = users.id
  where `+cond+`
  and (expires_at is null or expires_at > current_timestamp at time zone 'UTC')`, args...).Scan(
-                       &aca.UUID, &aca.ExpiresAt, &aca.APIToken, &scopesJSON,
-                       &user.UUID, &user.IsActive, &user.IsAdmin)
-               if err == sql.ErrNoRows {
-                       continue
-               } else if err != nil {
-                       return nil, nil, err
+               &aca.UUID, &expiresAt, &aca.APIToken, &scopesYAML,
+               &user.UUID, &user.IsActive, &user.IsAdmin)
+       if err == sql.ErrNoRows {
+               return nil, nil, nil
+       } else if err != nil {
+               return nil, nil, err
+       }
+       aca.ExpiresAt = expiresAt.Time
+       if len(scopesYAML) > 0 {
+               err = yaml.Unmarshal(scopesYAML, &aca.Scopes)
+               if err != nil {
+                       return nil, nil, fmt.Errorf("loading scopes for %s: %w", aca.UUID, err)
                }
-               if len(scopesJSON) > 0 {
-                       err = json.Unmarshal(scopesJSON, &aca.Scopes)
-                       if err != nil {
-                               return nil, nil, err
+       }
+       ent = &authcacheent{
+               expireTime:             time.Now().Add(authcacheTTL),
+               apiClientAuthorization: aca,
+               user:                   user,
+       }
+       ac.mtx.Lock()
+       defer ac.mtx.Unlock()
+       if ac.entries == nil {
+               ac.entries = map[string]*authcacheent{}
+       }
+       if ac.nextCleanup.IsZero() || ac.nextCleanup.Before(time.Now()) {
+               for token, ent := range ac.entries {
+                       if !ent.expireTime.After(time.Now()) {
+                               delete(ac.entries, token)
                        }
                }
-               return &user, &aca, nil
+               ac.nextCleanup = time.Now().Add(authcacheTTL)
        }
-       return nil, nil, ErrUnauthenticated
+       ac.entries[token] = ent
+       return &ent.user, &ent.apiClientAuthorization, nil
 }