16534: Test goroutine safety.
author Tom Clegg <tom@tomclegg.ca>
Tue, 30 Jun 2020 21:01:54 +0000 (17:01 -0400)
committer Tom Clegg <tom@tomclegg.ca>
Tue, 30 Jun 2020 21:01:54 +0000 (17:01 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@tomclegg.ca>

lib/controller/localdb/db.go
lib/controller/localdb/db_test.go

index a864e32d48af30c48c0bfb47c97726fcb5f65665..4f64e63524469cc9e9fb987a4570772eb445fd8b 100644 (file)
@@ -78,9 +78,15 @@ type transactionFinishFunc func(*error)
 func starttx(ctx context.Context, getdb func(context.Context) (*sql.DB, error)) (context.Context, transactionFinishFunc) {
        txn := &transaction{getdb: getdb}
        return context.WithValue(ctx, contextKeyTransaction, txn), func(err *error) {
-               // Ensure another goroutine can't open a transaction
-               // during/after finishtx().
-               txn.setup.Do(func() {})
+               txn.setup.Do(func() {
+                       // Using (*sync.Once)Do() prevents a future
+                       // call to currenttx() from opening a
+                       // transaction which would never get committed
+                       // or rolled back. If currenttx() hasn't been
+                       // called before now, future calls will return
+                       // this error.
+                       txn.err = errors.New("refusing to start a transaction after wrapped function already returned")
+               })
                if txn.tx == nil {
                        // we never [successfully] started a transaction
                        return
index 39ac524a6a4a918dc86008a49c753d7db1841d87..5bab86c60289e688475efa98e6be9061936a800a 100644 (file)
@@ -7,8 +7,12 @@ package localdb
 import (
        "context"
        "database/sql"
+       "sync"
+       "sync/atomic"
 
+       "git.arvados.org/arvados.git/lib/config"
        "git.arvados.org/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/ctxlog"
        _ "github.com/lib/pq"
        check "gopkg.in/check.v1"
 )
@@ -30,3 +34,65 @@ func testctx(c *check.C, db *sql.DB) (ctx context.Context, rollback func()) {
                c.Check(tx.Rollback(), check.IsNil)
        }
 }
+
+var _ = check.Suite(&DatabaseSuite{})
+
+type DatabaseSuite struct{}
+
+// TestTransactionContext checks that currenttx() shares one transaction
+// across goroutines and refuses to start one after the wrapped func returns.
+func (*DatabaseSuite) TestTransactionContext(c *check.C) {
+       cfg, err := config.NewLoader(nil, ctxlog.TestLogger(c)).Load()
+       c.Assert(err, check.IsNil)
+       cluster, err := cfg.GetCluster("")
+       c.Assert(err, check.IsNil)
+
+       var getterCalled int64 // accessed atomically: getter may run on any goroutine below
+       getter := func(context.Context) (*sql.DB, error) {
+               atomic.AddInt64(&getterCalled, 1)
+               return testdb(c, cluster), nil
+       }
+       wrapper := WrapCallsInTransactions(getter)
+       wrappedFunc := wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
+               txes := make([]*sql.Tx, 20)
+               var wg sync.WaitGroup
+               for i := range txes {
+                       i := i
+                       wg.Add(1)
+                       go func() {
+                               // Concurrent calls with different children of
+                               // the same parent ctx must share one transaction.
+                               defer wg.Done()
+                               ctx, cancel := context.WithCancel(ctx)
+                               defer cancel()
+                               tx, err := currenttx(ctx)
+                               c.Check(err, check.IsNil)
+                               txes[i] = tx
+                       }()
+               }
+               wg.Wait()
+               c.Check(txes[0], check.NotNil) // else the loop below proves nothing
+               for _, tx := range txes[1:] {
+                       c.Check(tx, check.Equals, txes[0])
+               }
+               return true, nil
+       })
+
+       ok, err := wrappedFunc(context.Background(), "blah")
+       c.Check(ok, check.Equals, true)
+       c.Check(err, check.IsNil)
+       c.Check(atomic.LoadInt64(&getterCalled), check.Equals, int64(1))
+
+       // If the wrapped func never called currenttx(), a later call must fail.
+       var savedctx context.Context
+       ok, err = wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
+               savedctx = ctx
+               return true, nil
+       })(context.Background(), "blah")
+       c.Check(ok, check.Equals, true)
+       c.Check(err, check.IsNil)
+       tx, err := currenttx(savedctx)
+       c.Check(tx, check.IsNil)
+       c.Check(err, check.NotNil)
+       c.Check(atomic.LoadInt64(&getterCalled), check.Equals, int64(1)) // no new transaction opened
+}