19388: Add user/auth context to ctrlctx.
author Tom Clegg <tom@curii.com>
Thu, 22 Sep 2022 18:03:01 +0000 (14:03 -0400)
committer Tom Clegg <tom@curii.com>
Thu, 22 Sep 2022 18:03:01 +0000 (14:03 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

lib/controller/handler.go
lib/ctrlctx/auth.go [new file with mode: 0644]
lib/ctrlctx/auth_test.go [new file with mode: 0644]

index 665fd5c636372fc4a21bd7de68c5d886aafbcc7c..e9c56db4d4b112b906dbaf36dd21b9a7a1300d98 100644 (file)
@@ -101,7 +101,10 @@ func (h *Handler) setup() {
        h.federation = federation.New(h.Cluster, &healthFuncs)
        rtr := router.New(h.federation, router.Config{
                MaxRequestSize: h.Cluster.API.MaxRequestSize,
-               WrapCalls:      api.ComposeWrappers(ctrlctx.WrapCallsInTransactions(h.db), oidcAuthorizer.WrapCalls),
+               WrapCalls: api.ComposeWrappers(
+                       ctrlctx.WrapCallsInTransactions(h.db),
+                       oidcAuthorizer.WrapCalls,
+                       ctrlctx.WrapCallsWithAuth(h.Cluster)),
        })
 
        healthRoutes := health.Routes{"ping": func() error { _, err := h.db(context.TODO()); return err }}
diff --git a/lib/ctrlctx/auth.go b/lib/ctrlctx/auth.go
new file mode 100644 (file)
index 0000000..5b96463
--- /dev/null
@@ -0,0 +1,125 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package ctrlctx
+
+import (
+       "context"
+       "crypto/hmac"
+       "crypto/sha256"
+       "database/sql"
+       "encoding/json"
+       "errors"
+       "fmt"
+       "io"
+       "strings"
+       "sync"
+
+       "git.arvados.org/arvados.git/lib/controller/api"
+       "git.arvados.org/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/auth"
+)
+
+var (
+       // ErrNoAuthContext is returned by CurrentAuth when the given
+       // context holds no *authcontext under contextKeyAuth, i.e. the
+       // call did not pass through a WrapCallsWithAuth wrapper.
+       ErrNoAuthContext   = errors.New("bug: there is no authorization in this context")
+       // ErrUnauthenticated is returned by CurrentAuth when the request
+       // supplied no tokens, or none of the supplied tokens matched an
+       // unexpired api_client_authorizations row.
+       ErrUnauthenticated = errors.New("unauthenticated request")
+)
+
+// WrapCallsWithAuth returns a call wrapper (suitable for assigning to
+// router.router.WrapCalls) that makes CurrentUser(ctx) et al. work
+// from inside the wrapped functions.
+//
+// The incoming context must come from WrapCallsInTransactions or
+// NewWithTransaction.
+func WrapCallsWithAuth(cluster *arvados.Cluster) func(api.RoutableFunc) api.RoutableFunc {
+       return func(origFunc api.RoutableFunc) api.RoutableFunc {
+               return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
+                       var tokens []string
+                       if creds, ok := auth.FromContext(ctx); ok {
+                               tokens = creds.Tokens
+                       }
+                       return origFunc(context.WithValue(ctx, contextKeyAuth, &authcontext{cluster: cluster, tokens: tokens}), opts)
+               }
+       }
+}
+
+// CurrentAuth returns the arvados.User whose privileges should be
+// used in the given context, and the arvados.APIClientAuthorization
+// the caller presented in order to authenticate the current request.
+//
+// Returns ErrUnauthenticated if the current request was not
+// authenticated (no token provided, token is expired, etc).
+func CurrentAuth(ctx context.Context) (*arvados.User, *arvados.APIClientAuthorization, error) {
+       ac, ok := ctx.Value(contextKeyAuth).(*authcontext)
+       if !ok {
+               return nil, nil, ErrNoAuthContext
+       }
+       ac.lookupOnce.Do(func() { ac.user, ac.apiClientAuthorization, ac.err = aclookup(ctx, ac.cluster, ac.tokens) })
+       return ac.user, ac.apiClientAuthorization, ac.err
+}
+
+// contextKeyT is the type of the context keys declared in this file.
+// A dedicated key type keeps these keys from colliding with context
+// values stored under plain strings or by other packages.
+type contextKeyT string
+
+// contextKeyAuth is the context key under which WrapCallsWithAuth
+// stores the per-request *authcontext read by CurrentAuth.
+var contextKeyAuth = contextKeyT("auth")
+
+// authcontext is the per-call state that WrapCallsWithAuth stores in
+// the context and CurrentAuth reads back.
+type authcontext struct {
+       // Inputs to aclookup, filled in by WrapCallsWithAuth.
+       cluster                *arvados.Cluster
+       tokens                 []string
+       // Results of aclookup, filled in by CurrentAuth on its first
+       // call and returned unchanged on later calls.
+       user                   *arvados.User
+       apiClientAuthorization *arvados.APIClientAuthorization
+       err                    error
+       // lookupOnce ensures aclookup runs at most once for this
+       // authcontext.
+       lookupOnce             sync.Once
+}
+
+// aclookup tries each of the given tokens in order, using the
+// transaction returned by CurrentTx(ctx), and returns the user and
+// authorization records for the first token that matches an
+// unexpired row in api_client_authorizations.
+//
+// A token that starts with "v2/" and has a '/' at byte offset 30 is
+// split on '/' and matched on uuid (2nd field) and api_token (3rd
+// field); any further fields are ignored. Any other non-empty token
+// is matched against api_token either verbatim or as the hex-encoded
+// HMAC-SHA256 of the token keyed with cluster.SystemRootToken. Empty
+// tokens are skipped.
+//
+// Returns ErrUnauthenticated if tokens is empty or no token matches.
+// Errors from CurrentTx, from the query, and from decoding the scopes
+// column are returned as-is.
+func aclookup(ctx context.Context, cluster *arvados.Cluster, tokens []string) (*arvados.User, *arvados.APIClientAuthorization, error) {
+       if len(tokens) == 0 {
+               return nil, nil, ErrUnauthenticated
+       }
+       tx, err := CurrentTx(ctx)
+       if err != nil {
+               return nil, nil, err
+       }
+       var aca arvados.APIClientAuthorization
+       var user arvados.User
+       for _, token := range tokens {
+               if token == "" {
+                       continue
+               }
+               var cond string
+               var args []interface{}
+               switch {
+               case len(token) > 30 && strings.HasPrefix(token, "v2/") && token[30] == '/':
+                       parts := strings.Split(token, "/")
+                       cond = `aca.uuid=$1 and aca.api_token=$2`
+                       args = []interface{}{parts[1], parts[2]}
+               default:
+                       // Bare token or OIDC access token
+                       mac := hmac.New(sha256.New, []byte(cluster.SystemRootToken))
+                       io.WriteString(mac, token)
+                       digest := fmt.Sprintf("%x", mac.Sum(nil))
+                       cond = `aca.api_token in ($1, $2)`
+                       args = []interface{}{token, digest}
+               }
+               query := `
+select aca.uuid, aca.expires_at, aca.api_token, aca.scopes, users.uuid, users.is_active, users.is_admin
+ from api_client_authorizations aca
+ left join users on aca.user_id = users.id
+ where ` + cond + `
+ and (expires_at is null or expires_at > current_timestamp at time zone 'UTC')`
+               var scopesJSON []byte
+               err = tx.QueryRowContext(ctx, query, args...).Scan(
+                       &aca.UUID, &aca.ExpiresAt, &aca.APIToken, &scopesJSON,
+                       &user.UUID, &user.IsActive, &user.IsAdmin)
+               if err == sql.ErrNoRows {
+                       // This token did not match; try the next one.
+                       continue
+               }
+               if err != nil {
+                       return nil, nil, err
+               }
+               if len(scopesJSON) > 0 {
+                       if err = json.Unmarshal(scopesJSON, &aca.Scopes); err != nil {
+                               return nil, nil, err
+                       }
+               }
+               return &user, &aca, nil
+       }
+       return nil, nil, ErrUnauthenticated
+}
diff --git a/lib/ctrlctx/auth_test.go b/lib/ctrlctx/auth_test.go
new file mode 100644 (file)
index 0000000..add7a67
--- /dev/null
@@ -0,0 +1,79 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package ctrlctx
+
+import (
+       "context"
+
+       "git.arvados.org/arvados.git/lib/config"
+       "git.arvados.org/arvados.git/sdk/go/auth"
+       "git.arvados.org/arvados.git/sdk/go/ctxlog"
+       "github.com/jmoiron/sqlx"
+       _ "github.com/lib/pq"
+       check "gopkg.in/check.v1"
+)
+
+func (*DatabaseSuite) TestAuthContext(c *check.C) {
+       cfg, err := config.NewLoader(nil, ctxlog.TestLogger(c)).Load()
+       c.Assert(err, check.IsNil)
+       cluster, err := cfg.GetCluster("")
+       c.Assert(err, check.IsNil)
+
+       getter := func(context.Context) (*sqlx.DB, error) {
+               return sqlx.Open("postgres", cluster.PostgreSQL.Connection.String())
+       }
+       authwrapper := WrapCallsWithAuth(cluster)
+       dbwrapper := WrapCallsInTransactions(getter)
+
+       // valid tokens
+       for _, token := range []string{
+               "3kg6k6lzmp9kj5cpkcoxie963cmvjahbt2fod9zru30k1jqdmi",
+               "v2/zzzzz-gj3su-077z32aux8dg2s1/3kg6k6lzmp9kj5cpkcoxie963cmvjahbt2fod9zru30k1jqdmi",
+               "v2/zzzzz-gj3su-077z32aux8dg2s1/3kg6k6lzmp9kj5cpkcoxie963cmvjahbt2fod9zru30k1jqdmi/asdfasdfasdf",
+       } {
+               ok, err := dbwrapper(authwrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
+                       user, aca, err := CurrentAuth(ctx)
+                       if c.Check(err, check.IsNil) {
+                               c.Check(user.UUID, check.Equals, "zzzzz-tpzed-xurymjxw79nv3jz")
+                               c.Check(aca.UUID, check.Equals, "zzzzz-gj3su-077z32aux8dg2s1")
+                               c.Check(aca.Scopes, check.DeepEquals, []string{"all"})
+                       }
+                       return true, nil
+               }))(auth.NewContext(context.Background(), auth.NewCredentials(token)), "blah")
+               c.Check(ok, check.Equals, true)
+               c.Check(err, check.IsNil)
+       }
+
+       // bad tokens
+       for _, token := range []string{
+               "3kg6k6lzmp9kj5cpkcoxie963cmvjahbt2fod9zru30k1jqdmI", // note last char mangled
+               "v2/zzzzz-gj3su-077z32aux8dg2s1/",
+               "bogus",
+               "",
+       } {
+               ok, err := dbwrapper(authwrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
+                       user, aca, err := CurrentAuth(ctx)
+                       c.Check(err, check.Equals, ErrUnauthenticated)
+                       c.Check(user, check.IsNil)
+                       c.Check(aca, check.IsNil)
+                       return true, err
+               }))(auth.NewContext(context.Background(), auth.NewCredentials(token)), "blah")
+               c.Check(ok, check.Equals, true)
+               c.Check(err, check.Equals, ErrUnauthenticated)
+       }
+
+       // no auth context
+       {
+               ok, err := dbwrapper(authwrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
+                       user, aca, err := CurrentAuth(ctx)
+                       c.Check(err, check.Equals, ErrUnauthenticated)
+                       c.Check(user, check.IsNil)
+                       c.Check(aca, check.IsNil)
+                       return true, err
+               }))(context.Background(), "blah")
+               c.Check(ok, check.Equals, true)
+               c.Check(err, check.Equals, ErrUnauthenticated)
+       }
+}