// Copyright (C) The Lightning Authors. All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0

package lightning

import (
	"context"
	"flag"
	"fmt"
	_ "net/http/pprof"
	"sync"
)

// batchArgs holds the command-line settings that control how work is
// split into batches. Its fields are populated by the Flags method.
type batchArgs struct {
	batch   int // index of the only batch to process; negative means all batches
	batches int // number of batches
}

// Flags registers the batching options on the given flag set and binds
// them to b:
//
//	-batch    index of the only batch to do (default -1, meaning all)
//	-batches  number of batches (default 1)
func (b *batchArgs) Flags(flags *flag.FlagSet) {
	flags.IntVar(&b.batch, "batch", -1, "only do `N`th batch (-1 = all)")
	flags.IntVar(&b.batches, "batches", 1, "number of batches")
}

// Args returns the command-line arguments that select the given batch
// out of b.batches batches, in the form accepted by the flags that
// Flags registers: "-batches=<b.batches>" followed by "-batch=<batch>".
func (b *batchArgs) Args(batch int) []string {
	totalArg := fmt.Sprintf("-batches=%d", b.batches)
	indexArg := fmt.Sprintf("-batch=%d", batch)
	return []string{totalArg, indexArg}
}

// RunBatches calls runFunc once per selected batch, each call in its
// own goroutine, and waits for all of the calls to return.
//
// If b.batch is negative, every batch in [0, b.batches) is run and the
// returned slice holds one entry per batch, indexed by batch number.
// Otherwise only batch b.batch is run and the returned slice holds just
// that batch's return value.
//
// The context passed to runFunc is canceled as soon as any call returns
// a non-nil error, and also when RunBatches returns. The returned error
// is the first error reported by a runFunc call, if any.
//
// An error is returned, and runFunc is never called, if b.batches is
// negative or if b.batch selects a batch that does not exist
// (b.batch >= b.batches).
func (b *batchArgs) RunBatches(ctx context.Context, runFunc func(context.Context, int) (string, error)) ([]string, error) {
	// Reject settings that would otherwise panic below: a negative length
	// passed to make, or a result slice expression that is out of range.
	if b.batches < 0 {
		return nil, fmt.Errorf("invalid number of batches: %d", b.batches)
	}
	if b.batch >= 0 && b.batch >= b.batches {
		return nil, fmt.Errorf("batch %d out of range: have %d batches", b.batch, b.batches)
	}

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	outputs := make([]string, b.batches)
	var wg WaitGroup
	for batch := 0; batch < b.batches; batch++ {
		if b.batch >= 0 && b.batch != batch {
			continue
		}
		// Per-iteration copy so that each goroutine sees its own batch
		// number even where loop variables are shared across iterations.
		batch := batch
		wg.Add(1)
		go func() {
			defer wg.Done()
			out, err := runFunc(ctx, batch)
			// Each goroutine writes only its own element of outputs.
			outputs[batch] = out
			if err != nil {
				wg.Error(err)
				// Signal the remaining calls to stop early.
				cancel()
			}
		}()
	}
	err := wg.Wait()
	if b.batch >= 0 {
		// Only one batch ran; return just its output.
		outputs = outputs[b.batch : b.batch+1]
	}
	return outputs, err
}

// Slice returns the portion of in that belongs to batch b.batch when in
// is divided into b.batches contiguous batches of equal size (the size
// is rounded up, so trailing batches may be shorter or empty).
//
// in is returned unchanged if b.batch is negative (all batches) or if
// b.batches is not positive (no batching). If the selected batch starts
// at or beyond the end of in, an empty slice is returned. The result
// shares its backing array with in.
func (b *batchArgs) Slice(in []string) []string {
	if b.batches <= 0 || b.batch < 0 {
		return in
	}
	batchsize := (len(in) + b.batches - 1) / b.batches
	// Because batchsize is rounded up, the start offset of a late batch
	// can exceed len(in) (e.g. 5 items in 4 batches: batch 3 would start
	// at 6), so both bounds are clamped to avoid an out-of-range panic.
	start := batchsize * b.batch
	if start > len(in) {
		start = len(in)
	}
	end := start + batchsize
	if end > len(in) {
		end = len(in)
	}
	return in[start:end]
}

// WaitGroup is a sync.WaitGroup that additionally records the first
// non-nil error reported through its Error method. Its Wait method
// returns that error after the embedded WaitGroup's Wait has returned.
type WaitGroup struct {
	sync.WaitGroup
	err     error     // first non-nil error passed to Error; nil if none
	errOnce sync.Once // ensures err is assigned at most once
}

// Error records err as the group's error if it is non-nil and no
// earlier non-nil error has been recorded. Nil errors are ignored and
// do not prevent a later error from being recorded.
func (wg *WaitGroup) Error(err error) {
	if err == nil {
		return
	}
	wg.errOnce.Do(func() {
		wg.err = err
	})
}

// Wait blocks until the embedded sync.WaitGroup's Wait returns, then
// returns the error recorded by Error, or nil if none was recorded.
func (wg *WaitGroup) Wait() error {
	inner := &wg.WaitGroup
	inner.Wait()
	return wg.err
}