Merge branch '18323-cwl-gpu2' refs #18323
[arvados.git] / lib / ctrlctx / db.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ctrlctx
6
7 import (
8         "context"
9         "errors"
10         "sync"
11
12         "git.arvados.org/arvados.git/lib/controller/api"
13         "git.arvados.org/arvados.git/sdk/go/ctxlog"
14         "github.com/jmoiron/sqlx"
15         // sqlx needs lib/pq to talk to PostgreSQL
16         _ "github.com/lib/pq"
17 )
18
19 var (
20         ErrNoTransaction   = errors.New("bug: there is no transaction in this context")
21         ErrContextFinished = errors.New("refusing to start a transaction after wrapped function already returned")
22 )
23
24 // WrapCallsInTransactions returns a call wrapper (suitable for
25 // assigning to router.router.WrapCalls) that starts a new transaction
26 // for each API call, and commits only if the call succeeds.
27 //
28 // The wrapper calls getdb() to get a database handle before each API
29 // call.
30 func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(api.RoutableFunc) api.RoutableFunc {
31         return func(origFunc api.RoutableFunc) api.RoutableFunc {
32                 return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
33                         ctx, finishtx := New(ctx, getdb)
34                         defer finishtx(&err)
35                         return origFunc(ctx, opts)
36                 }
37         }
38 }
39
40 // NewWithTransaction returns a child context in which the given
41 // transaction will be used by any localdb API call that needs one.
42 // The caller is responsible for calling Commit or Rollback on tx.
43 func NewWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context {
44         txn := &transaction{tx: tx}
45         txn.setup.Do(func() {})
46         return context.WithValue(ctx, contextKeyTransaction, txn)
47 }
48
49 type contextKeyT string
50
51 var contextKeyTransaction = contextKeyT("transaction")
52
53 type transaction struct {
54         tx    *sqlx.Tx
55         err   error
56         getdb func(context.Context) (*sqlx.DB, error)
57         setup sync.Once
58 }
59
60 type finishFunc func(*error)
61
62 // New returns a new child context that can be used with
63 // CurrentTx(). It does not open a database transaction until the
64 // first call to CurrentTx().
65 //
66 // The caller must eventually call the returned finishtx() func to
67 // commit or rollback the transaction, if any.
68 //
69 //      func example(ctx context.Context) (err error) {
70 //              ctx, finishtx := New(ctx, dber)
71 //              defer finishtx(&err)
72 //              // ...
73 //              tx, err := CurrentTx(ctx)
74 //              if err != nil {
75 //                      return fmt.Errorf("example: %s", err)
76 //              }
77 //              return tx.ExecContext(...)
78 //      }
79 //
80 // If *err is nil, finishtx() commits the transaction and assigns any
81 // resulting error to *err.
82 //
83 // If *err is non-nil, finishtx() rolls back the transaction, and
84 // does not modify *err.
85 func New(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, finishFunc) {
86         txn := &transaction{getdb: getdb}
87         return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) {
88                 txn.setup.Do(func() {
89                         // Using (*sync.Once)Do() prevents a future
90                         // call to CurrentTx() from opening a
91                         // transaction which would never get committed
92                         // or rolled back. If CurrentTx() hasn't been
93                         // called before now, future calls will return
94                         // this error.
95                         txn.err = ErrContextFinished
96                 })
97                 if txn.tx == nil {
98                         // we never [successfully] started a transaction
99                         return
100                 }
101                 if *err != nil {
102                         ctxlog.FromContext(ctx).Debug("rollback")
103                         txn.tx.Rollback()
104                         return
105                 }
106                 *err = txn.tx.Commit()
107         }
108 }
109
110 func CurrentTx(ctx context.Context) (*sqlx.Tx, error) {
111         txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
112         if !ok {
113                 return nil, ErrNoTransaction
114         }
115         txn.setup.Do(func() {
116                 if db, err := txn.getdb(ctx); err != nil {
117                         txn.err = err
118                 } else {
119                         txn.tx, txn.err = db.Beginx()
120                 }
121         })
122         return txn.tx, txn.err
123 }