X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/e4b9cffde1c932a0ca880d1542c62b4611142352..aa6ca6f50268f29c7cce987f5957796826bfeeed:/lib/ctrlctx/auth.go

# Summary: replaces aclookup (which looped over all tokens) with a per-token
# authcache.lookup whose results are cached for authcacheTTL (one minute),
# adds NewWithToken, moves the token loop into CurrentAuth, and decodes
# aca.scopes with yaml instead of json.
# NOTE(review): line breaks and indentation were restored from gofmt
# conventions and the hunk line counts. Leading whitespace inside the SQL
# literal is a best guess -- confirm against the upstream commit, or apply
# with `git apply --ignore-whitespace`.

diff --git a/lib/ctrlctx/auth.go b/lib/ctrlctx/auth.go
index 5b96463cce..31746b64cc 100644
--- a/lib/ctrlctx/auth.go
+++ b/lib/ctrlctx/auth.go
@@ -9,16 +9,17 @@ import (
 	"crypto/hmac"
 	"crypto/sha256"
 	"database/sql"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
 	"strings"
 	"sync"
+	"time"
 
 	"git.arvados.org/arvados.git/lib/controller/api"
 	"git.arvados.org/arvados.git/sdk/go/arvados"
 	"git.arvados.org/arvados.git/sdk/go/auth"
+	"github.com/ghodss/yaml"
 )
 
 var (
@@ -33,17 +34,43 @@ var (
 // The incoming context must come from WrapCallsInTransactions or
 // NewWithTransaction.
 func WrapCallsWithAuth(cluster *arvados.Cluster) func(api.RoutableFunc) api.RoutableFunc {
+	var authcache authcache
 	return func(origFunc api.RoutableFunc) api.RoutableFunc {
 		return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
 			var tokens []string
 			if creds, ok := auth.FromContext(ctx); ok {
 				tokens = creds.Tokens
 			}
-			return origFunc(context.WithValue(ctx, contextKeyAuth, &authcontext{cluster: cluster, tokens: tokens}), opts)
+			return origFunc(context.WithValue(ctx, contextKeyAuth, &authcontext{
+				authcache: &authcache,
+				cluster:   cluster,
+				tokens:    tokens,
+			}), opts)
 		}
 	}
 }
 
+// NewWithToken returns a context with the provided auth token.
+//
+// The incoming context must come from WrapCallsInTransactions or
+// NewWithTransaction.
+//
+// Used for attaching system auth to background threads.
+//
+// Also useful for tests, where context doesn't necessarily come from
+// a router that uses WrapCallsWithAuth.
+//
+// The returned context comes with its own token lookup cache, so
+// NewWithToken is not appropriate to use in a per-request code path.
+func NewWithToken(ctx context.Context, cluster *arvados.Cluster, token string) context.Context {
+	ctx = auth.NewContext(ctx, &auth.Credentials{Tokens: []string{token}})
+	return context.WithValue(ctx, contextKeyAuth, &authcontext{
+		authcache: &authcache{},
+		cluster:   cluster,
+		tokens:    []string{token},
+	})
+}
+
 // CurrentAuth returns the arvados.User whose privileges should be
 // used in the given context, and the arvados.APIClientAuthorization
 // the caller presented in order to authenticate the current request.
@@ -55,7 +82,25 @@ func CurrentAuth(ctx context.Context) (*arvados.User, *arvados.APIClientAuthoriz
 	if !ok {
 		return nil, nil, ErrNoAuthContext
 	}
-	ac.lookupOnce.Do(func() { ac.user, ac.apiClientAuthorization, ac.err = aclookup(ctx, ac.cluster, ac.tokens) })
+	ac.lookupOnce.Do(func() {
+		// We only validate/lookup the token once per API
+		// call, even though authcache should be efficient
+		// enough to do a lookup each time. This guarantees we
+		// always return the same result when called multiple
+		// times in the course of handling a single API call.
+		for _, token := range ac.tokens {
+			user, aca, err := ac.authcache.lookup(ctx, ac.cluster, token)
+			if err != nil {
+				ac.err = err
+				return
+			}
+			if user != nil {
+				ac.user, ac.apiClientAuthorization = user, aca
+				return
+			}
+		}
+		ac.err = ErrUnauthenticated
+	})
 	return ac.user, ac.apiClientAuthorization, ac.err
 }
 
@@ -64,6 +109,7 @@ type contextKeyA string
 var contextKeyAuth = contextKeyT("auth")
 
 type authcontext struct {
+	authcache              *authcache
 	cluster                *arvados.Cluster
 	tokens                 []string
 	user                   *arvados.User
@@ -72,9 +118,32 @@ type authcontext struct {
 	lookupOnce             sync.Once
 }
 
-func aclookup(ctx context.Context, cluster *arvados.Cluster, tokens []string) (*arvados.User, *arvados.APIClientAuthorization, error) {
-	if len(tokens) == 0 {
-		return nil, nil, ErrUnauthenticated
+var authcacheTTL = time.Minute
+
+type authcacheent struct {
+	expireTime             time.Time
+	apiClientAuthorization arvados.APIClientAuthorization
+	user                   arvados.User
+}
+
+type authcache struct {
+	mtx         sync.Mutex
+	entries     map[string]*authcacheent
+	nextCleanup time.Time
+}
+
+// lookup returns the user and aca info for a given token. Returns nil
+// if the token is not valid. Returns a non-nil error if there was an
+// unexpected error from the database, etc.
+func (ac *authcache) lookup(ctx context.Context, cluster *arvados.Cluster, token string) (*arvados.User, *arvados.APIClientAuthorization, error) {
+	ac.mtx.Lock()
+	ent := ac.entries[token]
+	ac.mtx.Unlock()
+	if ent != nil && ent.expireTime.After(time.Now()) {
+		return &ent.user, &ent.apiClientAuthorization, nil
+	}
+	if token == "" {
+		return nil, nil, nil
 	}
 	tx, err := CurrentTx(ctx)
 	if err != nil {
@@ -82,44 +151,61 @@ func aclookup(ctx context.Context, cluster *arvados.Cluster, tokens []string) (*
 	}
 	var aca arvados.APIClientAuthorization
 	var user arvados.User
-	for _, token := range tokens {
-		var cond string
-		var args []interface{}
-		if token == "" {
-			continue
-		} else if len(token) > 30 && strings.HasPrefix(token, "v2/") && token[30] == '/' {
-			fields := strings.Split(token, "/")
-			cond = `aca.uuid=$1 and aca.api_token=$2`
-			args = []interface{}{fields[1], fields[2]}
-		} else {
-			// Bare token or OIDC access token
-			mac := hmac.New(sha256.New, []byte(cluster.SystemRootToken))
-			io.WriteString(mac, token)
-			hmac := fmt.Sprintf("%x", mac.Sum(nil))
-			cond = `aca.api_token in ($1, $2)`
-			args = []interface{}{token, hmac}
-		}
-		var scopesJSON []byte
-		err = tx.QueryRowContext(ctx, `
+
+	var cond string
+	var args []interface{}
+	if len(token) > 30 && strings.HasPrefix(token, "v2/") && token[30] == '/' {
+		fields := strings.Split(token, "/")
+		cond = `aca.uuid = $1 and aca.api_token = $2`
+		args = []interface{}{fields[1], fields[2]}
+	} else {
+		// Bare token or OIDC access token
+		mac := hmac.New(sha256.New, []byte(cluster.SystemRootToken))
+		io.WriteString(mac, token)
+		hmac := fmt.Sprintf("%x", mac.Sum(nil))
+		cond = `aca.api_token in ($1, $2)`
+		args = []interface{}{token, hmac}
+	}
+	var expiresAt sql.NullTime
+	var scopesYAML []byte
+	err = tx.QueryRowContext(ctx, `
 select aca.uuid, aca.expires_at, aca.api_token, aca.scopes, users.uuid, users.is_active, users.is_admin
 from api_client_authorizations aca
 left join users on aca.user_id = users.id
 where `+cond+`
 and (expires_at is null or expires_at > current_timestamp at time zone 'UTC')`, args...).Scan(
-			&aca.UUID, &aca.ExpiresAt, &aca.APIToken, &scopesJSON,
-			&user.UUID, &user.IsActive, &user.IsAdmin)
-		if err == sql.ErrNoRows {
-			continue
-		} else if err != nil {
-			return nil, nil, err
+		&aca.UUID, &expiresAt, &aca.APIToken, &scopesYAML,
+		&user.UUID, &user.IsActive, &user.IsAdmin)
+	if err == sql.ErrNoRows {
+		return nil, nil, nil
+	} else if err != nil {
+		return nil, nil, err
+	}
+	aca.ExpiresAt = expiresAt.Time
+	if len(scopesYAML) > 0 {
+		err = yaml.Unmarshal(scopesYAML, &aca.Scopes)
+		if err != nil {
+			return nil, nil, fmt.Errorf("loading scopes for %s: %w", aca.UUID, err)
 		}
-		if len(scopesJSON) > 0 {
-			err = json.Unmarshal(scopesJSON, &aca.Scopes)
-			if err != nil {
-				return nil, nil, err
+	}
+	ent = &authcacheent{
+		expireTime:             time.Now().Add(authcacheTTL),
+		apiClientAuthorization: aca,
+		user:                   user,
+	}
+	ac.mtx.Lock()
+	defer ac.mtx.Unlock()
+	if ac.entries == nil {
+		ac.entries = map[string]*authcacheent{}
+	}
+	if ac.nextCleanup.IsZero() || ac.nextCleanup.Before(time.Now()) {
+		for token, ent := range ac.entries {
+			if !ent.expireTime.After(time.Now()) {
+				delete(ac.entries, token)
 			}
 		}
-		return &user, &aca, nil
+		ac.nextCleanup = time.Now().Add(authcacheTTL)
 	}
-	return nil, nil, ErrUnauthenticated
+	ac.entries[token] = ent
+	return &ent.user, &ent.apiClientAuthorization, nil
 }