--- /dev/null
+// Copyright (C) The Lightning Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package lightning
+
+import (
+ "bytes"
+ "errors"
+ "flag"
+ "fmt"
+ "io"
+ "math/rand"
+ "net/http"
+ _ "net/http/pprof"
+ "os"
+ "regexp"
+ "sort"
+ "strings"
+
+ "git.arvados.org/arvados.git/sdk/go/arvados"
+ log "github.com/sirupsen/logrus"
+)
+
+// chooseSamples is the receiver type for the "choose-samples"
+// command: its RunCommand method parses the command line and either
+// dispatches the work to an arvados container or does it locally.
+type chooseSamples struct {
+ // filter contributes its own command line flags (see run), supplies
+ // the MatchGenome regexp used to select samples, and provides the
+ // arguments that are forwarded to the container invocation.
+ filter filter
+}
+
+// RunCommand executes the command by delegating to cmd.run with the
+// same arguments.
+//
+// It returns exit code 0 when run succeeds. When run returns an
+// error, the error is written to stderr (followed by a newline) and
+// the exit code is 1.
+func (cmd *chooseSamples) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
+	exitCode := 0
+	if failure := cmd.run(prog, args, stdin, stdout, stderr); failure != nil {
+		fmt.Fprintf(stderr, "%s\n", failure)
+		exitCode = 1
+	}
+	return exitCode
+}
+
+// run implements the choose-samples command and returns nil on
+// success (including when only usage help was requested).
+//
+// Unless -local is given, the same command is re-dispatched to an
+// arvados container (with -input-dir and -case-control-file
+// translated by the runner) and the runner's output is printed to
+// stdout.
+//
+// In local mode, run:
+//
+//   - collects the genome names in the first (in sorted order) input
+//     file of -input-dir that match the filter's MatchGenome regexp;
+//   - looks up case/control status for those samples in the
+//     -case-control-file file(s);
+//   - randomly (seeded by -random-seed) moves samples that have a
+//     case/control status from the training set to the validation
+//     set until the training set has the size requested by
+//     -training-set-size;
+//   - writes one row per sample to samples.csv in -output-dir.
+//
+// Samples without a case/control status are still listed in
+// samples.csv, with empty CaseControl and TrainingValidation columns.
+func (cmd *chooseSamples) run(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
+	flags := flag.NewFlagSet("", flag.ContinueOnError)
+	flags.SetOutput(stderr)
+	pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
+	runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
+	projectUUID := flags.String("project", "", "project `UUID` for output data")
+	priority := flags.Int("priority", 500, "container request priority")
+	inputDir := flags.String("input-dir", "./in", "input `directory`")
+	outputDir := flags.String("output-dir", "./out", "output `directory`")
+	trainingSetSize := flags.Float64("training-set-size", 0.8, "number (or proportion, if <=1) of eligible samples to assign to the training set")
+	caseControlFilename := flags.String("case-control-file", "", "tsv file or directory indicating cases and controls (if directory, all .tsv files will be read)")
+	caseControlColumn := flags.String("case-control-column", "", "name of case/control column in case-control files (value must be 0 for control, 1 for case)")
+	randSeed := flags.Int64("random-seed", 0, "PRNG seed")
+	cmd.filter.Flags(flags)
+	err := flags.Parse(args)
+	if errors.Is(err, flag.ErrHelp) {
+		return nil
+	} else if err != nil {
+		return err
+	}
+	if *caseControlFilename == "" {
+		return errors.New("must provide -case-control-file")
+	}
+	if *caseControlColumn == "" {
+		return errors.New("must provide -case-control-column")
+	}
+	// Written as !(x >= 0) so that NaN is rejected along with negative
+	// values. Either one would otherwise yield a negative target size
+	// and make the selection loop below divide by zero.
+	if !(*trainingSetSize >= 0) {
+		return fmt.Errorf("-training-set-size must not be negative (got %v)", *trainingSetSize)
+	}
+
+	if *pprof != "" {
+		go func() {
+			log.Println(http.ListenAndServe(*pprof, nil))
+		}()
+	}
+
+	if !*runlocal {
+		runner := arvadosContainerRunner{
+			Name:        "lightning choose-samples",
+			Client:      arvados.NewClientFromEnv(),
+			ProjectUUID: *projectUUID,
+			RAM:         16000000000,
+			VCPUs:       4,
+			Priority:    *priority,
+			KeepCache:   2,
+			APIAccess:   true,
+		}
+		err = runner.TranslatePaths(inputDir, caseControlFilename)
+		if err != nil {
+			return err
+		}
+		runner.Args = []string{"choose-samples", "-local=true",
+			"-pprof=:6060",
+			"-input-dir=" + *inputDir,
+			"-output-dir=/mnt/output",
+			"-case-control-file=" + *caseControlFilename,
+			"-case-control-column=" + *caseControlColumn,
+			// %v prints the shortest representation that parses back
+			// to the same float64, so the container sees exactly the
+			// value given here (%f would round to 6 decimals).
+			"-training-set-size=" + fmt.Sprintf("%v", *trainingSetSize),
+			"-random-seed=" + fmt.Sprintf("%d", *randSeed),
+		}
+		runner.Args = append(runner.Args, cmd.filter.Args()...)
+		var output string
+		output, err = runner.Run()
+		if err != nil {
+			return err
+		}
+		fmt.Fprintln(stdout, output)
+		return nil
+	}
+
+	// Validate the regexp before opening any input, so a bad pattern
+	// cannot leave an input file open.
+	matchGenome, err := regexp.Compile(cmd.filter.MatchGenome)
+	if err != nil {
+		return fmt.Errorf("-match-genome: invalid regexp: %q: %w", cmd.filter.MatchGenome, err)
+	}
+
+	infiles, err := allFiles(*inputDir, matchGobFile)
+	if err != nil {
+		return err
+	}
+	if len(infiles) == 0 {
+		return fmt.Errorf("no input files found in %s", *inputDir)
+	}
+	sort.Strings(infiles)
+
+	// Only the first input file is consulted for genome names.
+	in0, err := open(infiles[0])
+	if err != nil {
+		return err
+	}
+	var sampleIDs []string
+	err = DecodeLibrary(in0, strings.HasSuffix(infiles[0], ".gz"), func(ent *LibraryEntry) error {
+		for _, cg := range ent.CompactGenomes {
+			if matchGenome.MatchString(cg.Name) {
+				sampleIDs = append(sampleIDs, cg.Name)
+			}
+		}
+		return nil
+	})
+	// Close before inspecting err so the file is released on the
+	// error path as well.
+	in0.Close()
+	if err != nil {
+		return err
+	}
+
+	if len(sampleIDs) == 0 {
+		return fmt.Errorf("no genomes found matching regexp %q", cmd.filter.MatchGenome)
+	}
+	sort.Strings(sampleIDs)
+	caseControl, err := cmd.loadCaseControlFiles(*caseControlFilename, *caseControlColumn, sampleIDs)
+	if err != nil {
+		return err
+	}
+	if len(caseControl) == 0 {
+		return fmt.Errorf("fatal: 0 cases, 0 controls, nothing to do")
+	}
+
+	// Start with every sample that has a case/control status in the
+	// training set. Map iteration order is random, so sort to make the
+	// selection below depend only on the seed.
+	trainingSet := make([]int, 0, len(caseControl))
+	for i := range caseControl {
+		trainingSet = append(trainingSet, i)
+	}
+	sort.Ints(trainingSet)
+
+	// A size <= 1 is a proportion of the eligible samples; a larger
+	// size is an absolute count, capped at the number of eligible
+	// samples (which also avoids converting a huge float to int).
+	wantlen := len(trainingSet)
+	if *trainingSetSize <= 1 {
+		wantlen = int(*trainingSetSize * float64(len(trainingSet)))
+	} else if *trainingSetSize < float64(len(trainingSet)) {
+		wantlen = int(*trainingSetSize)
+	}
+
+	// Move randomly chosen samples to the validation set until the
+	// training set has the wanted size. wantlen >= 0 here, so tslen
+	// is at least 1 whenever the loop body runs.
+	var validationSet []int
+	randsrc := rand.NewSource(*randSeed)
+	for tslen := len(trainingSet); tslen > wantlen; {
+		// Reduce in int64 before converting, so the index is valid
+		// even where int is 32 bits wide.
+		i := int(randsrc.Int63() % int64(tslen))
+		validationSet = append(validationSet, trainingSet[i])
+		tslen--
+		trainingSet[i] = trainingSet[tslen]
+		trainingSet = trainingSet[:tslen]
+	}
+	sort.Ints(trainingSet)
+	sort.Ints(validationSet)
+
+	samplesFilename := *outputDir + "/samples.csv"
+	log.Infof("writing sample metadata to %s", samplesFilename)
+	var f *os.File
+	f, err = os.Create(samplesFilename)
+	if err != nil {
+		return err
+	}
+	// Covers the early returns below. On success the file is closed
+	// explicitly so that a close error can be reported, and this
+	// second Close is then a harmless no-op whose error is ignored.
+	defer f.Close()
+	_, err = fmt.Fprint(f, "Index,SampleID,CaseControl,TrainingValidation\n")
+	if err != nil {
+		return fmt.Errorf("write %s: %w", samplesFilename, err)
+	}
+	// trainingSet and validationSet are sorted, so they can be merged
+	// against the ascending sample index with one cursor each.
+	tsi := 0 // next idx in training set
+	vsi := 0 // next idx in validation set
+	for i, name := range sampleIDs {
+		var cc, tv string
+		switch {
+		case tsi < len(trainingSet) && trainingSet[tsi] == i:
+			tv = "1"
+			tsi++
+		case vsi < len(validationSet) && validationSet[vsi] == i:
+			tv = "0"
+			vsi++
+		}
+		// Only samples placed in one of the two sets have a
+		// case/control status; all others get empty columns.
+		if tv != "" {
+			if caseControl[i] {
+				cc = "1"
+			} else {
+				cc = "0"
+			}
+		}
+		_, err = fmt.Fprintf(f, "%d,%s,%s,%s\n", i, trimFilenameForLabel(name), cc, tv)
+		if err != nil {
+			return fmt.Errorf("write %s: %w", samplesFilename, err)
+		}
+	}
+	err = f.Close()
+	if err != nil {
+		return fmt.Errorf("close %s: %w", samplesFilename, err)
+	}
+	return nil
+}
+
+// loadCaseControlFiles reads the case/control file(s) found at path.
+//
+// Each file is read as tab-separated text. Its first non-empty line
+// is the header row and must contain a column named colname,
+// otherwise an error is returned. In every following row, the first
+// column is a pattern that is matched as a substring against each
+// entry of sampleIDs, and the colname column gives the status: "1"
+// for case, "0" for control. Any other value assigns no status.
+// Rows that are too short to have the colname column, and rows whose
+// pattern is empty, are skipped. Lines may end in "\n" or "\r\n".
+//
+// Returned map m has m[i]==true if sampleIDs[i] is case, m[i]==false
+// if sampleIDs[i] is control. Samples that were assigned no status
+// have no entry in m. A sample matched by more than one pattern is
+// removed from m and stays excluded.
+func (cmd *chooseSamples) loadCaseControlFiles(path, colname string, sampleIDs []string) (map[int]bool, error) {
+	infiles, err := allFiles(path, nil)
+	if err != nil {
+		return nil, err
+	}
+	// index in sampleIDs => case(true) / control(false)
+	cc := map[int]bool{}
+	// index in sampleIDs => true if matched by multiple patterns in case/control files
+	dup := map[int]bool{}
+	for _, infile := range infiles {
+		f, err := open(infile)
+		if err != nil {
+			return nil, err
+		}
+		buf, err := io.ReadAll(f)
+		f.Close()
+		if err != nil {
+			return nil, err
+		}
+		// Column index of colname; -1 until this file's header row
+		// has been seen.
+		ccCol := -1
+		for _, tsv := range bytes.Split(buf, []byte{'\n'}) {
+			// Tolerate CRLF line endings: a trailing '\r' would
+			// otherwise become part of the last column and defeat
+			// the header and "0"/"1" comparisons.
+			tsv = bytes.TrimSuffix(tsv, []byte{'\r'})
+			if len(tsv) == 0 {
+				continue
+			}
+			split := strings.Split(string(tsv), "\t")
+			if ccCol < 0 {
+				// header row
+				for col, name := range split {
+					if name == colname {
+						ccCol = col
+						break
+					}
+				}
+				if ccCol < 0 {
+					return nil, fmt.Errorf("%s: no column named %q in header row %q", infile, colname, tsv)
+				}
+				continue
+			}
+			if len(split) <= ccCol {
+				continue
+			}
+			pattern := split[0]
+			if pattern == "" {
+				// An empty pattern is a substring of every sample
+				// ID, so this row would claim all samples at once.
+				log.Warnf("skipping row with empty sample ID pattern in %s", infile)
+				continue
+			}
+			value := split[ccCol]
+			// found is the last sample index claimed by this
+			// pattern. It stays -1 if nothing matched or every
+			// match was already claimed by another pattern.
+			found := -1
+			for i, name := range sampleIDs {
+				if !strings.Contains(name, pattern) {
+					continue
+				}
+				if found >= 0 {
+					log.Warnf("pattern %q in %s matches multiple sample IDs (%q, %q)", pattern, infile, sampleIDs[found], name)
+				}
+				if dup[i] {
+					continue
+				}
+				if _, ok := cc[i]; ok {
+					log.Warnf("multiple patterns match sample ID %q, omitting from cases/controls", name)
+					dup[i] = true
+					delete(cc, i)
+					continue
+				}
+				found = i
+				switch value {
+				case "0":
+					cc[i] = false
+				case "1":
+					cc[i] = true
+				}
+			}
+			if found < 0 {
+				log.Warnf("pattern %q in %s does not match any genome IDs", pattern, infile)
+			}
+		}
+	}
+	return cc, nil
+}