Concurrent-batches mode for vcf2fasta and import.
[lightning.git] / import.go
1 package main
2
3 import (
4         "bufio"
5         "compress/gzip"
6         "context"
7         "encoding/gob"
8         "encoding/json"
9         "errors"
10         "flag"
11         "fmt"
12         "io"
13         "net/http"
14         _ "net/http/pprof"
15         "os"
16         "os/exec"
17         "path/filepath"
18         "regexp"
19         "runtime"
20         "sort"
21         "strings"
22         "sync"
23         "sync/atomic"
24         "time"
25
26         "git.arvados.org/arvados.git/sdk/go/arvados"
27         "github.com/lucasb-eyer/go-colorful"
28         log "github.com/sirupsen/logrus"
29         "gonum.org/v1/plot"
30         "gonum.org/v1/plot/plotter"
31         "gonum.org/v1/plot/vg"
32         "gonum.org/v1/plot/vg/draw"
33 )
34
35 type importer struct {
36         tagLibraryFile      string
37         refFile             string
38         outputFile          string
39         projectUUID         string
40         loglevel            string
41         priority            int
42         runLocal            bool
43         skipOOO             bool
44         outputTiles         bool
45         saveIncompleteTiles bool
46         outputStats         string
47         matchChromosome     *regexp.Regexp
48         encoder             *gob.Encoder
49         batchArgs
50 }
51
52 func (cmd *importer) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
53         var err error
54         defer func() {
55                 if err != nil {
56                         fmt.Fprintf(stderr, "%s\n", err)
57                 }
58         }()
59         flags := flag.NewFlagSet("", flag.ContinueOnError)
60         flags.SetOutput(stderr)
61         flags.StringVar(&cmd.tagLibraryFile, "tag-library", "", "tag library fasta `file`")
62         flags.StringVar(&cmd.refFile, "ref", "", "reference fasta `file`")
63         flags.StringVar(&cmd.outputFile, "o", "-", "output `file`")
64         flags.StringVar(&cmd.projectUUID, "project", "", "project `UUID` for output data")
65         flags.BoolVar(&cmd.runLocal, "local", false, "run on local host (default: run in an arvados container)")
66         flags.BoolVar(&cmd.skipOOO, "skip-ooo", false, "skip out-of-order tags")
67         flags.BoolVar(&cmd.outputTiles, "output-tiles", false, "include tile variant sequences in output file")
68         flags.BoolVar(&cmd.saveIncompleteTiles, "save-incomplete-tiles", false, "treat tiles with no-calls as regular tiles")
69         flags.StringVar(&cmd.outputStats, "output-stats", "", "output stats to `file` (json)")
70         cmd.batchArgs.Flags(flags)
71         matchChromosome := flags.String("match-chromosome", "^(chr)?([0-9]+|X|Y|MT?)$", "import chromosomes that match the given `regexp`")
72         flags.IntVar(&cmd.priority, "priority", 500, "container request priority")
73         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
74         flags.StringVar(&cmd.loglevel, "loglevel", "info", "logging threshold (trace, debug, info, warn, error, fatal, or panic)")
75         err = flags.Parse(args)
76         if err == flag.ErrHelp {
77                 err = nil
78                 return 0
79         } else if err != nil {
80                 return 2
81         } else if cmd.tagLibraryFile == "" {
82                 fmt.Fprintln(os.Stderr, "cannot import without -tag-library argument")
83                 return 2
84         } else if flags.NArg() == 0 {
85                 flags.Usage()
86                 return 2
87         }
88
89         if *pprof != "" {
90                 go func() {
91                         log.Println(http.ListenAndServe(*pprof, nil))
92                 }()
93         }
94
95         lvl, err := log.ParseLevel(cmd.loglevel)
96         if err != nil {
97                 return 2
98         }
99         log.SetLevel(lvl)
100
101         cmd.matchChromosome, err = regexp.Compile(*matchChromosome)
102         if err != nil {
103                 return 1
104         }
105
106         if !cmd.runLocal {
107                 err = cmd.runBatches(stdout, flags.Args())
108                 if err != nil {
109                         return 1
110                 }
111                 return 0
112         }
113
114         infiles, err := listInputFiles(flags.Args())
115         if err != nil {
116                 return 1
117         }
118         infiles = cmd.batchArgs.Slice(infiles)
119
120         taglib, err := cmd.loadTagLibrary()
121         if err != nil {
122                 return 1
123         }
124
125         var outw, outf io.WriteCloser
126         if cmd.outputFile == "-" {
127                 outw = nopCloser{stdout}
128         } else {
129                 outf, err = os.OpenFile(cmd.outputFile, os.O_CREATE|os.O_WRONLY, 0777)
130                 if err != nil {
131                         return 1
132                 }
133                 defer outf.Close()
134                 if strings.HasSuffix(cmd.outputFile, ".gz") {
135                         outw = gzip.NewWriter(outf)
136                 } else {
137                         outw = outf
138                 }
139         }
140         bufw := bufio.NewWriter(outw)
141         cmd.encoder = gob.NewEncoder(bufw)
142
143         tilelib := &tileLibrary{taglib: taglib, retainNoCalls: cmd.saveIncompleteTiles, skipOOO: cmd.skipOOO}
144         if cmd.outputTiles {
145                 cmd.encoder.Encode(LibraryEntry{TagSet: taglib.Tags()})
146                 tilelib.encoder = cmd.encoder
147         }
148         go func() {
149                 for range time.Tick(10 * time.Minute) {
150                         log.Printf("tilelib.Len() == %d", tilelib.Len())
151                 }
152         }()
153
154         err = cmd.tileInputs(tilelib, infiles)
155         if err != nil {
156                 return 1
157         }
158         err = bufw.Flush()
159         if err != nil {
160                 return 1
161         }
162         err = outw.Close()
163         if err != nil {
164                 return 1
165         }
166         if outf != nil && outf != outw {
167                 err = outf.Close()
168                 if err != nil {
169                         return 1
170                 }
171         }
172         return 0
173 }
174
175 func (cmd *importer) runBatches(stdout io.Writer, inputs []string) error {
176         if cmd.outputFile != "-" {
177                 // Not yet implemented, but this should write
178                 // the collection to an existing collection,
179                 // possibly even an in-place update.
180                 return errors.New("cannot specify output file in container mode: not implemented")
181         }
182         client := arvados.NewClientFromEnv()
183         runner := arvadosContainerRunner{
184                 Name:        "lightning import",
185                 Client:      client,
186                 ProjectUUID: cmd.projectUUID,
187                 RAM:         80000000000,
188                 VCPUs:       32,
189                 Priority:    cmd.priority,
190         }
191         err := runner.TranslatePaths(&cmd.tagLibraryFile, &cmd.refFile, &cmd.outputFile)
192         if err != nil {
193                 return err
194         }
195         for i := range inputs {
196                 err = runner.TranslatePaths(&inputs[i])
197                 if err != nil {
198                         return err
199                 }
200         }
201
202         outputs, err := cmd.batchArgs.RunBatches(context.Background(), func(ctx context.Context, batch int) (string, error) {
203                 runner := runner
204                 if cmd.batches > 1 {
205                         runner.Name += fmt.Sprintf(" (batch %d of %d)", batch, cmd.batches)
206                 }
207                 runner.Args = []string{"import",
208                         "-local=true",
209                         "-loglevel=" + cmd.loglevel,
210                         fmt.Sprintf("-skip-ooo=%v", cmd.skipOOO),
211                         fmt.Sprintf("-output-tiles=%v", cmd.outputTiles),
212                         fmt.Sprintf("-save-incomplete-tiles=%v", cmd.saveIncompleteTiles),
213                         "-match-chromosome", cmd.matchChromosome.String(),
214                         "-output-stats", "/mnt/output/stats.json",
215                         "-tag-library", cmd.tagLibraryFile,
216                         "-ref", cmd.refFile,
217                         "-o", "/mnt/output/library.gob.gz",
218                 }
219                 runner.Args = append(runner.Args, cmd.batchArgs.Args(batch)...)
220                 runner.Args = append(runner.Args, inputs...)
221                 return runner.RunContext(ctx)
222         })
223         if err != nil {
224                 return err
225         }
226         var outfiles []string
227         for _, o := range outputs {
228                 outfiles = append(outfiles, o+"/library.gob.gz")
229         }
230         fmt.Fprintln(stdout, strings.Join(outfiles, " "))
231         return nil
232 }
233
234 func (cmd *importer) tileFasta(tilelib *tileLibrary, infile string) (tileSeq, []importStats, error) {
235         var input io.ReadCloser
236         input, err := os.Open(infile)
237         if err != nil {
238                 return nil, nil, err
239         }
240         defer input.Close()
241         if strings.HasSuffix(infile, ".gz") {
242                 input, err = gzip.NewReader(input)
243                 if err != nil {
244                         return nil, nil, err
245                 }
246                 defer input.Close()
247         }
248         return tilelib.TileFasta(infile, input, cmd.matchChromosome)
249 }
250
251 func (cmd *importer) loadTagLibrary() (*tagLibrary, error) {
252         log.Printf("tag library %s load starting", cmd.tagLibraryFile)
253         f, err := os.Open(cmd.tagLibraryFile)
254         if err != nil {
255                 return nil, err
256         }
257         defer f.Close()
258         var rdr io.ReadCloser = f
259         if strings.HasSuffix(cmd.tagLibraryFile, ".gz") {
260                 rdr, err = gzip.NewReader(f)
261                 if err != nil {
262                         return nil, fmt.Errorf("%s: gzip: %s", cmd.tagLibraryFile, err)
263                 }
264                 defer rdr.Close()
265         }
266         var taglib tagLibrary
267         err = taglib.Load(rdr)
268         if err != nil {
269                 return nil, err
270         }
271         if taglib.Len() < 1 {
272                 return nil, fmt.Errorf("cannot tile: tag library is empty")
273         }
274         log.Printf("tag library %s load done", cmd.tagLibraryFile)
275         return &taglib, nil
276 }
277
// Filename patterns used by listInputFiles and tileInputs to classify
// inputs. Each optionally allows a trailing ".gz".
//
// tileInputs pairs a ".1.fa[sta]" file with the ".2.fa[sta]" file it
// derives from it (substituting capture groups $1 and $2), and treats
// the two as variants[0] and variants[1] of one genome.

