// Copyright (C) The Lightning Authors. All rights reserved. // // SPDX-License-Identifier: AGPL-3.0 package lightning import ( "bytes" "errors" "flag" "fmt" "io" "math/rand" "net/http" _ "net/http/pprof" "os" "regexp" "sort" "strings" "git.arvados.org/arvados.git/sdk/go/arvados" log "github.com/sirupsen/logrus" ) type chooseSamples struct { filter filter } func (cmd *chooseSamples) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int { err := cmd.run(prog, args, stdin, stdout, stderr) if err != nil { fmt.Fprintf(stderr, "%s\n", err) return 1 } return 0 } func (cmd *chooseSamples) run(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) error { flags := flag.NewFlagSet("", flag.ContinueOnError) flags.SetOutput(stderr) pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`") runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)") projectUUID := flags.String("project", "", "project `UUID` for output data") priority := flags.Int("priority", 500, "container request priority") inputDir := flags.String("input-dir", "./in", "input `directory`") outputDir := flags.String("output-dir", "./out", "output `directory`") trainingSetSize := flags.Float64("training-set-size", 0.8, "number (or proportion, if <=1) of eligible samples to assign to the training set") caseControlFilename := flags.String("case-control-file", "", "tsv file or directory indicating cases and controls (if directory, all .tsv files will be read)") caseControlColumn := flags.String("case-control-column", "", "name of case/control column in case-control files (value must be 0 for control, 1 for case)") randSeed := flags.Int64("random-seed", 0, "PRNG seed") cmd.filter.Flags(flags) err := flags.Parse(args) if err == flag.ErrHelp { return nil } else if err != nil { return err } if *caseControlFilename == "" { return errors.New("must provide -case-control-file") } if *caseControlColumn == "" { return errors.New("must provide -case-control-column") } if *pprof != "" { go func() { log.Println(http.ListenAndServe(*pprof, nil)) }() } if !*runlocal { runner := arvadosContainerRunner{ Name: "lightning choose-samples", Client: arvados.NewClientFromEnv(), ProjectUUID: *projectUUID, RAM: 16000000000, VCPUs: 4, Priority: *priority, KeepCache: 2, APIAccess: true, } err = runner.TranslatePaths(inputDir, caseControlFilename) if err != nil { return err } runner.Args = []string{"choose-samples", "-local=true", "-pprof=:6060", "-input-dir=" + *inputDir, "-output-dir=/mnt/output", "-case-control-file=" + *caseControlFilename, "-case-control-column=" + *caseControlColumn, "-training-set-size=" + fmt.Sprintf("%f", *trainingSetSize), "-random-seed=" + fmt.Sprintf("%d", *randSeed), } runner.Args = append(runner.Args, cmd.filter.Args()...) var output string output, err = runner.Run() if err != nil { return err } fmt.Fprintln(stdout, output) return nil } infiles, err := allFiles(*inputDir, matchGobFile) if err != nil { return err } if len(infiles) == 0 { err = fmt.Errorf("no input files found in %s", *inputDir) return err } sort.Strings(infiles) in0, err := open(infiles[0]) if err != nil { return err } matchGenome, err := regexp.Compile(cmd.filter.MatchGenome) if err != nil { err = fmt.Errorf("-match-genome: invalid regexp: %q", cmd.filter.MatchGenome) return err } var sampleIDs []string err = DecodeLibrary(in0, strings.HasSuffix(infiles[0], ".gz"), func(ent *LibraryEntry) error { for _, cg := range ent.CompactGenomes { if matchGenome.MatchString(cg.Name) { sampleIDs = append(sampleIDs, cg.Name) } } return nil }) if err != nil { return err } in0.Close() if len(sampleIDs) == 0 { err = fmt.Errorf("no genomes found matching regexp %q", cmd.filter.MatchGenome) return err } sort.Strings(sampleIDs) caseControl, err := cmd.loadCaseControlFiles(*caseControlFilename, *caseControlColumn, sampleIDs) if err != nil { return err } if len(caseControl) == 0 { err = fmt.Errorf("fatal: 0 cases, 0 controls, nothing to do") return err } var trainingSet, validationSet []int for i := range caseControl { trainingSet = append(trainingSet, i) } sort.Ints(trainingSet) wantlen := int(*trainingSetSize) if *trainingSetSize <= 1 { wantlen = int(*trainingSetSize * float64(len(trainingSet))) } randsrc := rand.NewSource(*randSeed) for tslen := len(trainingSet); tslen > wantlen; { i := int(randsrc.Int63()) % tslen validationSet = append(validationSet, trainingSet[i]) tslen-- trainingSet[i] = trainingSet[tslen] trainingSet = trainingSet[:tslen] } sort.Ints(trainingSet) sort.Ints(validationSet) samplesFilename := *outputDir + "/samples.csv" log.Infof("writing sample metadata to %s", samplesFilename) var f *os.File f, err = os.Create(samplesFilename) if err != nil { return err } defer f.Close() _, err = fmt.Fprint(f, "Index,SampleID,CaseControl,TrainingValidation\n") if err != nil { return err } tsi := 0 // next idx in training set vsi := 0 // next idx in validation set for i, name := range sampleIDs { var cc, tv string if len(trainingSet) > tsi && trainingSet[tsi] == i { tv = "1" tsi++ if caseControl[i] { cc = "1" } else { cc = "0" } } else if len(validationSet) > vsi && validationSet[vsi] == i { tv = "0" vsi++ if caseControl[i] { cc = "1" } else { cc = "0" } } _, err = fmt.Fprintf(f, "%d,%s,%s,%s\n", i, trimFilenameForLabel(name), cc, tv) if err != nil { err = fmt.Errorf("write %s: %w", samplesFilename, err) return err } } err = f.Close() if err != nil { err = fmt.Errorf("close %s: %w", samplesFilename, err) return err } return nil } // Read case/control file(s). Returned map m has m[i]==true if // sampleIDs[i] is case, m[i]==false if sampleIDs[i] is control. func (cmd *chooseSamples) loadCaseControlFiles(path, colname string, sampleIDs []string) (map[int]bool, error) { infiles, err := allFiles(path, nil) if err != nil { return nil, err } // index in sampleIDs => case(true) / control(false) cc := map[int]bool{} // index in sampleIDs => true if matched by multiple patterns in case/control files dup := map[int]bool{} for _, infile := range infiles { f, err := open(infile) if err != nil { return nil, err } buf, err := io.ReadAll(f) f.Close() if err != nil { return nil, err } ccCol := -1 for _, tsv := range bytes.Split(buf, []byte{'\n'}) { if len(tsv) == 0 { continue } split := strings.Split(string(tsv), "\t") if ccCol < 0 { // header row for col, name := range split { if name == colname { ccCol = col break } } if ccCol < 0 { return nil, fmt.Errorf("%s: no column named %q in header row %q", infile, colname, tsv) } continue } if len(split) <= ccCol { continue } pattern := split[0] found := -1 for i, name := range sampleIDs { if strings.Contains(name, pattern) { if found >= 0 { log.Warnf("pattern %q in %s matches multiple sample IDs (%q, %q)", pattern, infile, sampleIDs[found], name) } if dup[i] { continue } else if _, ok := cc[i]; ok { log.Warnf("multiple patterns match sample ID %q, omitting from cases/controls", name) dup[i] = true delete(cc, i) continue } found = i if split[ccCol] == "0" { cc[found] = false } if split[ccCol] == "1" { cc[found] = true } } } if found < 0 { log.Warnf("pattern %q in %s does not match any genome IDs", pattern, infile) continue } } } return cc, nil }