Log dimensions.
[lightning.git] / exportnumpy.go
index 21928b54a0cac722e7f15949e8adfd23b2b7d6ae..fd198116b1512c01f6fb51db67bc82354edd40d5 100644 (file)
@@ -95,19 +95,8 @@ func (cmd *exportNumpy) RunCommand(prog string, args []string, stdin io.Reader,
                return 1
        }
        sort.Slice(cgs, func(i, j int) bool { return cgs[i].Name < cgs[j].Name })
-       cols := 0
-       for _, cg := range cgs {
-               if cols < len(cg.Variants) {
-                       cols = len(cg.Variants)
-               }
-       }
-       rows := len(cgs)
-       out := make([]uint16, rows*cols)
-       for row, cg := range cgs {
-               for i, v := range cg.Variants {
-                       out[row*cols+i] = uint16(v)
-               }
-       }
+
+       out, rows, cols := cgs2array(cgs)
 
        var output io.WriteCloser
        if *outputFilename == "-" {
@@ -125,13 +114,10 @@ func (cmd *exportNumpy) RunCommand(prog string, args []string, stdin io.Reader,
                return 1
        }
        if *onehot {
-               out, cols := recodeOnehot(out, cols)
-               npw.Shape = []int{rows, cols}
-               npw.WriteUint8(out)
-       } else {
-               npw.Shape = []int{rows, cols}
-               npw.WriteUint16(out)
+               out, cols = recodeOnehot(out, cols)
        }
+       npw.Shape = []int{rows, cols}
+       npw.WriteUint16(out)
        err = bufw.Flush()
        if err != nil {
                return 1
@@ -143,7 +129,23 @@ func (cmd *exportNumpy) RunCommand(prog string, args []string, stdin io.Reader,
        return 0
 }
 
-func recodeOnehot(in []uint16, incols int) ([]uint8, int) {
+func cgs2array(cgs []CompactGenome) (data []uint16, rows, cols int) {
+       rows = len(cgs)
+       for _, cg := range cgs {
+               if cols < len(cg.Variants) {
+                       cols = len(cg.Variants)
+               }
+       }
+       data = make([]uint16, rows*cols)
+       for row, cg := range cgs {
+               for i, v := range cg.Variants {
+                       data[row*cols+i] = uint16(v)
+               }
+       }
+       return
+}
+
+func recodeOnehot(in []uint16, incols int) ([]uint16, int) {
        rows := len(in) / incols
        maxvalue := make([]uint16, incols)
        for row := 0; row < rows; row++ {
@@ -159,7 +161,7 @@ func recodeOnehot(in []uint16, incols int) ([]uint8, int) {
                outcol[incol] = outcols
                outcols += int(v)
        }
-       out := make([]uint8, rows*outcols)
+       out := make([]uint16, rows*outcols)
        for row := 0; row < rows; row++ {
                for col := 0; col < incols; col++ {
                        if v := in[row*incols+col]; v > 0 {