From a2b994f10fd73bdd882e691854239fc2d3b2e3a0 Mon Sep 17 00:00:00 2001 From: Tom Clegg Date: Mon, 6 Jul 2020 16:00:35 -0400 Subject: [PATCH] 16534: Move controller transaction/context code to its own package. Arvados-DCO-1.1-Signed-off-by: Tom Clegg --- lib/controller/handler.go | 4 +- lib/controller/localdb/login.go | 3 +- lib/controller/localdb/login_ldap_test.go | 14 ++++-- lib/controller/router/router.go | 10 ++--- lib/{controller/localdb => ctrlctx}/db.go | 44 +++++++++++-------- .../localdb => ctrlctx}/db_test.go | 37 ++++++---------- sdk/go/arvados/api.go | 2 + sdk/go/arvadostest/db.go | 33 ++++++++++++++ 8 files changed, 91 insertions(+), 56 deletions(-) rename lib/{controller/localdb => ctrlctx}/db.go (67%) rename lib/{controller/localdb => ctrlctx}/db_test.go (65%) create mode 100644 sdk/go/arvadostest/db.go diff --git a/lib/controller/handler.go b/lib/controller/handler.go index 4d4963413b..e742bbc59b 100644 --- a/lib/controller/handler.go +++ b/lib/controller/handler.go @@ -15,9 +15,9 @@ import ( "time" "git.arvados.org/arvados.git/lib/controller/federation" - "git.arvados.org/arvados.git/lib/controller/localdb" "git.arvados.org/arvados.git/lib/controller/railsproxy" "git.arvados.org/arvados.git/lib/controller/router" + "git.arvados.org/arvados.git/lib/ctrlctx" "git.arvados.org/arvados.git/sdk/go/arvados" "git.arvados.org/arvados.git/sdk/go/ctxlog" "git.arvados.org/arvados.git/sdk/go/health" @@ -87,7 +87,7 @@ func (h *Handler) setup() { Routes: health.Routes{"ping": func() error { _, err := h.db(context.TODO()); return err }}, }) - rtr := router.New(federation.New(h.Cluster), localdb.WrapCallsInTransactions(h.db)) + rtr := router.New(federation.New(h.Cluster), ctrlctx.WrapCallsInTransactions(h.db)) mux.Handle("/arvados/v1/config", rtr) mux.Handle("/"+arvados.EndpointUserAuthenticate.Path, rtr) diff --git a/lib/controller/localdb/login.go b/lib/controller/localdb/login.go index dc2c7c875c..ee1ea56924 100644 --- a/lib/controller/localdb/login.go +++ b/lib/controller/localdb/login.go @@ -15,6 +15,7 @@ import ( "strings" "git.arvados.org/arvados.git/lib/controller/rpc" + "git.arvados.org/arvados.git/lib/ctrlctx" "git.arvados.org/arvados.git/sdk/go/arvados" "git.arvados.org/arvados.git/sdk/go/auth" "git.arvados.org/arvados.git/sdk/go/httpserver" @@ -117,7 +118,7 @@ func createAPIClientAuthorization(ctx context.Context, conn *rpc.Conn, rootToken return } token := target.Query().Get("api_token") - tx, err := currenttx(ctx) + tx, err := ctrlctx.CurrentTx(ctx) if err != nil { return } diff --git a/lib/controller/localdb/login_ldap_test.go b/lib/controller/localdb/login_ldap_test.go index 15343ab322..0c94fa6c0e 100644 --- a/lib/controller/localdb/login_ldap_test.go +++ b/lib/controller/localdb/login_ldap_test.go @@ -12,6 +12,7 @@ import ( "git.arvados.org/arvados.git/lib/config" "git.arvados.org/arvados.git/lib/controller/railsproxy" + "git.arvados.org/arvados.git/lib/ctrlctx" "git.arvados.org/arvados.git/sdk/go/arvados" "git.arvados.org/arvados.git/sdk/go/arvadostest" "git.arvados.org/arvados.git/sdk/go/auth" @@ -31,7 +32,7 @@ type LDAPSuite struct { // transaction context ctx context.Context - rollback func() + rollback func() error } func (s *LDAPSuite) TearDownSuite(c *check.C) { @@ -91,15 +92,20 @@ func (s *LDAPSuite) SetUpSuite(c *check.C) { Cluster: s.cluster, RailsProxy: railsproxy.NewConn(s.cluster), } - s.db = testdb(c, s.cluster) + s.db = arvadostest.DB(c, s.cluster) } func (s *LDAPSuite) SetUpTest(c *check.C) { - s.ctx, s.rollback = testctx(c, s.db) + tx, err := s.db.Beginx() + c.Assert(err, check.IsNil) + s.ctx = ctrlctx.NewWithTransaction(context.Background(), tx) + s.rollback = tx.Rollback } func (s *LDAPSuite) TearDownTest(c *check.C) { - s.rollback() + if s.rollback != nil { + s.rollback() + } } func (s *LDAPSuite) TestLoginSuccess(c *check.C) { diff --git a/lib/controller/router/router.go b/lib/controller/router/router.go index 29c81ac5ca..ed638fe7e8 100644 --- a/lib/controller/router/router.go +++ b/lib/controller/router/router.go @@ -21,7 +21,7 @@ import ( type router struct { mux *mux.Router backend arvados.API - wrapCalls func(RoutableFunc) RoutableFunc + wrapCalls func(arvados.RoutableFunc) arvados.RoutableFunc } // New returns a new router (which implements the http.Handler @@ -32,7 +32,7 @@ type router struct { // the returned method is used in its place. This can be used to // install hooks before and after each API call and alter responses; // see localdb.WrapCallsInTransaction for an example. -func New(backend arvados.API, wrapCalls func(RoutableFunc) RoutableFunc) *router { +func New(backend arvados.API, wrapCalls func(arvados.RoutableFunc) arvados.RoutableFunc) *router { rtr := &router{ mux: mux.NewRouter(), backend: backend, @@ -42,13 +42,11 @@ func New(backend arvados.API, wrapCalls func(RoutableFunc) RoutableFunc) *router return rtr } -type RoutableFunc func(ctx context.Context, opts interface{}) (interface{}, error) - func (rtr *router) addRoutes() { for _, route := range []struct { endpoint arvados.APIEndpoint defaultOpts func() interface{} - exec RoutableFunc + exec arvados.RoutableFunc }{ { arvados.EndpointConfigGet, @@ -340,7 +338,7 @@ var altMethod = map[string]string{ "GET": "HEAD", // Accept HEAD at any GET route } -func (rtr *router) addRoute(endpoint arvados.APIEndpoint, defaultOpts func() interface{}, exec RoutableFunc) { +func (rtr *router) addRoute(endpoint arvados.APIEndpoint, defaultOpts func() interface{}, exec arvados.RoutableFunc) { methods := []string{endpoint.Method} if alt, ok := altMethod[endpoint.Method]; ok { methods = append(methods, alt) diff --git a/lib/controller/localdb/db.go b/lib/ctrlctx/db.go similarity index 67% rename from lib/controller/localdb/db.go rename to lib/ctrlctx/db.go index cad5308853..e8d9248ffc 100644 --- a/lib/controller/localdb/db.go +++ b/lib/ctrlctx/db.go @@ -2,16 +2,22 @@ // // SPDX-License-Identifier: AGPL-3.0 -package localdb +package ctrlctx import ( "context" "errors" "sync" - "git.arvados.org/arvados.git/lib/controller/router" + "git.arvados.org/arvados.git/sdk/go/arvados" "git.arvados.org/arvados.git/sdk/go/ctxlog" "github.com/jmoiron/sqlx" + _ "github.com/lib/pq" +) + +var ( + ErrNoTransaction = errors.New("bug: there is no transaction in this context") + ErrContextFinished = errors.New("refusing to start a transaction after wrapped function already returned") ) // WrapCallsInTransactions returns a call wrapper (suitable for @@ -20,20 +26,20 @@ import ( // // The wrapper calls getdb() to get a database handle before each API // call. -func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(router.RoutableFunc) router.RoutableFunc { - return func(origFunc router.RoutableFunc) router.RoutableFunc { +func WrapCallsInTransactions(getdb func(context.Context) (*sqlx.DB, error)) func(arvados.RoutableFunc) arvados.RoutableFunc { + return func(origFunc arvados.RoutableFunc) arvados.RoutableFunc { return func(ctx context.Context, opts interface{}) (_ interface{}, err error) { - ctx, finishtx := starttx(ctx, getdb) + ctx, finishtx := New(ctx, getdb) defer finishtx(&err) return origFunc(ctx, opts) } } } -// ContextWithTransaction returns a child context in which the given +// NewWithTransaction returns a child context in which the given // transaction will be used by any localdb API call that needs one. // The caller is responsible for calling Commit or Rollback on tx. -func ContextWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context { +func NewWithTransaction(ctx context.Context, tx *sqlx.Tx) context.Context { txn := &transaction{tx: tx} txn.setup.Do(func() {}) return context.WithValue(ctx, contextKeyTransaction, txn) @@ -50,20 +56,20 @@ type transaction struct { setup sync.Once } -type transactionFinishFunc func(*error) +type finishFunc func(*error) -// starttx returns a new child context that can be used with -// currenttx(). It does not open a database transaction until the -// first call to currenttx(). +// New returns a new child context that can be used with +// CurrentTx(). It does not open a database transaction until the +// first call to CurrentTx(). // // The caller must eventually call the returned finishtx() func to // commit or rollback the transaction, if any. // // func example(ctx context.Context) (err error) { -// ctx, finishtx := starttx(ctx, dber) +// ctx, finishtx := NewContext(ctx, dber) // defer finishtx(&err) // // ... -// tx, err := currenttx(ctx) +// tx, err := CurrentTx(ctx) // if err != nil { // return fmt.Errorf("example: %s", err) // } @@ -75,17 +81,17 @@ type transactionFinishFunc func(*error) // // If *err is non-nil, finishtx() rolls back the transaction, and // does not modify *err. -func starttx(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, transactionFinishFunc) { +func New(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) (context.Context, finishFunc) { txn := &transaction{getdb: getdb} return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) { txn.setup.Do(func() { // Using (*sync.Once)Do() prevents a future - // call to currenttx() from opening a + // call to CurrentTx() from opening a // transaction which would never get committed - // or rolled back. If currenttx() hasn't been + // or rolled back. If CurrentTx() hasn't been // called before now, future calls will return // this error. - txn.err = errors.New("refusing to start a transaction after wrapped function already returned") + txn.err = ErrContextFinished }) if txn.tx == nil { // we never [successfully] started a transaction @@ -100,10 +106,10 @@ func starttx(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) } } -func currenttx(ctx context.Context) (*sqlx.Tx, error) { +func CurrentTx(ctx context.Context) (*sqlx.Tx, error) { txn, ok := ctx.Value(contextKeyTransaction).(*transaction) if !ok { - return nil, errors.New("bug: there is no transaction in this context") + return nil, ErrNoTransaction } txn.setup.Do(func() { if db, err := txn.getdb(ctx); err != nil { diff --git a/lib/controller/localdb/db_test.go b/lib/ctrlctx/db_test.go similarity index 65% rename from lib/controller/localdb/db_test.go rename to lib/ctrlctx/db_test.go index 741eabad9a..5361f13c68 100644 --- a/lib/controller/localdb/db_test.go +++ b/lib/ctrlctx/db_test.go @@ -2,37 +2,24 @@ // // SPDX-License-Identifier: AGPL-3.0 -package localdb +package ctrlctx import ( "context" "sync" "sync/atomic" + "testing" "git.arvados.org/arvados.git/lib/config" - "git.arvados.org/arvados.git/sdk/go/arvados" "git.arvados.org/arvados.git/sdk/go/ctxlog" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" check "gopkg.in/check.v1" ) -// testdb returns a DB connection for the given cluster config. -func testdb(c *check.C, cluster *arvados.Cluster) *sqlx.DB { - db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String()) - c.Assert(err, check.IsNil) - return db -} - -// testctx returns a context suitable for running a test case in a new -// transaction, and a rollback func which the caller should call after -// the test. -func testctx(c *check.C, db *sqlx.DB) (ctx context.Context, rollback func()) { - tx, err := db.Beginx() - c.Assert(err, check.IsNil) - return ContextWithTransaction(context.Background(), tx), func() { - c.Check(tx.Rollback(), check.IsNil) - } +// Gocheck boilerplate +func Test(t *testing.T) { + check.TestingT(t) } var _ = check.Suite(&DatabaseSuite{}) @@ -48,7 +35,9 @@ func (*DatabaseSuite) TestTransactionContext(c *check.C) { var getterCalled int64 getter := func(context.Context) (*sqlx.DB, error) { atomic.AddInt64(&getterCalled, 1) - return testdb(c, cluster), nil + db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String()) + c.Assert(err, check.IsNil) + return db, nil } wrapper := WrapCallsInTransactions(getter) wrappedFunc := wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) { @@ -58,14 +47,14 @@ func (*DatabaseSuite) TestTransactionContext(c *check.C) { i := i wg.Add(1) go func() { - // Concurrent calls to currenttx(), + // Concurrent calls to CurrentTx(), // with different children of the same // parent context, will all return the // same transaction. defer wg.Done() ctx, cancel := context.WithCancel(ctx) defer cancel() - tx, err := currenttx(ctx) + tx, err := CurrentTx(ctx) c.Check(err, check.IsNil) txes[i] = tx }() @@ -82,8 +71,8 @@ func (*DatabaseSuite) TestTransactionContext(c *check.C) { c.Check(err, check.IsNil) c.Check(getterCalled, check.Equals, int64(1)) - // When a wrapped func returns without calling currenttx(), - // calling currenttx() later shouldn't start a new + // When a wrapped func returns without calling CurrentTx(), + // calling CurrentTx() later shouldn't start a new // transaction. var savedctx context.Context ok, err = wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) { @@ -92,7 +81,7 @@ func (*DatabaseSuite) TestTransactionContext(c *check.C) { })(context.Background(), "blah") c.Check(ok, check.Equals, true) c.Check(err, check.IsNil) - tx, err := currenttx(savedctx) + tx, err := CurrentTx(savedctx) c.Check(tx, check.IsNil) c.Check(err, check.NotNil) } diff --git a/sdk/go/arvados/api.go b/sdk/go/arvados/api.go index c32f88864f..e97a365ad9 100644 --- a/sdk/go/arvados/api.go +++ b/sdk/go/arvados/api.go @@ -154,6 +154,8 @@ type LogoutOptions struct { ReturnTo string `json:"return_to"` // Redirect to this URL after logging out } +type RoutableFunc func(ctx context.Context, opts interface{}) (interface{}, error) + type API interface { ConfigGet(ctx context.Context) (json.RawMessage, error) Login(ctx context.Context, options LoginOptions) (LoginResponse, error) diff --git a/sdk/go/arvadostest/db.go b/sdk/go/arvadostest/db.go new file mode 100644 index 0000000000..41ecfacc48 --- /dev/null +++ b/sdk/go/arvadostest/db.go @@ -0,0 +1,33 @@ +// Copyright (C) The Arvados Authors. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 + +package arvadostest + +import ( + "context" + + "git.arvados.org/arvados.git/lib/ctrlctx" + "git.arvados.org/arvados.git/sdk/go/arvados" + "github.com/jmoiron/sqlx" + _ "github.com/lib/pq" + "gopkg.in/check.v1" +) + +// DB returns a DB connection for the given cluster config. +func DB(c *check.C, cluster *arvados.Cluster) *sqlx.DB { + db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String()) + c.Assert(err, check.IsNil) + return db +} + +// TransactionContext returns a context suitable for running a test +// case in a new transaction, and a rollback func which the caller +// should call after the test. +func TransactionContext(c *check.C, db *sqlx.DB) (ctx context.Context, rollback func()) { + tx, err := db.Beginx() + c.Assert(err, check.IsNil) + return ctrlctx.NewWithTransaction(context.Background(), tx), func() { + c.Check(tx.Rollback(), check.IsNil) + } +} -- 2.30.2