16669: Accept OIDC access token in lieu of arvados api token.
author: Tom Clegg <tom@tomclegg.ca>
Fri, 4 Sep 2020 15:20:29 +0000 (11:20 -0400)
committer: Tom Clegg <tom@tomclegg.ca>
Wed, 21 Oct 2020 12:57:55 +0000 (08:57 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@tomclegg.ca>

lib/controller/api/routable.go
lib/controller/auth_test.go [new file with mode: 0644]
lib/controller/handler.go
lib/controller/localdb/login_oidc.go
sdk/go/arvadostest/oidc_provider.go

index 6049cba8e403a576b4eba4774a6bfb192f4fa53c..3003ea2df76035cac71ae8bc8916c3e7eaa0a7d4 100644 (file)
@@ -15,3 +15,16 @@ import "context"
 // it to the router package would cause a circular dependency
 // router->arvadostest->ctrlctx->router.)
 type RoutableFunc func(ctx context.Context, opts interface{}) (interface{}, error)
+
+// RoutableFuncWrapper is a function that takes a RoutableFunc and
+// returns a RoutableFunc, typically one that adds behavior around a
+// call to the func it was given.
+type RoutableFuncWrapper func(RoutableFunc) RoutableFunc
+
+// ComposeWrappers combines the given wrappers into a single
+// RoutableFuncWrapper. The wrappers are applied last-to-first, so
+// with ComposeWrappers(w1, w2, w3) the supplied func is wrapped by
+// w3, then w2, then w1, which leaves w1 as the outermost wrapper.
+func ComposeWrappers(wraps ...RoutableFuncWrapper) RoutableFuncWrapper {
+       return func(f RoutableFunc) RoutableFunc {
+               wrapped := f
+               for remaining := len(wraps); remaining > 0; remaining-- {
+                       wrapped = wraps[remaining-1](wrapped)
+               }
+               return wrapped
+       }
+}
diff --git a/lib/controller/auth_test.go b/lib/controller/auth_test.go
new file mode 100644 (file)
index 0000000..a188c30
--- /dev/null
@@ -0,0 +1,118 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package controller
+
+import (
+       "context"
+       "encoding/json"
+       "fmt"
+       "net/http"
+       "net/http/httptest"
+       "os"
+       "time"
+
+       "git.arvados.org/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/arvadostest"
+       "git.arvados.org/arvados.git/sdk/go/ctxlog"
+       "git.arvados.org/arvados.git/sdk/go/httpserver"
+       "github.com/sirupsen/logrus"
+       check "gopkg.in/check.v1"
+)
+
+// Gocheck boilerplate
+var _ = check.Suite(&AuthSuite{})
+
+// AuthSuite tests the controller's handling of request credentials,
+// using a fake OIDC provider as the cluster's login backend.
+type AuthSuite struct {
+       // log is the test logger created in SetUpTest.
+       log logrus.FieldLogger
+       // testServer and testHandler are the controller being tested,
+       // "zhome".
+       testServer  *httpserver.Server
+       testHandler *Handler
+       // remoteServer ("zzzzz") forwards requests to the Rails API
+       // provided by the integration test environment.
+       remoteServer *httpserver.Server
+       // remoteMock ("zmock") is a second remote cluster.
+       // NOTE(review): SetUpTest installs http.NotFound as its
+       // handler, and nothing visible in this file appends to
+       // remoteMockRequests -- confirm whether either is still
+       // needed by this suite.
+       remoteMock         *httpserver.Server
+       remoteMockRequests []http.Request
+
+       // fakeProvider is the fake OIDC provider that SetUpTest
+       // configures as the cluster's OpenID Connect issuer.
+       fakeProvider *arvadostest.OIDCProvider
+}
+
+// SetUpTest starts the two remote servers and a fake OIDC provider,
+// builds a "zhome" cluster config with OpenID Connect login pointed
+// at the fake provider, and starts the controller under test.
+func (s *AuthSuite) SetUpTest(c *check.C) {
+       s.log = ctxlog.TestLogger(c)
+
+       // The remote servers are started first because their
+       // addresses are needed for cluster.RemoteClusters below.
+       s.remoteServer = newServerFromIntegrationTestEnv(c)
+       c.Assert(s.remoteServer.Start(), check.IsNil)
+
+       s.remoteMock = newServerFromIntegrationTestEnv(c)
+       s.remoteMock.Server.Handler = http.HandlerFunc(http.NotFound)
+       c.Assert(s.remoteMock.Start(), check.IsNil)
+
+       // Identity and client credentials served/accepted by the
+       // fake OIDC provider. The client ID and secret are copied
+       // into the cluster's login config below.
+       s.fakeProvider = arvadostest.NewOIDCProvider(c)
+       s.fakeProvider.AuthEmail = "active-user@arvados.local"
+       s.fakeProvider.AuthEmailVerified = true
+       s.fakeProvider.AuthName = "Fake User Name"
+       s.fakeProvider.ValidCode = fmt.Sprintf("abcdefgh-%d", time.Now().Unix())
+       s.fakeProvider.PeopleAPIResponse = map[string]interface{}{}
+       s.fakeProvider.ValidClientID = "test%client$id"
+       s.fakeProvider.ValidClientSecret = "test#client/secret"
+
+       // Config for the cluster under test, "zhome".
+       cluster := &arvados.Cluster{
+               ClusterID:        "zhome",
+               PostgreSQL:       integrationTestCluster().PostgreSQL,
+               ForceLegacyAPI14: forceLegacyAPI14,
+               SystemRootToken:  arvadostest.SystemRootToken,
+       }
+       cluster.TLS.Insecure = true
+       cluster.API.MaxItemsPerResponse = 1000
+       cluster.API.MaxRequestAmplification = 4
+       cluster.API.RequestTimeout = arvados.Duration(5 * time.Minute)
+       arvadostest.SetServiceURL(&cluster.Services.RailsAPI, "https://"+os.Getenv("ARVADOS_TEST_API_HOST"))
+       arvadostest.SetServiceURL(&cluster.Services.Controller, "http://localhost/")
+
+       cluster.RemoteClusters = map[string]arvados.RemoteCluster{
+               "zzzzz": {
+                       Host:   s.remoteServer.Addr,
+                       Proxy:  true,
+                       Scheme: "http",
+               },
+               "zmock": {
+                       Host:   s.remoteMock.Addr,
+                       Proxy:  true,
+                       Scheme: "http",
+               },
+               "*": {
+                       Scheme: "https",
+               },
+       }
+       // Use the fake provider as the OpenID Connect login backend.
+       cluster.Login.OpenIDConnect.Enable = true
+       cluster.Login.OpenIDConnect.Issuer = s.fakeProvider.Issuer.URL
+       cluster.Login.OpenIDConnect.ClientID = s.fakeProvider.ValidClientID
+       cluster.Login.OpenIDConnect.ClientSecret = s.fakeProvider.ValidClientSecret
+       cluster.Login.OpenIDConnect.EmailClaim = "email"
+       cluster.Login.OpenIDConnect.EmailVerifiedClaim = "email_verified"
+
+       // Serve the handler under test behind the request-ID and
+       // request-logging wrappers; the base context is built from
+       // s.log (presumably so handlers log to the test logger).
+       s.testHandler = &Handler{Cluster: cluster}
+       s.testServer = newServerFromIntegrationTestEnv(c)
+       s.testServer.Server.Handler = httpserver.HandlerWithContext(
+               ctxlog.Context(context.Background(), s.log),
+               httpserver.AddRequestIDs(httpserver.LogRequests(s.testHandler)))
+       c.Assert(s.testServer.Start(), check.IsNil)
+}
+
+// TestLocalOIDCAccessToken sends a "current user" request that
+// carries the fake provider's access token as a Bearer token, and
+// checks that the response describes the active user.
+func (s *AuthSuite) TestLocalOIDCAccessToken(c *check.C) {
+       request := httptest.NewRequest("GET", "/arvados/v1/users/current", nil)
+       bearer := "Bearer " + s.fakeProvider.ValidAccessToken()
+       request.Header.Set("Authorization", bearer)
+
+       recorder := httptest.NewRecorder()
+       s.testServer.Server.Handler.ServeHTTP(recorder, request)
+       response := recorder.Result()
+       c.Check(response.StatusCode, check.Equals, http.StatusOK)
+
+       var user arvados.User
+       decodeErr := json.NewDecoder(response.Body).Decode(&user)
+       c.Check(decodeErr, check.IsNil)
+       c.Check(user.UUID, check.Equals, arvadostest.ActiveUserUUID)
+       c.Check(user.OwnerUUID, check.Equals, "zzzzz-tpzed-000000000000000")
+}
index 2dd1d816e060a752fb8e71d4eeaacc5d0b3cfb9b..25bba558dc68096796143b1d9bd4483d07a6f44f 100644 (file)
@@ -14,7 +14,9 @@ import (
        "sync"
        "time"
 
+       "git.arvados.org/arvados.git/lib/controller/api"
        "git.arvados.org/arvados.git/lib/controller/federation"
+       "git.arvados.org/arvados.git/lib/controller/localdb"
        "git.arvados.org/arvados.git/lib/controller/railsproxy"
        "git.arvados.org/arvados.git/lib/controller/router"
        "git.arvados.org/arvados.git/lib/ctrlctx"
@@ -87,7 +89,8 @@ func (h *Handler) setup() {
                Routes: health.Routes{"ping": func() error { _, err := h.db(context.TODO()); return err }},
        })
 
-       rtr := router.New(federation.New(h.Cluster), ctrlctx.WrapCallsInTransactions(h.db))
+       oidcAuthorizer := localdb.OIDCAccessTokenAuthorizer(h.Cluster, h.db)
+       rtr := router.New(federation.New(h.Cluster), api.ComposeWrappers(ctrlctx.WrapCallsInTransactions(h.db), oidcAuthorizer.WrapCalls))
        mux.Handle("/arvados/v1/config", rtr)
        mux.Handle("/"+arvados.EndpointUserAuthenticate.Path, rtr)
 
@@ -103,6 +106,7 @@ func (h *Handler) setup() {
        hs := http.NotFoundHandler()
        hs = prepend(hs, h.proxyRailsAPI)
        hs = h.setupProxyRemoteCluster(hs)
+       hs = prepend(hs, oidcAuthorizer.Middleware)
        mux.Handle("/", hs)
        h.handlerStack = mux
 
index e0b01f13ebee8c4f01084d0dc4c8dca76804e696..60de70b5d2e51107e48649d615e1caf5e19c3f5e 100644 (file)
@@ -9,9 +9,11 @@ import (
        "context"
        "crypto/hmac"
        "crypto/sha256"
+       "database/sql"
        "encoding/base64"
        "errors"
        "fmt"
+       "io"
        "net/http"
        "net/url"
        "strings"
@@ -19,17 +21,28 @@ import (
        "text/template"
        "time"
 
+       "git.arvados.org/arvados.git/lib/controller/api"
+       "git.arvados.org/arvados.git/lib/controller/railsproxy"
        "git.arvados.org/arvados.git/lib/controller/rpc"
+       "git.arvados.org/arvados.git/lib/ctrlctx"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/auth"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
        "git.arvados.org/arvados.git/sdk/go/httpserver"
        "github.com/coreos/go-oidc"
+       lru "github.com/hashicorp/golang-lru"
+       "github.com/jmoiron/sqlx"
        "golang.org/x/oauth2"
        "google.golang.org/api/option"
        "google.golang.org/api/people/v1"
 )
 
+const (
+       // tokenCacheSize is the size passed to lru.New2Q when
+       // creating an oidcTokenAuthorizer's token cache.
+       tokenCacheSize        = 1000
+       // tokenCacheNegativeTTL is how long a token stays cached as
+       // "not an OIDC access token" after the provider's UserInfo
+       // request for it fails.
+       tokenCacheNegativeTTL = time.Minute * 5
+       // tokenCacheTTL (plus one minute) is how far in the future
+       // expires_at is set when an existing
+       // api_client_authorizations row is refreshed.
+       tokenCacheTTL         = time.Minute * 10
+)
+
 type oidcLoginController struct {
        Cluster            *arvados.Cluster
        RailsProxy         *railsProxy
@@ -139,17 +152,23 @@ func (ctrl *oidcLoginController) UserAuthenticate(ctx context.Context, opts arva
        return arvados.APIClientAuthorization{}, httpserver.ErrorWithStatus(errors.New("username/password authentication is not available"), http.StatusBadRequest)
 }
 
+// claimser can decode arbitrary claims into a map. Implemented by
+// *oidc.IDToken and, presumably, *oidc.UserInfo (both from
+// github.com/coreos/go-oidc).
+type claimser interface {
+       Claims(interface{}) error
+}
+
 // Use a person's token to get all of their email addresses, with the
 // primary address at index 0. The provided defaultAddr is always
 // included in the returned slice, and is used as the primary if the
 // Google API does not indicate one.
-func (ctrl *oidcLoginController) getAuthInfo(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (*rpc.UserSessionAuthInfo, error) {
+func (ctrl *oidcLoginController) getAuthInfo(ctx context.Context, token *oauth2.Token, claimser claimser) (*rpc.UserSessionAuthInfo, error) {
        var ret rpc.UserSessionAuthInfo
        defer ctxlog.FromContext(ctx).WithField("ret", &ret).Debug("getAuthInfo returned")
 
        var claims map[string]interface{}
-       if err := idToken.Claims(&claims); err != nil {
-               return nil, fmt.Errorf("error extracting claims from ID token: %s", err)
+       if err := claimser.Claims(&claims); err != nil {
+               return nil, fmt.Errorf("error extracting claims from token: %s", err)
        } else if verified, _ := claims[ctrl.EmailVerifiedClaim].(bool); verified || ctrl.EmailVerifiedClaim == "" {
                // Fall back to this info if the People API call
                // (below) doesn't return a primary && verified email.
@@ -297,3 +316,171 @@ func (s oauth2State) computeHMAC(key []byte) []byte {
        fmt.Fprintf(mac, "%x %s %s", s.Time, s.Remote, s.ReturnTo)
        return mac.Sum(nil)
 }
+
+// OIDCAccessTokenAuthorizer returns a new oidcTokenAuthorizer for
+// the given cluster. The authorizer uses getdb to obtain database
+// handles and starts with an empty token cache created with
+// tokenCacheSize. If the cluster's chosen login controller is not an
+// *oidcLoginController, the authorizer's ctrl field is nil.
+func OIDCAccessTokenAuthorizer(cluster *arvados.Cluster, getdb func(context.Context) (*sqlx.DB, error)) *oidcTokenAuthorizer {
+       // The second result of the type assertion is deliberately
+       // discarded: oidcCtrl should be nil when some other kind of
+       // login controller was chosen.
+       oidcCtrl, _ := chooseLoginController(cluster, railsproxy.NewConn(cluster)).(*oidcLoginController)
+       tokenCache, err := lru.New2Q(tokenCacheSize)
+       if err != nil {
+               panic(err)
+       }
+       ta := &oidcTokenAuthorizer{getdb: getdb}
+       ta.ctrl = oidcCtrl
+       ta.cache = tokenCache
+       return ta
+}
+
+// oidcTokenAuthorizer swaps OIDC access tokens found in incoming
+// requests for Arvados tokens (see exchangeToken).
+type oidcTokenAuthorizer struct {
+       // ctrl is the cluster's OIDC login controller, or nil if the
+       // cluster uses some other kind of login controller.
+       ctrl  *oidcLoginController
+       // getdb returns the database handle used to look up and
+       // update api_client_authorizations rows.
+       getdb func(context.Context) (*sqlx.DB, error)
+       // cache is keyed by incoming token. A time.Time value is a
+       // negative result (the entry's expiry time); any other value
+       // is treated as a positive result.
+       cache *lru.TwoQueueCache
+}
+
+// Middleware inspects the request's Authorization header. If it
+// carries an "OAuth2" or "Bearer" token, the token is passed through
+// exchangeToken and the header is rewritten as "Bearer <result>"
+// before next is called. If exchangeToken fails, the error is
+// reported to the client with status 500 and next is not called.
+//
+// When the cluster is not using an OIDC login controller (ta.ctrl is
+// nil) the request is passed to next unmodified, matching WrapCalls;
+// exchangeToken dereferences ta.ctrl and would panic otherwise.
+func (ta *oidcTokenAuthorizer) Middleware(w http.ResponseWriter, r *http.Request, next http.Handler) {
+       if ta.ctrl == nil {
+               // Not using a compatible (OIDC) login controller.
+               next.ServeHTTP(w, r)
+               return
+       }
+       if authhdr := strings.Split(r.Header.Get("Authorization"), " "); len(authhdr) > 1 && (authhdr[0] == "OAuth2" || authhdr[0] == "Bearer") {
+               tok, err := ta.exchangeToken(r.Context(), authhdr[1])
+               if err != nil {
+                       http.Error(w, err.Error(), http.StatusInternalServerError)
+                       return
+               }
+               r.Header.Set("Authorization", "Bearer "+tok)
+       }
+       next.ServeHTTP(w, r)
+}
+
+// WrapCalls returns a RoutableFunc that passes every token in the
+// call's credentials through exchangeToken, stores the updated
+// credentials in the context, and then calls origFunc. Calls whose
+// context carries no credentials go to origFunc unchanged. If the
+// cluster is not using an OIDC login controller, origFunc itself is
+// returned.
+func (ta *oidcTokenAuthorizer) WrapCalls(origFunc api.RoutableFunc) api.RoutableFunc {
+       if ta.ctrl == nil {
+               // Not using a compatible (OIDC) login controller.
+               return origFunc
+       }
+       return func(ctx context.Context, opts interface{}) (interface{}, error) {
+               creds, found := auth.FromContext(ctx)
+               if found {
+                       // Any token that turns out to be an OAuth2
+                       // access token is replaced, in place, by an
+                       // Arvados token.
+                       for i := range creds.Tokens {
+                               exchanged, err := ta.exchangeToken(ctx, creds.Tokens[i])
+                               if err != nil {
+                                       return nil, err
+                               }
+                               creds.Tokens[i] = exchanged
+                       }
+                       ctxlog.FromContext(ctx).WithField("creds", creds).Debug("(*oidcTokenAuthorizer)WrapCalls: new creds")
+                       ctx = auth.NewContext(ctx, creds)
+               }
+               return origFunc(ctx, opts)
+       }
+}
+
+// exchangeToken returns the token that should be passed downstream
+// in place of tok.
+//
+// Tokens starting with "v2/" are returned as-is, as are tokens with
+// an unexpired negative cache entry, tokens whose HMAC is already in
+// api_client_authorizations and not about to expire, and tokens the
+// OIDC provider's UserInfo endpoint does not accept. Otherwise tok
+// is treated as a valid OIDC access token: an
+// api_client_authorizations row is created (or an expiring one is
+// extended) and the corresponding Arvados v2 token is returned and
+// cached.
+//
+// ta.ctrl must be non-nil.
+func (ta *oidcTokenAuthorizer) exchangeToken(ctx context.Context, tok string) (string, error) {
+       if strings.HasPrefix(tok, "v2/") {
+               return tok, nil
+       }
+       if cached, hit := ta.cache.Get(tok); !hit {
+               // Fall through to database and OIDC provider checks
+               // below
+       } else if exp, ok := cached.(time.Time); ok {
+               // cached negative result (value is expiry time)
+               if time.Now().Before(exp) {
+                       return tok, nil
+               }
+               ta.cache.Remove(tok)
+       } else if aca, ok := cached.(*arvados.APIClientAuthorization); ok {
+               // cached positive result
+               var expiring bool
+               if aca.ExpiresAt != "" {
+                       t, err := time.Parse(time.RFC3339Nano, aca.ExpiresAt)
+                       if err != nil {
+                               return "", fmt.Errorf("error parsing expires_at value: %w", err)
+                       }
+                       expiring = t.Before(time.Now().Add(time.Minute))
+               }
+               if !expiring {
+                       return aca.TokenV2(), nil
+               }
+       }
+
+       db, err := ta.getdb(ctx)
+       if err != nil {
+               return "", err
+       }
+       tx, err := db.Beginx()
+       if err != nil {
+               return "", err
+       }
+       defer tx.Rollback()
+       ctx = ctrlctx.NewWithTransaction(ctx, tx)
+
+       // We use hmac-sha256(accesstoken,systemroottoken) as the
+       // secret part of our own token, and avoid storing the auth
+       // provider's real secret in our database.
+       mac := hmac.New(sha256.New, []byte(ta.ctrl.Cluster.SystemRootToken))
+       io.WriteString(mac, tok)
+       tokenHMAC := fmt.Sprintf("%x", mac.Sum(nil))
+
+       // The uuid is fetched along with the expiry check so that,
+       // if the row only needs its expiry extended, a complete v2
+       // token can still be built for it below.
+       var aca arvados.APIClientAuthorization
+       var expiring bool
+       err = tx.QueryRowContext(ctx, `select uuid, (expires_at is not null and expires_at - interval '1 minute' <= current_timestamp at time zone 'UTC') from api_client_authorizations where api_token=$1`, tokenHMAC).Scan(&aca.UUID, &expiring)
+       if err != nil && err != sql.ErrNoRows {
+               return "", fmt.Errorf("database error while checking token: %w", err)
+       } else if err == nil && !expiring {
+               // Token is already in the database as an Arvados
+               // token, and isn't about to expire, so we can pass it
+               // through to RailsAPI etc. regardless of whether it's
+               // an OIDC access token.
+               return tok, nil
+       }
+       updating := err == nil
+
+       // Check whether the token is a valid OIDC access token. If
+       // so, swap it out for an Arvados token (creating/updating an
+       // api_client_authorizations row if needed) which downstream
+       // server components will accept.
+       err = ta.ctrl.setup()
+       if err != nil {
+               return "", fmt.Errorf("error setting up OpenID Connect provider: %s", err)
+       }
+       oauth2Token := &oauth2.Token{
+               AccessToken: tok,
+       }
+       userinfo, err := ta.ctrl.provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token))
+       if err != nil {
+               // Not accepted by the provider: remember that for a
+               // while and pass the token through unchanged.
+               ta.cache.Add(tok, time.Now().Add(tokenCacheNegativeTTL))
+               return tok, nil
+       }
+       ctxlog.FromContext(ctx).WithField("userinfo", userinfo).Debug("(*oidcTokenAuthorizer)exchangeToken: got userinfo")
+       authinfo, err := ta.ctrl.getAuthInfo(ctx, oauth2Token, userinfo)
+       if err != nil {
+               return "", err
+       }
+
+       if updating {
+               // Use UTC explicitly: the expiry check above compares
+               // expires_at with the current time in UTC.
+               exp := time.Now().UTC().Add(tokenCacheTTL + time.Minute)
+               _, err = tx.ExecContext(ctx, `update api_client_authorizations set expires_at=$1 where api_token=$2`, exp, tokenHMAC)
+               if err != nil {
+                       return "", fmt.Errorf("error adding OIDC access token to database: %w", err)
+               }
+               // Fill in the fields needed to build the v2 token and
+               // to let the cached entry expire.
+               aca.APIToken = tokenHMAC
+               aca.ExpiresAt = exp.Format(time.RFC3339Nano)
+       } else {
+               aca, err = createAPIClientAuthorization(ctx, ta.ctrl.RailsProxy, ta.ctrl.Cluster.SystemRootToken, *authinfo)
+               if err != nil {
+                       return "", err
+               }
+               _, err = tx.ExecContext(ctx, `update api_client_authorizations set api_token=$1 where uuid=$2`, tokenHMAC, aca.UUID)
+               if err != nil {
+                       return "", fmt.Errorf("error adding OIDC access token to database: %w", err)
+               }
+               aca.APIToken = tokenHMAC
+       }
+       err = tx.Commit()
+       if err != nil {
+               return "", err
+       }
+       // Store a pointer: the cache lookup above expects
+       // *arvados.APIClientAuthorization for positive results.
+       ta.cache.Add(tok, &aca)
+       return aca.TokenV2(), nil
+}
index 0632010ba4d3a2f8065e23f126ec592ec7748dd3..96205f919fa79b813721af4304bdbc27084e4b7f 100644 (file)
@@ -47,6 +47,10 @@ func NewOIDCProvider(c *check.C) *OIDCProvider {
        return p
 }
 
+// ValidAccessToken returns the access token that this fake provider
+// issues from its token endpoint and accepts at /userinfo.
+func (p *OIDCProvider) ValidAccessToken() string {
+       payload := []byte("fake access token")
+       return p.fakeToken(payload)
+}
+
 func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
        req.ParseForm()
        p.c.Logf("serveOIDC: got req: %s %s %s", req.Method, req.URL, req.Form)
@@ -99,7 +103,7 @@ func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
                        ExpiresIn    int32  `json:"expires_in"`
                        IDToken      string `json:"id_token"`
                }{
-                       AccessToken:  p.fakeToken([]byte("fake access token")),
+                       AccessToken:  p.ValidAccessToken(),
                        TokenType:    "Bearer",
                        RefreshToken: "test-refresh-token",
                        ExpiresIn:    30,
@@ -114,7 +118,20 @@ func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
        case "/auth":
                w.WriteHeader(http.StatusInternalServerError)
        case "/userinfo":
-               w.WriteHeader(http.StatusInternalServerError)
+               if authhdr := req.Header.Get("Authorization"); strings.TrimPrefix(authhdr, "Bearer ") != p.ValidAccessToken() {
+                       p.c.Logf("OIDCProvider: bad auth %q", authhdr)
+                       w.WriteHeader(http.StatusUnauthorized)
+                       return
+               }
+               json.NewEncoder(w).Encode(map[string]interface{}{
+                       "sub":            "fake-user-id",
+                       "name":           p.AuthName,
+                       "given_name":     p.AuthName,
+                       "family_name":    "",
+                       "alt_username":   "desired-username",
+                       "email":          p.AuthEmail,
+                       "email_verified": p.AuthEmailVerified,
+               })
        default:
                w.WriteHeader(http.StatusNotFound)
        }