19527: Load training-set flag from samples.csv.
[lightning.git] / batchargs.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "context"
9         "flag"
10         "fmt"
11         _ "net/http/pprof"
12         "sync"
13 )
14
15 type batchArgs struct {
16         batch   int
17         batches int
18 }
19
20 func (b *batchArgs) Flags(flags *flag.FlagSet) {
21         flags.IntVar(&b.batches, "batches", 1, "number of batches")
22         flags.IntVar(&b.batch, "batch", -1, "only do `N`th batch (-1 = all)")
23 }
24
25 func (b *batchArgs) Args(batch int) []string {
26         return []string{
27                 fmt.Sprintf("-batches=%d", b.batches),
28                 fmt.Sprintf("-batch=%d", batch),
29         }
30 }
31
32 // RunBatches calls runFunc once per batch, and returns a slice of
33 // return values and the first returned error, if any.
34 func (b *batchArgs) RunBatches(ctx context.Context, runFunc func(context.Context, int) (string, error)) ([]string, error) {
35         ctx, cancel := context.WithCancel(ctx)
36         defer cancel()
37         outputs := make([]string, b.batches)
38         var wg WaitGroup
39         for batch := 0; batch < b.batches; batch++ {
40                 if b.batch >= 0 && b.batch != batch {
41                         continue
42                 }
43                 batch := batch
44                 wg.Add(1)
45                 go func() {
46                         defer wg.Done()
47                         out, err := runFunc(ctx, batch)
48                         outputs[batch] = out
49                         if err != nil {
50                                 wg.Error(err)
51                                 cancel()
52                         }
53                 }()
54         }
55         err := wg.Wait()
56         if b.batch >= 0 {
57                 outputs = outputs[b.batch : b.batch+1]
58         }
59         return outputs, err
60 }
61
62 func (b *batchArgs) Slice(in []string) []string {
63         if b.batches == 0 || b.batch < 0 {
64                 return in
65         }
66         batchsize := (len(in) + b.batches - 1) / b.batches
67         out := in[batchsize*b.batch:]
68         if len(out) > batchsize {
69                 out = out[:batchsize]
70         }
71         return out
72 }
73
74 type WaitGroup struct {
75         sync.WaitGroup
76         err     error
77         errOnce sync.Once
78 }
79
80 func (wg *WaitGroup) Error(err error) {
81         if err != nil {
82                 wg.errOnce.Do(func() { wg.err = err })
83         }
84 }
85
86 func (wg *WaitGroup) Wait() error {
87         wg.WaitGroup.Wait()
88         return wg.err
89 }