--- /dev/null
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package ctrlctx
+
+import (
+ "context"
+ "errors"
+ "sync"
+
+ "git.arvados.org/arvados.git/sdk/go/arvados"
+ "git.arvados.org/arvados.git/sdk/go/ctxlog"
+ "github.com/jmoiron/sqlx"
+ _ "github.com/lib/pq"
+)
+
+// ErrNoTransaction is returned by CurrentTx when the given context
+// was not derived from one created by New or NewWithTransaction.
+var ErrNoTransaction = errors.New("bug: there is no transaction in this context")
+
+// ErrContextFinished is returned by CurrentTx when the finish func
+// returned by New has already run and no transaction had been
+// started before that.
+var ErrContextFinished = errors.New("refusing to start a transaction after wrapped function already returned")
+
+// WrapCallsInTransactions returns a call wrapper (suitable for
+// assigning to router.router.WrapCalls) that starts a new transaction
+// for each API call, and commits only if the call succeeds.
+//
+// The wrapper calls getdb() to get a database handle before each API
+// call.
+//
+// If the wrapped call does not return normally (it panics, or calls
+// runtime.Goexit), the transaction is rolled back instead of being
+// committed, and the panic continues to propagate unchanged.
+func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(arvados.RoutableFunc) arvados.RoutableFunc {
+	return func(origFunc arvados.RoutableFunc) arvados.RoutableFunc {
+		return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
+			ctx, finishtx := New(ctx, getdb)
+			returned := false
+			defer func() {
+				if !returned {
+					// origFunc is unwinding abnormally, so
+					// err is still nil. Without this,
+					// finishtx would commit whatever the
+					// call had done so far. A non-nil err
+					// makes finishtx roll back; the panic
+					// (if any) resumes after this func.
+					err = errors.New("transaction rolled back: wrapped function did not return normally")
+				}
+				finishtx(&err)
+			}()
+			resp, callErr := origFunc(ctx, opts)
+			returned = true
+			return resp, callErr
+		}
+	}
+}
+
+// NewWithTransaction returns a child context in which the given
+// transaction will be used by any localdb API call that needs one.
+// The caller is responsible for calling Commit or Rollback on tx.
+//
+// CurrentTx on the returned context (or its descendants) returns tx
+// directly; it never attempts to open a transaction of its own.
+func NewWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context {
+	preset := &transaction{tx: tx}
+	// Consume the sync.Once with a no-op so CurrentTx skips its
+	// lazy-open path (getdb is nil here and must not be called).
+	preset.setup.Do(func() {})
+	return context.WithValue(ctx, contextKeyTransaction, preset)
+}
+
+// contextKeyT is the type of this package's context keys. Using an
+// unexported named type keeps them distinct from context keys set by
+// other packages, even ones that happen to use the same string.
+type contextKeyT string
+
+// contextKeyTransaction is the context key under which a
+// *transaction is stored.
+var contextKeyTransaction contextKeyT = "transaction"
+
+// transaction is the value stored in a context by New or
+// NewWithTransaction. It holds a database transaction that may not
+// have been started yet.
+type transaction struct {
+	// setup guards the one-time attempt to start tx (or to mark
+	// the context finished); see CurrentTx and New.
+	setup sync.Once
+	// getdb supplies the database handle used to begin tx. It is
+	// nil when the transaction was supplied by NewWithTransaction.
+	getdb func(context.Context) (*sqlx.DB, error)
+	// tx is the open transaction, or nil if none was started.
+	tx *sqlx.Tx
+	// err records why tx is unavailable, if it is.
+	err error
+}
+
+// finishFunc is the type of the func returned by New. It commits or
+// rolls back the context's transaction depending on *err.
+type finishFunc func(err *error)
+
+// New returns a new child context that can be used with
+// CurrentTx(). It does not open a database transaction until the
+// first call to CurrentTx().
+//
+// The caller must eventually call the returned finishtx() func to
+// commit or rollback the transaction, if any.
+//
+//	func example(ctx context.Context) (err error) {
+//		ctx, finishtx := New(ctx, getdb)
+//		defer finishtx(&err)
+//		// ...
+//		tx, err := CurrentTx(ctx)
+//		if err != nil {
+//			return fmt.Errorf("example: %s", err)
+//		}
+//		_, err = tx.ExecContext(ctx, ...)
+//		return err
+//	}
+//
+// If *err is nil, finishtx() commits the transaction and assigns any
+// resulting error to *err.
+//
+// If *err is non-nil, finishtx() rolls back the transaction, and
+// does not modify *err. A failed rollback is logged, not returned.
+//
+// If CurrentTx() was never called before finishtx(), no transaction
+// is opened, and any later CurrentTx() call on this context returns
+// ErrContextFinished.
+func New(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, finishFunc) {
+	txn := &transaction{getdb: getdb}
+	return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) {
+		txn.setup.Do(func() {
+			// Using (*sync.Once)Do() prevents a future
+			// call to CurrentTx() from opening a
+			// transaction which would never get committed
+			// or rolled back. If CurrentTx() hasn't been
+			// called before now, future calls will return
+			// this error.
+			txn.err = ErrContextFinished
+		})
+		if txn.tx == nil {
+			// we never [successfully] started a transaction
+			return
+		}
+		if *err != nil {
+			logger := ctxlog.FromContext(ctx)
+			logger.Debug("rollback")
+			// *err already describes the failure the
+			// caller cares about, so a rollback error is
+			// reported via the log rather than replacing it.
+			if rberr := txn.tx.Rollback(); rberr != nil {
+				logger.WithError(rberr).Warn("rollback failed")
+			}
+			return
+		}
+		*err = txn.tx.Commit()
+	}
+}
+
+// CurrentTx returns the database transaction associated with ctx.
+//
+// For a context created by New, the transaction is opened on the
+// first call, using the getdb func passed to New. Later calls return
+// the same transaction, or the same error if opening it failed. If
+// the finish func returned by New has already run, the error is
+// ErrContextFinished.
+//
+// For a context created by NewWithTransaction, the supplied
+// transaction is returned.
+//
+// If ctx was created by neither, it returns ErrNoTransaction.
+func CurrentTx(ctx context.Context) (*sqlx.Tx, error) {
+	txn, found := ctx.Value(contextKeyTransaction).(*transaction)
+	if !found {
+		return nil, ErrNoTransaction
+	}
+	txn.setup.Do(func() {
+		db, err := txn.getdb(ctx)
+		if err != nil {
+			txn.err = err
+			return
+		}
+		txn.tx, txn.err = db.Beginx()
+	})
+	return txn.tx, txn.err
+}