19524: Output PCA.
author Tom Clegg <tom@curii.com>
Tue, 11 Oct 2022 14:07:14 +0000 (10:07 -0400)
committer Tom Clegg <tom@curii.com>
Tue, 11 Oct 2022 14:07:14 +0000 (10:07 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

gob.go
pca.go
slicenumpy.go

diff --git a/gob.go b/gob.go
index 0bd18739612c3b14dc13ff847731479372166100..70a5d2811a7e94bf9646863ff0283d271f8beb5b 100644 (file)
--- a/gob.go
+++ b/gob.go
@@ -59,17 +59,20 @@ func DecodeLibrary(rdr io.Reader, gz bool, cb func(*LibraryEntry) error) error {
                if err != nil {
                        return err
                }
+               defer zrdr.Close()
        }
        dec := gob.NewDecoder(zrdr)
-       for err == nil {
+       for {
                var ent LibraryEntry
                err = dec.Decode(&ent)
-               if err == nil {
-                       err = cb(&ent)
+               if err == io.EOF {
+                       return zrdr.Close()
+               } else if err != nil {
+                       return err
+               }
+               err = cb(&ent)
+               if err != nil {
+                       return err
                }
        }
-       if err != io.EOF {
-               return err
-       }
-       return zrdr.Close()
 }
diff --git a/pca.go b/pca.go
index 121925d4c59d9833d37aabb397fa3a1bb0cc2243..22efc904fb887a776ea671f6b7f143f06adcd8c7 100644 (file)
--- a/pca.go
+++ b/pca.go
@@ -199,7 +199,7 @@ func (cmd *goPCA) RunCommand(prog string, args []string, stdin io.Reader, stdout
        if *outputFilename == "-" {
                output = nopCloser{stdout}
        } else {
-               output, err = os.OpenFile(*outputFilename, os.O_CREATE|os.O_WRONLY, 0777)
+               output, err = os.OpenFile(*outputFilename, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777)
                if err != nil {
                        return 1
                }
diff --git a/slicenumpy.go b/slicenumpy.go
index 7931cbede9a77402bd30489634f6ee000477a08c..c5e03d8ed2f412ca2acd912030ef3f6ea4b40513 100644 (file)
--- a/slicenumpy.go
+++ b/slicenumpy.go
@@ -28,10 +28,12 @@ import (
 
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "github.com/arvados/lightning/hgvs"
+       "github.com/james-bowman/nlp"
        "github.com/kshedden/gonpy"
        "github.com/sirupsen/logrus"
        log "github.com/sirupsen/logrus"
        "golang.org/x/crypto/blake2b"
+       "gonum.org/v1/gonum/mat"
 )
 
 const annotationMaxTileSpan = 100
@@ -50,12 +52,14 @@ type sliceNumpy struct {
 }
 
 func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
-       var err error
-       defer func() {
-               if err != nil {
-                       fmt.Fprintf(stderr, "%s\n", err)
-               }
-       }()
+       err := cmd.run(prog, args, stdin, stdout, stderr)
+       if err != nil {
+               fmt.Fprintf(stderr, "%s\n", err)
+               return 1
+       }
+       return 0
+}
+func (cmd *sliceNumpy) run(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
        flags := flag.NewFlagSet("", flag.ContinueOnError)
        flags.SetOutput(stderr)
        pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
@@ -73,6 +77,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
        onehotSingle := flags.Bool("single-onehot", false, "generate one-hot tile-based matrix")
        onehotChunked := flags.Bool("chunked-onehot", false, "generate one-hot tile-based matrix per input chunk")
        onlyPCA := flags.Bool("pca", false, "generate pca matrix")
+       pcaComponents := flags.Int("pca-components", 4, "number of PCA components")
        debugTag := flags.Int("debug-tag", -1, "log debugging details about specified tag")
        flags.IntVar(&cmd.threads, "threads", 16, "number of memory-hungry assembly threads")
        flags.StringVar(&cmd.chi2CaseControlFile, "chi2-case-control-file", "", "tsv file or directory indicating cases and controls for Χ² test (if directory, all .tsv files will be read)")
@@ -80,12 +85,11 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
        flags.Float64Var(&cmd.chi2PValue, "chi2-p-value", 1, "do Χ² test and omit columns with p-value above this threshold")
        flags.BoolVar(&cmd.includeVariant1, "include-variant-1", false, "include most common variant when building one-hot matrix")
        cmd.filter.Flags(flags)
-       err = flags.Parse(args)
+       err := flags.Parse(args)
        if err == flag.ErrHelp {
-               err = nil
-               return 0
+               return nil
        } else if err != nil {
-               return 2
+               return err
        }
 
        if *pprof != "" {
@@ -95,8 +99,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
        }
 
        if cmd.chi2PValue != 1 && (cmd.chi2CaseControlFile == "" || cmd.chi2CaseControlColumn == "") {
-               log.Errorf("cannot use provided -chi2-p-value=%f because -chi2-case-control-file= or -chi2-case-control-column= value is empty", cmd.chi2PValue)
-               return 2
+               return fmt.Errorf("cannot use provided -chi2-p-value=%f because -chi2-case-control-file= or -chi2-case-control-column= value is empty", cmd.chi2PValue)
        }
 
        cmd.debugTag = tagID(*debugTag)
@@ -114,7 +117,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                }
                err = runner.TranslatePaths(inputDir, regionsFilename, &cmd.chi2CaseControlFile)
                if err != nil {
-                       return 1
+                       return err
                }
                runner.Args = []string{"slice-numpy", "-local=true",
                        "-pprof=:6060",
@@ -139,19 +142,19 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                var output string
                output, err = runner.Run()
                if err != nil {
-                       return 1
+                       return err
                }
                fmt.Fprintln(stdout, output)
-               return 0
+               return nil
        }
 
        infiles, err := allFiles(*inputDir, matchGobFile)
        if err != nil {
-               return 1
+               return err
        }
        if len(infiles) == 0 {
                err = fmt.Errorf("no input files found in %s", *inputDir)
-               return 1
+               return err
        }
        sort.Strings(infiles)
 
@@ -159,13 +162,13 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
        var reftiledata = make(map[tileLibRef][]byte, 11000000)
        in0, err := open(infiles[0])
        if err != nil {
-               return 1
+               return err
        }
 
        matchGenome, err := regexp.Compile(cmd.filter.MatchGenome)
        if err != nil {
                err = fmt.Errorf("-match-genome: invalid regexp: %q", cmd.filter.MatchGenome)
-               return 1
+               return err
        }
 
        cmd.cgnames = nil
@@ -192,37 +195,37 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                return nil
        })
        if err != nil {
-               return 1
+               return err
        }
        in0.Close()
        if refseq == nil {
                err = fmt.Errorf("%s: reference sequence not found", infiles[0])
-               return 1
+               return err
        }
        if len(tagset) == 0 {
                err = fmt.Errorf("tagset not found")
-               return 1
+               return err
        }
 
        taglib := &tagLibrary{}
        err = taglib.setTags(tagset)
        if err != nil {
-               return 1
+               return err
        }
        taglen := taglib.TagLen()
 
        if len(cmd.cgnames) == 0 {
                err = fmt.Errorf("no genomes found matching regexp %q", cmd.filter.MatchGenome)
-               return 1
+               return err
        }
        sort.Strings(cmd.cgnames)
        err = cmd.useCaseControlFiles()
        if err != nil {
-               return 1
+               return err
        }
        if len(cmd.cgnames) == 0 {
                err = fmt.Errorf("fatal: 0 cases, 0 controls, nothing to do")
-               return 1
+               return err
        }
        if cmd.filter.MinCoverage == 1 {
                // In the generic formula below, floating point
@@ -240,7 +243,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                var f *os.File
                f, err = os.Create(labelsFilename)
                if err != nil {
-                       return 1
+                       return err
                }
                defer f.Close()
                for i, name := range cmd.cgnames {
@@ -251,13 +254,13 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        _, err = fmt.Fprintf(f, "%d,%q,%d\n", i, trimFilenameForLabel(name), cc)
                        if err != nil {
                                err = fmt.Errorf("write %s: %w", labelsFilename, err)
-                               return 1
+                               return err
                        }
                }
                err = f.Close()
                if err != nil {
                        err = fmt.Errorf("close %s: %w", labelsFilename, err)
-                       return 1
+                       return err
                }
        }
 
@@ -282,7 +285,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        tiledata := reftiledata[libref]
                        if len(tiledata) == 0 {
                                err = fmt.Errorf("missing tiledata for tag %d variant %d in %s in ref", libref.Tag, libref.Variant, seqname)
-                               return 1
+                               return err
                        }
                        foundthistag := false
                        taglib.FindAll(tiledata[:len(tiledata)-1], func(tagid tagID, offset, _ int) {
@@ -328,7 +331,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                log.Printf("loading regions from %s", *regionsFilename)
                mask, err = makeMask(*regionsFilename, *expandRegions)
                if err != nil {
-                       return 1
+                       return err
                }
                log.Printf("before applying mask, len(reftile) == %d", len(reftile))
                log.Printf("deleting reftile entries for regions outside %d intervals", mask.Len())
@@ -349,7 +352,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        var f *os.File
                        f, err = os.Create(*outputDir + "/tmp." + seqname + ".gob")
                        if err != nil {
-                               return 1
+                               return err
                        }
                        defer os.Remove(f.Name())
                        bufw := bufio.NewWriterSize(f, 1<<24)
@@ -824,7 +827,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                })
        }
        if err = throttleMem.Wait(); err != nil {
-               return 1
+               return err
        }
 
        if *hgvsChunked {
@@ -834,14 +837,14 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                }
                err = encodeHGVS.Wait()
                if err != nil {
-                       return 1
+                       return err
                }
                for seqname := range refseq {
                        log.Infof("%s: reading hgvsCols from temp file", seqname)
                        f := tmpHGVSCols[seqname]
                        _, err = f.Seek(0, io.SeekStart)
                        if err != nil {
-                               return 1
+                               return err
                        }
                        var hgvsCols hgvsColSet
                        dec := gob.NewDecoder(bufio.NewReaderSize(f, 1<<24))
@@ -849,7 +852,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                err = dec.Decode(&hgvsCols)
                        }
                        if err != io.EOF {
-                               return 1
+                               return err
                        }
                        log.Infof("%s: sorting %d hgvs variants", seqname, len(hgvsCols))
                        variants := make([]hgvs.Variant, 0, len(hgvsCols))
