Refactor chunked-hgvs to use less memory.
author Tom Clegg <tom@curii.com>
Tue, 28 Dec 2021 21:45:39 +0000 (16:45 -0500)
committer Tom Clegg <tom@curii.com>
Wed, 29 Dec 2021 14:34:06 +0000 (09:34 -0500)
refs #18438

Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

anno2vcf.go
slicenumpy.go

index 4b1295f5c0aa7102049156dfdf9b4c3b390ee173..23c6f21ece59f6c6b8ad7997472248519db12464 100644 (file)
@@ -128,7 +128,7 @@ func (cmd *anno2vcf) RunCommand(prog string, args []string, stdin io.Reader, std
                                if len(line) == 0 {
                                        continue
                                }
-                               if lineIdx & ^0xfff == 0 && thr.Err() != nil {
+                               if lineIdx&0xff == 0 && thr.Err() != nil {
                                        return nil
                                }
                                fields := bytes.Split(line, []byte{','})
index 193d3ef4d0eeecb3ccac0d85d7ebd91258940e10..d780eb55c7e59ba84eb19b1e5d06790443e72a21 100644 (file)
@@ -7,6 +7,7 @@ package lightning
 import (
        "bufio"
        "bytes"
+       "encoding/gob"
        "flag"
        "fmt"
        "io"
@@ -247,8 +248,27 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                log.Printf("after applying mask, len(reftile) == %d", len(reftile))
        }
 
+       tmpHGVSCols := map[string]*os.File{}
+       bufHGVSCols := map[string]*bufio.Writer{}
+       encodeHGVSCols := map[string]*gob.Encoder{}
+       if *hgvsChunked {
+               for seqname := range refseq {
+                       var f *os.File
+                       f, err = os.Create(*outputDir + "/tmp." + seqname + ".gob")
+                       if err != nil {
+                               return 1
+                       }
+                       defer os.Remove(f.Name())
+                       bufw := bufio.NewWriterSize(f, 1<<24)
+                       enc := gob.NewEncoder(bufw)
+                       tmpHGVSCols[seqname] = f
+                       bufHGVSCols[seqname] = bufw
+                       encodeHGVSCols[seqname] = enc
+               }
+       }
+
        var toMerge [][]int16
-       if *mergeOutput || *hgvsSingle || *hgvsChunked {
+       if *mergeOutput || *hgvsSingle {
                toMerge = make([][]int16, len(infiles))
        }
 
@@ -391,7 +411,14 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                variants := seq[tag]
                                reftilestr := strings.ToUpper(string(rt.tiledata))
                                remap := variantRemap[tag-tagstart]
-                               done := make([]bool, len(variants))
+                               maxv := tileVariantID(0)
+                               for _, v := range remap {
+                                       if maxv < v {
+                                               maxv = v
+                                       }
+                               }
+                               done := make([]bool, maxv+1)
+                               variantDiffs := make([][]hgvs.Variant, maxv+1)
                                for v, tv := range variants {
                                        v := remap[v]
                                        if v == rt.variant || done[v] {
@@ -412,6 +439,54 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                                diff.Position += rt.pos
                                                fmt.Fprintf(annow, "%d,%d,%d,%s:g.%s,%s,%d,%s,%s,%s\n", tag, outcol, v, rt.seqname, diff.String(), rt.seqname, diff.Position, diff.Ref, diff.New, diff.Left)
                                        }
+                                       if *hgvsChunked {
+                                               variantDiffs[v] = diffs
+                                       }
+                               }
+                               if *hgvsChunked {
+                                       // We can now determine, for each HGVS
+                                       // variant (diff) in this reftile
+                                       // region, whether a given genome
+                                       // phase/allele (1) has the variant, (0) has
+                                       // =ref or a different variant in that
+                                       // position, or (-1) is lacking
+                                       // coverage / couldn't be diffed.
+                                       hgvsCol := map[hgvs.Variant][2][]int8{}
+                                       for _, diffs := range variantDiffs {
+                                               for _, diff := range diffs {
+                                                       if _, ok := hgvsCol[diff]; ok {
+                                                               continue
+                                                       }
+                                                       hgvsCol[diff] = [2][]int8{
+                                                               make([]int8, len(cgnames)),
+                                                               make([]int8, len(cgnames)),
+                                                       }
+                                               }
+                                       }
+                                       for row, name := range cgnames {
+                                               variants := cgs[name].Variants[(tag-tagstart)*2:]
+                                               for ph := 0; ph < 2; ph++ {
+                                                       v := variants[ph]
+                                                       if int(v) >= len(remap) {
+                                                               v = 0
+                                                       } else {
+                                                               v = remap[v]
+                                                       }
+                                                       if v == rt.variant {
+                                                               // hgvsCol[*][ph][row] is already 0
+                                                       } else if len(variantDiffs[v]) == 0 {
+                                                               // lacking coverage / couldn't be diffed
+                                                               for _, col := range hgvsCol {
+                                                                       col[ph][row] = -1
+                                                               }
+                                                       } else {
+                                                               for _, diff := range variantDiffs[v] {
+                                                                       hgvsCol[diff][ph][row] = 1
+                                                               }
+                                                       }
+                                               }
+                                       }
+                                       encodeHGVSCols[rt.seqname].Encode(hgvsCol)
                                }
                                outcol++
                        }
@@ -448,7 +523,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        seq = nil
                        throttleNumpyMem.Release()
 
-                       if *mergeOutput || *hgvsSingle || *hgvsChunked {
+                       if *mergeOutput || *hgvsSingle {
                                log.Infof("%04d: matrix fragment %d rows x %d cols", infileIdx, rows, cols)
                                toMerge[infileIdx] = out
                        }
@@ -466,7 +541,78 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
        if err = throttleMem.Wait(); err != nil {
                return 1
        }
-       if *mergeOutput || *hgvsSingle || *hgvsChunked {
+
+       if *hgvsChunked {
+               log.Info("flushing hgvsCols temp files")
+               for seqname := range refseq {
+                       err = bufHGVSCols[seqname].Flush()
+                       if err != nil {
+                               return 1
+                       }
+                       bufHGVSCols[seqname] = nil // free buffer memory
+               }
+               for seqname := range refseq {
+                       log.Infof("%s: reading hgvsCols from temp file", seqname)
+                       f := tmpHGVSCols[seqname]
+                       _, err = f.Seek(0, io.SeekStart)
+                       if err != nil {
+                               return 1
+                       }
+                       var hgvsCols map[hgvs.Variant][2][]int8
+                       dec := gob.NewDecoder(bufio.NewReaderSize(f, 1<<24))
+                       for err == nil {
+                               err = dec.Decode(&hgvsCols)
+                       }
+                       if err != io.EOF {
+                               return 1
+                       }
+                       log.Infof("%s: sorting %d hgvs variants", seqname, len(hgvsCols))
+                       variants := make([]hgvs.Variant, 0, len(hgvsCols))
+                       for v := range hgvsCols {
+                               variants = append(variants, v)
+                       }
+                       sort.Slice(variants, func(i, j int) bool {
+                               vi, vj := &variants[i], &variants[j]
+                               if vi.Position != vj.Position {
+                                       return vi.Position < vj.Position
+                               } else if vi.Ref != vj.Ref {
+                                       return vi.Ref < vj.Ref
+                               } else {
+                                       return vi.New < vj.New
+                               }
+                       })
+                       rows := len(cgnames)
+                       cols := len(variants) * 2
+                       log.Infof("%s: building hgvs matrix (rows=%d, cols=%d, mem=%d)", seqname, rows, cols, rows*cols)
+                       out := make([]int8, rows*cols)
+                       for varIdx, variant := range variants {
+                               hgvsCols := hgvsCols[variant]
+                               for row := range cgnames {
+                                       for ph := 0; ph < 2; ph++ {
+                                               out[row*cols+varIdx*2+ph] = hgvsCols[ph][row] // cols == 2*len(variants): variant varIdx owns columns 2*varIdx (phase 0) and 2*varIdx+1 (phase 1)
+                                       }
+                               }
+                       }
+                       err = writeNumpyInt8(fmt.Sprintf("%s/hgvs.%s.npy", *outputDir, seqname), out, rows, cols)
+                       if err != nil {
+                               return 1
+                       }
+                       out = nil
+
+                       fnm := fmt.Sprintf("%s/hgvs.%s.annotations.csv", *outputDir, seqname)
+                       log.Infof("%s: writing hgvs column labels to %s", seqname, fnm)
+                       var hgvsLabels bytes.Buffer
+                       for varIdx, variant := range variants {
+                               fmt.Fprintf(&hgvsLabels, "%d,%s:g.%s\n", varIdx, seqname, variant.String())
+                       }
+                       err = ioutil.WriteFile(fnm, hgvsLabels.Bytes(), 0666)
+                       if err != nil {
+                               return 1
+                       }
+               }
+       }
+
+       if *mergeOutput || *hgvsSingle {
                var annow *bufio.Writer
                var annof *os.File
                if *mergeOutput {
@@ -564,12 +710,12 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                                }
                                        }
                                        hgvsCols[hgvsID] = hgvsColPair
-                                       hgvsref := hgvs.Variant{
-                                               Position: pos,
-                                               Ref:      string(refseq),
-                                               New:      string(refseq),
-                                       }
                                        if annow != nil {
+                                               hgvsref := hgvs.Variant{
+                                                       Position: pos,
+                                                       Ref:      string(refseq),
+                                                       New:      string(refseq),
+                                               }
                                                fmt.Fprintf(annow, "%d,%d,%d,%s:g.%s,%s,%d,%s,%s,%s\n", tag, incol+startcol/2, rt.variant, seqname, hgvsref.String(), seqname, pos, refseq, refseq, fields[8])
                                        }
                                }
@@ -604,37 +750,13 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                }
                out = nil
 
-               var seqnames []string
                if *hgvsSingle {
-                       seqnames = []string{""}
-               }
-               if *hgvsChunked {
-                       for seqname := range refseq {
-                               seqnames = append(seqnames, seqname)
-                       }
-               }
-               for _, seqname := range seqnames {
-                       basename := "hgvs"
-                       wantPrefix := ""
-                       if seqname == "" {
-                               cols = len(hgvsCols) * 2
-                       } else {
-                               basename = "hgvs." + seqname
-                               wantPrefix = seqname + ":"
-                               cols = 0
-                               for hgvsID := range hgvsCols {
-                                       if strings.HasPrefix(hgvsID, wantPrefix) {
-                                               cols += 2
-                                       }
-                               }
-                       }
-                       log.Printf("building hgvs-based matrix [%s]: %d rows x %d cols", seqname, rows, cols)
+                       cols = len(hgvsCols) * 2
+                       log.Printf("building hgvs-based matrix: %d rows x %d cols", rows, cols)
                        out = make([]int16, rows*cols)
                        hgvsIDs := make([]string, 0, cols/2)
                        for hgvsID := range hgvsCols {
-                               if strings.HasPrefix(hgvsID, wantPrefix) {
-                                       hgvsIDs = append(hgvsIDs, hgvsID)
-                               }
+                               hgvsIDs = append(hgvsIDs, hgvsID)
                        }
                        sort.Strings(hgvsIDs)
                        var hgvsLabels bytes.Buffer
@@ -647,12 +769,12 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                        }
                                }
                        }
