+
+func OIDCAccessTokenAuthorizer(cluster *arvados.Cluster, getdb func(context.Context) (*sqlx.DB, error)) *oidcTokenAuthorizer {
+ // We want ctrl to be nil if the chosen controller is not a
+ // *oidcLoginController, so we can ignore the 2nd return value
+ // of this type cast.
+ ctrl, _ := NewConn(cluster).loginController.(*oidcLoginController)
+ cache, err := lru.New2Q(tokenCacheSize)
+ if err != nil {
+ panic(err)
+ }
+ return &oidcTokenAuthorizer{
+ ctrl: ctrl,
+ getdb: getdb,
+ cache: cache,
+ }
+}
+
+type oidcTokenAuthorizer struct {
+ ctrl *oidcLoginController
+ getdb func(context.Context) (*sqlx.DB, error)
+ cache *lru.TwoQueueCache
+}
+
+func (ta *oidcTokenAuthorizer) Middleware(w http.ResponseWriter, r *http.Request, next http.Handler) {
+ if ta.ctrl == nil {
+ // Not using a compatible (OIDC) login controller.
+ } else if authhdr := strings.Split(r.Header.Get("Authorization"), " "); len(authhdr) > 1 && (authhdr[0] == "OAuth2" || authhdr[0] == "Bearer") {
+ err := ta.registerToken(r.Context(), authhdr[1])
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ }
+ next.ServeHTTP(w, r)
+}
+
+func (ta *oidcTokenAuthorizer) WrapCalls(origFunc api.RoutableFunc) api.RoutableFunc {
+ if ta.ctrl == nil {
+ // Not using a compatible (OIDC) login controller.
+ return origFunc
+ }
+ return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
+ creds, ok := auth.FromContext(ctx)
+ if !ok {
+ return origFunc(ctx, opts)
+ }
+ // Check each token in the incoming request. If any
+ // are valid OAuth2 access tokens, insert/update them
+ // in the database so RailsAPI's auth code accepts
+ // them.
+ for _, tok := range creds.Tokens {
+ err = ta.registerToken(ctx, tok)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return origFunc(ctx, opts)
+ }
+}
+
// re5xxError matches error text that begins with a 5xx HTTP status
// code followed by a space. It is used to recognize the error
// returned by the oidc UserInfo() call when the provider responds
// with a server error.
var re5xxError = regexp.MustCompile(`^5\d\d `)
+
+// registerToken checks whether tok is a valid OIDC Access Token and,
+// if so, ensures that an api_client_authorizations row exists so that
+// RailsAPI will accept it as an Arvados token.
+func (ta *oidcTokenAuthorizer) registerToken(ctx context.Context, tok string) error {
+ if tok == ta.ctrl.Cluster.SystemRootToken || strings.HasPrefix(tok, "v2/") {
+ return nil
+ }
+ if cached, hit := ta.cache.Get(tok); !hit {
+ // Fall through to database and OIDC provider checks
+ // below
+ } else if exp, ok := cached.(time.Time); ok {
+ // cached negative result (value is expiry time)
+ if time.Now().Before(exp) {
+ return nil
+ }
+ ta.cache.Remove(tok)
+ } else {
+ // cached positive result
+ aca := cached.(arvados.APIClientAuthorization)
+ var expiring bool
+ if !aca.ExpiresAt.IsZero() {
+ t := aca.ExpiresAt
+ expiring = t.Before(time.Now().Add(time.Minute))
+ }
+ if !expiring {
+ return nil
+ }
+ }
+
+ db, err := ta.getdb(ctx)
+ if err != nil {
+ return err
+ }
+ tx, err := db.Beginx()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+ ctx = ctrlctx.NewWithTransaction(ctx, tx)
+
+ // We use hmac-sha256(accesstoken,systemroottoken) as the
+ // secret part of our own token, and avoid storing the auth
+ // provider's real secret in our database.
+ mac := hmac.New(sha256.New, []byte(ta.ctrl.Cluster.SystemRootToken))
+ io.WriteString(mac, tok)
+ hmac := fmt.Sprintf("%x", mac.Sum(nil))
+
+ var expiring bool
+ err = tx.QueryRowContext(ctx, `select (expires_at is not null and expires_at - interval '1 minute' <= current_timestamp at time zone 'UTC') from api_client_authorizations where api_token=$1`, hmac).Scan(&expiring)
+ if err != nil && err != sql.ErrNoRows {
+ return fmt.Errorf("database error while checking token: %w", err)
+ } else if err == nil && !expiring {
+ // Token is already in the database as an Arvados
+ // token, and isn't about to expire, so we can pass it
+ // through to RailsAPI etc. regardless of whether it's
+ // an OIDC access token.
+ return nil
+ }
+ updating := err == nil
+
+ // Check whether the token is a valid OIDC access token. If
+ // so, swap it out for an Arvados token (creating/updating an
+ // api_client_authorizations row if needed) which downstream
+ // server components will accept.
+ err = ta.ctrl.setup()
+ if err != nil {
+ return fmt.Errorf("error setting up OpenID Connect provider: %s", err)
+ }
+ if ok, err := ta.checkAccessTokenScope(ctx, tok); err != nil || !ok {
+ // Note checkAccessTokenScope logs any interesting errors
+ ta.cache.Add(tok, time.Now().Add(tokenCacheNegativeTTL))
+ return err
+ }
+ oauth2Token := &oauth2.Token{
+ AccessToken: tok,
+ }
+ userinfo, err := ta.ctrl.provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token))
+ if err != nil {
+ if neterr := net.Error(nil); errors.As(err, &neterr) || re5xxError.MatchString(err.Error()) {
+ // If this token is in fact a valid OIDC
+ // token, but we failed to validate it here
+ // because of a network problem or internal
+ // server error, we error out now with a 5xx
+ // error, indicating to the client that they
+ // can try again. If we didn't error out now,
+ // the unrecognized token would eventually
+ // cause a 401 error further down the stack,
+ // which the caller would interpret as an
+ // unrecoverable failure.
+ ctxlog.FromContext(ctx).WithError(err).Debugf("treating OIDC UserInfo lookup error type %T as transient; failing request instead of forwarding token blindly", err)
+ return err
+ }
+ ctxlog.FromContext(ctx).WithError(err).WithField("HMAC", hmac).Debug("UserInfo failed (not an OIDC token?), caching negative result")
+ ta.cache.Add(tok, time.Now().Add(tokenCacheNegativeTTL))
+ return nil
+ }
+ ctxlog.FromContext(ctx).WithField("userinfo", userinfo).Debug("(*oidcTokenAuthorizer)registerToken: got userinfo")
+ authinfo, err := ta.ctrl.getAuthInfo(ctx, oauth2Token, userinfo)
+ if err != nil {
+ return err
+ }
+
+ // Expiry time for our token is one minute longer than our
+ // cache TTL, so we don't pass it through to RailsAPI just as
+ // it's expiring.
+ exp := time.Now().UTC().Add(tokenCacheTTL + tokenCacheRaceWindow)
+
+ if updating {
+ _, err = tx.ExecContext(ctx, `update api_client_authorizations set expires_at=$1 where api_token=$2`, exp, hmac)
+ if err != nil {
+ return fmt.Errorf("error updating token expiry time: %w", err)
+ }
+ ctxlog.FromContext(ctx).WithField("HMAC", hmac).Debug("(*oidcTokenAuthorizer)registerToken: updated api_client_authorizations row")
+ } else {
+ aca, err := ta.ctrl.Parent.CreateAPIClientAuthorization(ctx, ta.ctrl.Cluster.SystemRootToken, *authinfo)
+ if err != nil {
+ return err
+ }
+ _, err = tx.ExecContext(ctx, `savepoint upd`)
+ if err != nil {
+ return err
+ }
+ _, err = tx.ExecContext(ctx, `update api_client_authorizations set api_token=$1, expires_at=$2 where uuid=$3`, hmac, exp, aca.UUID)
+ if e, ok := err.(*pq.Error); ok && e.Code == pqCodeUniqueViolation {
+ // unique_violation, given that the above
+ // query did not find a row with matching
+ // api_token, means another thread/process
+ // also received this same token and won the
+ // race to insert it -- in which case this
+ // thread doesn't need to update the database.
+ // Discard the redundant row.
+ _, err = tx.ExecContext(ctx, `rollback to savepoint upd`)
+ if err != nil {
+ return err
+ }
+ _, err = tx.ExecContext(ctx, `delete from api_client_authorizations where uuid=$1`, aca.UUID)
+ if err != nil {
+ return err
+ }
+ ctxlog.FromContext(ctx).WithField("HMAC", hmac).Debug("(*oidcTokenAuthorizer)registerToken: api_client_authorizations row inserted by another thread")
+ } else if err != nil {
+ ctxlog.FromContext(ctx).Errorf("%#v", err)
+ return fmt.Errorf("error adding OIDC access token to database: %w", err)
+ } else {
+ ctxlog.FromContext(ctx).WithFields(logrus.Fields{"UUID": aca.UUID, "HMAC": hmac}).Debug("(*oidcTokenAuthorizer)registerToken: inserted api_client_authorizations row")
+ }
+ }
+ err = tx.Commit()
+ if err != nil {
+ return err
+ }
+ ta.cache.Add(tok, arvados.APIClientAuthorization{ExpiresAt: exp})
+ return nil
+}
+
+// Check that the provided access token is a JWT with the required
+// scope. If it is a valid JWT but missing the required scope, we
+// return a 403 error, otherwise true (acceptable as an API token) or
+// false (pass through unmodified).
+//
+// Return false if configured not to accept access tokens at all.
+//
+// Note we don't check signature or expiry here. We are relying on the
+// caller to verify those separately (e.g., by calling the UserInfo
+// endpoint).
+func (ta *oidcTokenAuthorizer) checkAccessTokenScope(ctx context.Context, tok string) (bool, error) {
+ if !ta.ctrl.AcceptAccessToken {
+ return false, nil
+ } else if ta.ctrl.AcceptAccessTokenScope == "" {
+ return true, nil
+ }
+ var claims struct {
+ Scope string `json:"scope"`
+ }
+ if t, err := jwt.ParseSigned(tok); err != nil {
+ ctxlog.FromContext(ctx).WithError(err).Debug("error parsing jwt")
+ return false, nil
+ } else if err = t.UnsafeClaimsWithoutVerification(&claims); err != nil {
+ ctxlog.FromContext(ctx).WithError(err).Debug("error extracting jwt claims")
+ return false, nil
+ }
+ for _, s := range strings.Split(claims.Scope, " ") {
+ if s == ta.ctrl.AcceptAccessTokenScope {
+ return true, nil
+ }
+ }
+ ctxlog.FromContext(ctx).WithFields(logrus.Fields{"have": claims.Scope, "need": ta.ctrl.AcceptAccessTokenScope}).Info("unacceptable access token scope")
+ return false, httpserver.ErrorWithStatus(errors.New("unacceptable access token scope"), http.StatusUnauthorized)
+}