21934: Use assertGreater
[arvados.git] / lib / ctrlctx / db.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ctrlctx
6
7 import (
8         "context"
9         "errors"
10         "sync"
11
12         "git.arvados.org/arvados.git/lib/controller/api"
13         "git.arvados.org/arvados.git/sdk/go/arvados"
14         "git.arvados.org/arvados.git/sdk/go/ctxlog"
15         "github.com/jmoiron/sqlx"
16
17         // sqlx needs lib/pq to talk to PostgreSQL
18         _ "github.com/lib/pq"
19 )
20
21 var (
22         ErrNoTransaction   = errors.New("bug: there is no transaction in this context")
23         ErrContextFinished = errors.New("refusing to start a transaction after wrapped function already returned")
24 )
25
26 // WrapCallsInTransactions returns a call wrapper (suitable for
27 // assigning to router.router.WrapCalls) that starts a new transaction
28 // for each API call, and commits only if the call succeeds.
29 //
30 // The wrapper calls getdb() to get a database handle before each API
31 // call.
32 func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(api.RoutableFunc) api.RoutableFunc {
33         return func(origFunc api.RoutableFunc) api.RoutableFunc {
34                 return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
35                         ctx, finishtx := New(ctx, getdb)
36                         defer finishtx(&err)
37                         return origFunc(ctx, opts)
38                 }
39         }
40 }
41
42 // NewWithTransaction returns a child context in which the given
43 // transaction will be used by any localdb API call that needs one.
44 // The caller is responsible for calling Commit or Rollback on tx.
45 func NewWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context {
46         txn := &transaction{tx: tx}
47         txn.setup.Do(func() {})
48         return context.WithValue(ctx, contextKeyTransaction, txn)
49 }
50
51 type contextKeyT string
52
53 var contextKeyTransaction = contextKeyT("transaction")
54
55 type transaction struct {
56         tx    *sqlx.Tx
57         err   error
58         getdb func(context.Context) (*sqlx.DB, error)
59         setup sync.Once
60 }
61
62 type finishFunc func(*error)
63
64 // New returns a new child context that can be used with
65 // CurrentTx(). It does not open a database transaction until the
66 // first call to CurrentTx().
67 //
68 // The caller must eventually call the returned finishtx() func to
69 // commit or rollback the transaction, if any.
70 //
71 //      func example(ctx context.Context) (err error) {
72 //              ctx, finishtx := New(ctx, getdb)
73 //              defer finishtx(&err)
74 //              // ...
75 //              tx, err := CurrentTx(ctx)
76 //              if err != nil {
77 //                      return fmt.Errorf("example: %s", err)
78 //              }
79 //              return tx.ExecContext(...)
80 //      }
81 //
82 // If *err is nil, finishtx() commits the transaction and assigns any
83 // resulting error to *err.
84 //
85 // If *err is non-nil, finishtx() rolls back the transaction, and
86 // does not modify *err.
87 func New(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, finishFunc) {
88         txn := &transaction{getdb: getdb}
89         return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) {
90                 txn.setup.Do(func() {
91                         // Using (*sync.Once)Do() prevents a future
92                         // call to CurrentTx() from opening a
93                         // transaction which would never get committed
94                         // or rolled back. If CurrentTx() hasn't been
95                         // called before now, future calls will return
96                         // this error.
97                         txn.err = ErrContextFinished
98                 })
99                 if txn.tx == nil {
100                         // we never [successfully] started a transaction
101                         return
102                 }
103                 if *err != nil {
104                         ctxlog.FromContext(ctx).Debug("rollback")
105                         txn.tx.Rollback()
106                         return
107                 }
108                 *err = txn.tx.Commit()
109         }
110 }
111
112 // NewTx starts a new transaction. The caller is responsible for
113 // calling Commit or Rollback. This is suitable for database queries
114 // that are separate from the API transaction (see CurrentTx), e.g.,
115 // ones that will be committed even if the API call fails, or held
116 // open after the API call finishes.
117 func NewTx(ctx context.Context) (*sqlx.Tx, error) {
118         txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
119         if !ok {
120                 return nil, ErrNoTransaction
121         }
122         db, err := txn.getdb(ctx)
123         if err != nil {
124                 return nil, err
125         }
126         return db.Beginx()
127 }
128
129 // CurrentTx returns a transaction that will be committed after the
130 // current API call completes, or rolled back if the current API call
131 // returns an error.
132 func CurrentTx(ctx context.Context) (*sqlx.Tx, error) {
133         txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
134         if !ok {
135                 return nil, ErrNoTransaction
136         }
137         txn.setup.Do(func() {
138                 if db, err := txn.getdb(ctx); err != nil {
139                         txn.err = err
140                 } else {
141                         txn.tx, txn.err = db.Beginx()
142                 }
143         })
144         return txn.tx, txn.err
145 }
146
147 var errDBConnection = errors.New("database connection error")
148
149 type DBConnector struct {
150         PostgreSQL arvados.PostgreSQL
151         pgdb       *sqlx.DB
152         mtx        sync.Mutex
153 }
154
155 func (dbc *DBConnector) GetDB(ctx context.Context) (*sqlx.DB, error) {
156         dbc.mtx.Lock()
157         defer dbc.mtx.Unlock()
158         if dbc.pgdb != nil {
159                 return dbc.pgdb, nil
160         }
161         db, err := sqlx.Open("postgres", dbc.PostgreSQL.Connection.String())
162         if err != nil {
163                 ctxlog.FromContext(ctx).WithError(err).Error("postgresql connect failed")
164                 return nil, errDBConnection
165         }
166         if p := dbc.PostgreSQL.ConnectionPool; p > 0 {
167                 db.SetMaxOpenConns(p)
168         }
169         if err := db.Ping(); err != nil {
170                 ctxlog.FromContext(ctx).WithError(err).Error("postgresql connect succeeded but ping failed")
171                 db.Close()
172                 return nil, errDBConnection
173         }
174         dbc.pgdb = db
175         return db, nil
176 }
177
178 func (dbc *DBConnector) Close() error {
179         dbc.mtx.Lock()
180         defer dbc.mtx.Unlock()
181         var err error
182         if dbc.pgdb != nil {
183                 err = dbc.pgdb.Close()
184                 dbc.pgdb = nil
185         }
186         return err
187 }