+++ /dev/null
-// Copyright (C) The Arvados Authors. All rights reserved.
-//
-// SPDX-License-Identifier: AGPL-3.0
-
-package localdb
-
-import (
- "context"
- "sync"
- "sync/atomic"
-
- "git.arvados.org/arvados.git/lib/config"
- "git.arvados.org/arvados.git/sdk/go/arvados"
- "git.arvados.org/arvados.git/sdk/go/ctxlog"
- "github.com/jmoiron/sqlx"
- _ "github.com/lib/pq"
- check "gopkg.in/check.v1"
-)
-
-// testdb opens a "postgres" database handle using the connection
-// string from the given cluster's PostgreSQL config, and asserts that
-// opening it did not return an error.
-func testdb(c *check.C, cluster *arvados.Cluster) *sqlx.DB {
-	connstr := cluster.PostgreSQL.Connection.String()
-	handle, err := sqlx.Open("postgres", connstr)
-	c.Assert(err, check.IsNil)
-	return handle
-}
-
-// testctx begins a new transaction on db and returns a background
-// context carrying that transaction, along with a rollback func. The
-// caller should invoke rollback once the test case is done; rollback
-// checks that rolling back the transaction succeeds.
-func testctx(c *check.C, db *sqlx.DB) (ctx context.Context, rollback func()) {
-	tx, err := db.Beginx()
-	c.Assert(err, check.IsNil)
-	ctx = ContextWithTransaction(context.Background(), tx)
-	rollback = func() {
-		rberr := tx.Rollback()
-		c.Check(rberr, check.IsNil)
-	}
-	return ctx, rollback
-}
-
-// DatabaseSuite groups the transaction-context tests. It carries no
-// state of its own.
-type DatabaseSuite struct{}
-
-// Register DatabaseSuite with the gocheck runner.
-var _ = check.Suite(&DatabaseSuite{})
-
-// TestTransactionContext exercises WrapCallsInTransactions in two
-// ways: (1) many concurrent currenttx() calls made inside a single
-// wrapped call all yield the same transaction while the DB getter is
-// invoked exactly once, and (2) a context captured from a wrapped
-// call that never asked for a transaction cannot be used to obtain
-// one after that call has returned.
-func (*DatabaseSuite) TestTransactionContext(c *check.C) {
-	cfg, err := config.NewLoader(nil, ctxlog.TestLogger(c)).Load()
-	c.Assert(err, check.IsNil)
-	cluster, err := cfg.GetCluster("")
-	c.Assert(err, check.IsNil)
-
-	// Count how many times the wrapper asks for a database handle.
-	var getterCalled int64
-	getter := func(context.Context) (*sqlx.DB, error) {
-		atomic.AddInt64(&getterCalled, 1)
-		return testdb(c, cluster), nil
-	}
-	wrapper := WrapCallsInTransactions(getter)
-
-	concurrent := func(ctx context.Context, opts interface{}) (interface{}, error) {
-		seen := make([]*sqlx.Tx, 20)
-		var wg sync.WaitGroup
-		wg.Add(len(seen))
-		for idx := range seen {
-			// Each goroutine derives its own child of the
-			// shared parent context; every child is
-			// expected to resolve to the parent's single
-			// transaction.
-			go func(slot int) {
-				defer wg.Done()
-				childctx, cancel := context.WithCancel(ctx)
-				defer cancel()
-				got, err := currenttx(childctx)
-				c.Check(err, check.IsNil)
-				seen[slot] = got
-			}(idx)
-		}
-		wg.Wait()
-		// Adjacent entries being equal implies all entries are
-		// equal.
-		for idx := 1; idx < len(seen); idx++ {
-			c.Check(seen[idx-1], check.Equals, seen[idx])
-		}
-		return true, nil
-	}
-
-	ok, err := wrapper(concurrent)(context.Background(), "blah")
-	c.Check(ok, check.Equals, true)
-	c.Check(err, check.IsNil)
-	c.Check(getterCalled, check.Equals, int64(1))
-
-	// A wrapped func that returns without ever calling
-	// currenttx() must not leave behind a context that can start
-	// a new transaction afterwards.
-	var savedctx context.Context
-	capture := func(ctx context.Context, opts interface{}) (interface{}, error) {
-		savedctx = ctx
-		return true, nil
-	}
-	ok, err = wrapper(capture)(context.Background(), "blah")
-	c.Check(ok, check.Equals, true)
-	c.Check(err, check.IsNil)
-	latetx, err := currenttx(savedctx)
-	c.Check(latetx, check.IsNil)
-	c.Check(err, check.NotNil)
-}