19566: Merge branch 'main'
[lightning.git] / choosesamples.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "bytes"
9         "errors"
10         "flag"
11         "fmt"
12         "io"
13         "math/rand"
14         "net/http"
15         _ "net/http/pprof"
16         "os"
17         "regexp"
18         "sort"
19         "strings"
20
21         "git.arvados.org/arvados.git/sdk/go/arvados"
22         log "github.com/sirupsen/logrus"
23 )
24
25 type chooseSamples struct {
26         filter filter
27 }
28
29 func (cmd *chooseSamples) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
30         err := cmd.run(prog, args, stdin, stdout, stderr)
31         if err != nil {
32                 fmt.Fprintf(stderr, "%s\n", err)
33                 return 1
34         }
35         return 0
36 }
37
38 func (cmd *chooseSamples) run(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
39         flags := flag.NewFlagSet("", flag.ContinueOnError)
40         flags.SetOutput(stderr)
41         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
42         runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
43         projectUUID := flags.String("project", "", "project `UUID` for output data")
44         priority := flags.Int("priority", 500, "container request priority")
45         inputDir := flags.String("input-dir", "./in", "input `directory`")
46         outputDir := flags.String("output-dir", "./out", "output `directory`")
47         trainingSetSize := flags.Float64("training-set-size", 0.8, "number (or proportion, if <=1) of eligible samples to assign to the training set")
48         caseControlFilename := flags.String("case-control-file", "", "tsv file or directory indicating cases and controls (if directory, all .tsv files will be read)")
49         caseControlColumn := flags.String("case-control-column", "", "name of case/control column in case-control files (value must be 0 for control, 1 for case)")
50         randSeed := flags.Int64("random-seed", 0, "PRNG seed")
51         cmd.filter.Flags(flags)
52         err := flags.Parse(args)
53         if err == flag.ErrHelp {
54                 return nil
55         } else if err != nil {
56                 return err
57         } else if flags.NArg() > 0 {
58                 return fmt.Errorf("errant command line arguments after parsed flags: %v", flags.Args())
59         }
60         if (*caseControlFilename == "") != (*caseControlColumn == "") {
61                 return errors.New("must provide both -case-control-file and -case-control-column, or neither")
62         }
63
64         if *pprof != "" {
65                 go func() {
66                         log.Println(http.ListenAndServe(*pprof, nil))
67                 }()
68         }
69
70         if !*runlocal {
71                 runner := arvadosContainerRunner{
72                         Name:        "lightning choose-samples",
73                         Client:      arvados.NewClientFromEnv(),
74                         ProjectUUID: *projectUUID,
75                         RAM:         16000000000,
76                         VCPUs:       4,
77                         Priority:    *priority,
78                         KeepCache:   2,
79                         APIAccess:   true,
80                 }
81                 err = runner.TranslatePaths(inputDir, caseControlFilename)
82                 if err != nil {
83                         return err
84                 }
85                 runner.Args = []string{"choose-samples", "-local=true",
86                         "-pprof=:6060",
87                         "-input-dir=" + *inputDir,
88                         "-output-dir=/mnt/output",
89                         "-case-control-file=" + *caseControlFilename,
90                         "-case-control-column=" + *caseControlColumn,
91                         "-training-set-size=" + fmt.Sprintf("%f", *trainingSetSize),
92                         "-random-seed=" + fmt.Sprintf("%d", *randSeed),
93                 }
94                 runner.Args = append(runner.Args, cmd.filter.Args()...)
95                 var output string
96                 output, err = runner.Run()
97                 if err != nil {
98                         return err
99                 }
100                 fmt.Fprintln(stdout, output)
101                 return nil
102         }
103
104         infiles, err := allFiles(*inputDir, matchGobFile)
105         if err != nil {
106                 return err
107         }
108         if len(infiles) == 0 {
109                 err = fmt.Errorf("no input files found in %s", *inputDir)
110                 return err
111         }
112         sort.Strings(infiles)
113
114         in0, err := open(infiles[0])
115         if err != nil {
116                 return err
117         }
118
119         matchGenome, err := regexp.Compile(cmd.filter.MatchGenome)
120         if err != nil {
121                 err = fmt.Errorf("-match-genome: invalid regexp: %q", cmd.filter.MatchGenome)
122                 return err
123         }
124
125         var sampleIDs []string
126         err = DecodeLibrary(in0, strings.HasSuffix(infiles[0], ".gz"), func(ent *LibraryEntry) error {
127                 for _, cg := range ent.CompactGenomes {
128                         if matchGenome.MatchString(cg.Name) {
129                                 sampleIDs = append(sampleIDs, cg.Name)
130                         }
131                 }
132                 return nil
133         })
134         if err != nil {
135                 return err
136         }
137         in0.Close()
138
139         if len(sampleIDs) == 0 {
140                 err = fmt.Errorf("no genomes found matching regexp %q", cmd.filter.MatchGenome)
141                 return err
142         }
143         sort.Strings(sampleIDs)
144         caseControl, err := cmd.loadCaseControlFiles(*caseControlFilename, *caseControlColumn, sampleIDs)
145         if err != nil {
146                 return err
147         }
148         if len(caseControl) == 0 {
149                 err = fmt.Errorf("fatal: 0 cases, 0 controls, nothing to do")
150                 return err
151         }
152
153         var trainingSet, validationSet []int
154         for i := range caseControl {
155                 trainingSet = append(trainingSet, i)
156         }
157         sort.Ints(trainingSet)
158         wantlen := int(*trainingSetSize)
159         if *trainingSetSize <= 1 {
160                 wantlen = int(*trainingSetSize * float64(len(trainingSet)))
161         }
162         randsrc := rand.NewSource(*randSeed)
163         for tslen := len(trainingSet); tslen > wantlen; {
164                 i := int(randsrc.Int63()) % tslen
165                 validationSet = append(validationSet, trainingSet[i])
166                 tslen--
167                 trainingSet[i] = trainingSet[tslen]
168                 trainingSet = trainingSet[:tslen]
169         }
170         sort.Ints(trainingSet)
171         sort.Ints(validationSet)
172
173         samplesFilename := *outputDir + "/samples.csv"
174         log.Infof("writing sample metadata to %s", samplesFilename)
175         var f *os.File
176         f, err = os.Create(samplesFilename)
177         if err != nil {
178                 return err
179         }
180         defer f.Close()
181         _, err = fmt.Fprint(f, "Index,SampleID,CaseControl,TrainingValidation\n")
182         if err != nil {
183                 return err
184         }
185         tsi := 0 // next idx in training set
186         vsi := 0 // next idx in validation set
187         for i, name := range sampleIDs {
188                 var cc, tv string
189                 if len(trainingSet) > tsi && trainingSet[tsi] == i {
190                         tv = "1"
191                         tsi++
192                         if caseControl[i] {
193                                 cc = "1"
194                         } else {
195                                 cc = "0"
196                         }
197                 } else if len(validationSet) > vsi && validationSet[vsi] == i {
198                         tv = "0"
199                         vsi++
200                         if caseControl[i] {
201                                 cc = "1"
202                         } else {
203                                 cc = "0"
204                         }
205                 }
206                 _, err = fmt.Fprintf(f, "%d,%s,%s,%s\n", i, trimFilenameForLabel(name), cc, tv)
207                 if err != nil {
208                         err = fmt.Errorf("write %s: %w", samplesFilename, err)
209                         return err
210                 }
211         }
212         err = f.Close()
213         if err != nil {
214                 err = fmt.Errorf("close %s: %w", samplesFilename, err)
215                 return err
216         }
217         return nil
218 }
219
220 // Read case/control file(s). Returned map m has m[i]==true if
221 // sampleIDs[i] is case, m[i]==false if sampleIDs[i] is control.
222 func (cmd *chooseSamples) loadCaseControlFiles(path, colname string, sampleIDs []string) (map[int]bool, error) {
223         if path == "" {
224                 // all samples are control group
225                 cc := make(map[int]bool, len(sampleIDs))
226                 for i := range sampleIDs {
227                         cc[i] = false
228                 }
229                 return cc, nil
230         }
231         infiles, err := allFiles(path, nil)
232         if err != nil {
233                 return nil, err
234         }
235         // index in sampleIDs => case(true) / control(false)
236         cc := map[int]bool{}
237         // index in sampleIDs => true if matched by multiple patterns in case/control files
238         dup := map[int]bool{}
239         for _, infile := range infiles {
240                 f, err := open(infile)
241                 if err != nil {
242                         return nil, err
243                 }
244                 buf, err := io.ReadAll(f)
245                 f.Close()
246                 if err != nil {
247                         return nil, err
248                 }
249                 ccCol := -1
250                 for _, tsv := range bytes.Split(buf, []byte{'\n'}) {
251                         if len(tsv) == 0 {
252                                 continue
253                         }
254                         split := strings.Split(string(tsv), "\t")
255                         if ccCol < 0 {
256                                 // header row
257                                 for col, name := range split {
258                                         if name == colname {
259                                                 ccCol = col
260                                                 break
261                                         }
262                                 }
263                                 if ccCol < 0 {
264                                         return nil, fmt.Errorf("%s: no column named %q in header row %q", infile, colname, tsv)
265                                 }
266                                 continue
267                         }
268                         if len(split) <= ccCol {
269                                 continue
270                         }
271                         pattern := split[0]
272                         found := -1
273                         for i, name := range sampleIDs {
274                                 if strings.Contains(name, pattern) {
275                                         if found >= 0 {
276                                                 log.Warnf("pattern %q in %s matches multiple sample IDs (%q, %q)", pattern, infile, sampleIDs[found], name)
277                                         }
278                                         if dup[i] {
279                                                 continue
280                                         } else if _, ok := cc[i]; ok {
281                                                 log.Warnf("multiple patterns match sample ID %q, omitting from cases/controls", name)
282                                                 dup[i] = true
283                                                 delete(cc, i)
284                                                 continue
285                                         }
286                                         found = i
287                                         if split[ccCol] == "0" {
288                                                 cc[found] = false
289                                         }
290                                         if split[ccCol] == "1" {
291                                                 cc[found] = true
292                                         }
293                                 }
294                         }
295                         if found < 0 {
296                                 log.Warnf("pattern %q in %s does not match any genome IDs", pattern, infile)
297                                 continue
298                         }
299                 }
300         }
301         return cc, nil
302 }