Change slice-numpy-onehot parameter
[lightning.git] / merge.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "bufio"
9         "context"
10         "encoding/gob"
11         "errors"
12         "flag"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "net/http"
17         _ "net/http/pprof"
18         "os"
19         "strings"
20         "sync"
21
22         "git.arvados.org/arvados.git/sdk/go/arvados"
23         "github.com/klauspost/pgzip"
24         log "github.com/sirupsen/logrus"
25 )
26
// merger implements the "lightning merge" command: it loads one or more
// input libraries into a single shared tileLibrary and writes the combined
// result (via that library's gob encoder) to an output stream.
type merger struct {
	stdin   io.Reader                               // read for any input named "-"
	inputs  []string                                // input paths from the command line
	output  io.Writer                               // destination of the merged gob stream
	tagSet  [][]byte                                // NOTE(review): not referenced in the code visible here
	tilelib *tileLibrary                            // shared library that every input is loaded into
	mapped  map[string]map[tileLibRef]tileVariantID // one map per input path; doMerge initializes each empty
	mtxTags sync.Mutex                              // NOTE(review): not referenced here; presumably guards tagSet — confirm
	errs    chan error                              // first error reported during a merge (doMerge creates it with capacity 1)
}
37
38 func (cmd *merger) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
39         var err error
40         defer func() {
41                 if err != nil {
42                         fmt.Fprintf(stderr, "%s\n", err)
43                 }
44         }()
45         flags := flag.NewFlagSet("", flag.ContinueOnError)
46         flags.SetOutput(stderr)
47         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
48         runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
49         projectUUID := flags.String("project", "", "project `UUID` for output data")
50         priority := flags.Int("priority", 500, "container request priority")
51         outputFilename := flags.String("o", "-", "output `file`")
52         err = flags.Parse(args)
53         if err == flag.ErrHelp {
54                 err = nil
55                 return 0
56         } else if err != nil {
57                 return 2
58         }
59         cmd.stdin = stdin
60         cmd.inputs = flags.Args()
61
62         if *pprof != "" {
63                 go func() {
64                         log.Println(http.ListenAndServe(*pprof, nil))
65                 }()
66         }
67
68         if !*runlocal {
69                 if *outputFilename != "-" {
70                         err = errors.New("cannot specify output file in container mode: not implemented")
71                         return 1
72                 }
73                 runner := arvadosContainerRunner{
74                         Name:        "lightning merge",
75                         Client:      arvados.NewClientFromEnv(),
76                         ProjectUUID: *projectUUID,
77                         RAM:         700000000000,
78                         VCPUs:       16,
79                         Priority:    *priority,
80                         APIAccess:   true,
81                         KeepCache:   1,
82                 }
83                 for i := range cmd.inputs {
84                         err = runner.TranslatePaths(&cmd.inputs[i])
85                         if err != nil {
86                                 return 1
87                         }
88                 }
89                 runner.Args = append([]string{"merge", "-local=true",
90                         "-o", "/mnt/output/library.gob.gz",
91                 }, cmd.inputs...)
92                 var output string
93                 output, err = runner.Run()
94                 if err != nil {
95                         return 1
96                 }
97                 fmt.Fprintln(stdout, output+"/library.gob.gz")
98                 return 0
99         }
100
101         var outf, outw io.WriteCloser
102         if *outputFilename == "-" {
103                 outw = nopCloser{stdout}
104         } else {
105                 outf, err = os.OpenFile(*outputFilename, os.O_CREATE|os.O_WRONLY, 0777)
106                 if err != nil {
107                         return 1
108                 }
109                 defer outf.Close()
110                 if strings.HasSuffix(*outputFilename, ".gz") {
111                         outw = pgzip.NewWriter(outf)
112                 } else {
113                         outw = nopCloser{outf}
114                 }
115         }
116         bufw := bufio.NewWriterSize(outw, 64*1024*1024)
117         cmd.output = bufw
118         err = cmd.doMerge()
119         if err != nil {
120                 return 1
121         }
122         err = bufw.Flush()
123         if err != nil {
124                 return 1
125         }
126         err = outw.Close()
127         if err != nil {
128                 return 1
129         }
130         if outf != nil {
131                 err = outf.Close()
132                 if err != nil {
133                         return 1
134                 }
135         }
136         return 0
137 }
138
139 func (cmd *merger) setError(err error) {
140         select {
141         case cmd.errs <- err:
142         default:
143         }
144 }
145
146 func (cmd *merger) doMerge() error {
147         w := bufio.NewWriter(cmd.output)
148         encoder := gob.NewEncoder(w)
149
150         ctx, cancel := context.WithCancel(context.Background())
151         defer cancel()
152
153         cmd.errs = make(chan error, 1)
154         cmd.tilelib = &tileLibrary{
155                 encoder:       encoder,
156                 retainNoCalls: true,
157         }
158
159         cmd.mapped = map[string]map[tileLibRef]tileVariantID{}
160         for _, input := range cmd.inputs {
161                 cmd.mapped[input] = map[tileLibRef]tileVariantID{}
162         }
163
164         var wg sync.WaitGroup
165         for _, input := range cmd.inputs {
166                 rdr := ioutil.NopCloser(cmd.stdin)
167                 if input != "-" {
168                         var err error
169                         rdr, err = open(input)
170                         if err != nil {
171                                 return err
172                         }
173                         defer rdr.Close()
174                 }
175                 rdr = ioutil.NopCloser(bufio.NewReaderSize(rdr, 8*1024*1024))
176                 wg.Add(1)
177                 go func(input string) {
178                         defer wg.Done()
179                         log.Printf("%s: reading", input)
180                         err := cmd.tilelib.LoadGob(ctx, rdr, strings.HasSuffix(input, ".gz"))
181                         if err != nil {
182                                 cmd.setError(fmt.Errorf("%s: load failed: %w", input, err))
183                                 cancel()
184                                 return
185                         }
186                         log.Printf("%s: done", input)
187                 }(input)
188         }
189         wg.Wait()
190         go close(cmd.errs)
191         if err := <-cmd.errs; err != nil {
192                 return err
193         }
194         log.Print("flushing")
195         err := w.Flush()
196         if err != nil {
197                 return err
198         }
199         return nil
200 }