16534: Merge branch 'master'
authorTom Clegg <tom@tomclegg.ca>
Thu, 9 Jul 2020 15:09:18 +0000 (11:09 -0400)
committerTom Clegg <tom@tomclegg.ca>
Thu, 9 Jul 2020 15:09:18 +0000 (11:09 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@tomclegg.ca>

12 files changed:
.licenseignore
go.mod
go.sum
lib/controller/api/routable.go [new file with mode: 0644]
lib/controller/handler.go
lib/controller/localdb/conn.go
lib/controller/localdb/login.go
lib/controller/localdb/login_ldap_test.go
lib/controller/router/router.go
lib/ctrlctx/db.go [moved from lib/controller/localdb/db.go with 62% similarity]
lib/ctrlctx/db_test.go [moved from lib/controller/localdb/db_test.go with 62% similarity]
sdk/go/arvadostest/db.go [new file with mode: 0644]

index ad80dc3f4b671cc165db40fe6b215359933a0315..81f6b7181d2083ff2b84b3b5ec0e88168d58ca4b 100644 (file)
@@ -79,4 +79,6 @@ lib/dispatchcloud/test/sshkey_*
 *.asc
 sdk/java-v2/build.gradle
 sdk/java-v2/settings.gradle
-sdk/cwl/tests/wf/feddemo
\ No newline at end of file
+sdk/cwl/tests/wf/feddemo
+go.mod
+go.sum
diff --git a/go.mod b/go.mod
index cc5457975f54da4d6e00702a955451f104fe39d1..45bde2cef48deebdc0004d0ab883810c90c25d97 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -35,6 +35,7 @@ require (
        github.com/imdario/mergo v0.3.8-0.20190415133143-5ef87b449ca7
        github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
        github.com/jmcvetta/randutil v0.0.0-20150817122601-2bb1b664bcff
+       github.com/jmoiron/sqlx v1.2.0
        github.com/julienschmidt/httprouter v1.2.0
        github.com/karalabe/xgo v0.0.0-20191115072854-c5ccff8648a7 // indirect
        github.com/kevinburke/ssh_config v0.0.0-20171013211458-802051befeb5 // indirect
diff --git a/go.sum b/go.sum
index 38153ce3eaa08844dd2abfb944b9318145fbeed0..6dda3992bc279e7a874dc6641aba22a919b7054c 100644 (file)
--- a/go.sum
+++ b/go.sum
@@ -72,6 +72,7 @@ github.com/go-ldap/ldap v3.0.3+incompatible h1:HTeSZO8hWMS1Rgb2Ziku6b8a7qRIZZMHj
 github.com/go-ldap/ldap v3.0.3+incompatible/go.mod h1:qfd9rJvER9Q0/D/Sqn1DfHRoBp40uXYvFoEVrNEPqRc=
 github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
+github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
 github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
 github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo=
 github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
@@ -109,6 +110,8 @@ github.com/jmcvetta/randutil v0.0.0-20150817122601-2bb1b664bcff h1:6NvhExg4omUC9
 github.com/jmcvetta/randutil v0.0.0-20150817122601-2bb1b664bcff/go.mod h1:ddfPX8Z28YMjiqoaJhNBzWHapTHXejnB5cDCUWDwriw=
 github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
 github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
+github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA=
+github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
 github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
 github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
 github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
@@ -121,10 +124,12 @@ github.com/kevinburke/ssh_config v0.0.0-20171013211458-802051befeb5/go.mod h1:CT
 github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
 github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
+github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
 github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU=
 github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
 github.com/marstr/guid v1.1.1-0.20170427235115-8bdf7d1a087c h1:ouxemItv3B/Zh008HJkEXDYCN3BIRyNHxtUN7ThJ5Js=
 github.com/marstr/guid v1.1.1-0.20170427235115-8bdf7d1a087c/go.mod h1:74gB1z2wpxxInTG6yaqA7KrtM0NZ+RbrcqDvYHefzho=
+github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
 github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
 github.com/mitchellh/go-homedir v0.0.0-20161203194507-b8bc1bf76747 h1:eQox4Rh4ewJF+mqYPxCkmBAirRnPaHEB26UkNuPyjlk=
diff --git a/lib/controller/api/routable.go b/lib/controller/api/routable.go
new file mode 100644 (file)
index 0000000..6049cba
--- /dev/null
@@ -0,0 +1,17 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+// Package api provides types used by controller/server-component
+// packages.
+package api
+
+import "context"
+
+// A RoutableFunc calls an API method (sometimes via a wrapped
+// RoutableFunc) that has real argument types.
+//
+// (It is used by ctrlctx to manage database transactions, so moving
+// it to the router package would cause a circular dependency
+// router->arvadostest->ctrlctx->router.)
+type RoutableFunc func(ctx context.Context, opts interface{}) (interface{}, error)
index cc06246420559479203e24843164cee281e07633..e742bbc59b08a3a01a8302fcadb2cda6042cded9 100644 (file)
@@ -6,7 +6,6 @@ package controller
 
 import (
        "context"
-       "database/sql"
        "errors"
        "fmt"
        "net/http"
@@ -16,13 +15,14 @@ import (
        "time"
 
        "git.arvados.org/arvados.git/lib/controller/federation"
-       "git.arvados.org/arvados.git/lib/controller/localdb"
        "git.arvados.org/arvados.git/lib/controller/railsproxy"
        "git.arvados.org/arvados.git/lib/controller/router"
+       "git.arvados.org/arvados.git/lib/ctrlctx"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
        "git.arvados.org/arvados.git/sdk/go/health"
        "git.arvados.org/arvados.git/sdk/go/httpserver"
+       "github.com/jmoiron/sqlx"
        _ "github.com/lib/pq"
 )
 
@@ -34,7 +34,7 @@ type Handler struct {
        proxy          *proxy
        secureClient   *http.Client
        insecureClient *http.Client
-       pgdb           *sql.DB
+       pgdb           *sqlx.DB
        pgdbMtx        sync.Mutex
 }
 
@@ -87,7 +87,7 @@ func (h *Handler) setup() {
                Routes: health.Routes{"ping": func() error { _, err := h.db(context.TODO()); return err }},
        })
 
-       rtr := router.New(federation.New(h.Cluster), localdb.WrapCallsInTransactions(h.db))
+       rtr := router.New(federation.New(h.Cluster), ctrlctx.WrapCallsInTransactions(h.db))
        mux.Handle("/arvados/v1/config", rtr)
        mux.Handle("/"+arvados.EndpointUserAuthenticate.Path, rtr)
 
@@ -121,14 +121,14 @@ func (h *Handler) setup() {
 
 var errDBConnection = errors.New("database connection error")
 
-func (h *Handler) db(ctx context.Context) (*sql.DB, error) {
+func (h *Handler) db(ctx context.Context) (*sqlx.DB, error) {
        h.pgdbMtx.Lock()
        defer h.pgdbMtx.Unlock()
        if h.pgdb != nil {
                return h.pgdb, nil
        }
 
-       db, err := sql.Open("postgres", h.Cluster.PostgreSQL.Connection.String())
+       db, err := sqlx.Open("postgres", h.Cluster.PostgreSQL.Connection.String())
        if err != nil {
                ctxlog.FromContext(ctx).WithError(err).Error("postgresql connect failed")
                return nil, errDBConnection
index 60263455bdb1d02c10a9164c7c235d22a0f90fb7..4f0035edf993ad525c4d82b8d5e880049432c6c2 100644 (file)
@@ -22,11 +22,13 @@ type Conn struct {
 
 func NewConn(cluster *arvados.Cluster) *Conn {
        railsProxy := railsproxy.NewConn(cluster)
-       return &Conn{
+       var conn Conn
+       conn = Conn{
                cluster:         cluster,
                railsProxy:      railsProxy,
                loginController: chooseLoginController(cluster, railsProxy),
        }
+       return &conn
 }
 
 func (conn *Conn) Logout(ctx context.Context, opts arvados.LogoutOptions) (arvados.LogoutResponse, error) {
index 1cd349a10eaa94d987899ac1315f811ffbf186e1..ee1ea56924c5700d25e43262347d1045d534ca5c 100644 (file)
@@ -15,6 +15,7 @@ import (
        "strings"
 
        "git.arvados.org/arvados.git/lib/controller/rpc"
+       "git.arvados.org/arvados.git/lib/ctrlctx"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/auth"
        "git.arvados.org/arvados.git/sdk/go/httpserver"
@@ -117,7 +118,7 @@ func createAPIClientAuthorization(ctx context.Context, conn *rpc.Conn, rootToken
                return
        }
        token := target.Query().Get("api_token")
-       tx, err := currenttx(ctx)
+       tx, err := ctrlctx.CurrentTx(ctx)
        if err != nil {
                return
        }
@@ -130,7 +131,7 @@ func createAPIClientAuthorization(ctx context.Context, conn *rpc.Conn, rootToken
        }
        var exp sql.NullString
        var scopes []byte
-       err = tx.QueryRowContext(ctx, "select uuid, api_token, expires_at, scopes from api_client_authorizations where api_token=$1", tokensecret).Scan(&resp.UUID, &resp.APIToken, &exp, &scopes)
+       err = tx.QueryRowxContext(ctx, "select uuid, api_token, expires_at, scopes from api_client_authorizations where api_token=$1", tokensecret).Scan(&resp.UUID, &resp.APIToken, &exp, &scopes)
        if err != nil {
                return
        }
index 64ae58bce2681f792020b1855c2465d5ad226ae1..0c94fa6c0e21be72949f6fd5b402ae252d7ce1cc 100644 (file)
@@ -6,18 +6,19 @@ package localdb
 
 import (
        "context"
-       "database/sql"
        "encoding/json"
        "net"
        "net/http"
 
        "git.arvados.org/arvados.git/lib/config"
        "git.arvados.org/arvados.git/lib/controller/railsproxy"
+       "git.arvados.org/arvados.git/lib/ctrlctx"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/arvadostest"
        "git.arvados.org/arvados.git/sdk/go/auth"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
        "github.com/bradleypeabody/godap"
+       "github.com/jmoiron/sqlx"
        check "gopkg.in/check.v1"
 )
 
@@ -27,11 +28,11 @@ type LDAPSuite struct {
        cluster *arvados.Cluster
        ctrl    *ldapLoginController
        ldap    *godap.LDAPServer // fake ldap server that accepts auth goodusername/goodpassword
-       db      *sql.DB
+       db      *sqlx.DB
 
        // transaction context
        ctx      context.Context
-       rollback func()
+       rollback func() error
 }
 
 func (s *LDAPSuite) TearDownSuite(c *check.C) {
@@ -91,15 +92,20 @@ func (s *LDAPSuite) SetUpSuite(c *check.C) {
                Cluster:    s.cluster,
                RailsProxy: railsproxy.NewConn(s.cluster),
        }
-       s.db = testdb(c, s.cluster)
+       s.db = arvadostest.DB(c, s.cluster)
 }
 
 func (s *LDAPSuite) SetUpTest(c *check.C) {
-       s.ctx, s.rollback = testctx(c, s.db)
+       tx, err := s.db.Beginx()
+       c.Assert(err, check.IsNil)
+       s.ctx = ctrlctx.NewWithTransaction(context.Background(), tx)
+       s.rollback = tx.Rollback
 }
 
 func (s *LDAPSuite) TearDownTest(c *check.C) {
-       s.rollback()
+       if s.rollback != nil {
+               s.rollback()
+       }
 }
 
 func (s *LDAPSuite) TestLoginSuccess(c *check.C) {
index 29c81ac5cae9ac63431e691852230a00c2335afe..2944524344e9028fa22cf0c9d18327cb39193733 100644 (file)
@@ -10,6 +10,7 @@ import (
        "net/http"
        "strings"
 
+       "git.arvados.org/arvados.git/lib/controller/api"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/auth"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
@@ -21,7 +22,7 @@ import (
 type router struct {
        mux       *mux.Router
        backend   arvados.API
-       wrapCalls func(RoutableFunc) RoutableFunc
+       wrapCalls func(api.RoutableFunc) api.RoutableFunc
 }
 
 // New returns a new router (which implements the http.Handler
@@ -32,7 +33,7 @@ type router struct {
 // the returned method is used in its place. This can be used to
 // install hooks before and after each API call and alter responses;
 // see localdb.WrapCallsInTransaction for an example.
-func New(backend arvados.API, wrapCalls func(RoutableFunc) RoutableFunc) *router {
+func New(backend arvados.API, wrapCalls func(api.RoutableFunc) api.RoutableFunc) *router {
        rtr := &router{
                mux:       mux.NewRouter(),
                backend:   backend,
@@ -42,13 +43,11 @@ func New(backend arvados.API, wrapCalls func(RoutableFunc) RoutableFunc) *router
        return rtr
 }
 
-type RoutableFunc func(ctx context.Context, opts interface{}) (interface{}, error)
-
 func (rtr *router) addRoutes() {
        for _, route := range []struct {
                endpoint    arvados.APIEndpoint
                defaultOpts func() interface{}
-               exec        RoutableFunc
+               exec        api.RoutableFunc
        }{
                {
                        arvados.EndpointConfigGet,
@@ -340,7 +339,7 @@ var altMethod = map[string]string{
        "GET":   "HEAD", // Accept HEAD at any GET route
 }
 
-func (rtr *router) addRoute(endpoint arvados.APIEndpoint, defaultOpts func() interface{}, exec RoutableFunc) {
+func (rtr *router) addRoute(endpoint arvados.APIEndpoint, defaultOpts func() interface{}, exec api.RoutableFunc) {
        methods := []string{endpoint.Method}
        if alt, ok := altMethod[endpoint.Method]; ok {
                methods = append(methods, alt)
similarity index 62%
rename from lib/controller/localdb/db.go
rename to lib/ctrlctx/db.go
index 4f64e63524469cc9e9fb987a4570772eb445fd8b..127be489df3a27e553f6aa421a6f1c40cdbdcc55 100644 (file)
@@ -2,16 +2,22 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package localdb
+package ctrlctx
 
 import (
        "context"
-       "database/sql"
        "errors"
        "sync"
 
-       "git.arvados.org/arvados.git/lib/controller/router"
+       "git.arvados.org/arvados.git/lib/controller/api"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
+       "github.com/jmoiron/sqlx"
+       _ "github.com/lib/pq"
+)
+
+var (
+       ErrNoTransaction   = errors.New("bug: there is no transaction in this context")
+       ErrContextFinished = errors.New("refusing to start a transaction after wrapped function already returned")
 )
 
 // WrapCallsInTransactions returns a call wrapper (suitable for
@@ -20,20 +26,20 @@ import (
 //
 // The wrapper calls getdb() to get a database handle before each API
 // call.
-func WrapCallsInTransactions(getdb func(context.Context) (*sql.DB, error)) func(router.RoutableFunc) router.RoutableFunc {
-       return func(origFunc router.RoutableFunc) router.RoutableFunc {
+func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(api.RoutableFunc) api.RoutableFunc {
+       return func(origFunc api.RoutableFunc) api.RoutableFunc {
                return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
-                       ctx, finishtx := starttx(ctx, getdb)
+                       ctx, finishtx := New(ctx, getdb)
                        defer finishtx(&err)
                        return origFunc(ctx, opts)
                }
        }
 }
 
-// ContextWithTransaction returns a child context in which the given
+// NewWithTransaction returns a child context in which the given
 // transaction will be used by any localdb API call that needs one.
 // The caller is responsible for calling Commit or Rollback on tx.
-func ContextWithTransaction(ctx context.Context, tx *sql.Tx) context.Context {
+func NewWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context {
        txn := &transaction{tx: tx}
        txn.setup.Do(func() {})
        return context.WithValue(ctx, contextKeyTransaction, txn)
@@ -44,26 +50,26 @@ type contextKeyT string
 var contextKeyTransaction = contextKeyT("transaction")
 
 type transaction struct {
-       tx    *sql.Tx
+       tx    *sqlx.Tx
        err   error
-       getdb func(context.Context) (*sql.DB, error)
+       getdb func(context.Context) (*sqlx.DB, error)
        setup sync.Once
 }
 
-type transactionFinishFunc func(*error)
+type finishFunc func(*error)
 
-// starttx returns a new child context that can be used with
-// currenttx(). It does not open a database transaction until the
-// first call to currenttx().
+// New returns a new child context that can be used with
+// CurrentTx(). It does not open a database transaction until the
+// first call to CurrentTx().
 //
 // The caller must eventually call the returned finishtx() func to
 // commit or rollback the transaction, if any.
 //
 //     func example(ctx context.Context) (err error) {
-//             ctx, finishtx := starttx(ctx, dber)
+//             ctx, finishtx := New(ctx, dber)
 //             defer finishtx(&err)
 //             // ...
-//             tx, err := currenttx(ctx)
+//             tx, err := CurrentTx(ctx)
 //             if err != nil {
 //                     return fmt.Errorf("example: %s", err)
 //             }
@@ -75,17 +81,17 @@ type transactionFinishFunc func(*error)
 //
 // If *err is non-nil, finishtx() rolls back the transaction, and
 // does not modify *err.
-func starttx(ctx context.Context, getdb func(context.Context) (*sql.DB, error)) (context.Context, transactionFinishFunc) {
+func New(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, finishFunc) {
        txn := &transaction{getdb: getdb}
        return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) {
                txn.setup.Do(func() {
                        // Using (*sync.Once)Do() prevents a future
-                       // call to currenttx() from opening a
+                       // call to CurrentTx() from opening a
                        // transaction which would never get committed
-                       // or rolled back. If currenttx() hasn't been
+                       // or rolled back. If CurrentTx() hasn't been
                        // called before now, future calls will return
                        // this error.
-                       txn.err = errors.New("refusing to start a transaction after wrapped function already returned")
+                       txn.err = ErrContextFinished
                })
                if txn.tx == nil {
                        // we never [successfully] started a transaction
@@ -100,16 +106,16 @@ func starttx(ctx context.Context, getdb func(context.Context) (*sql.DB, error))
        }
 }
 
-func currenttx(ctx context.Context) (*sql.Tx, error) {
+func CurrentTx(ctx context.Context) (*sqlx.Tx, error) {
        txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
        if !ok {
-               return nil, errors.New("bug: there is no transaction in this context")
+               return nil, ErrNoTransaction
        }
        txn.setup.Do(func() {
                if db, err := txn.getdb(ctx); err != nil {
                        txn.err = err
                } else {
-                       txn.tx, txn.err = db.Begin()
+                       txn.tx, txn.err = db.Beginx()
                }
        })
        return txn.tx, txn.err
similarity index 62%
rename from lib/controller/localdb/db_test.go
rename to lib/ctrlctx/db_test.go
index 5bab86c60289e688475efa98e6be9061936a800a..5361f13c68a4967168082b28f16ab562fce546ee 100644 (file)
@@ -2,37 +2,24 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package localdb
+package ctrlctx
 
 import (
        "context"
-       "database/sql"
        "sync"
        "sync/atomic"
+       "testing"
 
        "git.arvados.org/arvados.git/lib/config"
-       "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
+       "github.com/jmoiron/sqlx"
        _ "github.com/lib/pq"
        check "gopkg.in/check.v1"
 )
 
-// testdb returns a DB connection for the given cluster config.
-func testdb(c *check.C, cluster *arvados.Cluster) *sql.DB {
-       db, err := sql.Open("postgres", cluster.PostgreSQL.Connection.String())
-       c.Assert(err, check.IsNil)
-       return db
-}
-
-// testctx returns a context suitable for running a test case in a new
-// transaction, and a rollback func which the caller should call after
-// the test.
-func testctx(c *check.C, db *sql.DB) (ctx context.Context, rollback func()) {
-       tx, err := db.Begin()
-       c.Assert(err, check.IsNil)
-       return ContextWithTransaction(context.Background(), tx), func() {
-               c.Check(tx.Rollback(), check.IsNil)
-       }
+// Gocheck boilerplate
+func Test(t *testing.T) {
+       check.TestingT(t)
 }
 
 var _ = check.Suite(&DatabaseSuite{})
@@ -46,26 +33,28 @@ func (*DatabaseSuite) TestTransactionContext(c *check.C) {
        c.Assert(err, check.IsNil)
 
        var getterCalled int64
-       getter := func(context.Context) (*sql.DB, error) {
+       getter := func(context.Context) (*sqlx.DB, error) {
                atomic.AddInt64(&getterCalled, 1)
-               return testdb(c, cluster), nil
+               db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String())
+               c.Assert(err, check.IsNil)
+               return db, nil
        }
        wrapper := WrapCallsInTransactions(getter)
        wrappedFunc := wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
-               txes := make([]*sql.Tx, 20)
+               txes := make([]*sqlx.Tx, 20)
                var wg sync.WaitGroup
                for i := range txes {
                        i := i
                        wg.Add(1)
                        go func() {
-                               // Concurrent calls to currenttx(),
+                               // Concurrent calls to CurrentTx(),
                                // with different children of the same
                                // parent context, will all return the
                                // same transaction.
                                defer wg.Done()
                                ctx, cancel := context.WithCancel(ctx)
                                defer cancel()
-                               tx, err := currenttx(ctx)
+                               tx, err := CurrentTx(ctx)
                                c.Check(err, check.IsNil)
                                txes[i] = tx
                        }()
@@ -82,8 +71,8 @@ func (*DatabaseSuite) TestTransactionContext(c *check.C) {
        c.Check(err, check.IsNil)
        c.Check(getterCalled, check.Equals, int64(1))
 
-       // When a wrapped func returns without calling currenttx(),
-       // calling currenttx() later shouldn't start a new
+       // When a wrapped func returns without calling CurrentTx(),
+       // calling CurrentTx() later shouldn't start a new
        // transaction.
        var savedctx context.Context
        ok, err = wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
@@ -92,7 +81,7 @@ func (*DatabaseSuite) TestTransactionContext(c *check.C) {
        })(context.Background(), "blah")
        c.Check(ok, check.Equals, true)
        c.Check(err, check.IsNil)
-       tx, err := currenttx(savedctx)
+       tx, err := CurrentTx(savedctx)
        c.Check(tx, check.IsNil)
        c.Check(err, check.NotNil)
 }
diff --git a/sdk/go/arvadostest/db.go b/sdk/go/arvadostest/db.go
new file mode 100644 (file)
index 0000000..41ecfac
--- /dev/null
@@ -0,0 +1,33 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package arvadostest
+
+import (
+       "context"
+
+       "git.arvados.org/arvados.git/lib/ctrlctx"
+       "git.arvados.org/arvados.git/sdk/go/arvados"
+       "github.com/jmoiron/sqlx"
+       _ "github.com/lib/pq"
+       "gopkg.in/check.v1"
+)
+
+// DB returns a DB connection for the given cluster config.
+func DB(c *check.C, cluster *arvados.Cluster) *sqlx.DB {
+       db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String())
+       c.Assert(err, check.IsNil)
+       return db
+}
+
+// TransactionContext returns a context suitable for running a test
+// case in a new transaction, and a rollback func which the caller
+// should call after the test.
+func TransactionContext(c *check.C, db *sqlx.DB) (ctx context.Context, rollback func()) {
+       tx, err := db.Beginx()
+       c.Assert(err, check.IsNil)
+       return ctrlctx.NewWithTransaction(context.Background(), tx), func() {
+               c.Check(tx.Rollback(), check.IsNil)
+       }
+}