// Copyright (C) The Arvados Authors. All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0

package ctrlctx

import (
	"context"
	"sync"
	"sync/atomic"
	"testing"

	"git.arvados.org/arvados.git/lib/config"
	"git.arvados.org/arvados.git/sdk/go/ctxlog"
	"github.com/jmoiron/sqlx"
	_ "github.com/lib/pq"
	check "gopkg.in/check.v1"
)

// Test is the single "go test" entry point for this package: it hands
// control to gocheck, which runs the suites registered via check.Suite.
func Test(t *testing.T) { check.TestingT(t) }

// DatabaseSuite holds tests that run ctrlctx code against a real
// PostgreSQL database. It carries no state of its own.
type DatabaseSuite struct{}

// Register the suite with gocheck so the package's Test function runs it.
var _ = check.Suite(new(DatabaseSuite))

// TestTransactionContext checks the transaction-sharing contract of
// WrapCallsInTransactions and CurrentTx:
//
//   - concurrent CurrentTx calls made with different child contexts of
//     the wrapped call's context all get the same (non-nil) *sqlx.Tx,
//     and the database getter is called exactly once;
//   - if the wrapped func returns without calling CurrentTx, a later
//     CurrentTx call on its saved context returns an error and no
//     transaction.
func (*DatabaseSuite) TestTransactionContext(c *check.C) {
	cfg, err := config.NewLoader(nil, ctxlog.TestLogger(c)).Load()
	c.Assert(err, check.IsNil)
	cluster, err := cfg.GetCluster("")
	c.Assert(err, check.IsNil)

	// Every DB handle the getter opens is closed when the test
	// returns, so its connection pool does not outlive the test.
	var (
		dbsMtx sync.Mutex
		dbs    []*sqlx.DB
	)
	defer func() {
		dbsMtx.Lock()
		defer dbsMtx.Unlock()
		for _, db := range dbs {
			c.Check(db.Close(), check.IsNil)
		}
	}()

	var getterCalled int64
	getter := func(context.Context) (*sqlx.DB, error) {
		atomic.AddInt64(&getterCalled, 1)
		// The getter can run on one of the goroutines started
		// below (via CurrentTx), where c.Assert would stop only
		// that goroutine. Report failures through the return
		// value instead; CurrentTx hands the error back to the
		// caller, which checks it.
		db, err := sqlx.Open("postgres", cluster.PostgreSQL.Connection.String())
		if err != nil {
			return nil, err
		}
		dbsMtx.Lock()
		dbs = append(dbs, db)
		dbsMtx.Unlock()
		return db, nil
	}
	wrapper := WrapCallsInTransactions(getter)
	wrappedFunc := wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
		txes := make([]*sqlx.Tx, 20)
		var wg sync.WaitGroup
		for i := range txes {
			i := i
			wg.Add(1)
			go func() {
				// Concurrent calls to CurrentTx(),
				// with different children of the same
				// parent context, will all return the
				// same transaction.
				defer wg.Done()
				ctx, cancel := context.WithCancel(ctx)
				defer cancel()
				tx, err := CurrentTx(ctx)
				c.Check(err, check.IsNil)
				txes[i] = tx
			}()
		}
		wg.Wait()
		// Without this, the equality checks below would pass
		// trivially if every call had returned nil.
		c.Check(txes[0], check.NotNil)
		for _, tx := range txes[1:] {
			c.Check(tx, check.Equals, txes[0])
		}
		return true, nil
	})

	ok, err := wrappedFunc(context.Background(), "blah")
	c.Check(ok, check.Equals, true)
	c.Check(err, check.IsNil)
	c.Check(atomic.LoadInt64(&getterCalled), check.Equals, int64(1))

	// When a wrapped func returns without calling CurrentTx(),
	// calling CurrentTx() later shouldn't start a new
	// transaction.
	var savedctx context.Context
	ok, err = wrapper(func(ctx context.Context, opts interface{}) (interface{}, error) {
		savedctx = ctx
		return true, nil
	})(context.Background(), "blah")
	c.Check(ok, check.Equals, true)
	c.Check(err, check.IsNil)
	tx, err := CurrentTx(savedctx)
	c.Check(tx, check.IsNil)
	c.Check(err, check.NotNil)
}