1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
12 "git.arvados.org/arvados.git/lib/config"
13 "git.arvados.org/arvados.git/sdk/go/arvados"
14 "git.arvados.org/arvados.git/sdk/go/ctxlog"
15 "github.com/jmoiron/sqlx"
17 check "gopkg.in/check.v1"
20 // testdb returns a DB connection for the given cluster config.
//
// The handle is created by sqlx.Open with the "postgres" driver name and the
// string form of cluster.PostgreSQL.Connection. c.Assert verifies that Open
// reported no error.
//
// NOTE(review): sqlx.Open presumably behaves like database/sql's Open and
// may not actually connect until the handle is first used -- confirm before
// relying on this helper to detect an unreachable database.
21 func testdb(c *check.C, cluster *arvados.Cluster) *sqlx.DB {
22 db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String())
23 c.Assert(err, check.IsNil)
27 // testctx returns a context suitable for running a test case in a new
28 // transaction, and a rollback func which the caller should call after
30 func testctx(c *check.C, db *sqlx.DB) (ctx context.Context, rollback func()) {
// Begin the transaction that the returned context will carry; the
// assertion that follows verifies that Beginx reported no error.
31 tx, err := db.Beginx()
32 c.Assert(err, check.IsNil)
// The context is produced by ContextWithTransaction from
// context.Background() and tx. The returned func rolls tx back and uses
// c.Check to verify that Rollback reported no error.
33 return ContextWithTransaction(context.Background(), tx), func() {
34 c.Check(tx.Rollback(), check.IsNil)
// Register DatabaseSuite with gocheck so its test methods are run.
38 var _ = check.Suite(&DatabaseSuite{})
// DatabaseSuite is the gocheck suite for the tests in this file. It is an
// empty struct and holds no state.
40 type DatabaseSuite struct{}
// TestTransactionContext exercises WrapCallsInTransactions and currenttx.
// It checks that repeated currenttx() calls made inside one wrapped call
// yield the same transaction, that the DB getter is invoked exactly once
// for that call, and that currenttx() returns a nil transaction and a
// non-nil error when given a context saved from a wrapped call that has
// already returned.
42 func (*DatabaseSuite) TestTransactionContext(c *check.C) {
// Load the configuration and look up the cluster whose PostgreSQL
// settings testdb will use.
// NOTE(review): GetCluster("") presumably selects the default cluster --
// confirm against the config package.
43 cfg, err := config.NewLoader(nil, ctxlog.TestLogger(c)).Load()
44 c.Assert(err, check.IsNil)
45 cluster, err := cfg.GetCluster("")
46 c.Assert(err, check.IsNil)
// getterCalled counts invocations of getter. It is incremented atomically
// inside getter and later expected to equal 1.
48 var getterCalled int64
49 getter := func(context.Context) (*sqlx.DB, error) {
50 atomic.AddInt64(&getterCalled, 1)
51 return testdb(c, cluster), nil
// wrapper turns a func(ctx, opts) into a wrapped func; getter is the DB
// source handed to WrapCallsInTransactions.
53 wrapper := WrapCallsInTransactions(getter)
54 wrappedFunc := wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
55 txes := make([]*sqlx.Tx, 20) // one slot per currenttx() result compared below
61 // Concurrent calls to currenttx(),
62 // with different children of the same
63 // parent context, will all return the
66 ctx, cancel := context.WithCancel(ctx) // child of the wrapped call's ctx
68 tx, err := currenttx(ctx)
69 c.Check(err, check.IsNil)
// i runs over 0..len(txes)-2, so each adjacent pair (txes[i], txes[i+1])
// is compared; if every pair is equal, all entries are the same *sqlx.Tx.
74 for i := range txes[1:] {
75 c.Check(txes[i], check.Equals, txes[i+1])
// Invoke the wrapped func once. It is expected to return true with a nil
// error, and getter must have been called exactly once.
80 ok, err := wrappedFunc(context.Background(), "blah")
81 c.Check(ok, check.Equals, true)
82 c.Check(err, check.IsNil)
83 c.Check(getterCalled, check.Equals, int64(1))
85 // When a wrapped func returns without calling currenttx(),
86 // calling currenttx() later shouldn't start a new
88 var savedctx context.Context // NOTE(review): presumably assigned inside the closure below (not visible here) -- confirm
89 ok, err = wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
92 })(context.Background(), "blah")
93 c.Check(ok, check.Equals, true)
94 c.Check(err, check.IsNil)
// After the wrapped call has returned, currenttx on the saved context must
// yield a nil transaction and a non-nil error.
95 tx, err := currenttx(savedctx)
96 c.Check(tx, check.IsNil)
97 c.Check(err, check.NotNil)