choose-samples: training/validation set.
author Tom Clegg <tom@curii.com>
Mon, 7 Nov 2022 14:29:47 +0000 (09:29 -0500)
committer Tom Clegg <tom@curii.com>
Mon, 7 Nov 2022 14:29:47 +0000 (09:29 -0500)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

choosesamples.go [new file with mode: 0644]
cmd.go
exportnumpy.go
slicenumpy.go

diff --git a/choosesamples.go b/choosesamples.go
new file mode 100644 (file)
index 0000000..ad56437
--- /dev/null
@@ -0,0 +1,295 @@
+// Copyright (C) The Lightning Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package lightning
+
+import (
+       "bytes"
+       "errors"
+       "flag"
+       "fmt"
+       "io"
+       "math/rand"
+       "net/http"
+       _ "net/http/pprof"
+       "os"
+       "regexp"
+       "sort"
+       "strings"
+
+       "git.arvados.org/arvados.git/sdk/go/arvados"
+       log "github.com/sirupsen/logrus"
+)
+
+type chooseSamples struct {
+       filter filter
+}
+
+func (cmd *chooseSamples) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
+       err := cmd.run(prog, args, stdin, stdout, stderr)
+       if err != nil {
+               fmt.Fprintf(stderr, "%s\n", err)
+               return 1
+       }
+       return 0
+}
+
+func (cmd *chooseSamples) run(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
+       flags := flag.NewFlagSet("", flag.ContinueOnError)
+       flags.SetOutput(stderr)
+       pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
+       runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
+       projectUUID := flags.String("project", "", "project `UUID` for output data")
+       priority := flags.Int("priority", 500, "container request priority")
+       inputDir := flags.String("input-dir", "./in", "input `directory`")
+       outputDir := flags.String("output-dir", "./out", "output `directory`")
+       trainingSetSize := flags.Float64("training-set-size", 0.8, "number (or proportion, if <=1) of eligible samples to assign to the training set")
+       caseControlFilename := flags.String("case-control-file", "", "tsv file or directory indicating cases and controls (if directory, all .tsv files will be read)")
+       caseControlColumn := flags.String("case-control-column", "", "name of case/control column in case-control files (value must be 0 for control, 1 for case)")
+       randSeed := flags.Int64("random-seed", 0, "PRNG seed")
+       cmd.filter.Flags(flags)
+       err := flags.Parse(args)
+       if err == flag.ErrHelp {
+               return nil
+       } else if err != nil {
+               return err
+       }
+       if *caseControlFilename == "" {
+               return errors.New("must provide -case-control-file")
+       }
+       if *caseControlColumn == "" {
+               return errors.New("must provide -case-control-column")
+       }
+
+       if *pprof != "" {
+               go func() {
+                       log.Println(http.ListenAndServe(*pprof, nil))
+               }()
+       }
+
+       if !*runlocal {
+               runner := arvadosContainerRunner{
+                       Name:        "lightning choose-samples",
+                       Client:      arvados.NewClientFromEnv(),
+                       ProjectUUID: *projectUUID,
+                       RAM:         16000000000,
+                       VCPUs:       4,
+                       Priority:    *priority,
+                       KeepCache:   2,
+                       APIAccess:   true,
+               }
+               err = runner.TranslatePaths(inputDir, caseControlFilename)
+               if err != nil {
+                       return err
+               }
+               runner.Args = []string{"choose-samples", "-local=true",
+                       "-pprof=:6060",
+                       "-input-dir=" + *inputDir,
+                       "-output-dir=/mnt/output",
+                       "-case-control-file=" + *caseControlFilename,
+                       "-case-control-column=" + *caseControlColumn,
+                       "-training-set-size=" + fmt.Sprintf("%f", *trainingSetSize),
+                       "-random-seed=" + fmt.Sprintf("%d", *randSeed),
+               }
+               runner.Args = append(runner.Args, cmd.filter.Args()...)
+               var output string
+               output, err = runner.Run()
+               if err != nil {
+                       return err
+               }
+               fmt.Fprintln(stdout, output)
+               return nil
+       }
+
+       infiles, err := allFiles(*inputDir, matchGobFile)
+       if err != nil {
+               return err
+       }
+       if len(infiles) == 0 {
+               err = fmt.Errorf("no input files found in %s", *inputDir)
+               return err
+       }
+       sort.Strings(infiles)
+
+       in0, err := open(infiles[0])
+       if err != nil {
+               return err
+       }
+
+       matchGenome, err := regexp.Compile(cmd.filter.MatchGenome)
+       if err != nil {
+               err = fmt.Errorf("-match-genome: invalid regexp: %q", cmd.filter.MatchGenome)
+               return err
+       }
+
+       var sampleIDs []string
+       err = DecodeLibrary(in0, strings.HasSuffix(infiles[0], ".gz"), func(ent *LibraryEntry) error {
+               for _, cg := range ent.CompactGenomes {
+                       if matchGenome.MatchString(cg.Name) {
+                               sampleIDs = append(sampleIDs, cg.Name)
+                       }
+               }
+               return nil
+       })
+       if err != nil {
+               return err
+       }
+       in0.Close()
+
+       if len(sampleIDs) == 0 {
+               err = fmt.Errorf("no genomes found matching regexp %q", cmd.filter.MatchGenome)
+               return err
+       }
+       sort.Strings(sampleIDs)
+       caseControl, err := cmd.loadCaseControlFiles(*caseControlFilename, *caseControlColumn, sampleIDs)
+       if err != nil {
+               return err
+       }
+       if len(caseControl) == 0 {
+               err = fmt.Errorf("fatal: 0 cases, 0 controls, nothing to do")
+               return err
+       }
+
+       var trainingSet, validationSet []int
+       for i := range caseControl {
+               trainingSet = append(trainingSet, i)
+       }
+       sort.Ints(trainingSet)
+       wantlen := int(*trainingSetSize)
+       if *trainingSetSize <= 1 {
+               wantlen = int(*trainingSetSize * float64(len(trainingSet)))
+       }
+       randsrc := rand.NewSource(*randSeed)
+       for tslen := len(trainingSet); tslen > wantlen; {
+               i := int(randsrc.Int63()) % tslen
+               validationSet = append(validationSet, trainingSet[i])
+               tslen--
+               trainingSet[i] = trainingSet[tslen]
+               trainingSet = trainingSet[:tslen]
+       }
+       sort.Ints(trainingSet)
+       sort.Ints(validationSet)
+
+       samplesFilename := *outputDir + "/samples.csv"
+       log.Infof("writing sample metadata to %s", samplesFilename)
+       var f *os.File
+       f, err = os.Create(samplesFilename)
+       if err != nil {
+               return err
+       }
+       defer f.Close()
+       _, err = fmt.Fprint(f, "Index,SampleID,CaseControl,TrainingValidation\n")
+       if err != nil {
+               return err
+       }
+       tsi := 0 // next idx in training set
+       vsi := 0 // next idx in validation set
+       for i, name := range sampleIDs {
+               var cc, tv string
+               if len(trainingSet) > tsi && trainingSet[tsi] == i {
+                       tv = "1"
+                       tsi++
+                       if caseControl[i] {
+                               cc = "1"
+                       } else {
+                               cc = "0"
+                       }
+               } else if len(validationSet) > vsi && validationSet[vsi] == i {
+                       tv = "0"
+                       vsi++
+                       if caseControl[i] {
+                               cc = "1"
+                       } else {
+                               cc = "0"
+                       }
+               }
+               _, err = fmt.Fprintf(f, "%d,%s,%s,%s\n", i, trimFilenameForLabel(name), cc, tv)
+               if err != nil {
+                       err = fmt.Errorf("write %s: %w", samplesFilename, err)
+                       return err
+               }
+       }
+       err = f.Close()
+       if err != nil {
+               err = fmt.Errorf("close %s: %w", samplesFilename, err)
+               return err
+       }
+       return nil
+}
+
+// Read case/control file(s). Returned map m has m[i]==true if
+// sampleIDs[i] is case, m[i]==false if sampleIDs[i] is control.
+func (cmd *chooseSamples) loadCaseControlFiles(path, colname string, sampleIDs []string) (map[int]bool, error) {
+       infiles, err := allFiles(path, nil)
+       if err != nil {
+               return nil, err
+       }
+       // index in sampleIDs => case(true) / control(false)
+       cc := map[int]bool{}
+       // index in sampleIDs => true if matched by multiple patterns in case/control files
+       dup := map[int]bool{}
+       for _, infile := range infiles {
+               f, err := open(infile)
+               if err != nil {
+                       return nil, err
+               }
+               buf, err := io.ReadAll(f)
+               f.Close()
+               if err != nil {
+                       return nil, err
+               }
+               ccCol := -1
+               for _, tsv := range bytes.Split(buf, []byte{'\n'}) {
+                       if len(tsv) == 0 {
+                               continue
+                       }
+                       split := strings.Split(string(tsv), "\t")
+                       if ccCol < 0 {
+                               // header row
+                               for col, name := range split {
+                                       if name == colname {
+                                               ccCol = col
+                                               break
+                                       }
+                               }
+                               if ccCol < 0 {
+                                       return nil, fmt.Errorf("%s: no column named %q in header row %q", infile, colname, tsv)
+                               }
+                               continue
+                       }
+                       if len(split) <= ccCol {
+                               continue
+                       }
+                       pattern := split[0]
+                       found := -1
+                       for i, name := range sampleIDs {
+                               if strings.Contains(name, pattern) {
+                                       if found >= 0 {
+                                               log.Warnf("pattern %q in %s matches multiple sample IDs (%q, %q)", pattern, infile, sampleIDs[found], name)
+                                       }
+                                       if dup[i] {
+                                               continue
+                                       } else if _, ok := cc[i]; ok {
+                                               log.Warnf("multiple patterns match sample ID %q, omitting from cases/controls", name)
+                                               dup[i] = true
+                                               delete(cc, i)
+                                               continue
+                                       }
+                                       found = i
+                                       if split[ccCol] == "0" {
+                                               cc[found] = false
+                                       }
+                                       if split[ccCol] == "1" {
+                                               cc[found] = true
+                                       }
+                               }
+                       }
+                       if found < 0 {
+                               log.Warnf("pattern %q in %s does not match any genome IDs", pattern, infile)
+                               continue
+                       }
+               }
+       }
+       return cc, nil
+}
diff --git a/cmd.go b/cmd.go
index 0ccc6bc690c2f8070a5983ed20dd03f3c1d7f5c7..2b34b0ccaa4e1434b8ebfb333e083ddf89debe2b 100644 (file)
--- a/cmd.go
+++ b/cmd.go
@@ -44,6 +44,7 @@ var (
                "merge":              &merger{},
                "dump":               &dump{},
                "dumpgob":            &dumpGob{},
+               "choose-samples":     &chooseSamples{},
        })
 )
 
