Fix race.
[lightning.git] / import.go
1 package lightning
2
3 import (
4         "bufio"
5         "compress/gzip"
6         "context"
7         "encoding/gob"
8         "encoding/json"
9         "errors"
10         "flag"
11         "fmt"
12         "io"
13         "io/ioutil"
14         "net/http"
15         _ "net/http/pprof"
16         "os"
17         "os/exec"
18         "path/filepath"
19         "regexp"
20         "runtime"
21         "sort"
22         "strings"
23         "sync"
24         "sync/atomic"
25         "time"
26
27         "github.com/klauspost/pgzip"
28         log "github.com/sirupsen/logrus"
29 )
30
// importer implements the "import" subcommand. It tiles FASTA/VCF inputs
// against a tag library and writes the resulting library entries as a gob
// stream. Most fields are populated from command-line flags in RunCommand.
type importer struct {
	// -tag-library: tag library fasta file (required).
	tagLibraryFile string
	// -ref: reference fasta; tileGVCF refuses to import VCFs without it.
	refFile string
	// -o: output file; "-" means stdout, a ".gz" suffix enables compression.
	outputFile string
	// -project: Arvados project UUID for output data (container mode).
	projectUUID string
	// -loglevel: logrus threshold name.
	loglevel string
	// -priority: container request priority (container mode).
	priority int
	// -local: tile on this host instead of submitting Arvados containers.
	runLocal bool
	// -skip-ooo: skip out-of-order tags.
	skipOOO bool
	// -output-tiles: include tile variant sequences in the output file.
	outputTiles bool
	// -save-incomplete-tiles: treat tiles with no-calls as regular tiles.
	saveIncompleteTiles bool
	// -output-stats: if non-empty, import stats are written here as JSON.
	outputStats string
	// Compiled -match-chromosome pattern selecting which sequences to import.
	matchChromosome *regexp.Regexp
	// Encoder for the output stream; set up by RunCommand in local mode.
	encoder             *gob.Encoder
	retainAfterEncoding bool // keep imported genomes/refseqs in memory after writing to disk
	// Embedded batching options: registers its own flags (RunCommand) and
	// selects this run's slice of the input files.
	batchArgs
}
48
49 func (cmd *importer) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
50         var err error
51         defer func() {
52                 if err != nil {
53                         fmt.Fprintf(stderr, "%s\n", err)
54                 }
55         }()
56         flags := flag.NewFlagSet("", flag.ContinueOnError)
57         flags.SetOutput(stderr)
58         flags.StringVar(&cmd.tagLibraryFile, "tag-library", "", "tag library fasta `file`")
59         flags.StringVar(&cmd.refFile, "ref", "", "reference fasta `file`")
60         flags.StringVar(&cmd.outputFile, "o", "-", "output `file`")
61         flags.StringVar(&cmd.projectUUID, "project", "", "project `UUID` for output data")
62         flags.BoolVar(&cmd.runLocal, "local", false, "run on local host (default: run in an arvados container)")
63         flags.BoolVar(&cmd.skipOOO, "skip-ooo", false, "skip out-of-order tags")
64         flags.BoolVar(&cmd.outputTiles, "output-tiles", false, "include tile variant sequences in output file")
65         flags.BoolVar(&cmd.saveIncompleteTiles, "save-incomplete-tiles", false, "treat tiles with no-calls as regular tiles")
66         flags.StringVar(&cmd.outputStats, "output-stats", "", "output stats to `file` (json)")
67         cmd.batchArgs.Flags(flags)
68         matchChromosome := flags.String("match-chromosome", "^(chr)?([0-9]+|X|Y|MT?)$", "import chromosomes that match the given `regexp`")
69         flags.IntVar(&cmd.priority, "priority", 500, "container request priority")
70         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
71         flags.StringVar(&cmd.loglevel, "loglevel", "info", "logging threshold (trace, debug, info, warn, error, fatal, or panic)")
72         err = flags.Parse(args)
73         if err == flag.ErrHelp {
74                 err = nil
75                 return 0
76         } else if err != nil {
77                 return 2
78         } else if cmd.tagLibraryFile == "" {
79                 fmt.Fprintln(os.Stderr, "cannot import without -tag-library argument")
80                 return 2
81         } else if flags.NArg() == 0 {
82                 flags.Usage()
83                 return 2
84         }
85
86         if *pprof != "" {
87                 go func() {
88                         log.Println(http.ListenAndServe(*pprof, nil))
89                 }()
90         }
91
92         lvl, err := log.ParseLevel(cmd.loglevel)
93         if err != nil {
94                 return 2
95         }
96         log.SetLevel(lvl)
97
98         cmd.matchChromosome, err = regexp.Compile(*matchChromosome)
99         if err != nil {
100                 return 1
101         }
102
103         if !cmd.runLocal {
104                 err = cmd.runBatches(stdout, flags.Args())
105                 if err != nil {
106                         return 1
107                 }
108                 return 0
109         }
110
111         infiles, err := listInputFiles(flags.Args())
112         if err != nil {
113                 return 1
114         }
115         infiles = cmd.batchArgs.Slice(infiles)
116
117         taglib, err := cmd.loadTagLibrary()
118         if err != nil {
119                 return 1
120         }
121
122         var outw, outf io.WriteCloser
123         if cmd.outputFile == "-" {
124                 outw = nopCloser{stdout}
125         } else {
126                 outf, err = os.OpenFile(cmd.outputFile, os.O_CREATE|os.O_WRONLY, 0777)
127                 if err != nil {
128                         return 1
129                 }
130                 defer outf.Close()
131                 if strings.HasSuffix(cmd.outputFile, ".gz") {
132                         outw = pgzip.NewWriter(outf)
133                 } else {
134                         outw = outf
135                 }
136         }
137         bufw := bufio.NewWriterSize(outw, 64*1024*1024)
138         cmd.encoder = gob.NewEncoder(bufw)
139
140         tilelib := &tileLibrary{taglib: taglib, retainNoCalls: cmd.saveIncompleteTiles, skipOOO: cmd.skipOOO}
141         if cmd.outputTiles {
142                 cmd.encoder.Encode(LibraryEntry{TagSet: taglib.Tags()})
143                 tilelib.encoder = cmd.encoder
144         }
145         go func() {
146                 for range time.Tick(10 * time.Minute) {
147                         log.Printf("tilelib.Len() == %d", tilelib.Len())
148                 }
149         }()
150
151         err = cmd.tileInputs(tilelib, infiles)
152         if err != nil {
153                 return 1
154         }
155         err = bufw.Flush()
156         if err != nil {
157                 return 1
158         }
159         err = outw.Close()
160         if err != nil {
161                 return 1
162         }
163         if outf != nil && outf != outw {
164                 err = outf.Close()
165                 if err != nil {
166                         return 1
167                 }
168         }
169         return 0
170 }
171
172 func (cmd *importer) runBatches(stdout io.Writer, inputs []string) error {
173         if cmd.outputFile != "-" {
174                 // Not yet implemented, but this should write
175                 // the collection to an existing collection,
176                 // possibly even an in-place update.
177                 return errors.New("cannot specify output file in container mode: not implemented")
178         }
179         runner := arvadosContainerRunner{
180                 Name:        "lightning import",
181                 Client:      arvadosClientFromEnv,
182                 ProjectUUID: cmd.projectUUID,
183                 APIAccess:   true,
184                 RAM:         700000000000,
185                 VCPUs:       96,
186                 Priority:    cmd.priority,
187                 KeepCache:   1,
188         }
189         err := runner.TranslatePaths(&cmd.tagLibraryFile, &cmd.refFile, &cmd.outputFile)
190         if err != nil {
191                 return err
192         }
193         for i := range inputs {
194                 err = runner.TranslatePaths(&inputs[i])
195                 if err != nil {
196                         return err
197                 }
198         }
199
200         outputs, err := cmd.batchArgs.RunBatches(context.Background(), func(ctx context.Context, batch int) (string, error) {
201                 runner := runner
202                 if cmd.batches > 1 {
203                         runner.Name += fmt.Sprintf(" (batch %d of %d)", batch, cmd.batches)
204                 }
205                 runner.Args = []string{"import",
206                         "-local=true",
207                         "-loglevel=" + cmd.loglevel,
208                         "-pprof=:6061",
209                         fmt.Sprintf("-skip-ooo=%v", cmd.skipOOO),
210                         fmt.Sprintf("-output-tiles=%v", cmd.outputTiles),
211                         fmt.Sprintf("-save-incomplete-tiles=%v", cmd.saveIncompleteTiles),
212                         "-match-chromosome", cmd.matchChromosome.String(),
213                         "-output-stats", "/mnt/output/stats.json",
214                         "-tag-library", cmd.tagLibraryFile,
215                         "-ref", cmd.refFile,
216                         "-o", "/mnt/output/library.gob.gz",
217                 }
218                 runner.Args = append(runner.Args, cmd.batchArgs.Args(batch)...)
219                 runner.Args = append(runner.Args, inputs...)
220                 return runner.RunContext(ctx)
221         })
222         if err != nil {
223                 return err
224         }
225         var outfiles []string
226         for _, o := range outputs {
227                 outfiles = append(outfiles, o+"/library.gob.gz")
228         }
229         fmt.Fprintln(stdout, strings.Join(outfiles, " "))
230         return nil
231 }
232
233 func (cmd *importer) tileFasta(tilelib *tileLibrary, infile string) (tileSeq, []importStats, error) {
234         var input io.ReadCloser
235         input, err := open(infile)
236         if err != nil {
237                 return nil, nil, err
238         }
239         defer input.Close()
240         input = ioutil.NopCloser(bufio.NewReaderSize(input, 8*1024*1024))
241         if strings.HasSuffix(infile, ".gz") {
242                 input, err = pgzip.NewReader(input)
243                 if err != nil {
244                         return nil, nil, err
245                 }
246                 defer input.Close()
247         }
248         return tilelib.TileFasta(infile, input, cmd.matchChromosome)
249 }
250
251 func (cmd *importer) loadTagLibrary() (*tagLibrary, error) {
252         log.Printf("tag library %s load starting", cmd.tagLibraryFile)
253         f, err := open(cmd.tagLibraryFile)
254         if err != nil {
255                 return nil, err
256         }
257         defer f.Close()
258         rdr := ioutil.NopCloser(bufio.NewReaderSize(f, 64*1024*1024))
259         if strings.HasSuffix(cmd.tagLibraryFile, ".gz") {
260                 rdr, err = gzip.NewReader(rdr)
261                 if err != nil {
262                         return nil, fmt.Errorf("%s: gzip: %s", cmd.tagLibraryFile, err)
263                 }
264                 defer rdr.Close()
265         }
266         var taglib tagLibrary
267         err = taglib.Load(rdr)
268         if err != nil {
269                 return nil, err
270         }
271         if taglib.Len() < 1 {
272                 return nil, fmt.Errorf("cannot tile: tag library is empty")
273         }
274         log.Printf("tag library %s load done", cmd.tagLibraryFile)
275         return &taglib, nil
276 }
277
// Input filename patterns used to classify importer inputs. tileInputs
// reuses the capture groups of fasta1FilenameRe to derive the matching
// ".2." filename, so the group structure of these patterns matters.

