19388: Add user/auth context to ctrlctx.
[arvados.git] / lib / ctrlctx / auth.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ctrlctx
6
7 import (
8         "context"
9         "crypto/hmac"
10         "crypto/sha256"
11         "database/sql"
12         "encoding/json"
13         "errors"
14         "fmt"
15         "io"
16         "strings"
17         "sync"
18
19         "git.arvados.org/arvados.git/lib/controller/api"
20         "git.arvados.org/arvados.git/sdk/go/arvados"
21         "git.arvados.org/arvados.git/sdk/go/auth"
22 )
23
24 var (
25         ErrNoAuthContext   = errors.New("bug: there is no authorization in this context")
26         ErrUnauthenticated = errors.New("unauthenticated request")
27 )
28
29 // WrapCallsWithAuth returns a call wrapper (suitable for assigning to
30 // router.router.WrapCalls) that makes CurrentUser(ctx) et al. work
31 // from inside the wrapped functions.
32 //
33 // The incoming context must come from WrapCallsInTransactions or
34 // NewWithTransaction.
35 func WrapCallsWithAuth(cluster *arvados.Cluster) func(api.RoutableFunc) api.RoutableFunc {
36         return func(origFunc api.RoutableFunc) api.RoutableFunc {
37                 return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
38                         var tokens []string
39                         if creds, ok := auth.FromContext(ctx); ok {
40                                 tokens = creds.Tokens
41                         }
42                         return origFunc(context.WithValue(ctx, contextKeyAuth, &authcontext{cluster: cluster, tokens: tokens}), opts)
43                 }
44         }
45 }
46
47 // CurrentAuth returns the arvados.User whose privileges should be
48 // used in the given context, and the arvados.APIClientAuthorization
49 // the caller presented in order to authenticate the current request.
50 //
51 // Returns ErrUnauthenticated if the current request was not
52 // authenticated (no token provided, token is expired, etc).
53 func CurrentAuth(ctx context.Context) (*arvados.User, *arvados.APIClientAuthorization, error) {
54         ac, ok := ctx.Value(contextKeyAuth).(*authcontext)
55         if !ok {
56                 return nil, nil, ErrNoAuthContext
57         }
58         ac.lookupOnce.Do(func() { ac.user, ac.apiClientAuthorization, ac.err = aclookup(ctx, ac.cluster, ac.tokens) })
59         return ac.user, ac.apiClientAuthorization, ac.err
60 }
61
62 type contextKeyA string
63
64 var contextKeyAuth = contextKeyT("auth")
65
66 type authcontext struct {
67         cluster                *arvados.Cluster
68         tokens                 []string
69         user                   *arvados.User
70         apiClientAuthorization *arvados.APIClientAuthorization
71         err                    error
72         lookupOnce             sync.Once
73 }
74
75 func aclookup(ctx context.Context, cluster *arvados.Cluster, tokens []string) (*arvados.User, *arvados.APIClientAuthorization, error) {
76         if len(tokens) == 0 {
77                 return nil, nil, ErrUnauthenticated
78         }
79         tx, err := CurrentTx(ctx)
80         if err != nil {
81                 return nil, nil, err
82         }
83         var aca arvados.APIClientAuthorization
84         var user arvados.User
85         for _, token := range tokens {
86                 var cond string
87                 var args []interface{}
88                 if token == "" {
89                         continue
90                 } else if len(token) > 30 && strings.HasPrefix(token, "v2/") && token[30] == '/' {
91                         fields := strings.Split(token, "/")
92                         cond = `aca.uuid=$1 and aca.api_token=$2`
93                         args = []interface{}{fields[1], fields[2]}
94                 } else {
95                         // Bare token or OIDC access token
96                         mac := hmac.New(sha256.New, []byte(cluster.SystemRootToken))
97                         io.WriteString(mac, token)
98                         hmac := fmt.Sprintf("%x", mac.Sum(nil))
99                         cond = `aca.api_token in ($1, $2)`
100                         args = []interface{}{token, hmac}
101                 }
102                 var scopesJSON []byte
103                 err = tx.QueryRowContext(ctx, `
104 select aca.uuid, aca.expires_at, aca.api_token, aca.scopes, users.uuid, users.is_active, users.is_admin
105  from api_client_authorizations aca
106  left join users on aca.user_id = users.id
107  where `+cond+`
108  and (expires_at is null or expires_at > current_timestamp at time zone 'UTC')`, args...).Scan(
109                         &aca.UUID, &aca.ExpiresAt, &aca.APIToken, &scopesJSON,
110                         &user.UUID, &user.IsActive, &user.IsAdmin)
111                 if err == sql.ErrNoRows {
112                         continue
113                 } else if err != nil {
114                         return nil, nil, err
115                 }
116                 if len(scopesJSON) > 0 {
117                         err = json.Unmarshal(scopesJSON, &aca.Scopes)
118                         if err != nil {
119                                 return nil, nil, err
120                         }
121                 }
122                 return &user, &aca, nil
123         }
124         return nil, nil, ErrUnauthenticated
125 }