1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
13 "git.arvados.org/arvados.git/lib/config"
14 "git.arvados.org/arvados.git/sdk/go/ctxlog"
15 "github.com/jmoiron/sqlx"
17 check "gopkg.in/check.v1"
20 // Gocheck boilerplate
21 func Test(t *testing.T) {
25 var _ = check.Suite(&DatabaseSuite{})
type (
	// DatabaseSuite is the gocheck suite holding this package's database
	// tests. It carries no per-suite state.
	DatabaseSuite struct{}
)
29 func (*DatabaseSuite) TestTransactionContext(c *check.C) {
30 cfg, err := config.NewLoader(nil, ctxlog.TestLogger(c)).Load()
31 c.Assert(err, check.IsNil)
32 cluster, err := cfg.GetCluster("")
33 c.Assert(err, check.IsNil)
35 var getterCalled int64
36 getter := func(context.Context) (*sqlx.DB, error) {
37 atomic.AddInt64(&getterCalled, 1)
38 db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String())
39 c.Assert(err, check.IsNil)
42 wrapper := WrapCallsInTransactions(getter)
43 wrappedFunc := wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
44 txes := make([]*sqlx.Tx, 20)
50 // Concurrent calls to CurrentTx(),
51 // with different children of the same
52 // parent context, will all return the
55 ctx, cancel := context.WithCancel(ctx)
57 tx, err := CurrentTx(ctx)
58 c.Check(err, check.IsNil)
63 for i := range txes[1:] {
64 c.Check(txes[i], check.Equals, txes[i+1])
69 ok, err := wrappedFunc(context.Background(), "blah")
70 c.Check(ok, check.Equals, true)
71 c.Check(err, check.IsNil)
72 c.Check(getterCalled, check.Equals, int64(1))
74 // When a wrapped func returns without calling CurrentTx(),
75 // calling CurrentTx() later shouldn't start a new
77 var savedctx context.Context
78 ok, err = wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
81 })(context.Background(), "blah")
82 c.Check(ok, check.Equals, true)
83 c.Check(err, check.IsNil)
84 tx, err := CurrentTx(savedctx)
85 c.Check(tx, check.IsNil)
86 c.Check(err, check.NotNil)