// vcfFilenameRe matches VCF files, e.g. "x.vcf" or "x.vcf.gz".
var vcfFilenameRe = regexp.MustCompile(`\.vcf(\.gz)?$`)

// fasta1FilenameRe matches first-of-pair FASTA files, e.g. "x.1.fa.gz".
var fasta1FilenameRe = regexp.MustCompile(`\.1\.fa(sta)?(\.gz)?$`)

// fasta2FilenameRe matches second-of-pair FASTA files, e.g. "x.2.fasta".
var fasta2FilenameRe = regexp.MustCompile(`\.2\.fa(sta)?(\.gz)?$`)

// fastaFilenameRe matches any FASTA file, paired or not.
var fastaFilenameRe = regexp.MustCompile(`\.fa(sta)?(\.gz)?$`)
284
285 func listInputFiles(paths []string) (files []string, err error) {
286         for _, path := range paths {
287                 if fi, err := os.Stat(path); err != nil {
288                         return nil, fmt.Errorf("%s: stat failed: %s", path, err)
289                 } else if !fi.IsDir() {
290                         if !fasta2FilenameRe.MatchString(path) {
291                                 files = append(files, path)
292                         }
293                         continue
294                 }
295                 d, err := os.Open(path)
296                 if err != nil {
297                         return nil, fmt.Errorf("%s: open failed: %s", path, err)
298                 }
299                 defer d.Close()
300                 names, err := d.Readdirnames(0)
301                 if err != nil {
302                         return nil, fmt.Errorf("%s: readdir failed: %s", path, err)
303                 }
304                 sort.Strings(names)
305                 for _, name := range names {
306                         if vcfFilenameRe.MatchString(name) {
307                                 files = append(files, filepath.Join(path, name))
308                         } else if fastaFilenameRe.MatchString(name) && !fasta2FilenameRe.MatchString(name) {
309                                 files = append(files, filepath.Join(path, name))
310                         }
311                 }
312                 d.Close()
313         }
314         for _, file := range files {
315                 if fastaFilenameRe.MatchString(file) {
316                         continue
317                 } else if vcfFilenameRe.MatchString(file) {
318                         if _, err := os.Stat(file + ".csi"); err == nil {
319                                 continue
320                         } else if _, err = os.Stat(file + ".tbi"); err == nil {
321                                 continue
322                         } else {
323                                 return nil, fmt.Errorf("%s: cannot read without .tbi or .csi index file", file)
324                         }
325                 } else {
326                         return nil, fmt.Errorf("don't know how to handle filename %s", file)
327                 }
328         }
329         return
330 }
331
332 func (cmd *importer) tileInputs(tilelib *tileLibrary, infiles []string) error {
333         starttime := time.Now()
334         errs := make(chan error, 1)
335         todo := make(chan func() error, len(infiles)*2)
336         allstats := make([][]importStats, len(infiles)*2)
337         var encodeJobs sync.WaitGroup
338         for idx, infile := range infiles {
339                 idx, infile := idx, infile
340                 var phases sync.WaitGroup
341                 phases.Add(2)
342                 variants := make([][]tileVariantID, 2)
343                 if fasta1FilenameRe.MatchString(infile) {
344                         todo <- func() error {
345                                 defer phases.Done()
346                                 log.Printf("%s starting", infile)
347                                 defer log.Printf("%s done", infile)
348                                 tseqs, stats, err := cmd.tileFasta(tilelib, infile)
349                                 allstats[idx*2] = stats
350                                 var kept, dropped int
351                                 variants[0], kept, dropped = tseqs.Variants()
352                                 log.Printf("%s found %d unique tags plus %d repeats", infile, kept, dropped)
353                                 return err
354                         }
355                         infile2 := fasta1FilenameRe.ReplaceAllString(infile, `.2.fa$1$2`)
356                         todo <- func() error {
357                                 defer phases.Done()
358                                 log.Printf("%s starting", infile2)
359                                 defer log.Printf("%s done", infile2)
360                                 tseqs, stats, err := cmd.tileFasta(tilelib, infile2)
361                                 allstats[idx*2+1] = stats
362                                 var kept, dropped int
363                                 variants[1], kept, dropped = tseqs.Variants()
364                                 log.Printf("%s found %d unique tags plus %d repeats", infile2, kept, dropped)
365
366                                 return err
367                         }
368                 } else if fastaFilenameRe.MatchString(infile) {
369                         todo <- func() error {
370                                 defer phases.Done()
371                                 defer phases.Done()
372                                 log.Printf("%s starting", infile)
373                                 defer log.Printf("%s done", infile)
374                                 tseqs, stats, err := cmd.tileFasta(tilelib, infile)
375                                 allstats[idx*2] = stats
376                                 if err != nil {
377                                         return err
378                                 }
379                                 totlen := 0
380                                 for _, tseq := range tseqs {
381                                         totlen += len(tseq)
382                                 }
383                                 log.Printf("%s tiled %d seqs, total len %d", infile, len(tseqs), totlen)
384                                 return cmd.encoder.Encode(LibraryEntry{
385                                         CompactSequences: []CompactSequence{{Name: infile, TileSequences: tseqs}},
386                                 })
387                         }
388                         // Don't write out a CompactGenomes entry
389                         continue
390                 } else if vcfFilenameRe.MatchString(infile) {
391                         for phase := 0; phase < 2; phase++ {
392                                 phase := phase
393                                 todo <- func() error {
394                                         defer phases.Done()
395                                         log.Printf("%s phase %d starting", infile, phase+1)
396                                         defer log.Printf("%s phase %d done", infile, phase+1)
397                                         tseqs, stats, err := cmd.tileGVCF(tilelib, infile, phase)
398                                         allstats[idx*2] = stats
399                                         var kept, dropped int
400                                         variants[phase], kept, dropped = tseqs.Variants()
401                                         log.Printf("%s phase %d found %d unique tags plus %d repeats", infile, phase+1, kept, dropped)
402                                         return err
403                                 }
404                         }
405                 } else {
406                         panic(fmt.Sprintf("bug: unhandled filename %q", infile))
407                 }
408                 encodeJobs.Add(1)
409                 go func() {
410                         defer encodeJobs.Done()
411                         phases.Wait()
412                         if len(errs) > 0 {
413                                 return
414                         }
415                         err := cmd.encoder.Encode(LibraryEntry{
416                                 CompactGenomes: []CompactGenome{{Name: infile, Variants: flatten(variants)}},
417                         })
418                         if err != nil {
419                                 select {
420                                 case errs <- err:
421                                 default:
422                                 }
423                         }
424                 }()
425         }
426         go close(todo)
427         var tileJobs sync.WaitGroup
428         var running int64
429         for i := 0; i < runtime.NumCPU()*9/8+1; i++ {
430                 tileJobs.Add(1)
431                 atomic.AddInt64(&running, 1)
432                 go func() {
433                         defer tileJobs.Done()
434                         defer atomic.AddInt64(&running, -1)
435                         for fn := range todo {
436                                 if len(errs) > 0 {
437                                         return
438                                 }
439                                 err := fn()
440                                 if err != nil {
441                                         select {
442                                         case errs <- err:
443                                         default:
444                                         }
445                                 }
446                                 remain := len(todo) + int(atomic.LoadInt64(&running)) - 1
447                                 if remain < cap(todo) {
448                                         ttl := time.Now().Sub(starttime) * time.Duration(remain) / time.Duration(cap(todo)-remain)
449                                         eta := time.Now().Add(ttl)
450                                         log.Printf("progress %d/%d, eta %v (%v)", cap(todo)-remain, cap(todo), eta, ttl)
451                                 }
452                         }
453                 }()
454         }
455         tileJobs.Wait()
456         encodeJobs.Wait()
457
458         go close(errs)
459         err := <-errs
460         if err != nil {
461                 return err
462         }
463
464         if cmd.outputStats != "" {
465                 f, err := os.OpenFile(cmd.outputStats, os.O_CREATE|os.O_WRONLY, 0666)
466                 if err != nil {
467                         return err
468                 }
469                 var flatstats []importStats
470                 for _, stats := range allstats {
471                         flatstats = append(flatstats, stats...)
472                 }
473                 err = json.NewEncoder(f).Encode(flatstats)
474                 if err != nil {
475                         return err
476                 }
477         }
478
479         return nil
480 }
481
482 func (cmd *importer) tileGVCF(tilelib *tileLibrary, infile string, phase int) (tileseq tileSeq, stats []importStats, err error) {
483         if cmd.refFile == "" {
484                 err = errors.New("cannot import vcf: reference data (-ref) not specified")
485                 return
486         }
487         args := []string{"bcftools", "consensus", "--fasta-ref", cmd.refFile, "-H", fmt.Sprint(phase + 1), infile}
488         indexsuffix := ".tbi"
489         if _, err := os.Stat(infile + ".csi"); err == nil {
490                 indexsuffix = ".csi"
491         }
492         if out, err := exec.Command("docker", "image", "ls", "-q", "lightning-runtime").Output(); err == nil && len(out) > 0 {
493                 args = append([]string{
494                         "docker", "run", "--rm",
495                         "--log-driver=none",
496                         "--volume=" + infile + ":" + infile + ":ro",
497                         "--volume=" + infile + indexsuffix + ":" + infile + indexsuffix + ":ro",
498                         "--volume=" + cmd.refFile + ":" + cmd.refFile + ":ro",
499                         "lightning-runtime",
500                 }, args...)
501         }
502         consensus := exec.Command(args[0], args[1:]...)
503         consensus.Stderr = os.Stderr
504         stdout, err := consensus.StdoutPipe()
505         defer stdout.Close()
506         if err != nil {
507                 return
508         }
509         err = consensus.Start()
510         if err != nil {
511                 return
512         }
513         defer consensus.Wait()
514         tileseq, stats, err = tilelib.TileFasta(fmt.Sprintf("%s phase %d", infile, phase+1), stdout, cmd.matchChromosome)
515         if err != nil {
516                 return
517         }
518         err = stdout.Close()
519         if err != nil {
520                 return
521         }
522         err = consensus.Wait()
523         if err != nil {
524                 err = fmt.Errorf("%s phase %d: bcftools: %s", infile, phase, err)
525                 return
526         }
527         return
528 }
529
530 func flatten(variants [][]tileVariantID) []tileVariantID {
531         ntags := 0
532         for _, v := range variants {
533                 if ntags < len(v) {
534                         ntags = len(v)
535                 }
536         }
537         flat := make([]tileVariantID, ntags*2)
538         for i := 0; i < ntags; i++ {
539                 for hap := 0; hap < 2; hap++ {
540                         if i < len(variants[hap]) {
541                                 flat[i*2+hap] = variants[hap][i]
542                         }
543                 }
544         }
545         return flat
546 }
547
548 type importstatsplot struct{}
549
550 func (cmd *importstatsplot) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
551         err := cmd.Plot(stdin, stdout)
552         if err != nil {
553                 log.Errorf("%s", err)
554                 return 1
555         }
556         return 0
557 }
558
559 func (cmd *importstatsplot) Plot(stdin io.Reader, stdout io.Writer) error {
560         var stats []importStats
561         err := json.NewDecoder(stdin).Decode(&stats)
562         if err != nil {
563                 return err
564         }
565
566         p, err := plot.New()
567         if err != nil {
568                 return err
569         }
570         p.Title.Text = "coverage preserved by import (excl X<0.65)"
571         p.X.Label.Text = "input base calls ÷ sequence length"
572         p.Y.Label.Text = "output base calls ÷ input base calls"
573         p.Add(plotter.NewGrid())
574
575         data := map[string]plotter.XYs{}
576         for _, stat := range stats {
577                 data[stat.InputLabel] = append(data[stat.InputLabel], plotter.XY{
578                         X: float64(stat.InputCoverage) / float64(stat.InputLength),
579                         Y: float64(stat.TileCoverage) / float64(stat.InputCoverage),
580                 })
581         }
582
583         labels := []string{}
584         for label := range data {
585                 labels = append(labels, label)
586         }
587         sort.Strings(labels)
588         palette, err := colorful.SoftPalette(len(labels))
589         if err != nil {
590                 return err
591         }
592         nextInPalette := 0
593         for idx, label := range labels {
594                 s, err := plotter.NewScatter(data[label])
595                 if err != nil {
596                         return err
597                 }
598                 s.GlyphStyle.Color = palette[idx]
599                 s.GlyphStyle.Radius = vg.Millimeter / 2
600                 s.GlyphStyle.Shape = draw.CrossGlyph{}
601                 nextInPalette += 7
602                 p.Add(s)
603                 if false {
604                         p.Legend.Add(label, s)
605                 }
606         }
607         p.X.Min = 0.65
608         p.X.Max = 1
609
610         w, err := p.WriterTo(8*vg.Inch, 6*vg.Inch, "svg")
611         if err != nil {
612                 return err
613         }
614         _, err = w.WriteTo(stdout)
615         return err
616 }