16534: Supply *sqlx.Tx to controller handlers.
[arvados.git] / lib / controller / localdb / db.go
index 4f64e63524469cc9e9fb987a4570772eb445fd8b..cad530885315af894df98915fa79df30056b13c0 100644 (file)
@@ -6,12 +6,12 @@ package localdb
 
 import (
        "context"
-       "database/sql"
        "errors"
        "sync"
 
        "git.arvados.org/arvados.git/lib/controller/router"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
+       "github.com/jmoiron/sqlx"
 )
 
 // WrapCallsInTransactions returns a call wrapper (suitable for
@@ -20,7 +20,7 @@ import (
 //
 // The wrapper calls getdb() to get a database handle before each API
 // call.
-func WrapCallsInTransactions(getdb func(context.Context) (*sql.DB, error)) func(router.RoutableFunc) router.RoutableFunc {
+func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(router.RoutableFunc) router.RoutableFunc {
        return func(origFunc router.RoutableFunc) router.RoutableFunc {
                return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
                        ctx, finishtx := starttx(ctx, getdb)
@@ -33,7 +33,7 @@ func WrapCallsInTransactions(getdb func(context.Context) (*sql.DB, error)) func(
 // ContextWithTransaction returns a child context in which the given
 // transaction will be used by any localdb API call that needs one.
 // The caller is responsible for calling Commit or Rollback on tx.
-func ContextWithTransaction(ctx context.Context, tx *sql.Tx) context.Context {
+func ContextWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context {
        txn := &transaction{tx: tx}
        txn.setup.Do(func() {})
        return context.WithValue(ctx, contextKeyTransaction, txn)
@@ -44,9 +44,9 @@ type contextKeyT string
 var contextKeyTransaction = contextKeyT("transaction")
 
 type transaction struct {
-       tx    *sql.Tx
+       tx    *sqlx.Tx
        err   error
-       getdb func(context.Context) (*sql.DB, error)
+       getdb func(context.Context) (*sqlx.DB, error)
        setup sync.Once
 }
 
@@ -75,7 +75,7 @@ type transactionFinishFunc func(*error)
 //
 // If *err is non-nil, finishtx() rolls back the transaction, and
 // does not modify *err.
-func starttx(ctx context.Context, getdb func(context.Context) (*sql.DB, error)) (context.Context, transactionFinishFunc) {
+func starttx(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, transactionFinishFunc) {
        txn := &transaction{getdb: getdb}
        return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) {
                txn.setup.Do(func() {
@@ -100,7 +100,7 @@ func starttx(ctx context.Context, getdb func(context.Context) (*sql.DB, error))
        }
 }
 
-func currenttx(ctx context.Context) (*sql.Tx, error) {
+func currenttx(ctx context.Context) (*sqlx.Tx, error) {
        txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
        if !ok {
                return nil, errors.New("bug: there is no transaction in this context")
@@ -109,7 +109,7 @@ func currenttx(ctx context.Context) (*sql.Tx, error) {
                if db, err := txn.getdb(ctx); err != nil {
                        txn.err = err
                } else {
-                       txn.tx, txn.err = db.Begin()
+                       txn.tx, txn.err = db.Beginx()
                }
        })
        return txn.tx, txn.err