From: Tom Clegg Date: Tue, 11 Oct 2022 14:07:14 +0000 (-0400) Subject: 19524: Output PCA. X-Git-Url: https://git.arvados.org/lightning.git/commitdiff_plain/b015928f71399084e7b691016936839c5b174753?hp=837a48d358fc4f5f0fb04e1878223b7d47443695 19524: Output PCA. Arvados-DCO-1.1-Signed-off-by: Tom Clegg --- diff --git a/gob.go b/gob.go index 0bd1873961..70a5d2811a 100644 --- a/gob.go +++ b/gob.go @@ -59,17 +59,20 @@ func DecodeLibrary(rdr io.Reader, gz bool, cb func(*LibraryEntry) error) error { if err != nil { return err } + defer zrdr.Close() } dec := gob.NewDecoder(zrdr) - for err == nil { + for { var ent LibraryEntry err = dec.Decode(&ent) - if err == nil { - err = cb(&ent) + if err == io.EOF { + return zrdr.Close() + } else if err != nil { + return err + } + err = cb(&ent) + if err != nil { + return err } } - if err != io.EOF { - return err - } - return zrdr.Close() } diff --git a/pca.go b/pca.go index 121925d4c5..22efc904fb 100644 --- a/pca.go +++ b/pca.go @@ -199,7 +199,7 @@ func (cmd *goPCA) RunCommand(prog string, args []string, stdin io.Reader, stdout if *outputFilename == "-" { output = nopCloser{stdout} } else { - output, err = os.OpenFile(*outputFilename, os.O_CREATE|os.O_WRONLY, 0777) + output, err = os.OpenFile(*outputFilename, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777) if err != nil { return 1 } diff --git a/slicenumpy.go b/slicenumpy.go index 7931cbede9..c5e03d8ed2 100644 --- a/slicenumpy.go +++ b/slicenumpy.go @@ -28,10 +28,12 @@ import ( "git.arvados.org/arvados.git/sdk/go/arvados" "github.com/arvados/lightning/hgvs" + "github.com/james-bowman/nlp" "github.com/kshedden/gonpy" "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "golang.org/x/crypto/blake2b" + "gonum.org/v1/gonum/mat" ) const annotationMaxTileSpan = 100 @@ -50,12 +52,14 @@ type sliceNumpy struct { } func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int { - var err error - defer func() { - if err != nil { - fmt.Fprintf(stderr, "%s\n", err) - } - }() + err := cmd.run(prog, args, stdin, stdout, stderr) + if err != nil { + fmt.Fprintf(stderr, "%s\n", err) + return 1 + } + return 0 +} +func (cmd *sliceNumpy) run(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) error { flags := flag.NewFlagSet("", flag.ContinueOnError) flags.SetOutput(stderr) pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`") @@ -73,6 +77,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s onehotSingle := flags.Bool("single-onehot", false, "generate one-hot tile-based matrix") onehotChunked := flags.Bool("chunked-onehot", false, "generate one-hot tile-based matrix per input chunk") onlyPCA := flags.Bool("pca", false, "generate pca matrix") + pcaComponents := flags.Int("pca-components", 4, "number of PCA components") debugTag := flags.Int("debug-tag", -1, "log debugging details about specified tag") flags.IntVar(&cmd.threads, "threads", 16, "number of memory-hungry assembly threads") flags.StringVar(&cmd.chi2CaseControlFile, "chi2-case-control-file", "", "tsv file or directory indicating cases and controls for Χ² test (if directory, all .tsv files will be read)") @@ -80,12 +85,11 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s flags.Float64Var(&cmd.chi2PValue, "chi2-p-value", 1, "do Χ² test and omit columns with p-value above this threshold") flags.BoolVar(&cmd.includeVariant1, "include-variant-1", false, "include most common variant when building one-hot matrix") cmd.filter.Flags(flags) - err = flags.Parse(args) + err := flags.Parse(args) if err == flag.ErrHelp { - err = nil - return 0 + return nil } else if err != nil { - return 2 + return err } if *pprof != "" { @@ -95,8 +99,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s } if cmd.chi2PValue != 1 && (cmd.chi2CaseControlFile == "" || cmd.chi2CaseControlColumn == "") { - log.Errorf("cannot use provided -chi2-p-value=%f because -chi2-case-control-file= or -chi2-case-control-column= value is empty", cmd.chi2PValue) - return 2 + return fmt.Errorf("cannot use provided -chi2-p-value=%f because -chi2-case-control-file= or -chi2-case-control-column= value is empty", cmd.chi2PValue) } cmd.debugTag = tagID(*debugTag) @@ -114,7 +117,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s } err = runner.TranslatePaths(inputDir, regionsFilename, &cmd.chi2CaseControlFile) if err != nil { - return 1 + return err } runner.Args = []string{"slice-numpy", "-local=true", "-pprof=:6060", @@ -139,19 +142,19 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s var output string output, err = runner.Run() if err != nil { - return 1 + return err } fmt.Fprintln(stdout, output) - return 0 + return nil } infiles, err := allFiles(*inputDir, matchGobFile) if err != nil { - return 1 + return err } if len(infiles) == 0 { err = fmt.Errorf("no input files found in %s", *inputDir) - return 1 + return err } sort.Strings(infiles) @@ -159,13 +162,13 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s var reftiledata = make(map[tileLibRef][]byte, 11000000) in0, err := open(infiles[0]) if err != nil { - return 1 + return err } matchGenome, err := regexp.Compile(cmd.filter.MatchGenome) if err != nil { err = fmt.Errorf("-match-genome: invalid regexp: %q", cmd.filter.MatchGenome) - return 1 + return err } cmd.cgnames = nil @@ -192,37 +195,37 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s return nil }) if err != nil { - return 1 + return err } in0.Close() if refseq == nil { err = fmt.Errorf("%s: reference sequence not found", infiles[0]) - return 1 + return err } if len(tagset) == 0 { err = fmt.Errorf("tagset not found") - return 1 + return err } taglib := &tagLibrary{} err = taglib.setTags(tagset) if err != nil { - return 1 + return err } taglen := taglib.TagLen() if len(cmd.cgnames) == 0 { err = fmt.Errorf("no genomes found matching regexp %q", cmd.filter.MatchGenome) - return 1 + return err } sort.Strings(cmd.cgnames) err = cmd.useCaseControlFiles() if err != nil { - return 1 + return err } if len(cmd.cgnames) == 0 { err = fmt.Errorf("fatal: 0 cases, 0 controls, nothing to do") - return 1 + return err } if cmd.filter.MinCoverage == 1 { // In the generic formula below, floating point @@ -240,7 +243,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s var f *os.File f, err = os.Create(labelsFilename) if err != nil { - return 1 + return err } defer f.Close() for i, name := range cmd.cgnames { @@ -251,13 +254,13 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s _, err = fmt.Fprintf(f, "%d,%q,%d\n", i, trimFilenameForLabel(name), cc) if err != nil { err = fmt.Errorf("write %s: %w", labelsFilename, err) - return 1 + return err } } err = f.Close() if err != nil { err = fmt.Errorf("close %s: %w", labelsFilename, err) - return 1 + return err } } @@ -282,7 +285,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s tiledata := reftiledata[libref] if len(tiledata) == 0 { err = fmt.Errorf("missing tiledata for tag %d variant %d in %s in ref", libref.Tag, libref.Variant, seqname) - return 1 + return err } foundthistag := false taglib.FindAll(tiledata[:len(tiledata)-1], func(tagid tagID, offset, _ int) { @@ -328,7 +331,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s log.Printf("loading regions from %s", *regionsFilename) mask, err = makeMask(*regionsFilename, *expandRegions) if err != nil { - return 1 + return err } log.Printf("before applying mask, len(reftile) == %d", len(reftile)) log.Printf("deleting reftile entries for regions outside %d intervals", mask.Len()) @@ -349,7 +352,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s var f *os.File f, err = os.Create(*outputDir + "/tmp." + seqname + ".gob") if err != nil { - return 1 + return err } defer os.Remove(f.Name()) bufw := bufio.NewWriterSize(f, 1<<24) @@ -824,7 +827,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s }) } if err = throttleMem.Wait(); err != nil { - return 1 + return err } if *hgvsChunked { @@ -834,14 +837,14 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s } err = encodeHGVS.Wait() if err != nil { - return 1 + return err } for seqname := range refseq { log.Infof("%s: reading hgvsCols from temp file", seqname) f := tmpHGVSCols[seqname] _, err = f.Seek(0, io.SeekStart) if err != nil { - return 1 + return err } var hgvsCols hgvsColSet dec := gob.NewDecoder(bufio.NewReaderSize(f, 1<<24)) @@ -849,7 +852,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s err = dec.Decode(&hgvsCols) } if err != io.EOF { - return 1 + return err } log.Infof("%s: sorting %d hgvs variants", seqname, len(hgvsCols)) variants := make([]hgvs.Variant, 0, len(hgvsCols)) @@ -880,7 +883,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s } err = writeNumpyInt8(fmt.Sprintf("%s/hgvs.%s.npy", *outputDir, seqname), out, rows, cols) if err != nil { - return 1 + return err } out = nil @@ -892,7 +895,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s } err = ioutil.WriteFile(fnm, hgvsLabels.Bytes(), 0666) if err != nil { - return 1 + return err } } } @@ -904,7 +907,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s annoFilename := fmt.Sprintf("%s/matrix.annotations.csv", *outputDir) annof, err = os.Create(annoFilename) if err != nil { - return 1 + return err } annow = bufio.NewWriterSize(annof, 1<<20) } @@ -934,12 +937,12 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s log.Infof("reading %s", annotationsFilename) buf, err := os.ReadFile(annotationsFilename) if err != nil { - return 1 + return err } if *mergeOutput { err = os.Remove(annotationsFilename) if err != nil { - return 1 + return err } } for _, line := range bytes.Split(buf, []byte{'\n'}) { @@ -982,7 +985,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s rt, ok := reftile[tagID(tag)] if !ok { err = fmt.Errorf("bug: seeing annotations for tag %d, but it has no reftile entry", tag) - return 1 + return err } for ph := 0; ph < 2; ph++ { for row := 0; row < rows; row++ { @@ -1022,15 +1025,15 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s if *mergeOutput { err = annow.Flush() if err != nil { - return 1 + return err } err = annof.Close() if err != nil { - return 1 + return err } err = writeNumpyInt16(fmt.Sprintf("%s/matrix.npy", *outputDir), out, rows, cols) if err != nil { - return 1 + return err } } out = nil @@ -1056,18 +1059,18 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s } err = writeNumpyInt16(fmt.Sprintf("%s/hgvs.npy", *outputDir), out, rows, cols) if err != nil { - return 1 + return err } fnm := fmt.Sprintf("%s/hgvs.annotations.csv", *outputDir) log.Printf("writing hgvs labels: %s", fnm) err = ioutil.WriteFile(fnm, hgvsLabels.Bytes(), 0777) if err != nil { - return 1 + return err } } } - if *onehotSingle { + if *onehotSingle || *onlyPCA { nzCount := 0 for _, part := range onehotIndirect { nzCount += len(part[0]) @@ -1092,19 +1095,71 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s onehotXrefs[i] = nil debug.FreeOSMemory() } - fnm := fmt.Sprintf("%s/onehot.npy", *outputDir) - err = writeNumpyUint32(fnm, onehot, 2, nzCount) - if err != nil { - return 1 + if *onehotSingle { + fnm := fmt.Sprintf("%s/onehot.npy", *outputDir) + err = writeNumpyUint32(fnm, onehot, 2, nzCount) + if err != nil { + return err + } + fnm = fmt.Sprintf("%s/onehot-columns.npy", *outputDir) + err = writeNumpyInt32(fnm, onehotXref2int32(xrefs), 5, len(xrefs)) + if err != nil { + return err + } } - fnm = fmt.Sprintf("%s/onehot-columns.npy", *outputDir) - err = writeNumpyInt32(fnm, onehotXref2int32(xrefs), 5, len(xrefs)) - if err != nil { - return 1 + if *onlyPCA { + cols := 0 + for _, c := range onehot[nzCount:] { + if int(c) >= cols { + cols = int(c) + 1 + } + } + if cols == 0 { + return fmt.Errorf("cannot do PCA: one-hot matrix is empty") + } + log.Printf("creating matrix: %d rows, %d cols", len(cmd.cgnames), cols) + mtx := mat.NewDense(len(cmd.cgnames), cols, nil) + for i, c := range onehot[nzCount:] { + mtx.Set(int(onehot[i]), int(c), 1) + } + log.Print("fitting") + transformer := nlp.NewPCA(*pcaComponents) + transformer.Fit(mtx.T()) + log.Printf("transforming") + pca, err := transformer.Transform(mtx.T()) + if err != nil { + return err + } + pca = pca.T() + outrows, outcols := pca.Dims() + log.Printf("copying result to numpy output array: %d rows, %d cols", outrows, outcols) + out := make([]float64, outrows*outcols) + for i := 0; i < outrows; i++ { + for j := 0; j < outcols; j++ { + out[i*outcols+j] = pca.At(i, j) + } + } + fnm := fmt.Sprintf("%s/pca.npy", *outputDir) + log.Printf("writing numpy: %s", fnm) + output, err := os.OpenFile(fnm, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777) + if err != nil { + return err + } + npw, err := gonpy.NewWriter(nopCloser{output}) + if err != nil { + return fmt.Errorf("gonpy.NewWriter: %w", err) + } + npw.Shape = []int{outrows, outcols} + err = npw.WriteFloat64(out) + if err != nil { + return fmt.Errorf("WriteFloat64: %w", err) + } + err = output.Close() + if err != nil { + return err + } + log.Print("done") } - } - if *onlyPCA { - } if !*mergeOutput && !*onehotChunked && !*onehotSingle && !*onlyPCA { tagoffsetFilename := *outputDir + "/chunk-tag-offset.csv" @@ -1112,23 +1167,23 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s var f *os.File f, err = os.Create(tagoffsetFilename) if err != nil { - return 1 + return err } defer f.Close() for idx, offset := range chunkStartTag { _, err = fmt.Fprintf(f, "%q,%d\n", fmt.Sprintf("matrix.%04d.npy", idx), offset) if err != nil { err = fmt.Errorf("write %s: %w", tagoffsetFilename, err) - return 1 + return err } } err = f.Close() if err != nil { err = fmt.Errorf("close %s: %w", tagoffsetFilename, err) - return 1 + return err } } - return 0 + return nil } // Read case/control files, remove non-case/control entries from