@@ -880,7 +883,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        }
                        err = writeNumpyInt8(fmt.Sprintf("%s/hgvs.%s.npy", *outputDir, seqname), out, rows, cols)
                        if err != nil {
-                               return 1
+                               return err
                        }
                        out = nil
 
@@ -892,7 +895,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        }
                        err = ioutil.WriteFile(fnm, hgvsLabels.Bytes(), 0666)
                        if err != nil {
-                               return 1
+                               return err
                        }
                }
        }
@@ -904,7 +907,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        annoFilename := fmt.Sprintf("%s/matrix.annotations.csv", *outputDir)
                        annof, err = os.Create(annoFilename)
                        if err != nil {
-                               return 1
+                               return err
                        }
                        annow = bufio.NewWriterSize(annof, 1<<20)
                }
@@ -934,12 +937,12 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        log.Infof("reading %s", annotationsFilename)
                        buf, err := os.ReadFile(annotationsFilename)
                        if err != nil {
-                               return 1
+                               return err
                        }
                        if *mergeOutput {
                                err = os.Remove(annotationsFilename)
                                if err != nil {
-                                       return 1
+                                       return err
                                }
                        }
                        for _, line := range bytes.Split(buf, []byte{'\n'}) {
@@ -982,7 +985,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                                        rt, ok := reftile[tagID(tag)]
                                        if !ok {
                                                err = fmt.Errorf("bug: seeing annotations for tag %d, but it has no reftile entry", tag)
-                                               return 1
+                                               return err
                                        }
                                        for ph := 0; ph < 2; ph++ {
                                                for row := 0; row < rows; row++ {
@@ -1022,15 +1025,15 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                if *mergeOutput {
                        err = annow.Flush()
                        if err != nil {
-                               return 1
+                               return err
                        }
                        err = annof.Close()
                        if err != nil {
-                               return 1
+                               return err
                        }
                        err = writeNumpyInt16(fmt.Sprintf("%s/matrix.npy", *outputDir), out, rows, cols)
                        if err != nil {
-                               return 1
+                               return err
                        }
                }
                out = nil
@@ -1056,18 +1059,18 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        }
                        err = writeNumpyInt16(fmt.Sprintf("%s/hgvs.npy", *outputDir), out, rows, cols)
                        if err != nil {
-                               return 1
+                               return err
                        }
 
                        fnm := fmt.Sprintf("%s/hgvs.annotations.csv", *outputDir)
                        log.Printf("writing hgvs labels: %s", fnm)
                        err = ioutil.WriteFile(fnm, hgvsLabels.Bytes(), 0777)
                        if err != nil {
-                               return 1
+                               return err
                        }
                }
        }