@@ -58,6 +59,7 @@ func Main() {
                logrus.StandardLogger().Formatter = &logrus.TextFormatter{DisableTimestamp: true}
        }
        if len(os.Args) >= 2 && !strings.HasSuffix(os.Args[1], "version") {
+               // print version (then run subcommand)
                cmd.Version.RunCommand("lightning", nil, nil, os.Stderr, os.Stderr)
        }
        os.Exit(handler.RunCommand(os.Args[0], os.Args[1:], os.Stdin, os.Stdout, os.Stderr))
index 27c86a32107a98faf1952eaffe5fae477faa7213..050f3161bfaac798016802884856f93a6f37903c 100644 (file)
@@ -568,5 +568,6 @@ func trimFilenameForLabel(s string) string {
        s = strings.TrimSuffix(s, ".2")
        s = strings.TrimSuffix(s, ".gz")
        s = strings.TrimSuffix(s, ".vcf")
+       s = strings.Replace(s, ",", "-", -1)
        return s
 }
index b078bc124d9148055d7c77a3b49394981b197fc9..28267502a2442759a4e8253577a093891448c4fa 100644 (file)
@@ -61,6 +61,7 @@ func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, s
        }
        return 0
 }
+
 func (cmd *sliceNumpy) run(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
        flags := flag.NewFlagSet("", flag.ContinueOnError)
        flags.SetOutput(stderr)