4f64e63524469cc9e9fb987a4570772eb445fd8b
[arvados.git] / lib / controller / localdb / db.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package localdb
6
7 import (
8         "context"
9         "database/sql"
10         "errors"
11         "sync"
12
13         "git.arvados.org/arvados.git/lib/controller/router"
14         "git.arvados.org/arvados.git/sdk/go/ctxlog"
15 )
16
17 // WrapCallsInTransactions returns a call wrapper (suitable for
18 // assigning to router.router.WrapCalls) that starts a new transaction
19 // for each API call, and commits only if the call succeeds.
20 //
21 // The wrapper calls getdb() to get a database handle before each API
22 // call.
23 func WrapCallsInTransactions(getdb func(context.Context) (*sql.DB, error)) func(router.RoutableFunc) router.RoutableFunc {
24         return func(origFunc router.RoutableFunc) router.RoutableFunc {
25                 return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
26                         ctx, finishtx := starttx(ctx, getdb)
27                         defer finishtx(&err)
28                         return origFunc(ctx, opts)
29                 }
30         }
31 }
32
33 // ContextWithTransaction returns a child context in which the given
34 // transaction will be used by any localdb API call that needs one.
35 // The caller is responsible for calling Commit or Rollback on tx.
36 func ContextWithTransaction(ctx context.Context, tx *sql.Tx) context.Context {
37         txn := &transaction{tx: tx}
38         txn.setup.Do(func() {})
39         return context.WithValue(ctx, contextKeyTransaction, txn)
40 }
41
42 type contextKeyT string
43
44 var contextKeyTransaction = contextKeyT("transaction")
45
46 type transaction struct {
47         tx    *sql.Tx
48         err   error
49         getdb func(context.Context) (*sql.DB, error)
50         setup sync.Once
51 }
52
53 type transactionFinishFunc func(*error)
54
55 // starttx returns a new child context that can be used with
56 // currenttx(). It does not open a database transaction until the
57 // first call to currenttx().
58 //
59 // The caller must eventually call the returned finishtx() func to
60 // commit or rollback the transaction, if any.
61 //
62 //      func example(ctx context.Context) (err error) {
63 //              ctx, finishtx := starttx(ctx, dber)
64 //              defer finishtx(&err)
65 //              // ...
66 //              tx, err := currenttx(ctx)
67 //              if err != nil {
68 //                      return fmt.Errorf("example: %s", err)
69 //              }
70 //              return tx.ExecContext(...)
71 //      }
72 //
73 // If *err is nil, finishtx() commits the transaction and assigns any
74 // resulting error to *err.
75 //
76 // If *err is non-nil, finishtx() rolls back the transaction, and
77 // does not modify *err.
78 func starttx(ctx context.Context, getdb func(context.Context) (*sql.DB, error)) (context.Context, transactionFinishFunc) {
79         txn := &transaction{getdb: getdb}
80         return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) {
81                 txn.setup.Do(func() {
82                         // Using (*sync.Once)Do() prevents a future
83                         // call to currenttx() from opening a
84                         // transaction which would never get committed
85                         // or rolled back. If currenttx() hasn't been
86                         // called before now, future calls will return
87                         // this error.
88                         txn.err = errors.New("refusing to start a transaction after wrapped function already returned")
89                 })
90                 if txn.tx == nil {
91                         // we never [successfully] started a transaction
92                         return
93                 }
94                 if *err != nil {
95                         ctxlog.FromContext(ctx).Debug("rollback")
96                         txn.tx.Rollback()
97                         return
98                 }
99                 *err = txn.tx.Commit()
100         }
101 }
102
103 func currenttx(ctx context.Context) (*sql.Tx, error) {
104         txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
105         if !ok {
106                 return nil, errors.New("bug: there is no transaction in this context")
107         }
108         txn.setup.Do(func() {
109                 if db, err := txn.getdb(ctx); err != nil {
110                         txn.err = err
111                 } else {
112                         txn.tx, txn.err = db.Begin()
113                 }
114         })
115         return txn.tx, txn.err
116 }