1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
import (
	"context"
	"errors"
	"sync"

	"git.arvados.org/arvados.git/lib/controller/api"
	"git.arvados.org/arvados.git/sdk/go/ctxlog"
	"github.com/jmoiron/sqlx"
)
var (
	// ErrNoTransaction is returned by CurrentTx when the given
	// context does not carry a transaction, i.e., it was not
	// created by New or NewWithTransaction.
	ErrNoTransaction = errors.New("bug: there is no transaction in this context")

	// ErrContextFinished is returned by CurrentTx when it is first
	// called after the finish func returned by New has already
	// run, at which point a new transaction could never be
	// committed or rolled back.
	ErrContextFinished = errors.New("refusing to start a transaction after wrapped function already returned")
)
23 // WrapCallsInTransactions returns a call wrapper (suitable for
24 // assigning to router.router.WrapCalls) that starts a new transaction
25 // for each API call, and commits only if the call succeeds.
27 // The wrapper calls getdb() to get a database handle before each API
29 func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(api.RoutableFunc) api.RoutableFunc {
30 return func(origFunc api.RoutableFunc) api.RoutableFunc {
31 return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
32 ctx, finishtx := New(ctx, getdb)
34 return origFunc(ctx, opts)
39 // NewWithTransaction returns a child context in which the given
40 // transaction will be used by any localdb API call that needs one.
41 // The caller is responsible for calling Commit or Rollback on tx.
42 func NewWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context {
43 txn := &transaction{tx: tx}
44 txn.setup.Do(func() {})
45 return context.WithValue(ctx, contextKeyTransaction, txn)
// contextKeyT is a private type for context keys defined here, so
// they cannot collide with keys defined by other packages.
type contextKeyT string

// contextKeyTransaction is the context key under which the
// *transaction is stored.
var contextKeyTransaction = contextKeyT("transaction")
52 type transaction struct {
55 getdb func(context.Context) (*sqlx.DB, error)
// finishFunc is the type of the func returned by New. It is passed a
// pointer to the caller's error, which it reads and may assign to.
type finishFunc func(*error)
61 // New returns a new child context that can be used with
62 // CurrentTx(). It does not open a database transaction until the
63 // first call to CurrentTx().
65 // The caller must eventually call the returned finishtx() func to
66 // commit or rollback the transaction, if any.
68 // func example(ctx context.Context) (err error) {
69 // ctx, finishtx := New(ctx, dber)
70 // defer finishtx(&err)
72 // tx, err := CurrentTx(ctx)
74 // return fmt.Errorf("example: %s", err)
76 // return tx.ExecContext(...)
79 // If *err is nil, finishtx() commits the transaction and assigns any
80 // resulting error to *err.
82 // If *err is non-nil, finishtx() rolls back the transaction, and
83 // does not modify *err.
84 func New(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, finishFunc) {
85 txn := &transaction{getdb: getdb}
86 return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) {
88 // Using (*sync.Once)Do() prevents a future
89 // call to CurrentTx() from opening a
90 // transaction which would never get committed
91 // or rolled back. If CurrentTx() hasn't been
92 // called before now, future calls will return
94 txn.err = ErrContextFinished
97 // we never [successfully] started a transaction
101 ctxlog.FromContext(ctx).Debug("rollback")
105 *err = txn.tx.Commit()
109 func CurrentTx(ctx context.Context) (*sqlx.Tx, error) {
110 txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
112 return nil, ErrNoTransaction
114 txn.setup.Do(func() {
115 if db, err := txn.getdb(ctx); err != nil {
118 txn.tx, txn.err = db.Beginx()
121 return txn.tx, txn.err