Merge branch '16534-testable'
[arvados.git] / lib / ctrlctx / db.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ctrlctx
6
7 import (
8         "context"
9         "errors"
10         "sync"
11
12         "git.arvados.org/arvados.git/lib/controller/api"
13         "git.arvados.org/arvados.git/sdk/go/ctxlog"
14         "github.com/jmoiron/sqlx"
15         _ "github.com/lib/pq"
16 )
17
18 var (
19         ErrNoTransaction   = errors.New("bug: there is no transaction in this context")
20         ErrContextFinished = errors.New("refusing to start a transaction after wrapped function already returned")
21 )
22
23 // WrapCallsInTransactions returns a call wrapper (suitable for
24 // assigning to router.router.WrapCalls) that starts a new transaction
25 // for each API call, and commits only if the call succeeds.
26 //
27 // The wrapper calls getdb() to get a database handle before each API
28 // call.
29 func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(api.RoutableFunc) api.RoutableFunc {
30         return func(origFunc api.RoutableFunc) api.RoutableFunc {
31                 return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
32                         ctx, finishtx := New(ctx, getdb)
33                         defer finishtx(&err)
34                         return origFunc(ctx, opts)
35                 }
36         }
37 }
38
39 // NewWithTransaction returns a child context in which the given
40 // transaction will be used by any localdb API call that needs one.
41 // The caller is responsible for calling Commit or Rollback on tx.
42 func NewWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context {
43         txn := &transaction{tx: tx}
44         txn.setup.Do(func() {})
45         return context.WithValue(ctx, contextKeyTransaction, txn)
46 }
47
48 type contextKeyT string
49
50 var contextKeyTransaction = contextKeyT("transaction")
51
52 type transaction struct {
53         tx    *sqlx.Tx
54         err   error
55         getdb func(context.Context) (*sqlx.DB, error)
56         setup sync.Once
57 }
58
59 type finishFunc func(*error)
60
61 // New returns a new child context that can be used with
62 // CurrentTx(). It does not open a database transaction until the
63 // first call to CurrentTx().
64 //
65 // The caller must eventually call the returned finishtx() func to
66 // commit or rollback the transaction, if any.
67 //
68 //      func example(ctx context.Context) (err error) {
69 //              ctx, finishtx := New(ctx, dber)
70 //              defer finishtx(&err)
71 //              // ...
72 //              tx, err := CurrentTx(ctx)
73 //              if err != nil {
74 //                      return fmt.Errorf("example: %s", err)
75 //              }
76 //              return tx.ExecContext(...)
77 //      }
78 //
79 // If *err is nil, finishtx() commits the transaction and assigns any
80 // resulting error to *err.
81 //
82 // If *err is non-nil, finishtx() rolls back the transaction, and
83 // does not modify *err.
84 func New(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, finishFunc) {
85         txn := &transaction{getdb: getdb}
86         return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) {
87                 txn.setup.Do(func() {
88                         // Using (*sync.Once)Do() prevents a future
89                         // call to CurrentTx() from opening a
90                         // transaction which would never get committed
91                         // or rolled back. If CurrentTx() hasn't been
92                         // called before now, future calls will return
93                         // this error.
94                         txn.err = ErrContextFinished
95                 })
96                 if txn.tx == nil {
97                         // we never [successfully] started a transaction
98                         return
99                 }
100                 if *err != nil {
101                         ctxlog.FromContext(ctx).Debug("rollback")
102                         txn.tx.Rollback()
103                         return
104                 }
105                 *err = txn.tx.Commit()
106         }
107 }
108
109 func CurrentTx(ctx context.Context) (*sqlx.Tx, error) {
110         txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
111         if !ok {
112                 return nil, ErrNoTransaction
113         }
114         txn.setup.Do(func() {
115                 if db, err := txn.getdb(ctx); err != nil {
116                         txn.err = err
117                 } else {
118                         txn.tx, txn.err = db.Beginx()
119                 }
120         })
121         return txn.tx, txn.err
122 }