16534: Move controller transaction/context code to its own package.
authorTom Clegg <tom@tomclegg.ca>
Mon, 6 Jul 2020 20:00:35 +0000 (16:00 -0400)
committerTom Clegg <tom@tomclegg.ca>
Mon, 6 Jul 2020 20:00:35 +0000 (16:00 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@tomclegg.ca>

lib/controller/handler.go
lib/controller/localdb/login.go
lib/controller/localdb/login_ldap_test.go
lib/controller/router/router.go
lib/ctrlctx/db.go [moved from lib/controller/localdb/db.go with 67% similarity]
lib/ctrlctx/db_test.go [moved from lib/controller/localdb/db_test.go with 65% similarity]
sdk/go/arvados/api.go
sdk/go/arvadostest/db.go [new file with mode: 0644]

index 4d4963413b033aacea971585e714b033e3128c99..e742bbc59b08a3a01a8302fcadb2cda6042cded9 100644 (file)
@@ -15,9 +15,9 @@ import (
        "time"
 
        "git.arvados.org/arvados.git/lib/controller/federation"
-       "git.arvados.org/arvados.git/lib/controller/localdb"
        "git.arvados.org/arvados.git/lib/controller/railsproxy"
        "git.arvados.org/arvados.git/lib/controller/router"
+       "git.arvados.org/arvados.git/lib/ctrlctx"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
        "git.arvados.org/arvados.git/sdk/go/health"
@@ -87,7 +87,7 @@ func (h *Handler) setup() {
                Routes: health.Routes{"ping": func() error { _, err := h.db(context.TODO()); return err }},
        })
 
-       rtr := router.New(federation.New(h.Cluster), localdb.WrapCallsInTransactions(h.db))
+       rtr := router.New(federation.New(h.Cluster), ctrlctx.WrapCallsInTransactions(h.db))
        mux.Handle("/arvados/v1/config", rtr)
        mux.Handle("/"+arvados.EndpointUserAuthenticate.Path, rtr)
 
index dc2c7c875cda54823f852bccd0facd7c160ce2f7..ee1ea56924c5700d25e43262347d1045d534ca5c 100644 (file)
@@ -15,6 +15,7 @@ import (
        "strings"
 
        "git.arvados.org/arvados.git/lib/controller/rpc"
+       "git.arvados.org/arvados.git/lib/ctrlctx"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/auth"
        "git.arvados.org/arvados.git/sdk/go/httpserver"
@@ -117,7 +118,7 @@ func createAPIClientAuthorization(ctx context.Context, conn *rpc.Conn, rootToken
                return
        }
        token := target.Query().Get("api_token")
-       tx, err := currenttx(ctx)
+       tx, err := ctrlctx.CurrentTx(ctx)
        if err != nil {
                return
        }
index 15343ab32216e552504379dffb433b839c346e2c..0c94fa6c0e21be72949f6fd5b402ae252d7ce1cc 100644 (file)
@@ -12,6 +12,7 @@ import (
 
        "git.arvados.org/arvados.git/lib/config"
        "git.arvados.org/arvados.git/lib/controller/railsproxy"
+       "git.arvados.org/arvados.git/lib/ctrlctx"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/arvadostest"
        "git.arvados.org/arvados.git/sdk/go/auth"
@@ -31,7 +32,7 @@ type LDAPSuite struct {
 
        // transaction context
        ctx      context.Context
-       rollback func()
+       rollback func() error
 }
 
 func (s *LDAPSuite) TearDownSuite(c *check.C) {
@@ -91,15 +92,20 @@ func (s *LDAPSuite) SetUpSuite(c *check.C) {
                Cluster:    s.cluster,
                RailsProxy: railsproxy.NewConn(s.cluster),
        }
-       s.db = testdb(c, s.cluster)
+       s.db = arvadostest.DB(c, s.cluster)
 }
 
 func (s *LDAPSuite) SetUpTest(c *check.C) {
-       s.ctx, s.rollback = testctx(c, s.db)
+       tx, err := s.db.Beginx()
+       c.Assert(err, check.IsNil)
+       s.ctx = ctrlctx.NewWithTransaction(context.Background(), tx)
+       s.rollback = tx.Rollback
 }
 
 func (s *LDAPSuite) TearDownTest(c *check.C) {
-       s.rollback()
+       if s.rollback != nil {
+               s.rollback()
+       }
 }
 
 func (s *LDAPSuite) TestLoginSuccess(c *check.C) {
index 29c81ac5cae9ac63431e691852230a00c2335afe..ed638fe7e83c4ce0bb6d9a21b5fafd74cfc5e85e 100644 (file)
@@ -21,7 +21,7 @@ import (
 type router struct {
        mux       *mux.Router
        backend   arvados.API
-       wrapCalls func(RoutableFunc) RoutableFunc
+       wrapCalls func(arvados.RoutableFunc) arvados.RoutableFunc
 }
 
 // New returns a new router (which implements the http.Handler
@@ -32,7 +32,7 @@ type router struct {
 // the returned method is used in its place. This can be used to
 // install hooks before and after each API call and alter responses;
 // see localdb.WrapCallsInTransaction for an example.
-func New(backend arvados.API, wrapCalls func(RoutableFunc) RoutableFunc) *router {
+func New(backend arvados.API, wrapCalls func(arvados.RoutableFunc) arvados.RoutableFunc) *router {
        rtr := &router{
                mux:       mux.NewRouter(),
                backend:   backend,
@@ -42,13 +42,11 @@ func New(backend arvados.API, wrapCalls func(RoutableFunc) RoutableFunc) *router
        return rtr
 }
 
-type RoutableFunc func(ctx context.Context, opts interface{}) (interface{}, error)
-
 func (rtr *router) addRoutes() {
        for _, route := range []struct {
                endpoint    arvados.APIEndpoint
                defaultOpts func() interface{}
-               exec        RoutableFunc
+               exec        arvados.RoutableFunc
        }{
                {
                        arvados.EndpointConfigGet,
@@ -340,7 +338,7 @@ var altMethod = map[string]string{
        "GET":   "HEAD", // Accept HEAD at any GET route
 }
 
-func (rtr *router) addRoute(endpoint arvados.APIEndpoint, defaultOpts func() interface{}, exec RoutableFunc) {
+func (rtr *router) addRoute(endpoint arvados.APIEndpoint, defaultOpts func() interface{}, exec arvados.RoutableFunc) {
        methods := []string{endpoint.Method}
        if alt, ok := altMethod[endpoint.Method]; ok {
                methods = append(methods, alt)
similarity index 67%
rename from lib/controller/localdb/db.go
rename to lib/ctrlctx/db.go
index cad530885315af894df98915fa79df30056b13c0..e8d9248ffcc58f89daf3302debde266b093862ed 100644 (file)
@@ -2,16 +2,22 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package localdb
+package ctrlctx
 
 import (
        "context"
        "errors"
        "sync"
 
-       "git.arvados.org/arvados.git/lib/controller/router"
+       "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
        "github.com/jmoiron/sqlx"
+       _ "github.com/lib/pq"
+)
+
+var (
+       ErrNoTransaction   = errors.New("bug: there is no transaction in this context")
+       ErrContextFinished = errors.New("refusing to start a transaction after wrapped function already returned")
 )
 
 // WrapCallsInTransactions returns a call wrapper (suitable for
@@ -20,20 +26,20 @@ import (
 //
 // The wrapper calls getdb() to get a database handle before each API
 // call.
-func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(router.RoutableFunc) router.RoutableFunc {
-       return func(origFunc router.RoutableFunc) router.RoutableFunc {
+func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(arvados.RoutableFunc) arvados.RoutableFunc {
+       return func(origFunc arvados.RoutableFunc) arvados.RoutableFunc {
                return func(ctx context.Context, opts interface{}) (_ interface{}, err error) {
-                       ctx, finishtx := starttx(ctx, getdb)
+                       ctx, finishtx := New(ctx, getdb)
                        defer finishtx(&err)
                        return origFunc(ctx, opts)
                }
        }
 }
 
-// ContextWithTransaction returns a child context in which the given
+// NewWithTransaction returns a child context in which the given
 // transaction will be used by any localdb API call that needs one.
 // The caller is responsible for calling Commit or Rollback on tx.
-func ContextWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context {
+func NewWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context {
        txn := &transaction{tx: tx}
        txn.setup.Do(func() {})
        return context.WithValue(ctx, contextKeyTransaction, txn)
@@ -50,20 +56,20 @@ type transaction struct {
        setup sync.Once
 }
 
-type transactionFinishFunc func(*error)
+type finishFunc func(*error)
 
-// starttx returns a new child context that can be used with
-// currenttx(). It does not open a database transaction until the
-// first call to currenttx().
+// New returns a new child context that can be used with
+// CurrentTx(). It does not open a database transaction until the
+// first call to CurrentTx().
 //
 // The caller must eventually call the returned finishtx() func to
 // commit or rollback the transaction, if any.
 //
 //     func example(ctx context.Context) (err error) {
-//             ctx, finishtx := starttx(ctx, dber)
+//             ctx, finishtx := NewContext(ctx, dber)
 //             defer finishtx(&err)
 //             // ...
-//             tx, err := currenttx(ctx)
+//             tx, err := CurrentTx(ctx)
 //             if err != nil {
 //                     return fmt.Errorf("example: %s", err)
 //             }
@@ -75,17 +81,17 @@ type transactionFinishFunc func(*error)
 //
 // If *err is non-nil, finishtx() rolls back the transaction, and
 // does not modify *err.
-func starttx(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, transactionFinishFunc) {
+func New(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, finishFunc) {
        txn := &transaction{getdb: getdb}
        return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) {
                txn.setup.Do(func() {
                        // Using (*sync.Once)Do() prevents a future
-                       // call to currenttx() from opening a
+                       // call to CurrentTx() from opening a
                        // transaction which would never get committed
-                       // or rolled back. If currenttx() hasn't been
+                       // or rolled back. If CurrentTx() hasn't been
                        // called before now, future calls will return
                        // this error.
-                       txn.err = errors.New("refusing to start a transaction after wrapped function already returned")
+                       txn.err = ErrContextFinished
                })
                if txn.tx == nil {
                        // we never [successfully] started a transaction
@@ -100,10 +106,10 @@ func starttx(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error))
        }
 }
 
-func currenttx(ctx context.Context) (*sqlx.Tx, error) {
+func CurrentTx(ctx context.Context) (*sqlx.Tx, error) {
        txn, ok := ctx.Value(contextKeyTransaction).(*transaction)
        if !ok {
-               return nil, errors.New("bug: there is no transaction in this context")
+               return nil, ErrNoTransaction
        }
        txn.setup.Do(func() {
                if db, err := txn.getdb(ctx); err != nil {
similarity index 65%
rename from lib/controller/localdb/db_test.go
rename to lib/ctrlctx/db_test.go
index 741eabad9aa1a14a0c904a3ae71f549fc8dd7206..5361f13c68a4967168082b28f16ab562fce546ee 100644 (file)
@@ -2,37 +2,24 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package localdb
+package ctrlctx
 
 import (
        "context"
        "sync"
        "sync/atomic"
+       "testing"
 
        "git.arvados.org/arvados.git/lib/config"
-       "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
        "github.com/jmoiron/sqlx"
        _ "github.com/lib/pq"
        check "gopkg.in/check.v1"
 )
 
-// testdb returns a DB connection for the given cluster config.
-func testdb(c *check.C, cluster *arvados.Cluster) *sqlx.DB {
-       db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String())
-       c.Assert(err, check.IsNil)
-       return db
-}
-
-// testctx returns a context suitable for running a test case in a new
-// transaction, and a rollback func which the caller should call after
-// the test.
-func testctx(c *check.C, db *sqlx.DB) (ctx context.Context, rollback func()) {
-       tx, err := db.Beginx()
-       c.Assert(err, check.IsNil)
-       return ContextWithTransaction(context.Background(), tx), func() {
-               c.Check(tx.Rollback(), check.IsNil)
-       }
+// Gocheck boilerplate
+func Test(t *testing.T) {
+       check.TestingT(t)
 }
 
 var _ = check.Suite(&DatabaseSuite{})
@@ -48,7 +35,9 @@ func (*DatabaseSuite) TestTransactionContext(c *check.C) {
        var getterCalled int64
        getter := func(context.Context) (*sqlx.DB, error) {
                atomic.AddInt64(&getterCalled, 1)
-               return testdb(c, cluster), nil
+               db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String())
+               c.Assert(err, check.IsNil)
+               return db, nil
        }
        wrapper := WrapCallsInTransactions(getter)
        wrappedFunc := wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
@@ -58,14 +47,14 @@ func (*DatabaseSuite) TestTransactionContext(c *check.C) {
                        i := i
                        wg.Add(1)
                        go func() {
-                               // Concurrent calls to currenttx(),
+                               // Concurrent calls to CurrentTx(),
                                // with different children of the same
                                // parent context, will all return the
                                // same transaction.
                                defer wg.Done()
                                ctx, cancel := context.WithCancel(ctx)
                                defer cancel()
-                               tx, err := currenttx(ctx)
+                               tx, err := CurrentTx(ctx)
                                c.Check(err, check.IsNil)
                                txes[i] = tx
                        }()
@@ -82,8 +71,8 @@ func (*DatabaseSuite) TestTransactionContext(c *check.C) {
        c.Check(err, check.IsNil)
        c.Check(getterCalled, check.Equals, int64(1))
 
-       // When a wrapped func returns without calling currenttx(),
-       // calling currenttx() later shouldn't start a new
+       // When a wrapped func returns without calling CurrentTx(),
+       // calling CurrentTx() later shouldn't start a new
        // transaction.
        var savedctx context.Context
        ok, err = wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
@@ -92,7 +81,7 @@ func (*DatabaseSuite) TestTransactionContext(c *check.C) {
        })(context.Background(), "blah")
        c.Check(ok, check.Equals, true)
        c.Check(err, check.IsNil)
-       tx, err := currenttx(savedctx)
+       tx, err := CurrentTx(savedctx)
        c.Check(tx, check.IsNil)
        c.Check(err, check.NotNil)
 }
index c32f88864f88750c00fe896286e147ccd9d061ce..e97a365ad9147f1c69d5faee5dd73cd8d072ceff 100644 (file)
@@ -154,6 +154,8 @@ type LogoutOptions struct {
        ReturnTo string `json:"return_to"` // Redirect to this URL after logging out
 }
 
+type RoutableFunc func(ctx context.Context, opts interface{}) (interface{}, error)
+
 type API interface {
        ConfigGet(ctx context.Context) (json.RawMessage, error)
        Login(ctx context.Context, options LoginOptions) (LoginResponse, error)
diff --git a/sdk/go/arvadostest/db.go b/sdk/go/arvadostest/db.go
new file mode 100644 (file)
index 0000000..41ecfac
--- /dev/null
@@ -0,0 +1,33 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package arvadostest
+
+import (
+       "context"
+
+       "git.arvados.org/arvados.git/lib/ctrlctx"
+       "git.arvados.org/arvados.git/sdk/go/arvados"
+       "github.com/jmoiron/sqlx"
+       _ "github.com/lib/pq"
+       "gopkg.in/check.v1"
+)
+
+// DB returns a DB connection for the given cluster config.
+func DB(c *check.C, cluster *arvados.Cluster) *sqlx.DB {
+       db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String())
+       c.Assert(err, check.IsNil)
+       return db
+}
+
+// TransactionContext returns a context suitable for running a test
+// case in a new transaction, and a rollback func which the caller
+// should call after the test.
+func TransactionContext(c *check.C, db *sqlx.DB) (ctx context.Context, rollback func()) {
+       tx, err := db.Beginx()
+       c.Assert(err, check.IsNil)
+       return ctrlctx.NewWithTransaction(context.Background(), tx), func() {
+               c.Check(tx.Rollback(), check.IsNil)
+       }
+}