1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
12 "git.arvados.org/arvados.git/lib/controller/api"
13 "git.arvados.org/arvados.git/sdk/go/arvados"
14 "git.arvados.org/arvados.git/sdk/go/ctxlog"
15 "github.com/jmoiron/sqlx"
17 // sqlx needs lib/pq to talk to PostgreSQL
22 ErrNoTransaction = errors.New("bug: there is no transaction in this context")
23 ErrContextFinished = errors.New("refusing to start a transaction after wrapped function already returned")
26 // WrapCallsInTransactions returns a call wrapper (suitable for
27 // assigning to router.router.WrapCalls) that starts a new transaction
28 // for each API call, and commits only if the call succeeds.
// The wrapper calls getdb() to get a database handle before each API
// call.
32 func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(api.RoutableFunc) api.RoutableFunc {
33 return func(origFunc api.RoutableFunc) api.RoutableFunc {
34 return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
35 ctx, finishtx := New(ctx, getdb)
37 return origFunc(ctx, opts)
42 // NewWithTransaction returns a child context in which the given
43 // transaction will be used by any localdb API call that needs one.
44 // The caller is responsible for calling Commit or Rollback on tx.
45 func NewWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context {
46 txn := &transaction{tx: tx}
47 txn.setup.Do(func() {})
48 return context.WithValue(ctx, contextKeyTransaction, txn)
// contextKeyT is an unexported key type for values this package stores
// in a context.Context. Because the type is private, keys defined by
// other packages (even ones with the same underlying string) cannot
// collide with it.
type contextKeyT string

// contextKeyTransaction is the context key under which New and
// NewWithTransaction store the *transaction that CurrentTx and NewTx
// later retrieve.
var contextKeyTransaction contextKeyT = "transaction"
55 type transaction struct {
58 getdb func(context.Context) (*sqlx.DB, error)
// finishFunc is the completion callback returned by New. It is handed
// a pointer to the caller's error result so it can inspect it and, on
// commit, overwrite it.
type finishFunc func(errp *error)
64 // New returns a new child context that can be used with
65 // CurrentTx(). It does not open a database transaction until the
66 // first call to CurrentTx().
68 // The caller must eventually call the returned finishtx() func to
69 // commit or rollback the transaction, if any.
//	func example(ctx context.Context) (err error) {
//		ctx, finishtx := New(ctx, getdb)
//		defer finishtx(&err)
//		tx, err := CurrentTx(ctx)
//		if err != nil {
//			return fmt.Errorf("example: %s", err)
//		}
//		_, err = tx.ExecContext(ctx, ...)
//		return err
//	}
82 // If *err is nil, finishtx() commits the transaction and assigns any
83 // resulting error to *err.
85 // If *err is non-nil, finishtx() rolls back the transaction, and
86 // does not modify *err.
87 func New(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, finishFunc) {
88 txn := &transaction{getdb: getdb}
89 return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) {
91 // Using (*sync.Once)Do() prevents a future
92 // call to CurrentTx() from opening a
93 // transaction which would never get committed
94 // or rolled back. If CurrentTx() hasn't been
95 // called before now, future calls will return
97 txn.err = ErrContextFinished
100 // we never [successfully] started a transaction
104 ctxlog.FromContext(ctx).Debug("rollback")
108 *err = txn.tx.Commit()
112 // NewTx starts a new transaction. The caller is responsible for
113 // calling Commit or Rollback. This is suitable for database queries
114 // that are separate from the API transaction (see CurrentTx), e.g.,
115 // ones that will be committed even if the API call fails, or held
116 // open after the API call finishes.
117 func NewTx(ctx context.Context) (*sqlx.Tx, error) {
118 txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
120 return nil, ErrNoTransaction
122 db, err := txn.getdb(ctx)
129 // CurrentTx returns a transaction that will be committed after the
// current API call completes, or rolled back if the current API call
// returns an error.
132 func CurrentTx(ctx context.Context) (*sqlx.Tx, error) {
133 txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
135 return nil, ErrNoTransaction
137 txn.setup.Do(func() {
138 if db, err := txn.getdb(ctx); err != nil {
141 txn.tx, txn.err = db.Beginx()
144 return txn.tx, txn.err
// errDBConnection is the generic error GetDB reports when opening or
// pinging PostgreSQL fails; the underlying cause is logged rather than
// returned to the caller.
var errDBConnection error = errors.New("database connection error")
149 type DBConnector struct {
150 PostgreSQL arvados.PostgreSQL
155 func (dbc *DBConnector) GetDB(ctx context.Context) (*sqlx.DB, error) {
157 defer dbc.mtx.Unlock()
161 db, err := sqlx.Open("postgres", dbc.PostgreSQL.Connection.String())
163 ctxlog.FromContext(ctx).WithError(err).Error("postgresql connect failed")
164 return nil, errDBConnection
166 if p := dbc.PostgreSQL.ConnectionPool; p > 0 {
167 db.SetMaxOpenConns(p)
169 if err := db.Ping(); err != nil {
170 ctxlog.FromContext(ctx).WithError(err).Error("postgresql connect succeeded but ping failed")
172 return nil, errDBConnection
178 func (dbc *DBConnector) Close() error {
180 defer dbc.mtx.Unlock()
183 err = dbc.pgdb.Close()