1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
import (
	"context"
	"errors"
	"sync"

	"git.arvados.org/arvados.git/lib/controller/api"
	"git.arvados.org/arvados.git/sdk/go/ctxlog"
	"github.com/jmoiron/sqlx"
	// sqlx needs lib/pq to talk to PostgreSQL
)
var (
	// ErrNoTransaction is returned by CurrentTx when ctx does not
	// carry transaction state from New or NewWithTransaction.
	ErrNoTransaction = errors.New("bug: there is no transaction in this context")

	// ErrContextFinished is returned by CurrentTx when the finish
	// func returned by New ran before any transaction was started:
	// a transaction opened at that point could never be committed
	// or rolled back.
	ErrContextFinished = errors.New("refusing to start a transaction after wrapped function already returned")
)
24 // WrapCallsInTransactions returns a call wrapper (suitable for
25 // assigning to router.router.WrapCalls) that starts a new transaction
26 // for each API call, and commits only if the call succeeds.
28 // The wrapper calls getdb() to get a database handle before each API
30 func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(api.RoutableFunc) api.RoutableFunc {
31 return func(origFunc api.RoutableFunc) api.RoutableFunc {
32 return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
33 ctx, finishtx := New(ctx, getdb)
35 return origFunc(ctx, opts)
40 // NewWithTransaction returns a child context in which the given
41 // transaction will be used by any localdb API call that needs one.
42 // The caller is responsible for calling Commit or Rollback on tx.
43 func NewWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context {
44 txn := &transaction{tx: tx}
45 txn.setup.Do(func() {})
46 return context.WithValue(ctx, contextKeyTransaction, txn)
// contextKeyT is an unexported key type, so values this package
// stores in a context cannot collide with keys defined elsewhere.
type contextKeyT string

// contextKeyTransaction is the key under which New and
// NewWithTransaction store the *transaction that CurrentTx retrieves.
var contextKeyTransaction contextKeyT = "transaction"
53 type transaction struct {
56 getdb func(context.Context) (*sqlx.DB, error)
// finishFunc is the type of the finish callback returned by New. It
// receives a pointer to the caller's error result so it can both read
// it (to choose between commit and rollback) and overwrite it.
type finishFunc func(errp *error)
62 // New returns a new child context that can be used with
63 // CurrentTx(). It does not open a database transaction until the
64 // first call to CurrentTx().
66 // The caller must eventually call the returned finishtx() func to
67 // commit or rollback the transaction, if any.
69 // func example(ctx context.Context) (err error) {
70 // ctx, finishtx := New(ctx, dber)
71 // defer finishtx(&err)
73 // tx, err := CurrentTx(ctx)
75 // return fmt.Errorf("example: %s", err)
77 // return tx.ExecContext(...)
80 // If *err is nil, finishtx() commits the transaction and assigns any
81 // resulting error to *err.
83 // If *err is non-nil, finishtx() rolls back the transaction, and
84 // does not modify *err.
85 func New(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, finishFunc) {
86 txn := &transaction{getdb: getdb}
87 return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) {
89 // Using (*sync.Once)Do() prevents a future
90 // call to CurrentTx() from opening a
91 // transaction which would never get committed
92 // or rolled back. If CurrentTx() hasn't been
93 // called before now, future calls will return
95 txn.err = ErrContextFinished
98 // we never [successfully] started a transaction
102 ctxlog.FromContext(ctx).Debug("rollback")
106 *err = txn.tx.Commit()
110 func CurrentTx(ctx context.Context) (*sqlx.Tx, error) {
111 txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
113 return nil, ErrNoTransaction
115 txn.setup.Do(func() {
116 if db, err := txn.getdb(ctx); err != nil {
119 txn.tx, txn.err = db.Beginx()
122 return txn.tx, txn.err