16534: Exempt go.mod and go.sum from license header check.
[arvados.git] / lib / controller / localdb / db_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package localdb
6
7 import (
8         "context"
9         "sync"
10         "sync/atomic"
11
12         "git.arvados.org/arvados.git/lib/config"
13         "git.arvados.org/arvados.git/sdk/go/arvados"
14         "git.arvados.org/arvados.git/sdk/go/ctxlog"
15         "github.com/jmoiron/sqlx"
16         _ "github.com/lib/pq"
17         check "gopkg.in/check.v1"
18 )
19
20 // testdb returns a DB connection for the given cluster config.
21 func testdb(c *check.C, cluster *arvados.Cluster) *sqlx.DB {
22         db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String())
23         c.Assert(err, check.IsNil)
24         return db
25 }
26
27 // testctx returns a context suitable for running a test case in a new
28 // transaction, and a rollback func which the caller should call after
29 // the test.
30 func testctx(c *check.C, db *sqlx.DB) (ctx context.Context, rollback func()) {
31         tx, err := db.Beginx()
32         c.Assert(err, check.IsNil)
33         return ContextWithTransaction(context.Background(), tx), func() {
34                 c.Check(tx.Rollback(), check.IsNil)
35         }
36 }
37
38 var _ = check.Suite(&DatabaseSuite{})
39
40 type DatabaseSuite struct{}
41
42 func (*DatabaseSuite) TestTransactionContext(c *check.C) {
43         cfg, err := config.NewLoader(nil, ctxlog.TestLogger(c)).Load()
44         c.Assert(err, check.IsNil)
45         cluster, err := cfg.GetCluster("")
46         c.Assert(err, check.IsNil)
47
48         var getterCalled int64
49         getter := func(context.Context) (*sqlx.DB, error) {
50                 atomic.AddInt64(&getterCalled, 1)
51                 return testdb(c, cluster), nil
52         }
53         wrapper := WrapCallsInTransactions(getter)
54         wrappedFunc := wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
55                 txes := make([]*sqlx.Tx, 20)
56                 var wg sync.WaitGroup
57                 for i := range txes {
58                         i := i
59                         wg.Add(1)
60                         go func() {
61                                 // Concurrent calls to currenttx(),
62                                 // with different children of the same
63                                 // parent context, will all return the
64                                 // same transaction.
65                                 defer wg.Done()
66                                 ctx, cancel := context.WithCancel(ctx)
67                                 defer cancel()
68                                 tx, err := currenttx(ctx)
69                                 c.Check(err, check.IsNil)
70                                 txes[i] = tx
71                         }()
72                 }
73                 wg.Wait()
74                 for i := range txes[1:] {
75                         c.Check(txes[i], check.Equals, txes[i+1])
76                 }
77                 return true, nil
78         })
79
80         ok, err := wrappedFunc(context.Background(), "blah")
81         c.Check(ok, check.Equals, true)
82         c.Check(err, check.IsNil)
83         c.Check(getterCalled, check.Equals, int64(1))
84
85         // When a wrapped func returns without calling currenttx(),
86         // calling currenttx() later shouldn't start a new
87         // transaction.
88         var savedctx context.Context
89         ok, err = wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
90                 savedctx = ctx
91                 return true, nil
92         })(context.Background(), "blah")
93         c.Check(ok, check.Equals, true)
94         c.Check(err, check.IsNil)
95         tx, err := currenttx(savedctx)
96         c.Check(tx, check.IsNil)
97         c.Check(err, check.NotNil)
98 }