Add concurrent load to exportnumpy.
[lightning.git] / exportnumpy.go
1 package lightning
2
3 import (
4         "bufio"
5         "bytes"
6         "context"
7         "errors"
8         "flag"
9         "fmt"
10         "io"
11         "io/ioutil"
12         "net/http"
13         _ "net/http/pprof"
14         "os"
15         "sort"
16         "strconv"
17         "strings"
18         "sync"
19         "sync/atomic"
20
21         "git.arvados.org/arvados.git/sdk/go/arvados"
22         "github.com/arvados/lightning/hgvs"
23         "github.com/kshedden/gonpy"
24         "github.com/sirupsen/logrus"
25         log "github.com/sirupsen/logrus"
26 )
27
28 type exportNumpy struct {
29         filter filter
30 }
31
32 func (cmd *exportNumpy) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
33         var err error
34         defer func() {
35                 if err != nil {
36                         fmt.Fprintf(stderr, "%s\n", err)
37                 }
38         }()
39         flags := flag.NewFlagSet("", flag.ContinueOnError)
40         flags.SetOutput(stderr)
41         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
42         runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
43         projectUUID := flags.String("project", "", "project `UUID` for output data")
44         priority := flags.Int("priority", 500, "container request priority")
45         inputDir := flags.String("input-dir", "./in", "input `directory`")
46         outputDir := flags.String("output-dir", "./out", "output `directory`")
47         annotationsFilename := flags.String("output-annotations", "", "output `file` for tile variant annotations csv")
48         librefsFilename := flags.String("output-onehot2tilevar", "", "when using -one-hot, create csv `file` mapping column# to tag# and variant#")
49         labelsFilename := flags.String("output-labels", "", "output `file` for genome labels csv")
50         regionsFilename := flags.String("regions", "", "only output columns/annotations that intersect regions in specified bed `file`")
51         expandRegions := flags.Int("expand-regions", 0, "expand specified regions by `N` base pairs on each side`")
52         onehot := flags.Bool("one-hot", false, "recode tile variants as one-hot")
53         chunks := flags.Int("chunks", 1, "split output into `N` numpy files")
54         cmd.filter.Flags(flags)
55         err = flags.Parse(args)
56         if err == flag.ErrHelp {
57                 err = nil
58                 return 0
59         } else if err != nil {
60                 return 2
61         }
62
63         if *pprof != "" {
64                 go func() {
65                         log.Println(http.ListenAndServe(*pprof, nil))
66                 }()
67         }
68
69         if !*runlocal {
70                 runner := arvadosContainerRunner{
71                         Name:        "lightning export-numpy",
72                         Client:      arvados.NewClientFromEnv(),
73                         ProjectUUID: *projectUUID,
74                         RAM:         500000000000,
75                         VCPUs:       96,
76                         Priority:    *priority,
77                         KeepCache:   1,
78                         APIAccess:   true,
79                 }
80                 err = runner.TranslatePaths(inputDir, regionsFilename)
81                 if err != nil {
82                         return 1
83                 }
84                 runner.Args = []string{"export-numpy", "-local=true",
85                         "-pprof", ":6060",
86                         fmt.Sprintf("-one-hot=%v", *onehot),
87                         "-input-dir", *inputDir,
88                         "-output-dir", "/mnt/output",
89                         "-output-annotations", "/mnt/output/annotations.csv",
90                         "-output-onehot2tilevar", "/mnt/output/onehot2tilevar.csv",
91                         "-output-labels", "/mnt/output/labels.csv",
92                         "-regions", *regionsFilename,
93                         "-expand-regions", fmt.Sprintf("%d", *expandRegions),
94                         "-max-variants", fmt.Sprintf("%d", cmd.filter.MaxVariants),
95                         "-min-coverage", fmt.Sprintf("%f", cmd.filter.MinCoverage),
96                         "-max-tag", fmt.Sprintf("%d", cmd.filter.MaxTag),
97                         "-chunks", fmt.Sprintf("%d", *chunks),
98                 }
99                 var output string
100                 output, err = runner.Run()
101                 if err != nil {
102                         return 1
103                 }
104                 fmt.Fprintln(stdout, output+"/matrix.npy")
105                 return 0
106         }
107
108         tilelib := &tileLibrary{
109                 retainNoCalls:       true,
110                 retainTileSequences: true,
111                 compactGenomes:      map[string][]tileVariantID{},
112         }
113         err = tilelib.LoadDir(context.Background(), *inputDir, nil)
114         if err != nil {
115                 return 1
116         }
117
118         log.Info("filtering")
119         cmd.filter.Apply(tilelib)
120         log.Info("tidying")
121         tilelib.Tidy()
122
123         log.Info("building lowqual map")
124         lowqual := lowqual(tilelib)
125         names := cgnames(tilelib)
126
127         if *labelsFilename != "" {
128                 log.Infof("writing labels to %s", *labelsFilename)
129                 var f *os.File
130                 f, err = os.OpenFile(*labelsFilename, os.O_CREATE|os.O_WRONLY, 0777)
131                 if err != nil {
132                         return 1
133                 }
134                 defer f.Close()
135                 outBasename := "matrix.npy"
136                 for i, name := range names {
137                         _, err = fmt.Fprintf(f, "%d,%q,%q\n", i, trimFilenameForLabel(name), outBasename)
138                         if err != nil {
139                                 err = fmt.Errorf("write %s: %w", *labelsFilename, err)
140                                 return 1
141                         }
142                 }
143                 err = f.Close()
144                 if err != nil {
145                         err = fmt.Errorf("close %s: %w", *labelsFilename, err)
146                         return 1
147                 }
148         }
149
150         log.Info("determining which tiles intersect given regions")
151         dropTiles, err := chooseTiles(tilelib, *regionsFilename, *expandRegions)
152         if err != nil {
153                 return 1
154         }
155
156         annotation2tvs := map[string]map[hgvs.Variant][]tileLibRef{}
157         if *annotationsFilename != "" {
158                 log.Info("writing annotations")
159                 var annow io.WriteCloser
160                 annow, err = os.OpenFile(*annotationsFilename, os.O_CREATE|os.O_WRONLY, 0666)
161                 if err != nil {
162                         return 1
163                 }
164                 defer annow.Close()
165                 var mtx sync.Mutex
166                 err = (&annotatecmd{
167                         maxTileSize: 5000,
168                         dropTiles:   dropTiles,
169                         reportAnnotation: func(tag tagID, _ int, variant tileVariantID, refname string, seqname string, pdi hgvs.Variant) {
170                                 mtx.Lock()
171                                 defer mtx.Unlock()
172                                 if annotation2tvs[seqname] == nil {
173                                         annotation2tvs[seqname] = map[hgvs.Variant][]tileLibRef{}
174                                 }
175                                 annotation2tvs[seqname][pdi] = append(annotation2tvs[seqname][pdi], tileLibRef{Tag: tag, Variant: variant})
176                         },
177                 }).exportTileDiffs(annow, tilelib)
178                 if err != nil {
179                         return 1
180                 }
181                 err = annow.Close()
182                 if err != nil {
183                         return 1
184                 }
185         }
186
187         var lastErr atomic.Value
188         var wg sync.WaitGroup
189         for seqname, pdivars := range annotation2tvs {
190                 seqname, pdivars := seqname, pdivars
191                 wg.Add(1)
192                 go func() {
193                         defer wg.Done()
194                         log.Infof("choosing hgvs columns for seq %s", seqname)
195                         var pdis []hgvs.Variant
196                         for pdi, librefs := range pdivars {
197                                 // Include this HGVS column if it was
198                                 // seen in a variant of any
199                                 // non-dropped tile.
200                                 for _, libref := range librefs {
201                                         if int(libref.Tag) >= len(dropTiles) || !dropTiles[libref.Tag] {
202                                                 pdis = append(pdis, pdi)
203                                                 break
204                                         }
205                                 }
206                         }
207                         sort.Slice(pdis, func(i, j int) bool {
208                                 if cmp := pdis[i].Position - pdis[j].Position; cmp != 0 {
209                                         return cmp < 0
210                                 } else if pdis[i].Ref != pdis[j].Ref {
211                                         return pdis[i].Ref < pdis[j].Ref
212                                 } else {
213                                         return pdis[i].New < pdis[j].New
214                                 }
215                         })
216                         log.Infof("writing column labels for seq %s", seqname)
217                         var buf bytes.Buffer
218                         for _, pdi := range pdis {
219                                 fmt.Fprintf(&buf, "%s:g.%s\n", seqname, pdi.String())
220                         }
221                         err := ioutil.WriteFile(*outputDir+"/"+seqname+".columns.csv", buf.Bytes(), 0777)
222                         if err != nil {
223                                 lastErr.Store(err)
224                                 return
225                         }
226
227                         log.Infof("building hgvs matrix for seq %s", seqname)
228                         data := make([]int8, len(names)*len(pdis)*2)
229                         for row, name := range names {
230                                 cg := tilelib.compactGenomes[name]
231                                 rowstart := row * len(pdis) * 2
232                                 for col, pdi := range pdis {
233                                         for _, libref := range pdivars[pdi] {
234                                                 if len(cg) <= int(libref.Tag)*2+1 {
235                                                         continue
236                                                 }
237                                                 for phase := 0; phase < 2; phase++ {
238                                                         if cg[int(libref.Tag)*2+phase] == libref.Variant {
239                                                                 data[rowstart+col*2+phase] = 1
240                                                         }
241                                                 }
242                                         }
243                                 }
244                         }
245                         log.Infof("writing hgvs numpy for seq %s", seqname)
246                         f, err := os.OpenFile(*outputDir+"/"+seqname+".npy", os.O_CREATE|os.O_WRONLY, 0777)
247                         if err != nil {
248                                 lastErr.Store(err)
249                                 return
250                         }
251                         defer f.Close()
252                         npw, err := gonpy.NewWriter(f)
253                         if err != nil {
254                                 lastErr.Store(err)
255                                 return
256                         }
257                         npw.Shape = []int{len(names), len(pdis) * 2}
258                         npw.WriteInt8(data)
259                         // gonpy closes f and ignores errors, doh.
260                         // err = f.Close()
261                         // if err != nil {
262                         //      lastErr.Store(err)
263                         //      return
264                         // }
265                 }()
266         }
267         wg.Wait()
268         if e, ok := lastErr.Load().(error); ok {
269                 err = e
270                 return 1
271         }
272
273         chunksize := (len(tilelib.variant) + *chunks - 1) / *chunks
274         for chunk := 0; chunk < *chunks; chunk++ {
275                 log.Infof("preparing chunk %d of %d", chunk+1, *chunks)
276                 tagstart := chunk * chunksize
277                 tagend := tagstart + chunksize
278                 if tagend > len(tilelib.variant) {
279                         tagend = len(tilelib.variant)
280                 }
281                 out, rows, cols := cgs2array(tilelib, names, lowqual, dropTiles, tagstart, tagend)
282
283                 var npw *gonpy.NpyWriter
284                 var output io.WriteCloser
285                 fnm := *outputDir + "/matrix.npy"
286                 if *chunks > 1 {
287                         fnm = fmt.Sprintf("%s/matrix.%d.npy", *outputDir, chunk)
288                 }
289                 output, err = os.OpenFile(fnm, os.O_CREATE|os.O_WRONLY, 0777)
290                 if err != nil {
291                         return 1
292                 }
293                 defer output.Close()
294                 bufw := bufio.NewWriter(output)
295                 npw, err = gonpy.NewWriter(nopCloser{bufw})
296                 if err != nil {
297                         return 1
298                 }
299                 if *onehot {
300                         log.Info("recoding to onehot")
301                         recoded, librefs, recodedcols := recodeOnehot(out, cols)
302                         out, cols = recoded, recodedcols
303                         if *librefsFilename != "" {
304                                 log.Infof("writing onehot column mapping")
305                                 err = cmd.writeLibRefs(*librefsFilename, tilelib, librefs)
306                                 if err != nil {
307                                         return 1
308                                 }
309                         }
310                 }
311                 log.WithFields(logrus.Fields{
312                         "filename": fnm,
313                         "rows":     rows,
314                         "cols":     cols,
315                 }).Info("writing numpy")
316                 npw.Shape = []int{rows, cols}
317                 npw.WriteInt16(out)
318                 err = bufw.Flush()
319                 if err != nil {
320                         return 1
321                 }
322                 err = output.Close()
323                 if err != nil {
324                         return 1
325                 }
326         }
327         return 0
328 }
329
330 func (*exportNumpy) writeLibRefs(fnm string, tilelib *tileLibrary, librefs []tileLibRef) error {
331         f, err := os.OpenFile(fnm, os.O_CREATE|os.O_WRONLY, 0666)
332         if err != nil {
333                 return err
334         }
335         defer f.Close()
336         for i, libref := range librefs {
337                 _, err = fmt.Fprintf(f, "%d,%d,%d\n", i, libref.Tag, libref.Variant)
338                 if err != nil {
339                         return err
340                 }
341         }
342         return f.Close()
343 }
344
345 func cgnames(tilelib *tileLibrary) (cgnames []string) {
346         for name := range tilelib.compactGenomes {
347                 cgnames = append(cgnames, name)
348         }
349         sort.Slice(cgnames, func(i, j int) bool {
350                 return trimFilenameForLabel(cgnames[i]) < trimFilenameForLabel(cgnames[j])
351         })
352         return
353 }
354
355 func lowqual(tilelib *tileLibrary) (lowqual []map[tileVariantID]bool) {
356         lowqual = make([]map[tileVariantID]bool, len(tilelib.variant))
357         for tag, variants := range tilelib.variant {
358                 lq := lowqual[tag]
359                 for varidx, hash := range variants {
360                         if len(tilelib.seq[hash]) == 0 {
361                                 if lq == nil {
362                                         lq = map[tileVariantID]bool{}
363                                         lowqual[tag] = lq
364                                 }
365                                 lq[tileVariantID(varidx+1)] = true
366                         }
367                 }
368         }
369         return
370 }
371
372 func cgs2array(tilelib *tileLibrary, names []string, lowqual []map[tileVariantID]bool, dropTiles []bool, tagstart, tagend int) (data []int16, rows, cols int) {
373         rows = len(tilelib.compactGenomes)
374         for tag := tagstart; tag < tagend; tag++ {
375                 if len(dropTiles) <= tag || !dropTiles[tag] {
376                         cols += 2
377                 }
378         }
379         data = make([]int16, rows*cols)
380         for row, name := range names {
381                 cg := tilelib.compactGenomes[name]
382                 outidx := 0
383                 for tag := tagstart; tag < tagend && tag*2+1 < len(cg); tag++ {
384                         if len(dropTiles) > tag && dropTiles[tag] {
385                                 continue
386                         }
387                         for phase := 0; phase < 2; phase++ {
388                                 v := cg[tag*2+phase]
389                                 if v > 0 && lowqual[tag][v] {
390                                         data[row*cols+outidx] = -1
391                                 } else {
392                                         data[row*cols+outidx] = int16(v)
393                                 }
394                                 outidx++
395                         }
396                 }
397         }
398         return
399 }
400
401 func chooseTiles(tilelib *tileLibrary, regionsFilename string, expandRegions int) (drop []bool, err error) {
402         if regionsFilename == "" {
403                 return
404         }
405         rfile, err := zopen(regionsFilename)
406         if err != nil {
407                 return
408         }
409         defer rfile.Close()
410         regions, err := ioutil.ReadAll(rfile)
411         if err != nil {
412                 return
413         }
414
415         log.Print("chooseTiles: building mask")
416         mask := &mask{}
417         for _, line := range bytes.Split(regions, []byte{'\n'}) {
418                 if bytes.HasPrefix(line, []byte{'#'}) {
419                         continue
420                 }
421                 fields := bytes.Split(line, []byte{'\t'})
422                 if len(fields) < 3 {
423                         continue
424                 }
425                 refseqname := string(fields[0])
426                 if strings.HasPrefix(refseqname, "chr") {
427                         refseqname = refseqname[3:]
428                 }
429                 start, err1 := strconv.Atoi(string(fields[1]))
430                 end, err2 := strconv.Atoi(string(fields[2]))
431                 if err1 == nil && err2 == nil {
432                         // BED
433                 } else {
434                         start, err1 = strconv.Atoi(string(fields[3]))
435                         end, err2 = strconv.Atoi(string(fields[4]))
436                         if err1 == nil && err2 == nil {
437                                 // GFF/GTF
438                                 end++
439                         } else {
440                                 err = fmt.Errorf("cannot parse input line as BED or GFF/GTF: %q", line)
441                                 return
442                         }
443                 }
444                 mask.Add(refseqname, start-expandRegions, end+expandRegions)
445         }
446         log.Print("chooseTiles: mask.Freeze")
447         mask.Freeze()
448
449         tagset := tilelib.taglib.Tags()
450         if len(tagset) == 0 {
451                 err = errors.New("cannot choose tiles by region in a library without tags")
452                 return
453         }
454         taglen := len(tagset[0])
455
456         log.Print("chooseTiles: check ref tiles")
457         // Find position+size of each reference tile, and if it
458         // intersects any of the desired regions, set drop[tag]=false.
459         //
460         // (Note it doesn't quite work to do the more obvious thing --
461         // start with drop=false and change to true when ref tiles
462         // intersect target regions -- because that would give us
463         // drop=false for tiles that don't appear at all in the
464         // reference.)
465         //
466         // TODO: (optionally?) don't drop tags for which some tile
467         // variants are spanning tiles, i.e., where the reference tile
468         // does not intersect the desired regions, but a spanning tile
469         // from a genome does.
470         drop = make([]bool, len(tilelib.variant))
471         for i := range drop {
472                 drop[i] = true
473         }
474         for refname, refseqs := range tilelib.refseqs {
475                 for refseqname, reftiles := range refseqs {
476                         if strings.HasPrefix(refseqname, "chr") {
477                                 refseqname = refseqname[3:]
478                         }
479                         tileend := 0
480                         for _, libref := range reftiles {
481                                 if libref.Variant < 1 {
482                                         err = fmt.Errorf("reference %q seq %q uses variant zero at tag %d", refname, refseqname, libref.Tag)
483                                         return
484                                 }
485                                 seq := tilelib.TileVariantSequence(libref)
486                                 if len(seq) < taglen {
487                                         err = fmt.Errorf("reference %q seq %q uses tile %d variant %d with sequence len %d < taglen %d", refname, refseqname, libref.Tag, libref.Variant, len(seq), taglen)
488                                         return
489                                 }
490                                 tilestart := tileend
491                                 tileend = tilestart + len(seq) - taglen
492                                 if mask.Check(refseqname, tilestart, tileend) {
493                                         drop[libref.Tag] = false
494                                 }
495                         }
496                 }
497         }
498
499         log.Print("chooseTiles: done")
500         return
501 }
502
503 func recodeOnehot(in []int16, incols int) (out []int16, librefs []tileLibRef, outcols int) {
504         rows := len(in) / incols
505         maxvalue := make([]int16, incols)
506         for row := 0; row < rows; row++ {
507                 for col := 0; col < incols; col++ {
508                         if v := in[row*incols+col]; maxvalue[col] < v {
509                                 maxvalue[col] = v
510                         }
511                 }
512         }
513         outcol := make([]int, incols)
514         dropped := 0
515         for incol, maxv := range maxvalue {
516                 outcol[incol] = outcols
517                 if maxv == 0 {
518                         dropped++
519                 }
520                 for v := 1; v <= int(maxv); v++ {
521                         librefs = append(librefs, tileLibRef{Tag: tagID(incol), Variant: tileVariantID(v)})
522                         outcols++
523                 }
524         }
525         log.Printf("recodeOnehot: dropped %d input cols with zero maxvalue", dropped)
526
527         out = make([]int16, rows*outcols)
528         for inidx, row := 0, 0; row < rows; row++ {
529                 outrow := out[row*outcols:]
530                 for col := 0; col < incols; col++ {
531                         if v := in[inidx]; v > 0 {
532                                 outrow[outcol[col]+int(v)-1] = 1
533                         }
534                         inidx++
535                 }
536         }
537         return
538 }
539
// nopCloser adds a no-op Close method to an io.Writer, so that code
// which closes its writer when done does not close the underlying
// writer.
type nopCloser struct {
	io.Writer
}

// Close does nothing, and always returns nil.
func (nc nopCloser) Close() error {
	return nil
}
545
// trimFilenameForLabel derives a label from a filename: it discards
// everything up to and including the last "/", then removes each of
// the suffixes ".gz", ".fa", ".fasta", ".1", ".2", ".gz", ".vcf", in
// that order, each at most once, if present at that point.
func trimFilenameForLabel(s string) string {
	label := s
	if slash := strings.LastIndex(label, "/"); slash >= 0 {
		label = label[slash+1:]
	}
	// The order matters: each suffix is only tried once, against
	// whatever is left after trimming the previous ones.
	for _, suffix := range []string{".gz", ".fa", ".fasta", ".1", ".2", ".gz", ".vcf"} {
		label = strings.TrimSuffix(label, suffix)
	}
	return label
}