Add missing file for batchArgs.
[lightning.git] / batchargs.go
1 package lightning
2
3 import (
4         "context"
5         "flag"
6         "fmt"
7         _ "net/http/pprof"
8         "sync"
9 )
10
11 type batchArgs struct {
12         batch   int
13         batches int
14 }
15
16 func (b *batchArgs) Flags(flags *flag.FlagSet) {
17         flags.IntVar(&b.batches, "batches", 1, "number of batches")
18         flags.IntVar(&b.batch, "batch", -1, "only do `N`th batch (-1 = all)")
19 }
20
21 func (b *batchArgs) Args(batch int) []string {
22         return []string{
23                 fmt.Sprintf("-batches=%d", b.batches),
24                 fmt.Sprintf("-batch=%d", batch),
25         }
26 }
27
28 // RunBatches calls runFunc once per batch, and returns a slice of
29 // return values and the first returned error, if any.
30 func (b *batchArgs) RunBatches(ctx context.Context, runFunc func(context.Context, int) (string, error)) ([]string, error) {
31         ctx, cancel := context.WithCancel(ctx)
32         defer cancel()
33         outputs := make([]string, b.batches)
34         var wg WaitGroup
35         for batch := 0; batch < b.batches; batch++ {
36                 if b.batch >= 0 && b.batch != batch {
37                         continue
38                 }
39                 batch := batch
40                 wg.Add(1)
41                 go func() {
42                         defer wg.Done()
43                         out, err := runFunc(ctx, batch)
44                         outputs[batch] = out
45                         if err != nil {
46                                 wg.Error(err)
47                                 cancel()
48                         }
49                 }()
50         }
51         err := wg.Wait()
52         if b.batch >= 0 {
53                 outputs = outputs[b.batch : b.batch+1]
54         }
55         return outputs, err
56 }
57
58 func (b *batchArgs) Slice(in []string) []string {
59         if b.batches == 0 || b.batch < 0 {
60                 return in
61         }
62         batchsize := (len(in) + b.batches - 1) / b.batches
63         out := in[batchsize*b.batch:]
64         if len(out) > batchsize {
65                 out = out[:batchsize]
66         }
67         return out
68 }
69
70 type WaitGroup struct {
71         sync.WaitGroup
72         err     error
73         errOnce sync.Once
74 }
75
76 func (wg *WaitGroup) Error(err error) {
77         if err != nil {
78                 wg.errOnce.Do(func() { wg.err = err })
79         }
80 }
81
82 func (wg *WaitGroup) Wait() error {
83         wg.WaitGroup.Wait()
84         return wg.err
85 }