-       if *onehotSingle {
+       if *onehotSingle || *onlyPCA {
                nzCount := 0
                for _, part := range onehotIndirect {
                        nzCount += len(part[0])
@@ -1092,19 +1095,71 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                        onehotXrefs[i] = nil
                        debug.FreeOSMemory()
                }
-               fnm := fmt.Sprintf("%s/onehot.npy", *outputDir)
-               err = writeNumpyUint32(fnm, onehot, 2, nzCount)
-               if err != nil {
-                       return 1
+               if *onehotSingle {
+                       fnm := fmt.Sprintf("%s/onehot.npy", *outputDir)
+                       err = writeNumpyUint32(fnm, onehot, 2, nzCount)
+                       if err != nil {
+                               return err
+                       }
+                       fnm = fmt.Sprintf("%s/onehot-columns.npy", *outputDir)
+                       err = writeNumpyInt32(fnm, onehotXref2int32(xrefs), 5, len(xrefs))
+                       if err != nil {
+                               return err
+                       }
                }
-               fnm = fmt.Sprintf("%s/onehot-columns.npy", *outputDir)
-               err = writeNumpyInt32(fnm, onehotXref2int32(xrefs), 5, len(xrefs))
-               if err != nil {
-                       return 1
+               if *onlyPCA {
+                       cols := 0
+                       for _, c := range onehot[nzCount:] {
+                               if int(c) >= cols {
+                                       cols = int(c) + 1
+                               }
+                       }
+                       if cols == 0 {
+                               return fmt.Errorf("cannot do PCA: one-hot matrix is empty")
+                       }
+                       log.Printf("creating matrix: %d rows, %d cols", len(cmd.cgnames), cols)
+                       mtx := mat.NewDense(len(cmd.cgnames), cols, nil)
+                       for i, c := range onehot[nzCount:] {
+                               mtx.Set(int(onehot[i]), int(c), 1)
+                       }
+                       log.Print("fitting")
+                       transformer := nlp.NewPCA(*pcaComponents)
+                       transformer.Fit(mtx.T())
+                       log.Printf("transforming")
+                       pca, err := transformer.Transform(mtx.T())
+                       if err != nil {
+                               return err
+                       }
+                       pca = pca.T()
+                       outrows, outcols := pca.Dims()
+                       log.Printf("copying result to numpy output array: %d rows, %d cols", outrows, outcols)
+                       out := make([]float64, outrows*outcols)
+                       for i := 0; i < outrows; i++ {
+                               for j := 0; j < outcols; j++ {
+                                       out[i*outcols+j] = pca.At(i, j)
+                               }
+                       }
+                       fnm := fmt.Sprintf("%s/pca.npy", *outputDir)
+                       log.Printf("writing numpy: %s", fnm)
+                       output, err := os.OpenFile(fnm, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777)
+                       if err != nil {
+                               return err
+                       }
+                       npw, err := gonpy.NewWriter(nopCloser{output})
+                       if err != nil {
+                               return fmt.Errorf("gonpy.NewWriter: %w", err)
+                       }
+                       npw.Shape = []int{outrows, outcols}
+                       err = npw.WriteFloat64(out)
+                       if err != nil {
+                               return fmt.Errorf("WriteFloat64: %w", err)
+                       }
+                       err = output.Close()
+                       if err != nil {
+                               return err
+                       }
+                       log.Print("done")
                }
-       }
-       if *onlyPCA {
-
        }
        if !*mergeOutput && !*onehotChunked && !*onehotSingle && !*onlyPCA {
                tagoffsetFilename := *outputDir + "/chunk-tag-offset.csv"
@@ -1112,23 +1167,23 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
                var f *os.File
                f, err = os.Create(tagoffsetFilename)
                if err != nil {
-                       return 1
+                       return err
                }
                defer f.Close()
                for idx, offset := range chunkStartTag {
                        _, err = fmt.Fprintf(f, "%q,%d\n", fmt.Sprintf("matrix.%04d.npy", idx), offset)
                        if err != nil {
                                err = fmt.Errorf("write %s: %w", tagoffsetFilename, err)
-                               return 1
+                               return err
                        }
                }
                err = f.Close()
                if err != nil {
                        err = fmt.Errorf("close %s: %w", tagoffsetFilename, err)
-                       return 1
+                       return err
                }
        }
-       return 0
+       return nil
 }
 
 // Read case/control files, remove non-case/control entries from