19524: Fit PCA to specified training set.
author Tom Clegg <tom@curii.com>
Wed, 2 Nov 2022 14:49:09 +0000 (10:49 -0400)
committer Tom Clegg <tom@curii.com>
Wed, 2 Nov 2022 14:49:09 +0000 (10:49 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

slicenumpy.go

index bef164d72071092b0f07b3ec233e91806d3e59b5..b078bc124d9148055d7c77a3b49394981b197fc9 100644 (file)
@@ -45,6 +45,8 @@ type sliceNumpy struct {
        chi2CaseControlFile   string
        chi2Cases             []bool
        chi2PValue            float64
+       trainingSet           []int // see loadTrainingSet
+       trainingSetSize       int
        minCoverage           int
        cgnames               []string
        includeVariant1       bool
@@ -78,6 +80,7 @@ func (cmd *sliceNumpy) run(prog string, args []string, stdin io.Reader, stdout,
        hgvsChunked := flags.Bool("chunked-hgvs-matrix", false, "also generate hgvs-based matrix per chromosome")
        onehotSingle := flags.Bool("single-onehot", false, "generate one-hot tile-based matrix")
        onehotChunked := flags.Bool("chunked-onehot", false, "generate one-hot tile-based matrix per input chunk")
+       trainingSetFilename := flags.String("training-set", "", "`tsv` file with sample IDs to be used for PCA fitting and Χ² test (if not provided, use all samples)")
        onlyPCA := flags.Bool("pca", false, "generate pca matrix")
        pcaComponents := flags.Int("pca-components", 4, "number of PCA components")
        maxPCATiles := flags.Int("max-pca-tiles", 0, "maximum tiles to use as PCA input (filter, then drop every 2nd colum pair until below max)")
@@ -118,7 +121,7 @@ func (cmd *sliceNumpy) run(prog string, args []string, stdin io.Reader, stdout,
                        KeepCache:   2,
                        APIAccess:   true,
                }
-               err = runner.TranslatePaths(inputDir, regionsFilename, &cmd.chi2CaseControlFile)
+               err = runner.TranslatePaths(inputDir, regionsFilename, trainingSetFilename, &cmd.chi2CaseControlFile)
                if err != nil {
                        return err
                }
@@ -134,6 +137,7 @@ func (cmd *sliceNumpy) run(prog string, args []string, stdin io.Reader, stdout,
                        "-chunked-hgvs-matrix=" + fmt.Sprintf("%v", *hgvsChunked),
                        "-single-onehot=" + fmt.Sprintf("%v", *onehotSingle),
                        "-chunked-onehot=" + fmt.Sprintf("%v", *onehotChunked),
+                       "-training-set=" + *trainingSetFilename,
                        "-pca=" + fmt.Sprintf("%v", *onlyPCA),
                        "-pca-components=" + fmt.Sprintf("%d", *pcaComponents),
                        "-max-pca-tiles=" + fmt.Sprintf("%d", *maxPCATiles),
@@ -232,6 +236,10 @@ func (cmd *sliceNumpy) run(prog string, args []string, stdin io.Reader, stdout,
                err = fmt.Errorf("fatal: 0 cases, 0 controls, nothing to do")
                return err
        }
+       err = cmd.loadTrainingSet(*trainingSetFilename)
+       if err != nil {
+               return err
+       }
        if cmd.filter.MinCoverage == 1 {
                // In the generic formula below, floating point
                // arithmetic can effectively push the coverage
@@ -1128,18 +1136,23 @@ func (cmd *sliceNumpy) run(prog string, args []string, stdin io.Reader, stdout,
                                cols = (cols + 1) / 2
                                stride = stride * 2
                        }
-                       log.Printf("creating matrix: %d rows, %d cols, stride %d", len(cmd.cgnames), cols, stride)
-                       mtx := mat.NewDense(len(cmd.cgnames), cols, nil)
+                       log.Printf("creating full matrix (%d rows) and training matrix (%d rows) with %d cols, stride %d", len(cmd.cgnames), cmd.trainingSetSize, cols, stride)
+                       mtxFull := mat.NewDense(len(cmd.cgnames), cols, nil)
+                       mtxTrain := mat.NewDense(cmd.trainingSetSize, cols, nil)
                        for i, c := range onehot[nzCount:] {
                                if int(c/2)%stride == 0 {
-                                       mtx.Set(int(onehot[i]), int(c/2)/stride*2+int(c)%2, 1)
+                                       outcol := int(c/2)/stride*2 + int(c)%2
+                                       mtxFull.Set(int(onehot[i]), outcol, 1)
+                                       if trainRow := cmd.trainingSet[int(onehot[i])]; trainRow >= 0 {
+                                               mtxTrain.Set(trainRow, outcol, 1)
+                                       }
                                }
                        }
                        log.Print("fitting")
                        transformer := nlp.NewPCA(*pcaComponents)
-                       transformer.Fit(mtx.T())
+                       transformer.Fit(mtxTrain.T())
                        log.Printf("transforming")
-                       pca, err := transformer.Transform(mtx.T())
+                       pca, err := transformer.Transform(mtxFull.T())
                        if err != nil {
                                return err
                        }
@@ -1199,6 +1212,74 @@ func (cmd *sliceNumpy) run(prog string, args []string, stdin io.Reader, stdout,
        return nil
 }
 
+// Read training set file(s) from path (may be dir or file) and set up
+// cmd.trainingSet.
+//
+// cmd.trainingSet[i] == n >= 0 if cmd.cgnames[i] is the nth training
+// set sample.
+//
+// cmd.trainingSet[i] == -1 if cmd.cgnames[i] is not in the training
+// set.
+func (cmd *sliceNumpy) loadTrainingSet(path string) error {
+       cmd.trainingSet = make([]int, len(cmd.cgnames))
+       if path == "" {
+               cmd.trainingSetSize = len(cmd.cgnames)
+               for i := range cmd.trainingSet {
+                       cmd.trainingSet[i] = i
+               }
+               return nil
+       }
+       for i := range cmd.trainingSet {
+               cmd.trainingSet[i] = -1
+       }
+       infiles, err := allFiles(path, nil)
+       if err != nil {
+               return err
+       }
+       for _, infile := range infiles {
+               f, err := open(infile)
+               if err != nil {
+                       return err
+               }
+               buf, err := io.ReadAll(f)
+               f.Close()
+               if err != nil {
+                       return err
+               }
+               for _, tsv := range bytes.Split(buf, []byte{'\n'}) {
+                       if len(tsv) == 0 {
+                               continue
+                       }
+                       split := strings.Split(string(tsv), "\t")
+                       pattern := split[0]
+                       found := -1
+                       for i, name := range cmd.cgnames {
+                               if strings.Contains(name, pattern) {
+                                       if found >= 0 {
+                                               log.Warnf("pattern %q in %s already matched sample ID %q -- not using %q", pattern, infile, cmd.cgnames[found], name)
+                                       } else {
+                                               found = i
+                                               cmd.trainingSet[found] = 1
+                                       }
+                               }
+                       }
+                       if found < 0 {
+                               log.Warnf("pattern %q in %s does not match any genome IDs", pattern, infile)
+                               continue
+                       }
+               }
+       }
+       tsi := 0
+       for i, x := range cmd.trainingSet {
+               if x == 1 {
+                       cmd.trainingSet[i] = tsi
+                       tsi++
+               }
+       }
+       cmd.trainingSetSize = tsi + 1
+       return nil
+}
+
 // Read case/control files, remove non-case/control entries from
 // cmd.cgnames, and build cmd.chi2Cases.
 func (cmd *sliceNumpy) useCaseControlFiles() error {
@@ -1248,21 +1329,21 @@ func (cmd *sliceNumpy) useCaseControlFiles() error {
                        for i, name := range cmd.cgnames {
                                if strings.Contains(name, pattern) {
                                        if found >= 0 {
-                                               log.Warnf("pattern %q in %s matches multiple genome IDs (%qs, %q)", pattern, infile, cmd.cgnames[found], name)
+                                               log.Warnf("pattern %q in %s matches multiple genome IDs (%q, %q)", pattern, infile, cmd.cgnames[found], name)
                                        }
                                        found = i
+                                       if split[ccCol] == "0" {
+                                               cc[found] = false
+                                       }
+                                       if split[ccCol] == "1" {
+                                               cc[found] = true
+                                       }
                                }
                        }
                        if found < 0 {
                                log.Warnf("pattern %q in %s does not match any genome IDs", pattern, infile)
                                continue
                        }
-                       if split[ccCol] == "0" {
-                               cc[found] = false
-                       }
-                       if split[ccCol] == "1" {
-                               cc[found] = true
-                       }
                }
        }
        allnames := cmd.cgnames