X-Git-Url: https://git.arvados.org/lightning.git/blobdiff_plain/ac59d25dd8766ac37b889d64ecbf4771d06e80df..4a5362276b3ad274a02853c2181b43cf607a11d4:/exportnumpy.go diff --git a/exportnumpy.go b/exportnumpy.go index 21928b54a0..fd198116b1 100644 --- a/exportnumpy.go +++ b/exportnumpy.go @@ -95,19 +95,8 @@ func (cmd *exportNumpy) RunCommand(prog string, args []string, stdin io.Reader, return 1 } sort.Slice(cgs, func(i, j int) bool { return cgs[i].Name < cgs[j].Name }) - cols := 0 - for _, cg := range cgs { - if cols < len(cg.Variants) { - cols = len(cg.Variants) - } - } - rows := len(cgs) - out := make([]uint16, rows*cols) - for row, cg := range cgs { - for i, v := range cg.Variants { - out[row*cols+i] = uint16(v) - } - } + + out, rows, cols := cgs2array(cgs) var output io.WriteCloser if *outputFilename == "-" { @@ -125,13 +114,10 @@ func (cmd *exportNumpy) RunCommand(prog string, args []string, stdin io.Reader, return 1 } if *onehot { - out, cols := recodeOnehot(out, cols) - npw.Shape = []int{rows, cols} - npw.WriteUint8(out) - } else { - npw.Shape = []int{rows, cols} - npw.WriteUint16(out) + out, cols = recodeOnehot(out, cols) } + npw.Shape = []int{rows, cols} + npw.WriteUint16(out) err = bufw.Flush() if err != nil { return 1 @@ -143,7 +129,23 @@ func (cmd *exportNumpy) RunCommand(prog string, args []string, stdin io.Reader, return 0 } -func recodeOnehot(in []uint16, incols int) ([]uint8, int) { +func cgs2array(cgs []CompactGenome) (data []uint16, rows, cols int) { + rows = len(cgs) + for _, cg := range cgs { + if cols < len(cg.Variants) { + cols = len(cg.Variants) + } + } + data = make([]uint16, rows*cols) + for row, cg := range cgs { + for i, v := range cg.Variants { + data[row*cols+i] = uint16(v) + } + } + return +} + +func recodeOnehot(in []uint16, incols int) ([]uint16, int) { rows := len(in) / incols maxvalue := make([]uint16, incols) for row := 0; row < rows; row++ { @@ -159,7 +161,7 @@ func recodeOnehot(in []uint16, incols int) ([]uint8, int) { outcol[incol] = outcols outcols += int(v) } - out := make([]uint8, rows*outcols) + out := make([]uint16, rows*outcols) for row := 0; row < rows; row++ { for col := 0; col < incols; col++ { if v := in[row*incols+col]; v > 0 {