Fix divide by zero.
[lightning.git] / import.go
1 package main
2
3 import (
4         "bufio"
5         "compress/gzip"
6         "encoding/gob"
7         "encoding/json"
8         "errors"
9         "flag"
10         "fmt"
11         "io"
12         "net/http"
13         _ "net/http/pprof"
14         "os"
15         "os/exec"
16         "path/filepath"
17         "regexp"
18         "runtime"
19         "sort"
20         "strings"
21         "sync"
22         "sync/atomic"
23         "time"
24
25         "git.arvados.org/arvados.git/sdk/go/arvados"
26         "github.com/lucasb-eyer/go-colorful"
27         log "github.com/sirupsen/logrus"
28         "gonum.org/v1/plot"
29         "gonum.org/v1/plot/plotter"
30         "gonum.org/v1/plot/vg"
31         "gonum.org/v1/plot/vg/draw"
32 )
33
// importer holds the settings and output state of the "import" subcommand.
// Its fields are filled in by RunCommand: most come straight from
// command-line flags, matchChromosome is compiled from -match-chromosome, and
// encoder is created once the output stream has been opened.
type importer struct {
	// File/project locations: -tag-library, -ref, -o, -project.
	tagLibraryFile, refFile, outputFile, projectUUID string

	// Boolean switches: -local, -skip-ooo, -output-tiles, -save-incomplete-tiles.
	runLocal, skipOOO, outputTiles, saveIncompleteTiles bool

	outputStats     string         // -output-stats path; "" means no stats file is written
	matchChromosome *regexp.Regexp // only sequences whose names match are imported
	encoder         *gob.Encoder   // destination for LibraryEntry records
}
47
48 func (cmd *importer) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
49         var err error
50         defer func() {
51                 if err != nil {
52                         fmt.Fprintf(stderr, "%s\n", err)
53                 }
54         }()
55         flags := flag.NewFlagSet("", flag.ContinueOnError)
56         flags.SetOutput(stderr)
57         flags.StringVar(&cmd.tagLibraryFile, "tag-library", "", "tag library fasta `file`")
58         flags.StringVar(&cmd.refFile, "ref", "", "reference fasta `file`")
59         flags.StringVar(&cmd.outputFile, "o", "-", "output `file`")
60         flags.StringVar(&cmd.projectUUID, "project", "", "project `UUID` for output data")
61         flags.BoolVar(&cmd.runLocal, "local", false, "run on local host (default: run in an arvados container)")
62         flags.BoolVar(&cmd.skipOOO, "skip-ooo", false, "skip out-of-order tags")
63         flags.BoolVar(&cmd.outputTiles, "output-tiles", false, "include tile variant sequences in output file")
64         flags.BoolVar(&cmd.saveIncompleteTiles, "save-incomplete-tiles", false, "treat tiles with no-calls as regular tiles")
65         flags.StringVar(&cmd.outputStats, "output-stats", "", "output stats to `file` (json)")
66         matchChromosome := flags.String("match-chromosome", "^(chr)?([0-9]+|X|Y|MT?)$", "import chromosomes that match the given `regexp`")
67         priority := flags.Int("priority", 500, "container request priority")
68         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
69         loglevel := flags.String("loglevel", "info", "logging threshold (trace, debug, info, warn, error, fatal, or panic)")
70         err = flags.Parse(args)
71         if err == flag.ErrHelp {
72                 err = nil
73                 return 0
74         } else if err != nil {
75                 return 2
76         } else if cmd.tagLibraryFile == "" {
77                 fmt.Fprintln(os.Stderr, "cannot import without -tag-library argument")
78                 return 2
79         } else if flags.NArg() == 0 {
80                 flags.Usage()
81                 return 2
82         }
83
84         if *pprof != "" {
85                 go func() {
86                         log.Println(http.ListenAndServe(*pprof, nil))
87                 }()
88         }
89
90         lvl, err := log.ParseLevel(*loglevel)
91         if err != nil {
92                 return 2
93         }
94         log.SetLevel(lvl)
95
96         cmd.matchChromosome, err = regexp.Compile(*matchChromosome)
97         if err != nil {
98                 return 1
99         }
100
101         if !cmd.runLocal {
102                 runner := arvadosContainerRunner{
103                         Name:        "lightning import",
104                         Client:      arvados.NewClientFromEnv(),
105                         ProjectUUID: cmd.projectUUID,
106                         RAM:         80000000000,
107                         VCPUs:       32,
108                         Priority:    *priority,
109                 }
110                 err = runner.TranslatePaths(&cmd.tagLibraryFile, &cmd.refFile, &cmd.outputFile)
111                 if err != nil {
112                         return 1
113                 }
114                 inputs := flags.Args()
115                 for i := range inputs {
116                         err = runner.TranslatePaths(&inputs[i])
117                         if err != nil {
118                                 return 1
119                         }
120                 }
121                 if cmd.outputFile == "-" {
122                         cmd.outputFile = "/mnt/output/library.gob.gz"
123                 } else {
124                         // Not yet implemented, but this should write
125                         // the collection to an existing collection,
126                         // possibly even an in-place update.
127                         err = errors.New("cannot specify output file in container mode: not implemented")
128                         return 1
129                 }
130                 runner.Args = append([]string{"import",
131                         "-local=true",
132                         "-loglevel=" + *loglevel,
133                         fmt.Sprintf("-skip-ooo=%v", cmd.skipOOO),
134                         fmt.Sprintf("-output-tiles=%v", cmd.outputTiles),
135                         fmt.Sprintf("-save-incomplete-tiles=%v", cmd.saveIncompleteTiles),
136                         "-output-stats", "/mnt/output/stats.json",
137                         "-tag-library", cmd.tagLibraryFile,
138                         "-ref", cmd.refFile,
139                         "-o", cmd.outputFile,
140                 }, inputs...)
141                 var output string
142                 output, err = runner.Run()
143                 if err != nil {
144                         return 1
145                 }
146                 fmt.Fprintln(stdout, output+"/library.gob.gz")
147                 return 0
148         }
149
150         infiles, err := listInputFiles(flags.Args())
151         if err != nil {
152                 return 1
153         }
154
155         taglib, err := cmd.loadTagLibrary()
156         if err != nil {
157                 return 1
158         }
159
160         var outw, outf io.WriteCloser
161         if cmd.outputFile == "-" {
162                 outw = nopCloser{stdout}
163         } else {
164                 outf, err = os.OpenFile(cmd.outputFile, os.O_CREATE|os.O_WRONLY, 0777)
165                 if err != nil {
166                         return 1
167                 }
168                 defer outf.Close()
169                 if strings.HasSuffix(cmd.outputFile, ".gz") {
170                         outw = gzip.NewWriter(outf)
171                 } else {
172                         outw = outf
173                 }
174         }
175         bufw := bufio.NewWriter(outw)
176         cmd.encoder = gob.NewEncoder(bufw)
177
178         tilelib := &tileLibrary{taglib: taglib, retainNoCalls: cmd.saveIncompleteTiles, skipOOO: cmd.skipOOO}
179         if cmd.outputTiles {
180                 cmd.encoder.Encode(LibraryEntry{TagSet: taglib.Tags()})
181                 tilelib.encoder = cmd.encoder
182         }
183         go func() {
184                 for range time.Tick(10 * time.Minute) {
185                         log.Printf("tilelib.Len() == %d", tilelib.Len())
186                 }
187         }()
188
189         err = cmd.tileInputs(tilelib, infiles)
190         if err != nil {
191                 return 1
192         }
193         err = bufw.Flush()
194         if err != nil {
195                 return 1
196         }
197         err = outw.Close()
198         if err != nil {
199                 return 1
200         }
201         if outf != nil && outf != outw {
202                 err = outf.Close()
203                 if err != nil {
204                         return 1
205                 }
206         }
207         return 0
208 }
209
210 func (cmd *importer) tileFasta(tilelib *tileLibrary, infile string) (tileSeq, []importStats, error) {
211         var input io.ReadCloser
212         input, err := os.Open(infile)
213         if err != nil {
214                 return nil, nil, err
215         }
216         defer input.Close()
217         if strings.HasSuffix(infile, ".gz") {
218                 input, err = gzip.NewReader(input)
219                 if err != nil {
220                         return nil, nil, err
221                 }
222                 defer input.Close()
223         }
224         return tilelib.TileFasta(infile, input, cmd.matchChromosome)
225 }
226
227 func (cmd *importer) loadTagLibrary() (*tagLibrary, error) {
228         log.Printf("tag library %s load starting", cmd.tagLibraryFile)
229         f, err := os.Open(cmd.tagLibraryFile)
230         if err != nil {
231                 return nil, err
232         }
233         defer f.Close()
234         var rdr io.ReadCloser = f
235         if strings.HasSuffix(cmd.tagLibraryFile, ".gz") {
236                 rdr, err = gzip.NewReader(f)
237                 if err != nil {
238                         return nil, fmt.Errorf("%s: gzip: %s", cmd.tagLibraryFile, err)
239                 }
240                 defer rdr.Close()
241         }
242         var taglib tagLibrary
243         err = taglib.Load(rdr)
244         if err != nil {
245                 return nil, err
246         }
247         if taglib.Len() < 1 {
248                 return nil, fmt.Errorf("cannot tile: tag library is empty")
249         }
250         log.Printf("tag library %s load done", cmd.tagLibraryFile)
251         return &taglib, nil
252 }
253
// vcfFilenameRe matches VCF input names, optionally gzip-compressed.
var vcfFilenameRe = regexp.MustCompile(`\.vcf(\.gz)?$`)

