Use buffered writer to avoid overwhelming arv-mount.
[lightning.git] / exportnumpy.go
1 package main
2
3 import (
4         "bufio"
5         "context"
6         "errors"
7         "flag"
8         "fmt"
9         "io"
10         "io/ioutil"
11         "net/http"
12         _ "net/http/pprof"
13         "os"
14         "sort"
15         "strings"
16
17         "git.arvados.org/arvados.git/sdk/go/arvados"
18         "github.com/kshedden/gonpy"
19         log "github.com/sirupsen/logrus"
20 )
21
22 type exportNumpy struct {
23         filter filter
24 }
25
26 func (cmd *exportNumpy) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
27         var err error
28         defer func() {
29                 if err != nil {
30                         fmt.Fprintf(stderr, "%s\n", err)
31                 }
32         }()
33         flags := flag.NewFlagSet("", flag.ContinueOnError)
34         flags.SetOutput(stderr)
35         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
36         runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
37         projectUUID := flags.String("project", "", "project `UUID` for output data")
38         priority := flags.Int("priority", 500, "container request priority")
39         inputFilename := flags.String("i", "-", "input `file`")
40         outputFilename := flags.String("o", "-", "output `file`")
41         annotationsFilename := flags.String("output-annotations", "", "output `file` for tile variant annotations csv")
42         librefsFilename := flags.String("output-onehot2tilevar", "", "when using -one-hot, create csv `file` mapping column# to tag# and variant#")
43         labelsFilename := flags.String("output-labels", "", "output `file` for genome labels csv")
44         onehot := flags.Bool("one-hot", false, "recode tile variants as one-hot")
45         cmd.filter.Flags(flags)
46         err = flags.Parse(args)
47         if err == flag.ErrHelp {
48                 err = nil
49                 return 0
50         } else if err != nil {
51                 return 2
52         }
53
54         if *pprof != "" {
55                 go func() {
56                         log.Println(http.ListenAndServe(*pprof, nil))
57                 }()
58         }
59
60         if !*runlocal {
61                 if *outputFilename != "-" {
62                         err = errors.New("cannot specify output file in container mode: not implemented")
63                         return 1
64                 }
65                 runner := arvadosContainerRunner{
66                         Name:        "lightning export-numpy",
67                         Client:      arvados.NewClientFromEnv(),
68                         ProjectUUID: *projectUUID,
69                         RAM:         128000000000,
70                         VCPUs:       32,
71                         Priority:    *priority,
72                 }
73                 err = runner.TranslatePaths(inputFilename)
74                 if err != nil {
75                         return 1
76                 }
77                 runner.Args = []string{"export-numpy", "-local=true",
78                         fmt.Sprintf("-one-hot=%v", *onehot),
79                         "-i", *inputFilename,
80                         "-o", "/mnt/output/matrix.npy",
81                         "-output-annotations", "/mnt/output/annotations.csv",
82                         "-output-onehot2tilevar", "/mnt/output/onehot2tilevar.csv",
83                         "-output-labels", "/mnt/output/labels.csv",
84                         "-max-variants", fmt.Sprintf("%d", cmd.filter.MaxVariants),
85                         "-min-coverage", fmt.Sprintf("%f", cmd.filter.MinCoverage),
86                         "-max-tag", fmt.Sprintf("%d", cmd.filter.MaxTag),
87                 }
88                 var output string
89                 output, err = runner.Run()
90                 if err != nil {
91                         return 1
92                 }
93                 fmt.Fprintln(stdout, output+"/matrix.npy")
94                 return 0
95         }
96
97         var input io.ReadCloser
98         if *inputFilename == "-" {
99                 input = ioutil.NopCloser(stdin)
100         } else {
101                 input, err = os.Open(*inputFilename)
102                 if err != nil {
103                         return 1
104                 }
105                 defer input.Close()
106         }
107         tilelib := &tileLibrary{
108                 retainNoCalls:       true,
109                 retainTileSequences: true,
110                 compactGenomes:      map[string][]tileVariantID{},
111         }
112         err = tilelib.LoadGob(context.Background(), input, strings.HasSuffix(*inputFilename, ".gz"), nil)
113         if err != nil {
114                 return 1
115         }
116         err = input.Close()
117         if err != nil {
118                 return 1
119         }
120
121         log.Info("filtering")
122         cmd.filter.Apply(tilelib)
123         log.Info("tidying")
124         tilelib.Tidy()
125
126         if *annotationsFilename != "" {
127                 log.Infof("writing annotations")
128                 var annow io.WriteCloser
129                 annow, err = os.OpenFile(*annotationsFilename, os.O_CREATE|os.O_WRONLY, 0666)
130                 if err != nil {
131                         return 1
132                 }
133                 defer annow.Close()
134                 err = (&annotatecmd{maxTileSize: 5000}).exportTileDiffs(annow, tilelib)
135                 if err != nil {
136                         return 1
137                 }
138                 err = annow.Close()
139                 if err != nil {
140                         return 1
141                 }
142         }
143
144         log.Info("building numpy array")
145         out, rows, cols, names := cgs2array(tilelib)
146
147         if *labelsFilename != "" {
148                 log.Infof("writing labels to %s", *labelsFilename)
149                 var f *os.File
150                 f, err = os.OpenFile(*labelsFilename, os.O_CREATE|os.O_WRONLY, 0777)
151                 if err != nil {
152                         return 1
153                 }
154                 defer f.Close()
155                 for i, name := range names {
156                         _, err = fmt.Fprintf(f, "%d,%q\n", i, trimFilenameForLabel(name))
157                         if err != nil {
158                                 err = fmt.Errorf("write %s: %w", *labelsFilename, err)
159                                 return 1
160                         }
161                 }
162                 err = f.Close()
163                 if err != nil {
164                         err = fmt.Errorf("close %s: %w", *labelsFilename, err)
165                         return 1
166                 }
167         }
168
169         log.Info("writing numpy file")
170         var output io.WriteCloser
171         if *outputFilename == "-" {
172                 output = nopCloser{stdout}
173         } else {
174                 output, err = os.OpenFile(*outputFilename, os.O_CREATE|os.O_WRONLY, 0777)
175                 if err != nil {
176                         return 1
177                 }
178                 defer output.Close()
179         }
180         bufw := bufio.NewWriter(output)
181         npw, err := gonpy.NewWriter(nopCloser{bufw})
182         if err != nil {
183                 return 1
184         }
185         if *onehot {
186                 log.Info("recoding to onehot")
187                 recoded, librefs, recodedcols := recodeOnehot(out, cols)
188                 out, cols = recoded, recodedcols
189                 if *librefsFilename != "" {
190                         log.Infof("writing onehot column mapping")
191                         err = cmd.writeLibRefs(*librefsFilename, tilelib, librefs)
192                         if err != nil {
193                                 return 1
194                         }
195                 }
196         }
197         log.Info("writing numpy")
198         npw.Shape = []int{rows, cols}
199         npw.WriteInt16(out)
200         err = bufw.Flush()
201         if err != nil {
202                 return 1
203         }
204         err = output.Close()
205         if err != nil {
206                 return 1
207         }
208         return 0
209 }
210
211 func (*exportNumpy) writeLibRefs(fnm string, tilelib *tileLibrary, librefs []tileLibRef) error {
212         f, err := os.OpenFile(fnm, os.O_CREATE|os.O_WRONLY, 0666)
213         if err != nil {
214                 return err
215         }
216         defer f.Close()
217         for i, libref := range librefs {
218                 _, err = fmt.Fprintf(f, "%d,%d,%d\n", i, libref.Tag, libref.Variant)
219                 if err != nil {
220                         return err
221                 }
222         }
223         return f.Close()
224 }
225
226 func cgs2array(tilelib *tileLibrary) (data []int16, rows, cols int, cgnames []string) {
227         for name := range tilelib.compactGenomes {
228                 cgnames = append(cgnames, name)
229         }
230         sort.Strings(cgnames)
231
232         rows = len(tilelib.compactGenomes)
233         for _, cg := range tilelib.compactGenomes {
234                 if cols < len(cg) {
235                         cols = len(cg)
236                 }
237         }
238
239         // flag low-quality tile variants so we can change to -1 below
240         lowqual := make([]map[tileVariantID]bool, cols/2)
241         for tag, variants := range tilelib.variant {
242                 lq := lowqual[tag]
243                 for varidx, hash := range variants {
244                         if len(tilelib.seq[hash]) == 0 {
245                                 if lq == nil {
246                                         lq = map[tileVariantID]bool{}
247                                         lowqual[tag] = lq
248                                 }
249                                 lq[tileVariantID(varidx+1)] = true
250                         }
251                 }
252         }
253
254         data = make([]int16, rows*cols)
255         for row, name := range cgnames {
256                 for i, v := range tilelib.compactGenomes[name] {
257                         if v > 0 && lowqual[i/2][v] {
258                                 data[row*cols+i] = -1
259                         } else {
260                                 data[row*cols+i] = int16(v)
261                         }
262                 }
263         }
264
265         return
266 }
267
268 func recodeOnehot(in []int16, incols int) (out []int16, librefs []tileLibRef, outcols int) {
269         rows := len(in) / incols
270         maxvalue := make([]int16, incols)
271         for row := 0; row < rows; row++ {
272                 for col := 0; col < incols; col++ {
273                         if v := in[row*incols+col]; maxvalue[col] < v {
274                                 maxvalue[col] = v
275                         }
276                 }
277         }
278         outcol := make([]int, incols)
279         dropped := 0
280         for incol, maxv := range maxvalue {
281                 outcol[incol] = outcols
282                 if maxv == 0 {
283                         dropped++
284                 }
285                 for v := 1; v <= int(maxv); v++ {
286                         librefs = append(librefs, tileLibRef{Tag: tagID(incol), Variant: tileVariantID(v)})
287                         outcols++
288                 }
289         }
290         log.Printf("recodeOnehot: dropped %d input cols with zero maxvalue", dropped)
291
292         out = make([]int16, rows*outcols)
293         for inidx, row := 0, 0; row < rows; row++ {
294                 outrow := out[row*outcols:]
295                 for col := 0; col < incols; col++ {
296                         if v := in[inidx]; v > 0 {
297                                 outrow[outcol[col]+int(v)-1] = 1
298                         }
299                         inidx++
300                 }
301         }
302         return
303 }
304
// nopCloser turns any io.Writer into an io.WriteCloser whose Close
// does nothing. Writes are passed straight through to the embedded
// writer.
type nopCloser struct{ io.Writer }

// Close leaves the embedded writer untouched and always reports
// success.
func (nopCloser) Close() error {
	return nil
}
310
311 func trimFilenameForLabel(s string) string {
312         if i := strings.LastIndex(s, "/"); i >= 0 {
313                 s = s[i+1:]
314         }
315         s = strings.TrimSuffix(s, ".gz")
316         s = strings.TrimSuffix(s, ".fa")
317         s = strings.TrimSuffix(s, ".fasta")
318         s = strings.TrimSuffix(s, ".1")
319         s = strings.TrimSuffix(s, ".2")
320         s = strings.TrimSuffix(s, ".gz")
321         s = strings.TrimSuffix(s, ".vcf")
322         return s
323 }