a76420860604b9a6fb9823bdc6b3775c70f85ff4
[arvados.git] / lib / ctrlctx / db.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ctrlctx
6
7 import (
8         "context"
9         "errors"
10         "sync"
11
12         "git.arvados.org/arvados.git/lib/controller/api"
13         "git.arvados.org/arvados.git/sdk/go/ctxlog"
14         "github.com/jmoiron/sqlx"
15
16         // sqlx needs lib/pq to talk to PostgreSQL
17         _ "github.com/lib/pq"
18 )
19
20 var (
21         ErrNoTransaction   = errors.New("bug: there is no transaction in this context")
22         ErrContextFinished = errors.New("refusing to start a transaction after wrapped function already returned")
23 )
24
25 // WrapCallsInTransactions returns a call wrapper (suitable for
26 // assigning to router.router.WrapCalls) that starts a new transaction
27 // for each API call, and commits only if the call succeeds.
28 //
29 // The wrapper calls getdb() to get a database handle before each API
30 // call.
31 func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(api.RoutableFunc) api.RoutableFunc {
32         return func(origFunc api.RoutableFunc) api.RoutableFunc {
33                 return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
34                         ctx, finishtx := New(ctx, getdb)
35                         defer finishtx(&err)
36                         return origFunc(ctx, opts)
37                 }
38         }
39 }
40
41 // NewWithTransaction returns a child context in which the given
42 // transaction will be used by any localdb API call that needs one.
43 // The caller is responsible for calling Commit or Rollback on tx.
44 func NewWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context {
45         txn := &transaction{tx: tx}
46         txn.setup.Do(func() {})
47         return context.WithValue(ctx, contextKeyTransaction, txn)
48 }
49
50 type contextKeyT string
51
52 var contextKeyTransaction = contextKeyT("transaction")
53
54 type transaction struct {
55         tx    *sqlx.Tx
56         err   error
57         getdb func(context.Context) (*sqlx.DB, error)
58         setup sync.Once
59 }
60
61 type finishFunc func(*error)
62
63 // New returns a new child context that can be used with
64 // CurrentTx(). It does not open a database transaction until the
65 // first call to CurrentTx().
66 //
67 // The caller must eventually call the returned finishtx() func to
68 // commit or rollback the transaction, if any.
69 //
70 //      func example(ctx context.Context) (err error) {
71 //              ctx, finishtx := New(ctx, dber)
72 //              defer finishtx(&err)
73 //              // ...
74 //              tx, err := CurrentTx(ctx)
75 //              if err != nil {
76 //                      return fmt.Errorf("example: %s", err)
77 //              }
78 //              return tx.ExecContext(...)
79 //      }
80 //
81 // If *err is nil, finishtx() commits the transaction and assigns any
82 // resulting error to *err.
83 //
84 // If *err is non-nil, finishtx() rolls back the transaction, and
85 // does not modify *err.
86 func New(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, finishFunc) {
87         txn := &transaction{getdb: getdb}
88         return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) {
89                 txn.setup.Do(func() {
90                         // Using (*sync.Once)Do() prevents a future
91                         // call to CurrentTx() from opening a
92                         // transaction which would never get committed
93                         // or rolled back. If CurrentTx() hasn't been
94                         // called before now, future calls will return
95                         // this error.
96                         txn.err = ErrContextFinished
97                 })
98                 if txn.tx == nil {
99                         // we never [successfully] started a transaction
100                         return
101                 }
102                 if *err != nil {
103                         ctxlog.FromContext(ctx).Debug("rollback")
104                         txn.tx.Rollback()
105                         return
106                 }
107                 *err = txn.tx.Commit()
108         }
109 }
110
111 // NewTx starts a new transaction. The caller is responsible for
112 // calling Commit or Rollback. This is suitable for database queries
113 // that are separate from the API transaction (see CurrentTx), e.g.,
114 // ones that will be committed even if the API call fails, or held
115 // open after the API call finishes.
116 func NewTx(ctx context.Context) (*sqlx.Tx, error) {
117         txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
118         if !ok {
119                 return nil, ErrNoTransaction
120         }
121         db, err := txn.getdb(ctx)
122         if err != nil {
123                 return nil, err
124         }
125         return db.Beginx()
126 }
127
128 // CurrentTx returns a transaction that will be committed after the
129 // current API call completes, or rolled back if the current API call
130 // returns an error.
131 func CurrentTx(ctx context.Context) (*sqlx.Tx, error) {
132         txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
133         if !ok {
134                 return nil, ErrNoTransaction
135         }
136         txn.setup.Do(func() {
137                 if db, err := txn.getdb(ctx); err != nil {
138                         txn.err = err
139                 } else {
140                         txn.tx, txn.err = db.Beginx()
141                 }
142         })
143         return txn.tx, txn.err
144 }