Add -chunked-onehot option.
author Tom Clegg <tom@curii.com>
Thu, 13 Jan 2022 19:47:40 +0000 (14:47 -0500)
committer Tom Clegg <tom@curii.com>
Thu, 13 Jan 2022 19:47:40 +0000 (14:47 -0500)
refs #18581

Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

slice_test.go
slicenumpy.go

index 0dd93d1f729bbac6233a2a598a3cad97f1b2de1a..62c0c9181f761ba6685bc30aeea4b37f4946e8d7 100644 (file)
@@ -196,7 +196,9 @@ func (s *sliceSuite) TestImportAndSlice(c *check.C) {
                c.Assert(err, check.IsNil)
                c.Check(npy.Shape, check.DeepEquals, []int{4, 4})
                variants, err := npy.GetInt16()
-               c.Check(variants, check.DeepEquals, []int16{2, 1, 3, 1, -1, -1, 4, 2, 2, 1, 3, 1, -1, -1, 4, 2})
+               if c.Check(err, check.IsNil) {
+                       c.Check(variants, check.DeepEquals, []int16{2, 1, 3, 1, -1, -1, 4, 2, 2, 1, 3, 1, -1, -1, 4, 2})
+               }
 
                annotations, err := ioutil.ReadFile(npydir + "/matrix.annotations.csv")
                c.Assert(err, check.IsNil)
@@ -235,4 +237,42 @@ func (s *sliceSuite) TestImportAndSlice(c *check.C) {
 2,chr2:g.472G>A
 `)
        }
+
+       c.Log("=== slice-numpy + onehot ===")
+       {
+               err = ioutil.WriteFile(tmpdir+"/cases.txt", []byte("pipeline1/input1\npipeline1dup/input1\n"), 0600)
+               c.Assert(err, check.IsNil)
+               npydir := c.MkDir()
+               exited := (&sliceNumpy{}).RunCommand("slice-numpy", []string{
+                       "-local=true",
+                       "-chunked-onehot=true",
+                       "-chi2-cases-file=" + tmpdir + "/cases.txt",
+                       "-chi2-p-value=0.05",
+                       "-min-coverage=0.75",
+                       "-input-dir=" + slicedir,
+                       "-output-dir=" + npydir,
+               }, nil, os.Stderr, os.Stderr)
+               c.Check(exited, check.Equals, 0)
+               out, _ := exec.Command("find", npydir, "-ls").CombinedOutput()
+               c.Logf("%s", out)
+
+               f, err := os.Open(npydir + "/onehot.0002.npy")
+               c.Assert(err, check.IsNil)
+               defer f.Close()
+               npy, err := gonpy.NewReader(f)
+               c.Assert(err, check.IsNil)
+               c.Check(npy.Shape, check.DeepEquals, []int{4, 6})
+               onehot, err := npy.GetInt8()
+               if c.Check(err, check.IsNil) {
+                       for r := 0; r < npy.Shape[0]; r++ {
+                               c.Logf("%v", onehot[r*npy.Shape[1]:(r+1)*npy.Shape[1]])
+                       }
+                       c.Check(onehot, check.DeepEquals, []int8{
+                               0, 0, 0, 1, 0, 0, // input1
+                               0, 1, 0, 0, 0, 1, // input2
+                               0, 0, 0, 1, 0, 0, // dup/input1
+                               0, 1, 0, 0, 0, 1, // dup/input2
+                       })
+               }
+       }
 }
index 37dac81d80fc05af541962eb7aa0739efd488379..c05f1509c05e37a40fb215b90a96950b36ad932e 100644 (file)
@@ -38,6 +38,7 @@ type sliceNumpy struct {
        chi2Cases     []bool
        chi2PValue    float64
        minCoverage   int
+       cgnames       []string
 }
 
 func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
@@ -61,6 +62,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
        mergeOutput := flags.Bool("merge-output", false, "merge output into one matrix.npy and one matrix.annotations.csv")
        hgvsSingle := flags.Bool("single-hgvs-matrix", false, "also generate hgvs-based matrix")
        hgvsChunked := flags.Bool("chunked-hgvs-matrix", false, "also generate hgvs-based matrix per chromosome")
+       onehotChunked := flags.Bool("chunked-onehot", false, "generate one-hot tile-based matrix")
        flags.IntVar(&cmd.threads, "threads", 16, "number of memory-hungry assembly threads")
        flags.StringVar(&cmd.chi2CasesFile, "chi2-cases-file", "", "text file indicating positive cases (for Χ² test)")
        flags.Float64Var(&cmd.chi2PValue, "chi2-p-value", 1, "do Χ² test and omit columns with p-value above this threshold")
@@ -109,6 +111,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        "-merge-output=" + fmt.Sprintf("%v", *mergeOutput),
                        "-single-hgvs-matrix=" + fmt.Sprintf("%v", *hgvsSingle),
                        "-chunked-hgvs-matrix=" + fmt.Sprintf("%v", *hgvsChunked),
+                       "-chunked-onehot=" + fmt.Sprintf("%v", *onehotChunked),
                        "-chi2-cases-file=" + cmd.chi2CasesFile,
                        "-chi2-p-value=" + fmt.Sprintf("%f", cmd.chi2PValue),
                }
@@ -132,7 +135,6 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
        }
        sort.Strings(infiles)
 
-       var cgnames []string
        var refseq map[string][]tileLibRef
        var reftiledata = make(map[tileLibRef][]byte, 11000000)
        in0, err := open(infiles[0])
@@ -146,6 +148,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                return 1
        }
 
+       cmd.cgnames = nil
        taglen := -1
        DecodeLibrary(in0, strings.HasSuffix(infiles[0], ".gz"), func(ent *LibraryEntry) error {
                if len(ent.TagSet) > 0 {
@@ -158,7 +161,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                }
                for _, cg := range ent.CompactGenomes {
                        if matchGenome.MatchString(cg.Name) {
-                               cgnames = append(cgnames, cg.Name)
+                               cmd.cgnames = append(cmd.cgnames, cg.Name)
                        }
                }
                for _, tv := range ent.TileVariants {
@@ -180,13 +183,13 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                err = fmt.Errorf("tagset not found")
                return 1
        }
-       if len(cgnames) == 0 {
+       if len(cmd.cgnames) == 0 {
                err = fmt.Errorf("no genomes found matching regexp %q", cmd.filter.MatchGenome)
                return 1
        }
-       sort.Strings(cgnames)
+       sort.Strings(cmd.cgnames)
 
-       cmd.minCoverage = int(math.Ceil(cmd.filter.MinCoverage * float64(len(cgnames))))
+       cmd.minCoverage = int(math.Ceil(cmd.filter.MinCoverage * float64(len(cmd.cgnames))))
 
        if cmd.chi2CasesFile != "" {
                f, err2 := open(cmd.chi2CasesFile)
@@ -200,7 +203,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        err = err2
                        return 1
                }
-               cmd.chi2Cases = make([]bool, len(cgnames))
+               cmd.chi2Cases = make([]bool, len(cmd.cgnames))
                ncases := 0
                for _, pattern := range bytes.Split(buf, []byte{'\n'}) {
                        if len(pattern) == 0 {
@@ -208,14 +211,14 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        }
                        pattern := string(pattern)
                        idx := -1
-                       for i, name := range cgnames {
+                       for i, name := range cmd.cgnames {
                                if !strings.Contains(name, pattern) {
                                        continue
                                }
                                cmd.chi2Cases[i] = true
                                ncases++
                                if idx >= 0 {
-                                       log.Warnf("pattern %q in cases file matches multiple genome IDs: %q, %q", pattern, cgnames[idx], name)
+                                       log.Warnf("pattern %q in cases file matches multiple genome IDs: %q, %q", pattern, cmd.cgnames[idx], name)
                                } else {
                                        idx = i
                                }
@@ -225,7 +228,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                continue
                        }
                }
-               log.Printf("%d cases, %d controls", ncases, len(cgnames)-ncases)
+               log.Printf("%d cases, %d controls", ncases, len(cmd.cgnames)-ncases)
        }
 
        {
@@ -237,7 +240,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        return 1
                }
                defer f.Close()
-               for i, name := range cgnames {
+               for i, name := range cmd.cgnames {
                        _, err = fmt.Fprintf(f, "%d,%q\n", i, trimFilenameForLabel(name))
                        if err != nil {
                                err = fmt.Errorf("write %s: %w", labelsFilename, err)
@@ -350,7 +353,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                infileIdx, infile := infileIdx, infile
                throttleMem.Go(func() error {
                        seq := make(map[tagID][]TileVariant, 50000)
-                       cgs := make(map[string]CompactGenome, len(cgnames))
+                       cgs := make(map[string]CompactGenome, len(cmd.cgnames))
                        f, err := open(infile)
                        if err != nil {
                                return err
@@ -395,8 +398,8 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        if err != nil {
                                return err
                        }
-                       tagstart := cgs[cgnames[0]].StartTag
-                       tagend := cgs[cgnames[0]].EndTag
+                       tagstart := cgs[cmd.cgnames[0]].StartTag
+                       tagend := cgs[cmd.cgnames[0]].EndTag
 
                        // TODO: filters
 
@@ -461,6 +464,9 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        }
                        throttleCPU.Wait()
 
+                       var onehotChunk [][]int8
+                       var onehotXrefs []onehotXref
+
                        annotationsFilename := fmt.Sprintf("%s/matrix.%04d.annotations.csv", *outputDir, infileIdx)
                        log.Infof("%04d: writing %s", infileIdx, annotationsFilename)
                        annof, err := os.Create(annotationsFilename)
@@ -482,9 +488,6 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                        // mention it in annotations?)
                                        continue
                                }
-                               fmt.Fprintf(annow, "%d,%d,%d,=,%s,%d,,,\n", tag, outcol, rt.variant, rt.seqname, rt.pos)
-                               variants := seq[tag]
-                               reftilestr := strings.ToUpper(string(rt.tiledata))
                                remap := variantRemap[tag-tagstart]
                                maxv := tileVariantID(0)
                                for _, v := range remap {
@@ -492,6 +495,15 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                                maxv = v
                                        }
                                }
+                               if *onehotChunked {
+                                       onehot, xrefs := cmd.tv2homhet(cgs, maxv, remap, tag, tagstart)
+                                       onehotChunk = append(onehotChunk, onehot...)
+                                       onehotXrefs = append(onehotXrefs, xrefs...)
+                               }
+                               fmt.Fprintf(annow, "%d,%d,%d,=,%s,%d,,,\n", tag, outcol, rt.variant, rt.seqname, rt.pos)
+                               variants := seq[tag]
+                               reftilestr := strings.ToUpper(string(rt.tiledata))
+
                                done := make([]bool, maxv+1)
                                variantDiffs := make([][]hgvs.Variant, maxv+1)
                                for v, tv := range variants {
@@ -535,12 +547,12 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                                                continue
                                                        }
                                                        hgvsCol[diff] = [2][]int8{
-                                                               make([]int8, len(cgnames)),
-                                                               make([]int8, len(cgnames)),
+                                                               make([]int8, len(cmd.cgnames)),
+                                                               make([]int8, len(cmd.cgnames)),
                                                        }
                                                }
                                        }
-                                       for row, name := range cgnames {
+                                       for row, name := range cmd.cgnames {
                                                variants := cgs[name].Variants[(tag-tagstart)*2:]
                                                for ph := 0; ph < 2; ph++ {
                                                        v := variants[ph]
@@ -584,44 +596,85 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                return err
                        }
 
-                       log.Infof("%04d: preparing numpy", infileIdx)
-                       throttleNumpyMem.Acquire()
-                       rows := len(cgnames)
-                       cols := 2 * outcol
-                       out := make([]int16, rows*cols)
-                       for row, name := range cgnames {
-                               out := out[row*cols:]
-                               outcol := 0
-                               for col, v := range cgs[name].Variants {
-                                       tag := tagstart + tagID(col/2)
-                                       if mask != nil && reftile[tag] == nil {
-                                               continue
+                       if *onehotChunked {
+                               // transpose onehotChunk[col][row] to numpy[row*ncols+col]
+                               rows := len(cmd.cgnames)
+                               cols := len(onehotChunk)
+                               log.Infof("%04d: preparing onehot numpy (rows=%d, cols=%d, mem=%d)", infileIdx, rows, cols, rows*cols)
+                               throttleNumpyMem.Acquire()
+                               out := make([]int8, rows*cols)
+                               for row := range cmd.cgnames {
+                                       out := out[row*cols:]
+                                       for colnum, values := range onehotChunk {
+                                               out[colnum] = values[row]
                                        }
-                                       if variants, ok := seq[tag]; ok && len(variants) > int(v) && len(variants[v].Sequence) > 0 {
-                                               out[outcol] = int16(variantRemap[tag-tagstart][v])
-                                       } else {
-                                               out[outcol] = -1
-                                       }
-                                       outcol++
                                }
-                       }
-                       seq = nil
-                       cgs = nil
-                       debug.FreeOSMemory()
-                       throttleNumpyMem.Release()
+                               seq = nil
+                               cgs = nil
+                               debug.FreeOSMemory()
+                               throttleNumpyMem.Release()
 
-                       if *mergeOutput || *hgvsSingle {
-                               log.Infof("%04d: matrix fragment %d rows x %d cols", infileIdx, rows, cols)
-                               toMerge[infileIdx] = out
-                       }
-                       if !*mergeOutput {
-                               fnm := fmt.Sprintf("%s/matrix.%04d.npy", *outputDir, infileIdx)
-                               err = writeNumpyInt16(fnm, out, rows, cols)
+                               fnm := fmt.Sprintf("%s/onehot.%04d.npy", *outputDir, infileIdx)
+                               err = writeNumpyInt8(fnm, out, rows, cols)
                                if err != nil {
                                        return err
                                }
+
+                               fnm = fmt.Sprintf("%s/onehot-columns.%04d.npy", *outputDir, infileIdx)
+                               xcols := len(onehotXrefs)
+                               xdata := make([]int32, 4*xcols)
+                               for i, xref := range onehotXrefs {
+                                       xdata[i] = int32(xref.tag)
+                                       xdata[xcols+i] = int32(xref.variant)
+                                       if xref.het {
+                                               xdata[xcols*2+i] = 1
+                                       }
+                                       xdata[xcols*3+i] = int32(xref.pvalue * 1000000)
+                               }
+                               err = writeNumpyInt32(fnm, xdata, 4, xcols)
+                               if err != nil {
+                                       return err
+                               }
+                       }
+                       if !*onehotChunked || *mergeOutput || *hgvsSingle {
+                               log.Infof("%04d: preparing numpy", infileIdx)
+                               throttleNumpyMem.Acquire()
+                               rows := len(cmd.cgnames)
+                               cols := 2 * outcol
+                               out := make([]int16, rows*cols)
+                               for row, name := range cmd.cgnames {
+                                       out := out[row*cols:]
+                                       outcol := 0
+                                       for col, v := range cgs[name].Variants {
+                                               tag := tagstart + tagID(col/2)
+                                               if mask != nil && reftile[tag] == nil {
+                                                       continue
+                                               }
+                                               if variants, ok := seq[tag]; ok && len(variants) > int(v) && len(variants[v].Sequence) > 0 {
+                                                       out[outcol] = int16(variantRemap[tag-tagstart][v])
+                                               } else {
+                                                       out[outcol] = -1
+                                               }
+                                               outcol++
+                                       }
+                               }
+                               seq = nil
+                               cgs = nil
                                debug.FreeOSMemory()
+                               throttleNumpyMem.Release()
+                               if *mergeOutput || *hgvsSingle {
+                                       log.Infof("%04d: matrix fragment %d rows x %d cols", infileIdx, rows, cols)
+                                       toMerge[infileIdx] = out
+                               }
+                               if !*mergeOutput && !*onehotChunked {
+                                       fnm := fmt.Sprintf("%s/matrix.%04d.npy", *outputDir, infileIdx)
+                                       err = writeNumpyInt16(fnm, out, rows, cols)
+                                       if err != nil {
+                                               return err
+                                       }
+                               }
                        }
+                       debug.FreeOSMemory()
                        log.Infof("%s: done (%d/%d)", infile, int(atomic.AddInt64(&done, 1)), len(infiles))
                        return nil
                })
@@ -669,13 +722,13 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                        return vi.New < vj.New
                                }
                        })
-                       rows := len(cgnames)
+                       rows := len(cmd.cgnames)
                        cols := len(variants) * 2
                        log.Infof("%s: building hgvs matrix (rows=%d, cols=%d, mem=%d)", seqname, rows, cols, rows*cols)
                        out := make([]int8, rows*cols)
                        for varIdx, variant := range variants {
                                hgvsCols := hgvsCols[variant]
-                               for row := range cgnames {
+                               for row := range cmd.cgnames {
                                        for ph := 0; ph < 2; ph++ {
                                                out[row*cols+varIdx+ph] = hgvsCols[ph][row]
                                        }
@@ -712,7 +765,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        annow = bufio.NewWriterSize(annof, 1<<20)
                }
 
-               rows := len(cgnames)
+               rows := len(cmd.cgnames)
                cols := 0
                for _, chunk := range toMerge {
                        cols += len(chunk) / rows
@@ -781,7 +834,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                        // change to 1 ("hgvs variant
                                        // present") below, either on
                                        // this line or a future line.
-                                       hgvsColPair = [2][]int16{make([]int16, len(cgnames)), make([]int16, len(cgnames))}
+                                       hgvsColPair = [2][]int16{make([]int16, len(cmd.cgnames)), make([]int16, len(cmd.cgnames))}
                                        rt, ok := reftile[tagID(tag)]
                                        if !ok {
                                                err = fmt.Errorf("bug: seeing annotations for tag %d, but it has no reftile entry", tag)
@@ -892,6 +945,32 @@ func (cmd *sliceNumpy) filterHGVScolpair(colpair [2][]int8) bool {
                (pvalue(cases, col0) <= cmd.chi2PValue || pvalue(cases, col1) <= cmd.chi2PValue)
 }
 
+func writeNumpyInt32(fnm string, out []int32, rows, cols int) error {
+       output, err := os.Create(fnm)
+       if err != nil {
+               return err
+       }
+       defer output.Close()
+       bufw := bufio.NewWriterSize(output, 1<<26)
+       npw, err := gonpy.NewWriter(nopCloser{bufw})
+       if err != nil {
+               return err
+       }
+       log.WithFields(log.Fields{
+               "filename": fnm,
+               "rows":     rows,
+               "cols":     cols,
+               "bytes":    rows * cols * 4,
+       }).Infof("writing numpy: %s", fnm)
+       npw.Shape = []int{rows, cols}
+       npw.WriteInt32(out)
+       err = bufw.Flush()
+       if err != nil {
+               return err
+       }
+       return output.Close()
+}
+
 func writeNumpyInt16(fnm string, out []int16, rows, cols int) error {
        output, err := os.Create(fnm)
        if err != nil {
@@ -907,6 +986,7 @@ func writeNumpyInt16(fnm string, out []int16, rows, cols int) error {
                "filename": fnm,
                "rows":     rows,
                "cols":     cols,
+               "bytes":    rows * cols * 2,
        }).Infof("writing numpy: %s", fnm)
        npw.Shape = []int{rows, cols}
        npw.WriteInt16(out)
@@ -932,6 +1012,7 @@ func writeNumpyInt8(fnm string, out []int8, rows, cols int) error {
                "filename": fnm,
                "rows":     rows,
                "cols":     cols,
+               "bytes":    rows * cols,
        }).Infof("writing numpy: %s", fnm)
        npw.Shape = []int{rows, cols}
        npw.WriteInt8(out)
@@ -961,3 +1042,76 @@ func allele2homhet(colpair [2][]int8) {
                }
        }
 }
+
+// onehotXref identifies what a single onehot matrix column
+// represents. tv2homhet returns one onehotXref per onehot column it
+// returns, in the same order.
+type onehotXref struct {
+       // tag number of the tile this column belongs to
+       tag     tagID
+       // tile variant number (after remapping) this column tests for
+       variant tileVariantID
+       // false for the "hom" column (both phases have the variant),
+       // true for the "het" column (exactly one phase has it)
+       het     bool
+       // result of pvalue() for this column vs. the chi2 cases
+       pvalue  float64
+}
+
+// tv2homhet builds onehot columns for the variants of a single
+// tile/tag#.
+//
+// For each remapped variant number v in 2..maxv, two candidate
+// columns are built, each with one entry per genome in cmd.cgnames:
+// a "hom" column (1 if both phases have variant v) and a "het"
+// column (1 if exactly one phase has variant v). Variants 0 and 1
+// never produce columns.
+//
+// If cmd.chi2PValue < 1, a variant's hom/het pair is omitted unless
+// at least one of the two columns has a Χ² p-value (vs.
+// cmd.chi2Cases) at or below cmd.chi2PValue. This is the same "omit
+// if above threshold" rule that filterHGVScolpair applies.
+//
+// The two returned slices are parallel: xref[i] gives the tag,
+// variant, hom/het flag, and p-value for column onehot[i], and
+// onehot[i][genome] is 0 or 1.
+//
+// Both return values are nil if maxv < 2, if fewer than
+// cmd.minCoverage genomes in cgs have a nonzero variant in both
+// phases, or if no variant passes the Χ² filter.
+func (cmd *sliceNumpy) tv2homhet(cgs map[string]CompactGenome, maxv tileVariantID, remap []tileVariantID, tag, chunkstarttag tagID) ([][]int8, []onehotXref) {
+       if maxv < 2 {
+               // everyone has the most common variant
+               return nil, nil
+       }
+       tagoffset := tag - chunkstarttag
+       coverage := 0
+       for _, cg := range cgs {
+               if cg.Variants[tagoffset*2] > 0 && cg.Variants[tagoffset*2+1] > 0 {
+                       coverage++
+               }
+       }
+       if coverage < cmd.minCoverage {
+               return nil, nil
+       }
+       // 2 slices (hom + het) for each variant#. Sizes and indexes
+       // are computed as int so they cannot wrap around if
+       // tileVariantID is a narrow integer type.
+       obs := make([][]bool, (int(maxv)+1)*2)
+       for i := range obs {
+               obs[i] = make([]bool, len(cmd.cgnames))
+       }
+       for cgid, name := range cmd.cgnames {
+               cgvars := cgs[name].Variants
+               v0 := remap[cgvars[tagoffset*2]]
+               v1 := remap[cgvars[tagoffset*2+1]]
+               // Only variants 2..maxv get columns, so each phase
+               // can be recorded directly instead of scanning every
+               // variant number for a match.
+               if v0 == v1 {
+                       if v0 >= 2 && v0 <= maxv {
+                               obs[int(v0)*2][cgid] = true
+                       }
+                       continue
+               }
+               if v0 >= 2 && v0 <= maxv {
+                       obs[int(v0)*2+1][cgid] = true
+               }
+               if v1 >= 2 && v1 <= maxv {
+                       obs[int(v1)*2+1][cgid] = true
+               }
+       }
+       var onehot [][]int8
+       var xref []onehotXref
+       // Start at homcol 4, i.e., variant 2: variants 0 and 1 are
+       // never reported.
+       for homcol := 4; homcol < len(obs); homcol += 2 {
+               p := [2]float64{
+                       pvalue(cmd.chi2Cases, obs[homcol]),
+                       pvalue(cmd.chi2Cases, obs[homcol+1]),
+               }
+               // Omit only if both p-values are above the threshold;
+               // a p-value equal to the threshold is kept, as in
+               // filterHGVScolpair.
+               if cmd.chi2PValue < 1 && !(p[0] <= cmd.chi2PValue || p[1] <= cmd.chi2PValue) {
+                       continue
+               }
+               for het := 0; het < 2; het++ {
+                       onehot = append(onehot, bool2int8(obs[homcol+het]))
+                       xref = append(xref, onehotXref{
+                               tag:     tag,
+                               variant: tileVariantID(homcol / 2),
+                               het:     het == 1,
+                               pvalue:  p[het],
+                       })
+               }
+       }
+       return onehot, xref
+}
+
+// bool2int8 returns a new slice with the same length as in, holding
+// 1 at each position where in is true and 0 elsewhere.
+func bool2int8(in []bool) []int8 {
+       converted := make([]int8, 0, len(in))
+       for _, set := range in {
+               var bit int8
+               if set {
+                       bit = 1
+               }
+               converted = append(converted, bit)
+       }
+       return converted
+}