// fasta1FilenameRe matches the first file of a ".1.fa"/".2.fa" haplotype
// pair. Its groups capture the optional "sta" and ".gz" suffixes so the
// partner name can be derived by replacement.
var fasta1FilenameRe = regexp.MustCompile(`\.1\.fa(sta)?(\.gz)?$`)

// fasta2FilenameRe matches the second file of a haplotype pair.
var fasta2FilenameRe = regexp.MustCompile(`\.2\.fa(sta)?(\.gz)?$`)

// fastaFilenameRe matches any FASTA input name, optionally gzip-compressed.
var fastaFilenameRe = regexp.MustCompile(`\.fa(sta)?(\.gz)?$`)
260
261 func listInputFiles(paths []string) (files []string, err error) {
262         for _, path := range paths {
263                 if fi, err := os.Stat(path); err != nil {
264                         return nil, fmt.Errorf("%s: stat failed: %s", path, err)
265                 } else if !fi.IsDir() {
266                         if !fasta2FilenameRe.MatchString(path) {
267                                 files = append(files, path)
268                         }
269                         continue
270                 }
271                 d, err := os.Open(path)
272                 if err != nil {
273                         return nil, fmt.Errorf("%s: open failed: %s", path, err)
274                 }
275                 defer d.Close()
276                 names, err := d.Readdirnames(0)
277                 if err != nil {
278                         return nil, fmt.Errorf("%s: readdir failed: %s", path, err)
279                 }
280                 sort.Strings(names)
281                 for _, name := range names {
282                         if vcfFilenameRe.MatchString(name) {
283                                 files = append(files, filepath.Join(path, name))
284                         } else if fastaFilenameRe.MatchString(name) && !fasta2FilenameRe.MatchString(name) {
285                                 files = append(files, filepath.Join(path, name))
286                         }
287                 }
288                 d.Close()
289         }
290         for _, file := range files {
291                 if fastaFilenameRe.MatchString(file) {
292                         continue
293                 } else if vcfFilenameRe.MatchString(file) {
294                         if _, err := os.Stat(file + ".csi"); err == nil {
295                                 continue
296                         } else if _, err = os.Stat(file + ".tbi"); err == nil {
297                                 continue
298                         } else {
299                                 return nil, fmt.Errorf("%s: cannot read without .tbi or .csi index file", file)
300                         }
301                 } else {
302                         return nil, fmt.Errorf("don't know how to handle filename %s", file)
303                 }
304         }
305         return
306 }
307
308 func (cmd *importer) tileInputs(tilelib *tileLibrary, infiles []string) error {
309         starttime := time.Now()
310         errs := make(chan error, 1)
311         todo := make(chan func() error, len(infiles)*2)
312         allstats := make([][]importStats, len(infiles)*2)
313         var encodeJobs sync.WaitGroup
314         for idx, infile := range infiles {
315                 idx, infile := idx, infile
316                 var phases sync.WaitGroup
317                 phases.Add(2)
318                 variants := make([][]tileVariantID, 2)
319                 if fasta1FilenameRe.MatchString(infile) {
320                         todo <- func() error {
321                                 defer phases.Done()
322                                 log.Printf("%s starting", infile)
323                                 defer log.Printf("%s done", infile)
324                                 tseqs, stats, err := cmd.tileFasta(tilelib, infile)
325                                 allstats[idx*2] = stats
326                                 var kept, dropped int
327                                 variants[0], kept, dropped = tseqs.Variants()
328                                 log.Printf("%s found %d unique tags plus %d repeats", infile, kept, dropped)
329                                 return err
330                         }
331                         infile2 := fasta1FilenameRe.ReplaceAllString(infile, `.2.fa$1$2`)
332                         todo <- func() error {
333                                 defer phases.Done()
334                                 log.Printf("%s starting", infile2)
335                                 defer log.Printf("%s done", infile2)
336                                 tseqs, stats, err := cmd.tileFasta(tilelib, infile2)
337                                 allstats[idx*2+1] = stats
338                                 var kept, dropped int
339                                 variants[1], kept, dropped = tseqs.Variants()
340                                 log.Printf("%s found %d unique tags plus %d repeats", infile2, kept, dropped)
341
342                                 return err
343                         }
344                 } else if fastaFilenameRe.MatchString(infile) {
345                         todo <- func() error {
346                                 defer phases.Done()
347                                 defer phases.Done()
348                                 log.Printf("%s starting", infile)
349                                 defer log.Printf("%s done", infile)
350                                 tseqs, stats, err := cmd.tileFasta(tilelib, infile)
351                                 allstats[idx*2] = stats
352                                 if err != nil {
353                                         return err
354                                 }
355                                 totlen := 0
356                                 for _, tseq := range tseqs {
357                                         totlen += len(tseq)
358                                 }
359                                 log.Printf("%s tiled %d seqs, total len %d", infile, len(tseqs), totlen)
360                                 return cmd.encoder.Encode(LibraryEntry{
361                                         CompactSequences: []CompactSequence{{Name: infile, TileSequences: tseqs}},
362                                 })
363                         }
364                         // Don't write out a CompactGenomes entry
365                         continue
366                 } else if vcfFilenameRe.MatchString(infile) {
367                         for phase := 0; phase < 2; phase++ {
368                                 phase := phase
369                                 todo <- func() error {
370                                         defer phases.Done()
371                                         log.Printf("%s phase %d starting", infile, phase+1)
372                                         defer log.Printf("%s phase %d done", infile, phase+1)
373                                         tseqs, stats, err := cmd.tileGVCF(tilelib, infile, phase)
374                                         allstats[idx*2] = stats
375                                         var kept, dropped int
376                                         variants[phase], kept, dropped = tseqs.Variants()
377                                         log.Printf("%s phase %d found %d unique tags plus %d repeats", infile, phase+1, kept, dropped)
378                                         return err
379                                 }
380                         }
381                 } else {
382                         panic(fmt.Sprintf("bug: unhandled filename %q", infile))
383                 }
384                 encodeJobs.Add(1)
385                 go func() {
386                         defer encodeJobs.Done()
387                         phases.Wait()
388                         if len(errs) > 0 {
389                                 return
390                         }
391                         err := cmd.encoder.Encode(LibraryEntry{
392                                 CompactGenomes: []CompactGenome{{Name: infile, Variants: flatten(variants)}},
393                         })
394                         if err != nil {
395                                 select {
396                                 case errs <- err:
397                                 default:
398                                 }
399                         }
400                 }()
401         }
402         go close(todo)
403         var tileJobs sync.WaitGroup
404         var running int64
405         for i := 0; i < runtime.NumCPU()*9/8+1; i++ {
406                 tileJobs.Add(1)
407                 atomic.AddInt64(&running, 1)
408                 go func() {
409                         defer tileJobs.Done()
410                         defer atomic.AddInt64(&running, -1)
411                         for fn := range todo {
412                                 if len(errs) > 0 {
413                                         return
414                                 }
415                                 err := fn()
416                                 if err != nil {
417                                         select {
418                                         case errs <- err:
419                                         default:
420                                         }
421                                 }
422                                 remain := len(todo) + int(atomic.LoadInt64(&running)) - 1
423                                 if remain < cap(todo) {
424                                         ttl := time.Now().Sub(starttime) * time.Duration(remain) / time.Duration(cap(todo)-remain)
425                                         eta := time.Now().Add(ttl)
426                                         log.Printf("progress %d/%d, eta %v (%v)", cap(todo)-remain, cap(todo), eta, ttl)
427                                 }
428                         }
429                 }()
430         }
431         tileJobs.Wait()
432         encodeJobs.Wait()
433
434         go close(errs)
435         err := <-errs
436         if err != nil {
437                 return err
438         }
439
440         if cmd.outputStats != "" {
441                 f, err := os.OpenFile(cmd.outputStats, os.O_CREATE|os.O_WRONLY, 0666)
442                 if err != nil {
443                         return err
444                 }
445                 var flatstats []importStats
446                 for _, stats := range allstats {
447                         flatstats = append(flatstats, stats...)
448                 }
449                 err = json.NewEncoder(f).Encode(flatstats)
450                 if err != nil {
451                         return err
452                 }
453         }
454
455         return nil
456 }
457
458 func (cmd *importer) tileGVCF(tilelib *tileLibrary, infile string, phase int) (tileseq tileSeq, stats []importStats, err error) {
459         if cmd.refFile == "" {
460                 err = errors.New("cannot import vcf: reference data (-ref) not specified")
461                 return
462         }
463         args := []string{"bcftools", "consensus", "--fasta-ref", cmd.refFile, "-H", fmt.Sprint(phase + 1), infile}
464         indexsuffix := ".tbi"
465         if _, err := os.Stat(infile + ".csi"); err == nil {
466                 indexsuffix = ".csi"
467         }
468         if out, err := exec.Command("docker", "image", "ls", "-q", "lightning-runtime").Output(); err == nil && len(out) > 0 {
469                 args = append([]string{
470                         "docker", "run", "--rm",
471                         "--log-driver=none",
472                         "--volume=" + infile + ":" + infile + ":ro",
473                         "--volume=" + infile + indexsuffix + ":" + infile + indexsuffix + ":ro",
474                         "--volume=" + cmd.refFile + ":" + cmd.refFile + ":ro",
475                         "lightning-runtime",
476                 }, args...)
477         }
478         consensus := exec.Command(args[0], args[1:]...)
479         consensus.Stderr = os.Stderr
480         stdout, err := consensus.StdoutPipe()
481         defer stdout.Close()
482         if err != nil {
483                 return
484         }
485         err = consensus.Start()
486         if err != nil {
487                 return
488         }
489         defer consensus.Wait()
490         tileseq, stats, err = tilelib.TileFasta(fmt.Sprintf("%s phase %d", infile, phase+1), stdout, cmd.matchChromosome)
491         if err != nil {
492                 return
493         }
494         err = stdout.Close()
495         if err != nil {
496                 return
497         }
498         err = consensus.Wait()
499         if err != nil {
500                 err = fmt.Errorf("%s phase %d: bcftools: %s", infile, phase, err)
501                 return
502         }
503         return
504 }
505
506 func flatten(variants [][]tileVariantID) []tileVariantID {
507         ntags := 0
508         for _, v := range variants {
509                 if ntags < len(v) {
510                         ntags = len(v)
511                 }
512         }
513         flat := make([]tileVariantID, ntags*2)
514         for i := 0; i < ntags; i++ {
515                 for hap := 0; hap < 2; hap++ {
516                         if i < len(variants[hap]) {
517                                 flat[i*2+hap] = variants[hap][i]
518                         }
519                 }
520         }
521         return flat
522 }
523
524 type importstatsplot struct{}
525
526 func (cmd *importstatsplot) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
527         err := cmd.Plot(stdin, stdout)
528         if err != nil {
529                 log.Errorf("%s", err)
530                 return 1
531         }
532         return 0
533 }
534
535 func (cmd *importstatsplot) Plot(stdin io.Reader, stdout io.Writer) error {
536         var stats []importStats
537         err := json.NewDecoder(stdin).Decode(&stats)
538         if err != nil {
539                 return err
540         }
541
542         p, err := plot.New()
543         if err != nil {
544                 return err
545         }
546         p.Title.Text = "coverage preserved by import (excl X<0.65)"
547         p.X.Label.Text = "input base calls ÷ sequence length"
548         p.Y.Label.Text = "output base calls ÷ input base calls"
549         p.Add(plotter.NewGrid())
550
551         data := map[string]plotter.XYs{}
552         for _, stat := range stats {
553                 data[stat.InputLabel] = append(data[stat.InputLabel], plotter.XY{
554                         X: float64(stat.InputCoverage) / float64(stat.InputLength),
555                         Y: float64(stat.TileCoverage) / float64(stat.InputCoverage),
556                 })
557         }
558
559         labels := []string{}
560         for label := range data {
561                 labels = append(labels, label)
562         }
563         sort.Strings(labels)
564         palette, err := colorful.SoftPalette(len(labels))
565         if err != nil {
566                 return err
567         }
568         nextInPalette := 0
569         for idx, label := range labels {
570                 s, err := plotter.NewScatter(data[label])
571                 if err != nil {
572                         return err
573                 }
574                 s.GlyphStyle.Color = palette[idx]
575                 s.GlyphStyle.Radius = vg.Millimeter / 2
576                 s.GlyphStyle.Shape = draw.CrossGlyph{}
577                 nextInPalette += 7
578                 p.Add(s)
579                 if false {
580                         p.Legend.Add(label, s)
581                 }
582         }
583         p.X.Min = 0.65
584         p.X.Max = 1
585
586         w, err := p.WriterTo(8*vg.Inch, 6*vg.Inch, "svg")
587         if err != nil {
588                 return err
589         }
590         _, err = w.WriteTo(stdout)
591         return err
592 }