Fix some tests.
[lightning.git] / exportnumpy.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "bufio"
9         "bytes"
10         "context"
11         "errors"
12         "flag"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "net/http"
17         _ "net/http/pprof"
18         "os"
19         "sort"
20         "strconv"
21         "strings"
22         "sync"
23         "sync/atomic"
24
25         "git.arvados.org/arvados.git/sdk/go/arvados"
26         "github.com/arvados/lightning/hgvs"
27         "github.com/kshedden/gonpy"
28         "github.com/sirupsen/logrus"
29         log "github.com/sirupsen/logrus"
30 )
31
32 type exportNumpy struct {
33         filter filter
34 }
35
36 func (cmd *exportNumpy) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
37         var err error
38         defer func() {
39                 if err != nil {
40                         fmt.Fprintf(stderr, "%s\n", err)
41                 }
42         }()
43         flags := flag.NewFlagSet("", flag.ContinueOnError)
44         flags.SetOutput(stderr)
45         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
46         runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
47         projectUUID := flags.String("project", "", "project `UUID` for output data")
48         priority := flags.Int("priority", 500, "container request priority")
49         inputDir := flags.String("input-dir", "./in", "input `directory`")
50         outputDir := flags.String("output-dir", "./out", "output `directory`")
51         annotationsFilename := flags.String("output-annotations", "", "output `file` for tile variant annotations csv")
52         librefsFilename := flags.String("output-onehot2tilevar", "", "when using -one-hot, create csv `file` mapping column# to tag# and variant#")
53         labelsFilename := flags.String("output-labels", "", "output `file` for genome labels csv")
54         regionsFilename := flags.String("regions", "", "only output columns/annotations that intersect regions in specified bed `file`")
55         expandRegions := flags.Int("expand-regions", 0, "expand specified regions by `N` base pairs on each side`")
56         onehot := flags.Bool("one-hot", false, "recode tile variants as one-hot")
57         chunks := flags.Int("chunks", 1, "split output into `N` numpy files")
58         cmd.filter.Flags(flags)
59         err = flags.Parse(args)
60         if err == flag.ErrHelp {
61                 err = nil
62                 return 0
63         } else if err != nil {
64                 return 2
65         } else if flags.NArg() > 0 {
66                 err = fmt.Errorf("errant command line arguments after parsed flags: %v", flags.Args())
67                 return 2
68         }
69
70         if *pprof != "" {
71                 go func() {
72                         log.Println(http.ListenAndServe(*pprof, nil))
73                 }()
74         }
75
76         if !*runlocal {
77                 runner := arvadosContainerRunner{
78                         Name:        "lightning export-numpy",
79                         Client:      arvados.NewClientFromEnv(),
80                         ProjectUUID: *projectUUID,
81                         RAM:         500000000000,
82                         VCPUs:       96,
83                         Priority:    *priority,
84                         KeepCache:   1,
85                         APIAccess:   true,
86                 }
87                 err = runner.TranslatePaths(inputDir, regionsFilename)
88                 if err != nil {
89                         return 1
90                 }
91                 runner.Args = []string{"export-numpy", "-local=true",
92                         "-pprof", ":6060",
93                         fmt.Sprintf("-one-hot=%v", *onehot),
94                         "-input-dir", *inputDir,
95                         "-output-dir", "/mnt/output",
96                         "-output-annotations", "/mnt/output/annotations.csv",
97                         "-output-onehot2tilevar", "/mnt/output/onehot2tilevar.csv",
98                         "-output-labels", "/mnt/output/labels.csv",
99                         "-regions", *regionsFilename,
100                         "-expand-regions", fmt.Sprintf("%d", *expandRegions),
101                         "-chunks", fmt.Sprintf("%d", *chunks),
102                 }
103                 runner.Args = append(runner.Args, cmd.filter.Args()...)
104                 var output string
105                 output, err = runner.Run()
106                 if err != nil {
107                         return 1
108                 }
109                 fmt.Fprintln(stdout, output+"/matrix.npy")
110                 return 0
111         }
112
113         tilelib := &tileLibrary{
114                 retainNoCalls:       true,
115                 retainTileSequences: true,
116                 compactGenomes:      map[string][]tileVariantID{},
117         }
118         err = tilelib.LoadDir(context.Background(), *inputDir)
119         if err != nil {
120                 return 1
121         }
122
123         log.Info("filtering")
124         cmd.filter.Apply(tilelib)
125         log.Info("tidying")
126         tilelib.Tidy()
127
128         log.Info("building lowqual map")
129         lowqual := lowqual(tilelib)
130         names := cgnames(tilelib)
131
132         if *labelsFilename != "" {
133                 log.Infof("writing labels to %s", *labelsFilename)
134                 var f *os.File
135                 f, err = os.OpenFile(*labelsFilename, os.O_CREATE|os.O_WRONLY, 0777)
136                 if err != nil {
137                         return 1
138                 }
139                 defer f.Close()
140                 outBasename := "matrix.npy"
141                 for i, name := range names {
142                         _, err = fmt.Fprintf(f, "%d,%q,%q\n", i, trimFilenameForLabel(name), outBasename)
143                         if err != nil {
144                                 err = fmt.Errorf("write %s: %w", *labelsFilename, err)
145                                 return 1
146                         }
147                 }
148                 err = f.Close()
149                 if err != nil {
150                         err = fmt.Errorf("close %s: %w", *labelsFilename, err)
151                         return 1
152                 }
153         }
154
155         log.Info("determining which tiles intersect given regions")
156         dropTiles, err := chooseTiles(tilelib, *regionsFilename, *expandRegions)
157         if err != nil {
158                 return 1
159         }
160
161         annotation2tvs := map[string]map[hgvs.Variant][]tileLibRef{}
162         if *annotationsFilename != "" {
163                 log.Info("writing annotations")
164                 var annow io.WriteCloser
165                 annow, err = os.OpenFile(*annotationsFilename, os.O_CREATE|os.O_WRONLY, 0666)
166                 if err != nil {
167                         return 1
168                 }
169                 defer annow.Close()
170                 var mtx sync.Mutex
171                 err = (&annotatecmd{
172                         maxTileSize: 5000,
173                         dropTiles:   dropTiles,
174                         reportAnnotation: func(tag tagID, _ int, variant tileVariantID, refname string, seqname string, pdi hgvs.Variant) {
175                                 mtx.Lock()
176                                 defer mtx.Unlock()
177                                 if annotation2tvs[seqname] == nil {
178                                         annotation2tvs[seqname] = map[hgvs.Variant][]tileLibRef{}
179                                 }
180                                 annotation2tvs[seqname][pdi] = append(annotation2tvs[seqname][pdi], tileLibRef{Tag: tag, Variant: variant})
181                         },
182                 }).exportTileDiffs(annow, tilelib)
183                 if err != nil {
184                         return 1
185                 }
186                 err = annow.Close()
187                 if err != nil {
188                         return 1
189                 }
190         }
191
192         var lastErr atomic.Value
193         var wg sync.WaitGroup
194         for seqname, pdivars := range annotation2tvs {
195                 seqname, pdivars := seqname, pdivars
196                 wg.Add(1)
197                 go func() {
198                         defer wg.Done()
199                         log.Infof("choosing hgvs columns for seq %s", seqname)
200                         var pdis []hgvs.Variant
201                         for pdi, librefs := range pdivars {
202                                 // Include this HGVS column if it was
203                                 // seen in a variant of any
204                                 // non-dropped tile.
205                                 for _, libref := range librefs {
206                                         if int(libref.Tag) >= len(dropTiles) || !dropTiles[libref.Tag] {
207                                                 pdis = append(pdis, pdi)
208                                                 break
209                                         }
210                                 }
211                         }
212                         sort.Slice(pdis, func(i, j int) bool {
213                                 if cmp := pdis[i].Position - pdis[j].Position; cmp != 0 {
214                                         return cmp < 0
215                                 } else if pdis[i].Ref != pdis[j].Ref {
216                                         return pdis[i].Ref < pdis[j].Ref
217                                 } else {
218                                         return pdis[i].New < pdis[j].New
219                                 }
220                         })
221                         log.Infof("writing column labels for seq %s", seqname)
222                         var buf bytes.Buffer
223                         for _, pdi := range pdis {
224                                 fmt.Fprintf(&buf, "%s:g.%s\n", seqname, pdi.String())
225                         }
226                         err := ioutil.WriteFile(*outputDir+"/"+seqname+".columns.csv", buf.Bytes(), 0777)
227                         if err != nil {
228                                 lastErr.Store(err)
229                                 return
230                         }
231
232                         log.Infof("building hgvs matrix for seq %s", seqname)
233                         data := make([]int8, len(names)*len(pdis)*2)
234                         for row, name := range names {
235                                 cg := tilelib.compactGenomes[name]
236                                 rowstart := row * len(pdis) * 2
237                                 for col, pdi := range pdis {
238                                         for _, libref := range pdivars[pdi] {
239                                                 if len(cg) <= int(libref.Tag)*2+1 {
240                                                         continue
241                                                 }
242                                                 for phase := 0; phase < 2; phase++ {
243                                                         if cg[int(libref.Tag)*2+phase] == libref.Variant {
244                                                                 data[rowstart+col*2+phase] = 1
245                                                         }
246                                                 }
247                                         }
248                                 }
249                         }
250                         log.Infof("writing hgvs numpy for seq %s", seqname)
251                         f, err := os.OpenFile(*outputDir+"/"+seqname+".npy", os.O_CREATE|os.O_WRONLY, 0777)
252                         if err != nil {
253                                 lastErr.Store(err)
254                                 return
255                         }
256                         defer f.Close()
257                         // gonpy closes our writer and ignores errors. Give it a nopCloser so we can close f properly.
258                         npw, err := gonpy.NewWriter(nopCloser{f})
259                         if err != nil {
260                                 lastErr.Store(err)
261                                 return
262                         }
263                         npw.Shape = []int{len(names), len(pdis) * 2}
264                         err = npw.WriteInt8(data)
265                         if err != nil {
266                                 lastErr.Store(err)
267                                 return
268                         }
269                         err = f.Close()
270                         if err != nil {
271                                 lastErr.Store(err)
272                                 return
273                         }
274                 }()
275         }
276         wg.Wait()
277         if e, ok := lastErr.Load().(error); ok {
278                 err = e
279                 return 1
280         }
281
282         chunksize := (len(tilelib.variant) + *chunks - 1) / *chunks
283         for chunk := 0; chunk < *chunks; chunk++ {
284                 log.Infof("preparing chunk %d of %d", chunk+1, *chunks)
285                 tagstart := chunk * chunksize
286                 tagend := tagstart + chunksize
287                 if tagend > len(tilelib.variant) {
288                         tagend = len(tilelib.variant)
289                 }
290                 out, rows, cols := cgs2array(tilelib, names, lowqual, dropTiles, tagstart, tagend)
291
292                 var npw *gonpy.NpyWriter
293                 var output io.WriteCloser
294                 fnm := *outputDir + "/matrix.npy"
295                 if *chunks > 1 {
296                         fnm = fmt.Sprintf("%s/matrix.%d.npy", *outputDir, chunk)
297                 }
298                 output, err = os.OpenFile(fnm, os.O_CREATE|os.O_WRONLY, 0777)
299                 if err != nil {
300                         return 1
301                 }
302                 defer output.Close()
303                 bufw := bufio.NewWriter(output)
304                 npw, err = gonpy.NewWriter(nopCloser{bufw})
305                 if err != nil {
306                         return 1
307                 }
308                 if *onehot {
309                         log.Info("recoding to onehot")
310                         recoded, librefs, recodedcols := recodeOnehot(out, cols)
311                         out, cols = recoded, recodedcols
312                         if *librefsFilename != "" {
313                                 log.Infof("writing onehot column mapping")
314                                 err = cmd.writeLibRefs(*librefsFilename, tilelib, librefs)
315                                 if err != nil {
316                                         return 1
317                                 }
318                         }
319                 }
320                 log.WithFields(logrus.Fields{
321                         "filename": fnm,
322                         "rows":     rows,
323                         "cols":     cols,
324                 }).Info("writing numpy")
325                 npw.Shape = []int{rows, cols}
326                 npw.WriteInt16(out)
327                 err = bufw.Flush()
328                 if err != nil {
329                         return 1
330                 }
331                 err = output.Close()
332                 if err != nil {
333                         return 1
334                 }
335         }
336         return 0
337 }
338
339 func (*exportNumpy) writeLibRefs(fnm string, tilelib *tileLibrary, librefs []tileLibRef) error {
340         f, err := os.OpenFile(fnm, os.O_CREATE|os.O_WRONLY, 0666)
341         if err != nil {
342                 return err
343         }
344         defer f.Close()
345         for i, libref := range librefs {
346                 _, err = fmt.Fprintf(f, "%d,%d,%d\n", i, libref.Tag, libref.Variant)
347                 if err != nil {
348                         return err
349                 }
350         }
351         return f.Close()
352 }
353
354 func cgnames(tilelib *tileLibrary) (cgnames []string) {
355         for name := range tilelib.compactGenomes {
356                 cgnames = append(cgnames, name)
357         }
358         sort.Slice(cgnames, func(i, j int) bool {
359                 return trimFilenameForLabel(cgnames[i]) < trimFilenameForLabel(cgnames[j])
360         })
361         return
362 }
363
364 func lowqual(tilelib *tileLibrary) (lowqual []map[tileVariantID]bool) {
365         lowqual = make([]map[tileVariantID]bool, len(tilelib.variant))
366         for tag, variants := range tilelib.variant {
367                 lq := lowqual[tag]
368                 for varidx, hash := range variants {
369                         if len(tilelib.hashSequence(hash)) == 0 {
370                                 if lq == nil {
371                                         lq = map[tileVariantID]bool{}
372                                         lowqual[tag] = lq
373                                 }
374                                 lq[tileVariantID(varidx+1)] = true
375                         }
376                 }
377         }
378         return
379 }
380
381 func cgs2array(tilelib *tileLibrary, names []string, lowqual []map[tileVariantID]bool, dropTiles []bool, tagstart, tagend int) (data []int16, rows, cols int) {
382         rows = len(tilelib.compactGenomes)
383         for tag := tagstart; tag < tagend; tag++ {
384                 if len(dropTiles) <= tag || !dropTiles[tag] {
385                         cols += 2
386                 }
387         }
388         data = make([]int16, rows*cols)
389         for row, name := range names {
390                 cg := tilelib.compactGenomes[name]
391                 outidx := 0
392                 for tag := tagstart; tag < tagend && tag*2+1 < len(cg); tag++ {
393                         if len(dropTiles) > tag && dropTiles[tag] {
394                                 continue
395                         }
396                         for phase := 0; phase < 2; phase++ {
397                                 v := cg[tag*2+phase]
398                                 if v > 0 && lowqual[tag][v] {
399                                         data[row*cols+outidx] = -1
400                                 } else {
401                                         data[row*cols+outidx] = int16(v)
402                                 }
403                                 outidx++
404                         }
405                 }
406         }
407         return
408 }
409
410 func makeMask(regionsFilename string, expandRegions int) (*mask, error) {
411         log.Printf("makeMask: reading %s", regionsFilename)
412         rfile, err := zopen(regionsFilename)
413         if err != nil {
414                 return nil, err
415         }
416         defer rfile.Close()
417         regions, err := io.ReadAll(rfile)
418         if err != nil {
419                 return nil, err
420         }
421
422         log.Print("makeMask: building mask")
423         var mask mask
424         for _, line := range bytes.Split(regions, []byte{'\n'}) {
425                 if bytes.HasPrefix(line, []byte{'#'}) {
426                         continue
427                 }
428                 fields := bytes.Split(line, []byte{'\t'})
429                 if len(fields) < 3 {
430                         continue
431                 }
432                 refseqname := string(fields[0])
433                 if strings.HasPrefix(refseqname, "chr") {
434                         refseqname = refseqname[3:]
435                 }
436                 start, err1 := strconv.Atoi(string(fields[1]))
437                 end, err2 := strconv.Atoi(string(fields[2]))
438                 if err1 == nil && err2 == nil {
439                         // BED
440                 } else {
441                         start, err1 = strconv.Atoi(string(fields[3]))
442                         end, err2 = strconv.Atoi(string(fields[4]))
443                         if err1 == nil && err2 == nil {
444                                 // GFF/GTF
445                                 end++
446                         } else {
447                                 return nil, fmt.Errorf("cannot parse input line as BED or GFF/GTF: %q", line)
448                         }
449                 }
450                 mask.Add(refseqname, start-expandRegions, end+expandRegions)
451         }
452         log.Print("makeMask: mask.Freeze")
453         mask.Freeze()
454         return &mask, nil
455 }
456
457 func chooseTiles(tilelib *tileLibrary, regionsFilename string, expandRegions int) (drop []bool, err error) {
458         if regionsFilename == "" {
459                 return
460         }
461         mask, err := makeMask(regionsFilename, expandRegions)
462         if err != nil {
463                 return
464         }
465
466         tagset := tilelib.taglib.Tags()
467         if len(tagset) == 0 {
468                 err = errors.New("cannot choose tiles by region in a library without tags")
469                 return
470         }
471         taglen := len(tagset[0])
472
473         log.Print("chooseTiles: check ref tiles")
474         // Find position+size of each reference tile, and if it
475         // intersects any of the desired regions, set drop[tag]=false.
476         //
477         // (Note it doesn't quite work to do the more obvious thing --
478         // start with drop=false and change to true when ref tiles
479         // intersect target regions -- because that would give us
480         // drop=false for tiles that don't appear at all in the
481         // reference.)
482         //
483         // TODO: (optionally?) don't drop tags for which some tile
484         // variants are spanning tiles, i.e., where the reference tile
485         // does not intersect the desired regions, but a spanning tile
486         // from a genome does.
487         drop = make([]bool, len(tilelib.variant))
488         for i := range drop {
489                 drop[i] = true
490         }
491         for refname, refseqs := range tilelib.refseqs {
492                 for refseqname, reftiles := range refseqs {
493                         if strings.HasPrefix(refseqname, "chr") {
494                                 refseqname = refseqname[3:]
495                         }
496                         tileend := 0
497                         for _, libref := range reftiles {
498                                 if libref.Variant < 1 {
499                                         err = fmt.Errorf("reference %q seq %q uses variant zero at tag %d", refname, refseqname, libref.Tag)
500                                         return
501                                 }
502                                 seq := tilelib.TileVariantSequence(libref)
503                                 if len(seq) < taglen {
504                                         err = fmt.Errorf("reference %q seq %q uses tile %d variant %d with sequence len %d < taglen %d", refname, refseqname, libref.Tag, libref.Variant, len(seq), taglen)
505                                         return
506                                 }
507                                 tilestart := tileend
508                                 tileend = tilestart + len(seq) - taglen
509                                 if mask.Check(refseqname, tilestart, tileend) {
510                                         drop[libref.Tag] = false
511                                 }
512                         }
513                 }
514         }
515
516         log.Print("chooseTiles: done")
517         return
518 }
519
520 func recodeOnehot(in []int16, incols int) (out []int16, librefs []tileLibRef, outcols int) {
521         rows := len(in) / incols
522         maxvalue := make([]int16, incols)
523         for row := 0; row < rows; row++ {
524                 for col := 0; col < incols; col++ {
525                         if v := in[row*incols+col]; maxvalue[col] < v {
526                                 maxvalue[col] = v
527                         }
528                 }
529         }
530         outcol := make([]int, incols)
531         dropped := 0
532         for incol, maxv := range maxvalue {
533                 outcol[incol] = outcols
534                 if maxv == 0 {
535                         dropped++
536                 }
537                 for v := 1; v <= int(maxv); v++ {
538                         librefs = append(librefs, tileLibRef{Tag: tagID(incol), Variant: tileVariantID(v)})
539                         outcols++
540                 }
541         }
542         log.Printf("recodeOnehot: dropped %d input cols with zero maxvalue", dropped)
543
544         out = make([]int16, rows*outcols)
545         for inidx, row := 0, 0; row < rows; row++ {
546                 outrow := out[row*outcols:]
547                 for col := 0; col < incols; col++ {
548                         if v := in[inidx]; v > 0 {
549                                 outrow[outcol[col]+int(v)-1] = 1
550                         }
551                         inidx++
552                 }
553         }
554         return
555 }
556
// nopCloser turns any io.Writer into an io.WriteCloser. Write goes to the
// embedded writer; Close does nothing and never fails, leaving the
// underlying writer for the caller to close.
type nopCloser struct {
	io.Writer
}

// Close implements io.Closer as a no-op.
func (c nopCloser) Close() error {
	return nil
}
562
// trimFilenameForLabel turns a genome file name into a short label.
//
// It drops everything up to and including the last '/'. It then strips the
// suffixes ".gz", ".fa", ".fasta", ".1", ".2", ".gz", ".vcf", each at most
// once and in exactly that order. Finally it replaces every ',' with '-'.
func trimFilenameForLabel(s string) string {
	// LastIndex returns -1 when there is no '/', so this keeps all of s.
	s = s[strings.LastIndex(s, "/")+1:]
	for _, suffix := range []string{".gz", ".fa", ".fasta", ".1", ".2", ".gz", ".vcf"} {
		s = strings.TrimSuffix(s, suffix)
	}
	return strings.ReplaceAll(s, ",", "-")
}