--- /dev/null
+package lightning
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ _ "net/http/pprof"
+ "sync"
+)
+
// batchArgs holds the settings that describe how work is divided into
// batches and which of those batches should be processed.
type batchArgs struct {
	// batch is the index of the only batch to process; a negative value
	// selects every batch. batches is the total number of batches.
	batch, batches int
}
+
+func (b *batchArgs) Flags(flags *flag.FlagSet) {
+ flags.IntVar(&b.batches, "batches", 1, "number of batches")
+ flags.IntVar(&b.batch, "batch", -1, "only do `N`th batch (-1 = all)")
+}
+
+func (b *batchArgs) Args(batch int) []string {
+ return []string{
+ fmt.Sprintf("-batches=%d", b.batches),
+ fmt.Sprintf("-batch=%d", batch),
+ }
+}
+
+// RunBatches calls runFunc once per batch, and returns a slice of
+// return values and the first returned error, if any.
+func (b *batchArgs) RunBatches(ctx context.Context, runFunc func(context.Context, int) (string, error)) ([]string, error) {
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ outputs := make([]string, b.batches)
+ var wg WaitGroup
+ for batch := 0; batch < b.batches; batch++ {
+ if b.batch >= 0 && b.batch != batch {
+ continue
+ }
+ batch := batch
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ out, err := runFunc(ctx, batch)
+ outputs[batch] = out
+ if err != nil {
+ wg.Error(err)
+ cancel()
+ }
+ }()
+ }
+ err := wg.Wait()
+ if b.batch >= 0 {
+ outputs = outputs[b.batch : b.batch+1]
+ }
+ return outputs, err
+}
+
+func (b *batchArgs) Slice(in []string) []string {
+ if b.batches == 0 || b.batch < 0 {
+ return in
+ }
+ batchsize := (len(in) + b.batches - 1) / b.batches
+ out := in[batchsize*b.batch:]
+ if len(out) > batchsize {
+ out = out[:batchsize]
+ }
+ return out
+}
+
// WaitGroup is a sync.WaitGroup that can additionally carry one error:
// err is assigned at most once, guarded by errOnce.
type WaitGroup struct {
	sync.WaitGroup

	err     error     // the recorded error, if any
	errOnce sync.Once // ensures err is assigned only once
}
+
+func (wg *WaitGroup) Error(err error) {
+ if err != nil {
+ wg.errOnce.Do(func() { wg.err = err })
+ }
+}
+
+func (wg *WaitGroup) Wait() error {
+ wg.WaitGroup.Wait()
+ return wg.err
+}