-                       err = writeNumpyInt16(fmt.Sprintf("%s/%s.npy", *outputDir, basename), out, rows, cols)
+                       err = writeNumpyInt16(fmt.Sprintf("%s/hgvs.npy", *outputDir), out, rows, cols)
                        if err != nil {
                                return 1
                        }
 
-                       fnm := fmt.Sprintf("%s/%s.annotations.csv", *outputDir, basename)
+                       fnm := fmt.Sprintf("%s/hgvs.annotations.csv", *outputDir)
                        log.Printf("writing hgvs labels: %s", fnm)
                        err = ioutil.WriteFile(fnm, hgvsLabels.Bytes(), 0777)
                        if err != nil {
@@ -687,3 +809,28 @@ func writeNumpyInt16(fnm string, out []int16, rows, cols int) error {
        }
        return output.Close()
 }
+
+// writeNumpyInt8 creates (or truncates) the file fnm and writes out to
+// it through a gonpy writer with shape [rows, cols].
+//
+// It returns the first error encountered while creating the file,
+// setting up the writer, writing the data, flushing the buffer, or
+// closing the file.
+func writeNumpyInt8(fnm string, out []int8, rows, cols int) error {
+       output, err := os.Create(fnm)
+       if err != nil {
+               return err
+       }
+       // Only matters on early error returns; on success the file is
+       // closed explicitly below so that a close error is reported.
+       defer output.Close()
+       bufw := bufio.NewWriterSize(output, 1<<26)
+       // nopCloser keeps the numpy writer from closing the buffered
+       // writer, so flushing and closing stay under our control.
+       npw, err := gonpy.NewWriter(nopCloser{bufw})
+       if err != nil {
+               return err
+       }
+       log.WithFields(log.Fields{
+               "filename": fnm,
+               "rows":     rows,
+               "cols":     cols,
+       }).Infof("writing numpy: %s", fnm)
+       npw.Shape = []int{rows, cols}
+       // Propagate write failures instead of silently producing a
+       // truncated or invalid file.
+       err = npw.WriteInt8(out)
+       if err != nil {
+               return err
+       }
+       err = bufw.Flush()
+       if err != nil {
+               return err
+       }
+       return output.Close()
+}