Write hgvs-based numpy matrix.
[lightning.git] / exportnumpy.go
1 package lightning
2
3 import (
4         "bufio"
5         "bytes"
6         "context"
7         "errors"
8         "flag"
9         "fmt"
10         "io"
11         "io/ioutil"
12         "net/http"
13         _ "net/http/pprof"
14         "os"
15         "sort"
16         "strconv"
17         "strings"
18         "sync"
19         "sync/atomic"
20
21         "git.arvados.org/arvados.git/sdk/go/arvados"
22         "github.com/arvados/lightning/hgvs"
23         "github.com/kshedden/gonpy"
24         "github.com/sirupsen/logrus"
25         log "github.com/sirupsen/logrus"
26 )
27
28 type exportNumpy struct {
29         filter filter
30 }
31
32 func (cmd *exportNumpy) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
33         var err error
34         defer func() {
35                 if err != nil {
36                         fmt.Fprintf(stderr, "%s\n", err)
37                 }
38         }()
39         flags := flag.NewFlagSet("", flag.ContinueOnError)
40         flags.SetOutput(stderr)
41         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
42         runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
43         projectUUID := flags.String("project", "", "project `UUID` for output data")
44         priority := flags.Int("priority", 500, "container request priority")
45         inputFilename := flags.String("i", "-", "input `file`")
46         outputDir := flags.String("output-dir", "/tmp", "output `directory`")
47         annotationsFilename := flags.String("output-annotations", "", "output `file` for tile variant annotations csv")
48         librefsFilename := flags.String("output-onehot2tilevar", "", "when using -one-hot, create csv `file` mapping column# to tag# and variant#")
49         labelsFilename := flags.String("output-labels", "", "output `file` for genome labels csv")
50         regionsFilename := flags.String("regions", "", "only output columns/annotations that intersect regions in specified bed `file`")
51         expandRegions := flags.Int("expand-regions", 0, "expand specified regions by `N` base pairs on each side`")
52         onehot := flags.Bool("one-hot", false, "recode tile variants as one-hot")
53         chunks := flags.Int("chunks", 1, "split output into `N` numpy files")
54         cmd.filter.Flags(flags)
55         err = flags.Parse(args)
56         if err == flag.ErrHelp {
57                 err = nil
58                 return 0
59         } else if err != nil {
60                 return 2
61         }
62
63         if *pprof != "" {
64                 go func() {
65                         log.Println(http.ListenAndServe(*pprof, nil))
66                 }()
67         }
68
69         if !*runlocal {
70                 runner := arvadosContainerRunner{
71                         Name:        "lightning export-numpy",
72                         Client:      arvados.NewClientFromEnv(),
73                         ProjectUUID: *projectUUID,
74                         RAM:         750000000000,
75                         VCPUs:       32,
76                         Priority:    *priority,
77                         KeepCache:   1,
78                         APIAccess:   true,
79                 }
80                 err = runner.TranslatePaths(inputFilename, regionsFilename)
81                 if err != nil {
82                         return 1
83                 }
84                 runner.Args = []string{"export-numpy", "-local=true",
85                         "-pprof", ":6000",
86                         fmt.Sprintf("-one-hot=%v", *onehot),
87                         "-i", *inputFilename,
88                         "-output-dir", "/mnt/output",
89                         "-output-annotations", "/mnt/output/annotations.csv",
90                         "-output-onehot2tilevar", "/mnt/output/onehot2tilevar.csv",
91                         "-output-labels", "/mnt/output/labels.csv",
92                         "-regions", *regionsFilename,
93                         "-expand-regions", fmt.Sprintf("%d", *expandRegions),
94                         "-max-variants", fmt.Sprintf("%d", cmd.filter.MaxVariants),
95                         "-min-coverage", fmt.Sprintf("%f", cmd.filter.MinCoverage),
96                         "-max-tag", fmt.Sprintf("%d", cmd.filter.MaxTag),
97                         "-chunks", fmt.Sprintf("%d", *chunks),
98                 }
99                 var output string
100                 output, err = runner.Run()
101                 if err != nil {
102                         return 1
103                 }
104                 fmt.Fprintln(stdout, output+"/matrix.npy")
105                 return 0
106         }
107
108         var input io.ReadCloser
109         if *inputFilename == "-" {
110                 input = ioutil.NopCloser(stdin)
111         } else {
112                 input, err = open(*inputFilename)
113                 if err != nil {
114                         return 1
115                 }
116                 defer input.Close()
117         }
118         input = ioutil.NopCloser(bufio.NewReaderSize(input, 8*1024*1024))
119         tilelib := &tileLibrary{
120                 retainNoCalls:       true,
121                 retainTileSequences: true,
122                 compactGenomes:      map[string][]tileVariantID{},
123         }
124         err = tilelib.LoadGob(context.Background(), input, strings.HasSuffix(*inputFilename, ".gz"), nil)
125         if err != nil {
126                 return 1
127         }
128         err = input.Close()
129         if err != nil {
130                 return 1
131         }
132
133         log.Info("filtering")
134         cmd.filter.Apply(tilelib)
135         log.Info("tidying")
136         tilelib.Tidy()
137
138         log.Info("building lowqual map")
139         lowqual := lowqual(tilelib)
140         names := cgnames(tilelib)
141
142         if *labelsFilename != "" {
143                 log.Infof("writing labels to %s", *labelsFilename)
144                 var f *os.File
145                 f, err = os.OpenFile(*labelsFilename, os.O_CREATE|os.O_WRONLY, 0777)
146                 if err != nil {
147                         return 1
148                 }
149                 defer f.Close()
150                 outBasename := "matrix.npy"
151                 for i, name := range names {
152                         _, err = fmt.Fprintf(f, "%d,%q,%q\n", i, trimFilenameForLabel(name), outBasename)
153                         if err != nil {
154                                 err = fmt.Errorf("write %s: %w", *labelsFilename, err)
155                                 return 1
156                         }
157                 }
158                 err = f.Close()
159                 if err != nil {
160                         err = fmt.Errorf("close %s: %w", *labelsFilename, err)
161                         return 1
162                 }
163         }
164
165         log.Info("determining which tiles intersect given regions")
166         dropTiles, err := chooseTiles(tilelib, *regionsFilename, *expandRegions)
167         if err != nil {
168                 return 1
169         }
170
171         annotation2tvs := map[string]map[hgvs.Variant][]tileLibRef{}
172         if *annotationsFilename != "" {
173                 log.Info("writing annotations")
174                 var annow io.WriteCloser
175                 annow, err = os.OpenFile(*annotationsFilename, os.O_CREATE|os.O_WRONLY, 0666)
176                 if err != nil {
177                         return 1
178                 }
179                 defer annow.Close()
180                 var mtx sync.Mutex
181                 err = (&annotatecmd{
182                         maxTileSize: 5000,
183                         dropTiles:   dropTiles,
184                         reportAnnotation: func(tag tagID, _ int, variant tileVariantID, refname string, seqname string, pdi hgvs.Variant) {
185                                 mtx.Lock()
186                                 defer mtx.Unlock()
187                                 if annotation2tvs[seqname] == nil {
188                                         annotation2tvs[seqname] = map[hgvs.Variant][]tileLibRef{}
189                                 }
190                                 annotation2tvs[seqname][pdi] = append(annotation2tvs[seqname][pdi], tileLibRef{Tag: tag, Variant: variant})
191                         },
192                 }).exportTileDiffs(annow, tilelib)
193                 if err != nil {
194                         return 1
195                 }
196                 err = annow.Close()
197                 if err != nil {
198                         return 1
199                 }
200         }
201
202         var lastErr atomic.Value
203         var wg sync.WaitGroup
204         for seqname, pdivars := range annotation2tvs {
205                 seqname, pdivars := seqname, pdivars
206                 wg.Add(1)
207                 go func() {
208                         defer wg.Done()
209                         log.Infof("choosing hgvs columns for seq %s", seqname)
210                         var pdis []hgvs.Variant
211                         for pdi, librefs := range pdivars {
212                                 // Include this HGVS column if it was
213                                 // seen in a variant of any
214                                 // non-dropped tile.
215                                 for _, libref := range librefs {
216                                         if int(libref.Tag) >= len(dropTiles) || !dropTiles[libref.Tag] {
217                                                 pdis = append(pdis, pdi)
218                                                 break
219                                         }
220                                 }
221                         }
222                         sort.Slice(pdis, func(i, j int) bool {
223                                 if cmp := pdis[i].Position - pdis[j].Position; cmp != 0 {
224                                         return cmp < 0
225                                 } else if pdis[i].Ref != pdis[j].Ref {
226                                         return pdis[i].Ref < pdis[j].Ref
227                                 } else {
228                                         return pdis[i].New < pdis[j].New
229                                 }
230                         })
231                         log.Infof("writing column labels for seq %s", seqname)
232                         var buf bytes.Buffer
233                         for _, pdi := range pdis {
234                                 fmt.Fprintf(&buf, "%s:g.%s\n", seqname, pdi.String())
235                         }
236                         err := ioutil.WriteFile(*outputDir+"/"+seqname+".columns.csv", buf.Bytes(), 0777)
237                         if err != nil {
238                                 lastErr.Store(err)
239                                 return
240                         }
241
242                         log.Infof("building hgvs matrix for seq %s", seqname)
243                         data := make([]int8, len(names)*len(pdis)*2)
244                         for row, name := range names {
245                                 cg := tilelib.compactGenomes[name]
246                                 rowstart := row * len(pdis) * 2
247                                 for col, pdi := range pdis {
248                                         for _, libref := range pdivars[pdi] {
249                                                 if len(cg) <= int(libref.Tag)*2+1 {
250                                                         continue
251                                                 }
252                                                 for phase := 0; phase < 2; phase++ {
253                                                         if cg[int(libref.Tag)*2+phase] == libref.Variant {
254                                                                 data[rowstart+col*2+phase] = 1
255                                                         }
256                                                 }
257                                         }
258                                 }
259                         }
260                         log.Infof("writing hgvs numpy for seq %s", seqname)
261                         f, err := os.OpenFile(*outputDir+"/"+seqname+".npy", os.O_CREATE|os.O_WRONLY, 0777)
262                         if err != nil {
263                                 lastErr.Store(err)
264                                 return
265                         }
266                         defer f.Close()
267                         npw, err := gonpy.NewWriter(f)
268                         if err != nil {
269                                 lastErr.Store(err)
270                                 return
271                         }
272                         npw.Shape = []int{len(names), len(pdis) * 2}
273                         npw.WriteInt8(data)
274                         // gonpy closes f and ignores errors, doh.
275                         // err = f.Close()
276                         // if err != nil {
277                         //      lastErr.Store(err)
278                         //      return
279                         // }
280                 }()
281         }
282         wg.Wait()
283         if e, ok := lastErr.Load().(error); ok {
284                 err = e
285                 return 1
286         }
287
288         chunksize := (len(tilelib.variant) + *chunks - 1) / *chunks
289         for chunk := 0; chunk < *chunks; chunk++ {
290                 log.Infof("preparing chunk %d of %d", chunk+1, *chunks)
291                 tagstart := chunk * chunksize
292                 tagend := tagstart + chunksize
293                 if tagend > len(tilelib.variant) {
294                         tagend = len(tilelib.variant)
295                 }
296                 out, rows, cols := cgs2array(tilelib, names, lowqual, dropTiles, tagstart, tagend)
297
298                 var npw *gonpy.NpyWriter
299                 var output io.WriteCloser
300                 fnm := *outputDir + "/matrix.npy"
301                 if *chunks > 1 {
302                         fnm = fmt.Sprintf("%s/matrix.%d.npy", *outputDir, chunk)
303                 }
304                 output, err = os.OpenFile(fnm, os.O_CREATE|os.O_WRONLY, 0777)
305                 if err != nil {
306                         return 1
307                 }
308                 defer output.Close()
309                 bufw := bufio.NewWriter(output)
310                 npw, err = gonpy.NewWriter(nopCloser{bufw})
311                 if err != nil {
312                         return 1
313                 }
314                 if *onehot {
315                         log.Info("recoding to onehot")
316                         recoded, librefs, recodedcols := recodeOnehot(out, cols)
317                         out, cols = recoded, recodedcols
318                         if *librefsFilename != "" {
319                                 log.Infof("writing onehot column mapping")
320                                 err = cmd.writeLibRefs(*librefsFilename, tilelib, librefs)
321                                 if err != nil {
322                                         return 1
323                                 }
324                         }
325                 }
326                 log.WithFields(logrus.Fields{
327                         "filename": fnm,
328                         "rows":     rows,
329                         "cols":     cols,
330                 }).Info("writing numpy")
331                 npw.Shape = []int{rows, cols}
332                 npw.WriteInt16(out)
333                 err = bufw.Flush()
334                 if err != nil {
335                         return 1
336                 }
337                 err = output.Close()
338                 if err != nil {
339                         return 1
340                 }
341         }
342         return 0
343 }
344
345 func (*exportNumpy) writeLibRefs(fnm string, tilelib *tileLibrary, librefs []tileLibRef) error {
346         f, err := os.OpenFile(fnm, os.O_CREATE|os.O_WRONLY, 0666)
347         if err != nil {
348                 return err
349         }
350         defer f.Close()
351         for i, libref := range librefs {
352                 _, err = fmt.Fprintf(f, "%d,%d,%d\n", i, libref.Tag, libref.Variant)
353                 if err != nil {
354                         return err
355                 }
356         }
357         return f.Close()
358 }
359
360 func cgnames(tilelib *tileLibrary) (cgnames []string) {
361         for name := range tilelib.compactGenomes {
362                 cgnames = append(cgnames, name)
363         }
364         sort.Slice(cgnames, func(i, j int) bool {
365                 return trimFilenameForLabel(cgnames[i]) < trimFilenameForLabel(cgnames[j])
366         })
367         return
368 }
369
370 func lowqual(tilelib *tileLibrary) (lowqual []map[tileVariantID]bool) {
371         lowqual = make([]map[tileVariantID]bool, len(tilelib.variant))
372         for tag, variants := range tilelib.variant {
373                 lq := lowqual[tag]
374                 for varidx, hash := range variants {
375                         if len(tilelib.seq[hash]) == 0 {
376                                 if lq == nil {
377                                         lq = map[tileVariantID]bool{}
378                                         lowqual[tag] = lq
379                                 }
380                                 lq[tileVariantID(varidx+1)] = true
381                         }
382                 }
383         }
384         return
385 }
386
387 func cgs2array(tilelib *tileLibrary, names []string, lowqual []map[tileVariantID]bool, dropTiles []bool, tagstart, tagend int) (data []int16, rows, cols int) {
388         rows = len(tilelib.compactGenomes)
389         for tag := tagstart; tag < tagend; tag++ {
390                 if len(dropTiles) <= tag || !dropTiles[tag] {
391                         cols += 2
392                 }
393         }
394         data = make([]int16, rows*cols)
395         for row, name := range names {
396                 cg := tilelib.compactGenomes[name]
397                 outidx := 0
398                 for tag := tagstart; tag < tagend && tag*2+1 < len(cg); tag++ {
399                         if len(dropTiles) > tag && dropTiles[tag] {
400                                 continue
401                         }
402                         for phase := 0; phase < 2; phase++ {
403                                 v := cg[tag*2+phase]
404                                 if v > 0 && lowqual[tag][v] {
405                                         data[row*cols+outidx] = -1
406                                 } else {
407                                         data[row*cols+outidx] = int16(v)
408                                 }
409                                 outidx++
410                         }
411                 }
412         }
413         return
414 }
415
416 func chooseTiles(tilelib *tileLibrary, regionsFilename string, expandRegions int) (drop []bool, err error) {
417         if regionsFilename == "" {
418                 return
419         }
420         rfile, err := zopen(regionsFilename)
421         if err != nil {
422                 return
423         }
424         defer rfile.Close()
425         regions, err := ioutil.ReadAll(rfile)
426         if err != nil {
427                 return
428         }
429
430         log.Print("chooseTiles: building mask")
431         mask := &mask{}
432         for _, line := range bytes.Split(regions, []byte{'\n'}) {
433                 if bytes.HasPrefix(line, []byte{'#'}) {
434                         continue
435                 }
436                 fields := bytes.Split(line, []byte{'\t'})
437                 if len(fields) < 3 {
438                         continue
439                 }
440                 refseqname := string(fields[0])
441                 if strings.HasPrefix(refseqname, "chr") {
442                         refseqname = refseqname[3:]
443                 }
444                 start, err1 := strconv.Atoi(string(fields[1]))
445                 end, err2 := strconv.Atoi(string(fields[2]))
446                 if err1 == nil && err2 == nil {
447                         // BED
448                 } else {
449                         start, err1 = strconv.Atoi(string(fields[3]))
450                         end, err2 = strconv.Atoi(string(fields[4]))
451                         if err1 == nil && err2 == nil {
452                                 // GFF/GTF
453                                 end++
454                         } else {
455                                 err = fmt.Errorf("cannot parse input line as BED or GFF/GTF: %q", line)
456                                 return
457                         }
458                 }
459                 mask.Add(refseqname, start-expandRegions, end+expandRegions)
460         }
461         log.Print("chooseTiles: mask.Freeze")
462         mask.Freeze()
463
464         tagset := tilelib.taglib.Tags()
465         if len(tagset) == 0 {
466                 err = errors.New("cannot choose tiles by region in a library without tags")
467                 return
468         }
469         taglen := len(tagset[0])
470
471         log.Print("chooseTiles: check ref tiles")
472         // Find position+size of each reference tile, and if it
473         // intersects any of the desired regions, set drop[tag]=false.
474         //
475         // (Note it doesn't quite work to do the more obvious thing --
476         // start with drop=false and change to true when ref tiles
477         // intersect target regions -- because that would give us
478         // drop=false for tiles that don't appear at all in the
479         // reference.)
480         //
481         // TODO: (optionally?) don't drop tags for which some tile
482         // variants are spanning tiles, i.e., where the reference tile
483         // does not intersect the desired regions, but a spanning tile
484         // from a genome does.
485         drop = make([]bool, len(tilelib.variant))
486         for i := range drop {
487                 drop[i] = true
488         }
489         for refname, refseqs := range tilelib.refseqs {
490                 for refseqname, reftiles := range refseqs {
491                         if strings.HasPrefix(refseqname, "chr") {
492                                 refseqname = refseqname[3:]
493                         }
494                         tileend := 0
495                         for _, libref := range reftiles {
496                                 if libref.Variant < 1 {
497                                         err = fmt.Errorf("reference %q seq %q uses variant zero at tag %d", refname, refseqname, libref.Tag)
498                                         return
499                                 }
500                                 seq := tilelib.TileVariantSequence(libref)
501                                 if len(seq) < taglen {
502                                         err = fmt.Errorf("reference %q seq %q uses tile %d variant %d with sequence len %d < taglen %d", refname, refseqname, libref.Tag, libref.Variant, len(seq), taglen)
503                                         return
504                                 }
505                                 tilestart := tileend
506                                 tileend = tilestart + len(seq) - taglen
507                                 if mask.Check(refseqname, tilestart, tileend) {
508                                         drop[libref.Tag] = false
509                                 }
510                         }
511                 }
512         }
513
514         log.Print("chooseTiles: done")
515         return
516 }
517
518 func recodeOnehot(in []int16, incols int) (out []int16, librefs []tileLibRef, outcols int) {
519         rows := len(in) / incols
520         maxvalue := make([]int16, incols)
521         for row := 0; row < rows; row++ {
522                 for col := 0; col < incols; col++ {
523                         if v := in[row*incols+col]; maxvalue[col] < v {
524                                 maxvalue[col] = v
525                         }
526                 }
527         }
528         outcol := make([]int, incols)
529         dropped := 0
530         for incol, maxv := range maxvalue {
531                 outcol[incol] = outcols
532                 if maxv == 0 {
533                         dropped++
534                 }
535                 for v := 1; v <= int(maxv); v++ {
536                         librefs = append(librefs, tileLibRef{Tag: tagID(incol), Variant: tileVariantID(v)})
537                         outcols++
538                 }
539         }
540         log.Printf("recodeOnehot: dropped %d input cols with zero maxvalue", dropped)
541
542         out = make([]int16, rows*outcols)
543         for inidx, row := 0, 0; row < rows; row++ {
544                 outrow := out[row*outcols:]
545                 for col := 0; col < incols; col++ {
546                         if v := in[inidx]; v > 0 {
547                                 outrow[outcol[col]+int(v)-1] = 1
548                         }
549                         inidx++
550                 }
551         }
552         return
553 }
554
// nopCloser adapts an io.Writer into an io.WriteCloser whose Close does
// nothing. RunCommand wraps its bufio.Writer in it before handing it to
// gonpy.NewWriter, so it can flush and close the file itself.
type nopCloser struct {
	io.Writer
}

// Close satisfies io.Closer; it does nothing and always returns nil.
func (_ nopCloser) Close() error {
	return nil
}
560
// trimFilenameForLabel turns a genome file path into a short label. It
// keeps only the part after the last '/', then removes each of these
// suffixes once, in this order, if present:
// .gz .fa .fasta .1 .2 .gz .vcf
func trimFilenameForLabel(s string) string {
	if slash := strings.LastIndexByte(s, '/'); slash != -1 {
		s = s[slash+1:]
	}
	for _, suffix := range []string{".gz", ".fa", ".fasta", ".1", ".2", ".gz", ".vcf"} {
		s = strings.TrimSuffix(s, suffix)
	}
	return s
}