// vcfFilenameRe matches VCF inputs: "*.vcf" or "*.vcf.gz".
var vcfFilenameRe = regexp.MustCompile(`\.vcf(\.gz)?$`)

// fasta1FilenameRe matches first-haplotype FASTA inputs such as
// "*.1.fa", "*.1.fasta", optionally followed by ".gz".
var fasta1FilenameRe = regexp.MustCompile(`\.1\.fa(sta)?(\.gz)?$`)

// fasta2FilenameRe matches second-haplotype FASTA inputs such as
// "*.2.fa", "*.2.fasta", optionally followed by ".gz".
var fasta2FilenameRe = regexp.MustCompile(`\.2\.fa(sta)?(\.gz)?$`)

// fastaFilenameRe matches any FASTA input: "*.fa" or "*.fasta",
// optionally followed by ".gz".
var fastaFilenameRe = regexp.MustCompile(`\.fa(sta)?(\.gz)?$`)
284
285 func listInputFiles(paths []string) (files []string, err error) {
286         for _, path := range paths {
287                 if fi, err := os.Stat(path); err != nil {
288                         return nil, fmt.Errorf("%s: stat failed: %s", path, err)
289                 } else if !fi.IsDir() {
290                         if !fasta2FilenameRe.MatchString(path) {
291                                 files = append(files, path)
292                         }
293                         continue
294                 }
295                 d, err := os.Open(path)
296                 if err != nil {
297                         return nil, fmt.Errorf("%s: open failed: %s", path, err)
298                 }
299                 defer d.Close()
300                 names, err := d.Readdirnames(0)
301                 if err != nil {
302                         return nil, fmt.Errorf("%s: readdir failed: %s", path, err)
303                 }
304                 sort.Strings(names)
305                 for _, name := range names {
306                         if vcfFilenameRe.MatchString(name) {
307                                 files = append(files, filepath.Join(path, name))
308                         } else if fastaFilenameRe.MatchString(name) && !fasta2FilenameRe.MatchString(name) {
309                                 files = append(files, filepath.Join(path, name))
310                         }
311                 }
312                 d.Close()
313         }
314         for _, file := range files {
315                 if fastaFilenameRe.MatchString(file) {
316                         continue
317                 } else if vcfFilenameRe.MatchString(file) {
318                         if _, err := os.Stat(file + ".csi"); err == nil {
319                                 continue
320                         } else if _, err = os.Stat(file + ".tbi"); err == nil {
321                                 continue
322                         } else {
323                                 return nil, fmt.Errorf("%s: cannot read without .tbi or .csi index file", file)
324                         }
325                 } else {
326                         return nil, fmt.Errorf("don't know how to handle filename %s", file)
327                 }
328         }
329         return
330 }
331
332 func (cmd *importer) tileInputs(tilelib *tileLibrary, infiles []string) error {
333         starttime := time.Now()
334         errs := make(chan error, 1)
335         todo := make(chan func() error, len(infiles)*2)
336         allstats := make([][]importStats, len(infiles)*2)
337         var encodeJobs sync.WaitGroup
338         for idx, infile := range infiles {
339                 idx, infile := idx, infile
340                 var phases sync.WaitGroup
341                 phases.Add(2)
342                 variants := make([][]tileVariantID, 2)
343                 if fasta1FilenameRe.MatchString(infile) {
344                         todo <- func() error {
345                                 defer phases.Done()
346                                 log.Printf("%s starting", infile)
347                                 defer log.Printf("%s done", infile)
348                                 tseqs, stats, err := cmd.tileFasta(tilelib, infile)
349                                 allstats[idx*2] = stats
350                                 var kept, dropped int
351                                 variants[0], kept, dropped = tseqs.Variants()
352                                 log.Printf("%s found %d unique tags plus %d repeats", infile, kept, dropped)
353                                 return err
354                         }
355                         infile2 := fasta1FilenameRe.ReplaceAllString(infile, `.2.fa$1$2`)
356                         todo <- func() error {
357                                 defer phases.Done()
358                                 log.Printf("%s starting", infile2)
359                                 defer log.Printf("%s done", infile2)
360                                 tseqs, stats, err := cmd.tileFasta(tilelib, infile2)
361                                 allstats[idx*2+1] = stats
362                                 var kept, dropped int
363                                 variants[1], kept, dropped = tseqs.Variants()
364                                 log.Printf("%s found %d unique tags plus %d repeats", infile2, kept, dropped)
365                                 return err
366                         }
367                 } else if fastaFilenameRe.MatchString(infile) {
368                         todo <- func() error {
369                                 defer phases.Done()
370                                 defer phases.Done()
371                                 log.Printf("%s starting", infile)
372                                 defer log.Printf("%s done", infile)
373                                 tseqs, stats, err := cmd.tileFasta(tilelib, infile)
374                                 allstats[idx*2] = stats
375                                 if err != nil {
376                                         return err
377                                 }
378                                 totlen := 0
379                                 for _, tseq := range tseqs {
380                                         totlen += len(tseq)
381                                 }
382                                 log.Printf("%s tiled %d seqs, total len %d", infile, len(tseqs), totlen)
383
384                                 if cmd.retainAfterEncoding {
385                                         tilelib.mtx.Lock()
386                                         if tilelib.refseqs == nil {
387                                                 tilelib.refseqs = map[string]map[string][]tileLibRef{}
388                                         }
389                                         tilelib.refseqs[infile] = tseqs
390                                         tilelib.mtx.Unlock()
391                                 }
392
393                                 return cmd.encoder.Encode(LibraryEntry{
394                                         CompactSequences: []CompactSequence{{Name: infile, TileSequences: tseqs}},
395                                 })
396                         }
397                         // Don't write out a CompactGenomes entry
398                         continue
399                 } else if vcfFilenameRe.MatchString(infile) {
400                         for phase := 0; phase < 2; phase++ {
401                                 phase := phase
402                                 todo <- func() error {
403                                         defer phases.Done()
404                                         log.Printf("%s phase %d starting", infile, phase+1)
405                                         defer log.Printf("%s phase %d done", infile, phase+1)
406                                         tseqs, stats, err := cmd.tileGVCF(tilelib, infile, phase)
407                                         allstats[idx*2] = stats
408                                         var kept, dropped int
409                                         variants[phase], kept, dropped = tseqs.Variants()
410                                         log.Printf("%s phase %d found %d unique tags plus %d repeats", infile, phase+1, kept, dropped)
411                                         return err
412                                 }
413                         }
414                 } else {
415                         panic(fmt.Sprintf("bug: unhandled filename %q", infile))
416                 }
417                 encodeJobs.Add(1)
418                 go func() {
419                         defer encodeJobs.Done()
420                         phases.Wait()
421                         if len(errs) > 0 {
422                                 return
423                         }
424                         variants := flatten(variants)
425                         err := cmd.encoder.Encode(LibraryEntry{
426                                 CompactGenomes: []CompactGenome{{Name: infile, Variants: variants}},
427                         })
428                         if err != nil {
429                                 select {
430                                 case errs <- err:
431                                 default:
432                                 }
433                         }
434                         if cmd.retainAfterEncoding {
435                                 tilelib.mtx.Lock()
436                                 if tilelib.compactGenomes == nil {
437                                         tilelib.compactGenomes = make(map[string][]tileVariantID)
438                                 }
439                                 tilelib.compactGenomes[infile] = variants
440                                 tilelib.mtx.Unlock()
441                         }
442                 }()
443         }
444         go close(todo)
445         var tileJobs sync.WaitGroup
446         var running int64
447         for i := 0; i < runtime.GOMAXPROCS(-1)*2; i++ {
448                 tileJobs.Add(1)
449                 atomic.AddInt64(&running, 1)
450                 go func() {
451                         defer tileJobs.Done()
452                         defer atomic.AddInt64(&running, -1)
453                         for fn := range todo {
454                                 if len(errs) > 0 {
455                                         return
456                                 }
457                                 err := fn()
458                                 if err != nil {
459                                         select {
460                                         case errs <- err:
461                                         default:
462                                         }
463                                 }
464                                 remain := len(todo) + int(atomic.LoadInt64(&running)) - 1
465                                 if remain < cap(todo) {
466                                         ttl := time.Now().Sub(starttime) * time.Duration(remain) / time.Duration(cap(todo)-remain)
467                                         eta := time.Now().Add(ttl)
468                                         log.Printf("progress %d/%d, eta %v (%v)", cap(todo)-remain, cap(todo), eta, ttl)
469                                 }
470                         }
471                 }()
472         }
473         tileJobs.Wait()
474         if len(errs) > 0 {
475                 // Must not wait on encodeJobs in this case. If the
476                 // tileJobs goroutines exited early, some funcs in
477                 // todo haven't been called, so the corresponding
478                 // encodeJobs will wait forever.
479                 return <-errs
480         }
481         encodeJobs.Wait()
482
483         go close(errs)
484         err := <-errs
485         if err != nil {
486                 return err
487         }
488
489         if cmd.outputStats != "" {
490                 f, err := os.OpenFile(cmd.outputStats, os.O_CREATE|os.O_WRONLY, 0666)
491                 if err != nil {
492                         return err
493                 }
494                 var flatstats []importStats
495                 for _, stats := range allstats {
496                         flatstats = append(flatstats, stats...)
497                 }
498                 err = json.NewEncoder(f).Encode(flatstats)
499                 if err != nil {
500                         return err
501                 }
502         }
503
504         return nil
505 }
506
507 func (cmd *importer) tileGVCF(tilelib *tileLibrary, infile string, phase int) (tileseq tileSeq, stats []importStats, err error) {
508         if cmd.refFile == "" {
509                 err = errors.New("cannot import vcf: reference data (-ref) not specified")
510                 return
511         }
512         args := []string{"bcftools", "consensus", "--fasta-ref", cmd.refFile, "-H", fmt.Sprint(phase + 1), infile}
513         indexsuffix := ".tbi"
514         if _, err := os.Stat(infile + ".csi"); err == nil {
515                 indexsuffix = ".csi"
516         }
517         if out, err := exec.Command("docker", "image", "ls", "-q", "lightning-runtime").Output(); err == nil && len(out) > 0 {
518                 args = append([]string{
519                         "docker", "run", "--rm",
520                         "--log-driver=none",
521                         "--volume=" + infile + ":" + infile + ":ro",
522                         "--volume=" + infile + indexsuffix + ":" + infile + indexsuffix + ":ro",
523                         "--volume=" + cmd.refFile + ":" + cmd.refFile + ":ro",
524                         "lightning-runtime",
525                 }, args...)
526         }
527         consensus := exec.Command(args[0], args[1:]...)
528         consensus.Stderr = os.Stderr
529         stdout, err := consensus.StdoutPipe()
530         defer stdout.Close()
531         if err != nil {
532                 return
533         }
534         err = consensus.Start()
535         if err != nil {
536                 return
537         }
538         defer consensus.Wait()
539         tileseq, stats, err = tilelib.TileFasta(fmt.Sprintf("%s phase %d", infile, phase+1), stdout, cmd.matchChromosome)
540         if err != nil {
541                 return
542         }
543         err = stdout.Close()
544         if err != nil {
545                 return
546         }
547         err = consensus.Wait()
548         if err != nil {
549                 err = fmt.Errorf("%s phase %d: bcftools: %s", infile, phase, err)
550                 return
551         }
552         return
553 }
554
555 func flatten(variants [][]tileVariantID) []tileVariantID {
556         ntags := 0
557         for _, v := range variants {
558                 if ntags < len(v) {
559                         ntags = len(v)
560                 }
561         }
562         flat := make([]tileVariantID, ntags*2)
563         for i := 0; i < ntags; i++ {
564                 for hap := 0; hap < 2; hap++ {
565                         if i < len(variants[hap]) {
566                                 flat[i*2+hap] = variants[hap][i]
567                         }
568                 }
569         }
570         return flat
571 }