Output -single-onehot as coordinates of non-zero values.
author Tom Clegg <tom@curii.com>
Thu, 27 Jan 2022 05:31:02 +0000 (00:31 -0500)
committer Tom Clegg <tom@curii.com>
Thu, 27 Jan 2022 05:31:02 +0000 (00:31 -0500)
refs #18581

Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

slice_test.go
slicenumpy.go

index be89a764c084964e15fee5bd10d57ca3d9f80744..3bc3d8b82c9286dfda3008a3b01f8662d760fcd6 100644 (file)
@@ -317,17 +317,15 @@ pipeline1dup/input2       0
                defer f.Close()
                npy, err := gonpy.NewReader(f)
                c.Assert(err, check.IsNil)
-               c.Check(npy.Shape, check.DeepEquals, []int{4, 16})
-               onehot, err := npy.GetInt8()
+               c.Check(npy.Shape, check.DeepEquals, []int{2, 16})
+               onehot, err := npy.GetUint32()
                if c.Check(err, check.IsNil) {
                        for r := 0; r < npy.Shape[0]; r++ {
                                c.Logf("%v", onehot[r*npy.Shape[1]:(r+1)*npy.Shape[1]])
                        }
-                       c.Check(onehot, check.DeepEquals, []int8{
-                               0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, // input1
-                               0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, // input2
-                               0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, // dup/input1
-                               0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, // dup/input2
+                       c.Check(onehot, check.DeepEquals, []uint32{
+                               0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 0, 2,
+                               1, 1, 2, 2, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15,
                        })
                }
        }
index 5ef777bf21cd820a61ad944036cd860f0135b456..9e5e315c50a7ec1054bc6ec96499bd2309291558 100644 (file)
@@ -23,6 +23,7 @@ import (
        "strconv"
        "strings"
        "sync/atomic"
+       "unsafe"
 
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "github.com/arvados/lightning/hgvs"
@@ -316,10 +317,10 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
        if *mergeOutput || *hgvsSingle {
                toMerge = make([][]int16, len(infiles))
        }
-       var onehotChunks [][][]int8
+       var onehotIndirect [][2][]uint32 // [chunkIndex][axis][index]
        var onehotXrefs [][]onehotXref
        if *onehotSingle {
-               onehotChunks = make([][][]int8, len(infiles))
+               onehotIndirect = make([][2][]uint32, len(infiles))
                onehotXrefs = make([][]onehotXref, len(infiles))
        }
 
@@ -595,7 +596,8 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                throttleNumpyMem.Release()
                        }
                        if *onehotSingle {
-                               onehotChunks[infileIdx] = onehotChunk
+                               log.Infof("%04d: keeping onehot chunk in memory (rows=%d, cols=%d, mem=%d)", infileIdx, len(cmd.cgnames), len(onehotChunk), (len(cmd.cgnames)+int(onehotXrefSize))*len(onehotChunk))
+                               onehotIndirect[infileIdx] = onehotChunk2Indirect(onehotChunk)
                                onehotXrefs[infileIdx] = onehotXref
                        }
                        if !(*onehotSingle || *onehotChunked) || *mergeOutput || *hgvsSingle {
@@ -886,16 +888,29 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                }
        }
        if *onehotSingle {
-               var onehot [][]int8
+               nzCount := 0
+               for _, part := range onehotIndirect {
+                       nzCount += len(part[0])
+               }
+               onehot := make([]uint32, nzCount*2) // [r,r,r,...,c,c,c,...]
                var xrefs []onehotXref
-               for i := range onehotChunks {
-                       onehot = append(onehot, onehotChunks[i]...)
-                       onehotChunks[i] = nil
+               outcol := 0
+               for i, part := range onehotIndirect {
+                       for i := range part[1] {
+                               part[1][i] += uint32(outcol)
+                       }
+                       copy(onehot[outcol:], part[0])
+                       copy(onehot[outcol+nzCount:], part[1])
+                       outcol += len(part[0])
                        xrefs = append(xrefs, onehotXrefs[i]...)
+
+                       part[0] = nil
+                       part[1] = nil
                        onehotXrefs[i] = nil
+                       debug.FreeOSMemory()
                }
                fnm := fmt.Sprintf("%s/onehot.npy", *outputDir)
-               err = writeNumpyInt8(fnm, onehotcols2int8(onehot), len(cmd.cgnames), len(onehot))
+               err = writeNumpyUint32(fnm, onehot, 2, nzCount)
                if err != nil {
                        return 1
                }
@@ -1010,6 +1025,32 @@ func (cmd *sliceNumpy) filterHGVScolpair(colpair [2][]int8) bool {
                (pvalue(col0, cases) <= cmd.chi2PValue || pvalue(col1, cases) <= cmd.chi2PValue)
 }
 
+// writeNumpyUint32 writes out to a newly created file fnm as a numpy
+// array of uint32 values with shape (rows, cols).
+//
+// The caller is responsible for out holding rows*cols values; that
+// is not checked here. The returned error is the first one
+// encountered while creating, writing, flushing, or closing the
+// file.
+func writeNumpyUint32(fnm string, out []uint32, rows, cols int) error {
+       output, err := os.Create(fnm)
+       if err != nil {
+               return err
+       }
+       // Covers the early-return paths; the explicit Close at the
+       // end is what reports a close error on the success path.
+       defer output.Close()
+       bufw := bufio.NewWriterSize(output, 1<<26)
+       npw, err := gonpy.NewWriter(nopCloser{bufw})
+       if err != nil {
+               return err
+       }
+       log.WithFields(log.Fields{
+               "filename": fnm,
+               "rows":     rows,
+               "cols":     cols,
+               "bytes":    rows * cols * 4,
+       }).Infof("writing numpy: %s", fnm)
+       npw.Shape = []int{rows, cols}
+       // Propagate the numpy writer's error instead of discarding
+       // it, so a failed write is not reported as success.
+       err = npw.WriteUint32(out)
+       if err != nil {
+               return err
+       }
+       err = bufw.Flush()
+       if err != nil {
+               return err
+       }
+       return output.Close()
+}
+
 func writeNumpyInt32(fnm string, out []int32, rows, cols int) error {
        output, err := os.Create(fnm)
        if err != nil {
@@ -1115,6 +1156,8 @@ type onehotXref struct {
        pvalue  float64
 }
 
+// onehotXrefSize is the size in bytes of one onehotXref value, as
+// reported by unsafe.Sizeof. It is used in the memory estimate
+// logged when a onehot chunk is kept in memory.
+const onehotXrefSize = unsafe.Sizeof(onehotXref{})
+
 // Build onehot matrix (m[variant*2+isHet][genome] == 0 or 1) for all
 // variants of a single tile/tag#.
 //
@@ -1218,3 +1261,18 @@ func onehotcols2int8(in [][]int8) []int8 {
        }
        return out
 }
+
+// onehotChunk2Indirect converts a transposed matrix (indexed as
+// matrixT[c][r]) to coordinate form.
+//
+// It returns [2][]uint32{rowIndices, colIndices}: for each i,
+// (rowIndices[i], colIndices[i]) is the position of one non-zero
+// element. Coordinates are emitted column by column, and in
+// increasing row order within each column. Both slices are nil when
+// no element is non-zero.
+func onehotChunk2Indirect(matrixT [][]int8) [2][]uint32 {
+       var nz [2][]uint32
+       for c, col := range matrixT {
+               for r, val := range col {
+                       if val != 0 {
+                               nz[0] = append(nz[0], uint32(r))
+                               nz[1] = append(nz[1], uint32(c))
+                       }
+               }
+       }
+       return nz
+}