20183: Fixup auth contexts in tests.
[arvados.git] / lib / ctrlctx / auth.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ctrlctx
6
7 import (
8         "context"
9         "crypto/hmac"
10         "crypto/sha256"
11         "database/sql"
12         "errors"
13         "fmt"
14         "io"
15         "strings"
16         "sync"
17         "time"
18
19         "git.arvados.org/arvados.git/lib/controller/api"
20         "git.arvados.org/arvados.git/sdk/go/arvados"
21         "git.arvados.org/arvados.git/sdk/go/auth"
22         "github.com/ghodss/yaml"
23 )
24
25 var (
26         ErrNoAuthContext   = errors.New("bug: there is no authorization in this context")
27         ErrUnauthenticated = errors.New("unauthenticated request")
28 )
29
30 // WrapCallsWithAuth returns a call wrapper (suitable for assigning to
31 // router.router.WrapCalls) that makes CurrentUser(ctx) et al. work
32 // from inside the wrapped functions.
33 //
34 // The incoming context must come from WrapCallsInTransactions or
35 // NewWithTransaction.
36 func WrapCallsWithAuth(cluster *arvados.Cluster) func(api.RoutableFunc) api.RoutableFunc {
37         var authcache authcache
38         return func(origFunc api.RoutableFunc) api.RoutableFunc {
39                 return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
40                         var tokens []string
41                         if creds, ok := auth.FromContext(ctx); ok {
42                                 tokens = creds.Tokens
43                         }
44                         return origFunc(context.WithValue(ctx, contextKeyAuth, &authcontext{
45                                 authcache: &authcache,
46                                 cluster:   cluster,
47                                 tokens:    tokens,
48                         }), opts)
49                 }
50         }
51 }
52
53 // NewWithToken returns a context with the provided auth token.
54 //
55 // The incoming context must come from WrapCallsInTransactions or
56 // NewWithTransaction.
57 //
58 // Used for attaching system auth to background threads.
59 //
60 // Also useful for tests, where context doesn't necessarily come from
61 // a router that uses WrapCallsWithAuth.
62 //
63 // The returned context comes with its own token lookup cache, so
64 // NewWithToken is not appropriate to use in a per-request code path.
65 func NewWithToken(ctx context.Context, cluster *arvados.Cluster, token string) context.Context {
66         ctx = auth.NewContext(ctx, &auth.Credentials{Tokens: []string{token}})
67         return context.WithValue(ctx, contextKeyAuth, &authcontext{
68                 authcache: &authcache{},
69                 cluster:   cluster,
70                 tokens:    []string{token},
71         })
72 }
73
74 // CurrentAuth returns the arvados.User whose privileges should be
75 // used in the given context, and the arvados.APIClientAuthorization
76 // the caller presented in order to authenticate the current request.
77 //
78 // Returns ErrUnauthenticated if the current request was not
79 // authenticated (no token provided, token is expired, etc).
80 func CurrentAuth(ctx context.Context) (*arvados.User, *arvados.APIClientAuthorization, error) {
81         ac, ok := ctx.Value(contextKeyAuth).(*authcontext)
82         if !ok {
83                 return nil, nil, ErrNoAuthContext
84         }
85         ac.lookupOnce.Do(func() {
86                 // We only validate/lookup the token once per API
87                 // call, even though authcache should be efficient
88                 // enough to do a lookup each time. This guarantees we
89                 // always return the same result when called multiple
90                 // times in the course of handling a single API call.
91                 for _, token := range ac.tokens {
92                         user, aca, err := ac.authcache.lookup(ctx, ac.cluster, token)
93                         if err != nil {
94                                 ac.err = err
95                                 return
96                         }
97                         if user != nil {
98                                 ac.user, ac.apiClientAuthorization = user, aca
99                                 return
100                         }
101                 }
102                 ac.err = ErrUnauthenticated
103         })
104         return ac.user, ac.apiClientAuthorization, ac.err
105 }
106
107 type contextKeyA string
108
109 var contextKeyAuth = contextKeyT("auth")
110
111 type authcontext struct {
112         authcache              *authcache
113         cluster                *arvados.Cluster
114         tokens                 []string
115         user                   *arvados.User
116         apiClientAuthorization *arvados.APIClientAuthorization
117         err                    error
118         lookupOnce             sync.Once
119 }
120
121 var authcacheTTL = time.Minute
122
123 type authcacheent struct {
124         expireTime             time.Time
125         apiClientAuthorization arvados.APIClientAuthorization
126         user                   arvados.User
127 }
128
129 type authcache struct {
130         mtx         sync.Mutex
131         entries     map[string]*authcacheent
132         nextCleanup time.Time
133 }
134
135 // lookup returns the user and aca info for a given token. Returns nil
136 // if the token is not valid. Returns a non-nil error if there was an
137 // unexpected error from the database, etc.
138 func (ac *authcache) lookup(ctx context.Context, cluster *arvados.Cluster, token string) (*arvados.User, *arvados.APIClientAuthorization, error) {
139         ac.mtx.Lock()
140         ent := ac.entries[token]
141         ac.mtx.Unlock()
142         if ent != nil && ent.expireTime.After(time.Now()) {
143                 return &ent.user, &ent.apiClientAuthorization, nil
144         }
145         if token == "" {
146                 return nil, nil, nil
147         }
148         tx, err := CurrentTx(ctx)
149         if err != nil {
150                 return nil, nil, err
151         }
152         var aca arvados.APIClientAuthorization
153         var user arvados.User
154
155         var cond string
156         var args []interface{}
157         if len(token) > 30 && strings.HasPrefix(token, "v2/") && token[30] == '/' {
158                 fields := strings.Split(token, "/")
159                 cond = `aca.uuid = $1 and aca.api_token = $2`
160                 args = []interface{}{fields[1], fields[2]}
161         } else {
162                 // Bare token or OIDC access token
163                 mac := hmac.New(sha256.New, []byte(cluster.SystemRootToken))
164                 io.WriteString(mac, token)
165                 hmac := fmt.Sprintf("%x", mac.Sum(nil))
166                 cond = `aca.api_token in ($1, $2)`
167                 args = []interface{}{token, hmac}
168         }
169         var expiresAt sql.NullTime
170         var scopesYAML []byte
171         err = tx.QueryRowContext(ctx, `
172 select aca.uuid, aca.expires_at, aca.api_token, aca.scopes, users.uuid, users.is_active, users.is_admin
173  from api_client_authorizations aca
174  left join users on aca.user_id = users.id
175  where `+cond+`
176  and (expires_at is null or expires_at > current_timestamp at time zone 'UTC')`, args...).Scan(
177                 &aca.UUID, &expiresAt, &aca.APIToken, &scopesYAML,
178                 &user.UUID, &user.IsActive, &user.IsAdmin)
179         if err == sql.ErrNoRows {
180                 return nil, nil, nil
181         } else if err != nil {
182                 return nil, nil, err
183         }
184         aca.ExpiresAt = expiresAt.Time
185         if len(scopesYAML) > 0 {
186                 err = yaml.Unmarshal(scopesYAML, &aca.Scopes)
187                 if err != nil {
188                         return nil, nil, fmt.Errorf("loading scopes for %s: %w", aca.UUID, err)
189                 }
190         }
191         ent = &authcacheent{
192                 expireTime:             time.Now().Add(authcacheTTL),
193                 apiClientAuthorization: aca,
194                 user:                   user,
195         }
196         ac.mtx.Lock()
197         defer ac.mtx.Unlock()
198         if ac.entries == nil {
199                 ac.entries = map[string]*authcacheent{}
200         }
201         if ac.nextCleanup.IsZero() || ac.nextCleanup.Before(time.Now()) {
202                 for token, ent := range ac.entries {
203                         if !ent.expireTime.After(time.Now()) {
204                                 delete(ac.entries, token)
205                         }
206                 }
207                 ac.nextCleanup = time.Now().Add(authcacheTTL)
208         }
209         ac.entries[token] = ent
210         return &ent.user, &ent.apiClientAuthorization, nil
211 }