16534: Supply *sqlx.Tx to controller handlers.
[arvados.git] / lib / controller / localdb / db_test.go
index 5bab86c60289e688475efa98e6be9061936a800a..741eabad9aa1a14a0c904a3ae71f549fc8dd7206 100644 (file)
@@ -6,20 +6,20 @@ package localdb
 
 import (
        "context"
-       "database/sql"
        "sync"
        "sync/atomic"
 
        "git.arvados.org/arvados.git/lib/config"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
+       "github.com/jmoiron/sqlx"
        _ "github.com/lib/pq"
        check "gopkg.in/check.v1"
 )
 
 // testdb returns a DB connection for the given cluster config.
-func testdb(c *check.C, cluster *arvados.Cluster) *sql.DB {
-       db, err := sql.Open("postgres", cluster.PostgreSQL.Connection.String())
+func testdb(c *check.C, cluster *arvados.Cluster) *sqlx.DB {
+       db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String())
        c.Assert(err, check.IsNil)
        return db
 }
@@ -27,8 +27,8 @@ func testdb(c *check.C, cluster *arvados.Cluster) *sql.DB {
 // testctx returns a context suitable for running a test case in a new
 // transaction, and a rollback func which the caller should call after
 // the test.
-func testctx(c *check.C, db *sql.DB) (ctx context.Context, rollback func()) {
-       tx, err := db.Begin()
+func testctx(c *check.C, db *sqlx.DB) (ctx context.Context, rollback func()) {
+       tx, err := db.Beginx()
        c.Assert(err, check.IsNil)
        return ContextWithTransaction(context.Background(), tx), func() {
                c.Check(tx.Rollback(), check.IsNil)
@@ -46,13 +46,13 @@ func (*DatabaseSuite) TestTransactionContext(c *check.C) {
        c.Assert(err, check.IsNil)
 
        var getterCalled int64
-       getter := func(context.Context) (*sql.DB, error) {
+       getter := func(context.Context) (*sqlx.DB, error) {
                atomic.AddInt64(&getterCalled, 1)
                return testdb(c, cluster), nil
        }
        wrapper := WrapCallsInTransactions(getter)
        wrappedFunc := wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
-               txes := make([]*sql.Tx, 20)
+               txes := make([]*sqlx.Tx, 20)
                var wg sync.WaitGroup
                for i := range txes {
                        i := i