Merge branch 'patch-1' of https://github.com/mr-c/arvados into mr-c-patch-1
[arvados.git] / lib / ctrlctx / db_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ctrlctx
6
7 import (
8         "context"
9         "sync"
10         "sync/atomic"
11         "testing"
12
13         "git.arvados.org/arvados.git/lib/config"
14         "git.arvados.org/arvados.git/sdk/go/ctxlog"
15         "github.com/jmoiron/sqlx"
16         _ "github.com/lib/pq"
17         check "gopkg.in/check.v1"
18 )
19
20 // Gocheck boilerplate
21 func Test(t *testing.T) {
22         check.TestingT(t)
23 }
24
25 var _ = check.Suite(&DatabaseSuite{})
26
27 type DatabaseSuite struct{}
28
29 func (*DatabaseSuite) TestTransactionContext(c *check.C) {
30         cfg, err := config.NewLoader(nil, ctxlog.TestLogger(c)).Load()
31         c.Assert(err, check.IsNil)
32         cluster, err := cfg.GetCluster("")
33         c.Assert(err, check.IsNil)
34
35         var getterCalled int64
36         getter := func(context.Context) (*sqlx.DB, error) {
37                 atomic.AddInt64(&getterCalled, 1)
38                 db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String())
39                 c.Assert(err, check.IsNil)
40                 return db, nil
41         }
42         wrapper := WrapCallsInTransactions(getter)
43         wrappedFunc := wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
44                 txes := make([]*sqlx.Tx, 20)
45                 var wg sync.WaitGroup
46                 for i := range txes {
47                         i := i
48                         wg.Add(1)
49                         go func() {
50                                 // Concurrent calls to CurrentTx(),
51                                 // with different children of the same
52                                 // parent context, will all return the
53                                 // same transaction.
54                                 defer wg.Done()
55                                 ctx, cancel := context.WithCancel(ctx)
56                                 defer cancel()
57                                 tx, err := CurrentTx(ctx)
58                                 c.Check(err, check.IsNil)
59                                 txes[i] = tx
60                         }()
61                 }
62                 wg.Wait()
63                 for i := range txes[1:] {
64                         c.Check(txes[i], check.Equals, txes[i+1])
65                 }
66                 return true, nil
67         })
68
69         ok, err := wrappedFunc(context.Background(), "blah")
70         c.Check(ok, check.Equals, true)
71         c.Check(err, check.IsNil)
72         c.Check(getterCalled, check.Equals, int64(1))
73
74         // When a wrapped func returns without calling CurrentTx(),
75         // calling CurrentTx() later shouldn't start a new
76         // transaction.
77         var savedctx context.Context
78         ok, err = wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
79                 savedctx = ctx
80                 return true, nil
81         })(context.Background(), "blah")
82         c.Check(ok, check.Equals, true)
83         c.Check(err, check.IsNil)
84         tx, err := CurrentTx(savedctx)
85         c.Check(tx, check.IsNil)
86         c.Check(err, check.NotNil)
87 }