Propagate -match-chromosome arg to container.
[lightning.git] / import.go
1 package main
2
3 import (
4         "bufio"
5         "compress/gzip"
6         "encoding/gob"
7         "encoding/json"
8         "errors"
9         "flag"
10         "fmt"
11         "io"
12         "net/http"
13         _ "net/http/pprof"
14         "os"
15         "os/exec"
16         "path/filepath"
17         "regexp"
18         "runtime"
19         "sort"
20         "strings"
21         "sync"
22         "sync/atomic"
23         "time"
24
25         "git.arvados.org/arvados.git/sdk/go/arvados"
26         "github.com/lucasb-eyer/go-colorful"
27         log "github.com/sirupsen/logrus"
28         "gonum.org/v1/plot"
29         "gonum.org/v1/plot/plotter"
30         "gonum.org/v1/plot/vg"
31         "gonum.org/v1/plot/vg/draw"
32 )
33
// importer carries the settings and output state of the "import"
// subcommand. Apart from matchChromosome and encoder, the fields
// are filled in from command line flags by RunCommand.
type importer struct {
	// Input/output locations.
	tagLibraryFile string // -tag-library: tag library fasta file
	refFile        string // -ref: reference fasta file
	outputFile     string // -o: output file ("-" by default)
	projectUUID    string // -project: project UUID for output data

	// Processing switches.
	runLocal            bool // -local: run on local host instead of an arvados container
	skipOOO             bool // -skip-ooo: skip out-of-order tags
	outputTiles         bool // -output-tiles: include tile variant sequences in output file
	saveIncompleteTiles bool // -save-incomplete-tiles: treat tiles with no-calls as regular tiles

	outputStats     string         // -output-stats: file that receives import stats (json)
	matchChromosome *regexp.Regexp // compiled -match-chromosome pattern
	encoder         *gob.Encoder   // output encoder, set up by RunCommand in local mode
}
47
48 func (cmd *importer) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
49         var err error
50         defer func() {
51                 if err != nil {
52                         fmt.Fprintf(stderr, "%s\n", err)
53                 }
54         }()
55         flags := flag.NewFlagSet("", flag.ContinueOnError)
56         flags.SetOutput(stderr)
57         flags.StringVar(&cmd.tagLibraryFile, "tag-library", "", "tag library fasta `file`")
58         flags.StringVar(&cmd.refFile, "ref", "", "reference fasta `file`")
59         flags.StringVar(&cmd.outputFile, "o", "-", "output `file`")
60         flags.StringVar(&cmd.projectUUID, "project", "", "project `UUID` for output data")
61         flags.BoolVar(&cmd.runLocal, "local", false, "run on local host (default: run in an arvados container)")
62         flags.BoolVar(&cmd.skipOOO, "skip-ooo", false, "skip out-of-order tags")
63         flags.BoolVar(&cmd.outputTiles, "output-tiles", false, "include tile variant sequences in output file")
64         flags.BoolVar(&cmd.saveIncompleteTiles, "save-incomplete-tiles", false, "treat tiles with no-calls as regular tiles")
65         flags.StringVar(&cmd.outputStats, "output-stats", "", "output stats to `file` (json)")
66         matchChromosome := flags.String("match-chromosome", "^(chr)?([0-9]+|X|Y|MT?)$", "import chromosomes that match the given `regexp`")
67         priority := flags.Int("priority", 500, "container request priority")
68         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
69         loglevel := flags.String("loglevel", "info", "logging threshold (trace, debug, info, warn, error, fatal, or panic)")
70         err = flags.Parse(args)
71         if err == flag.ErrHelp {
72                 err = nil
73                 return 0
74         } else if err != nil {
75                 return 2
76         } else if cmd.tagLibraryFile == "" {
77                 fmt.Fprintln(os.Stderr, "cannot import without -tag-library argument")
78                 return 2
79         } else if flags.NArg() == 0 {
80                 flags.Usage()
81                 return 2
82         }
83
84         if *pprof != "" {
85                 go func() {
86                         log.Println(http.ListenAndServe(*pprof, nil))
87                 }()
88         }
89
90         lvl, err := log.ParseLevel(*loglevel)
91         if err != nil {
92                 return 2
93         }
94         log.SetLevel(lvl)
95
96         cmd.matchChromosome, err = regexp.Compile(*matchChromosome)
97         if err != nil {
98                 return 1
99         }
100
101         if !cmd.runLocal {
102                 runner := arvadosContainerRunner{
103                         Name:        "lightning import",
104                         Client:      arvados.NewClientFromEnv(),
105                         ProjectUUID: cmd.projectUUID,
106                         RAM:         80000000000,
107                         VCPUs:       32,
108                         Priority:    *priority,
109                 }
110                 err = runner.TranslatePaths(&cmd.tagLibraryFile, &cmd.refFile, &cmd.outputFile)
111                 if err != nil {
112                         return 1
113                 }
114                 inputs := flags.Args()
115                 for i := range inputs {
116                         err = runner.TranslatePaths(&inputs[i])
117                         if err != nil {
118                                 return 1
119                         }
120                 }
121                 if cmd.outputFile == "-" {
122                         cmd.outputFile = "/mnt/output/library.gob.gz"
123                 } else {
124                         // Not yet implemented, but this should write
125                         // the collection to an existing collection,
126                         // possibly even an in-place update.
127                         err = errors.New("cannot specify output file in container mode: not implemented")
128                         return 1
129                 }
130                 runner.Args = append([]string{"import",
131                         "-local=true",
132                         "-loglevel=" + *loglevel,
133                         fmt.Sprintf("-skip-ooo=%v", cmd.skipOOO),
134                         fmt.Sprintf("-output-tiles=%v", cmd.outputTiles),
135                         fmt.Sprintf("-save-incomplete-tiles=%v", cmd.saveIncompleteTiles),
136                         "-match-chromosome", cmd.matchChromosome.String(),
137                         "-output-stats", "/mnt/output/stats.json",
138                         "-tag-library", cmd.tagLibraryFile,
139                         "-ref", cmd.refFile,
140                         "-o", cmd.outputFile,
141                 }, inputs...)
142                 var output string
143                 output, err = runner.Run()
144                 if err != nil {
145                         return 1
146                 }
147                 fmt.Fprintln(stdout, output+"/library.gob.gz")
148                 return 0
149         }
150
151         infiles, err := listInputFiles(flags.Args())
152         if err != nil {
153                 return 1
154         }
155
156         taglib, err := cmd.loadTagLibrary()
157         if err != nil {
158                 return 1
159         }
160
161         var outw, outf io.WriteCloser
162         if cmd.outputFile == "-" {
163                 outw = nopCloser{stdout}
164         } else {
165                 outf, err = os.OpenFile(cmd.outputFile, os.O_CREATE|os.O_WRONLY, 0777)
166                 if err != nil {
167                         return 1
168                 }
169                 defer outf.Close()
170                 if strings.HasSuffix(cmd.outputFile, ".gz") {
171                         outw = gzip.NewWriter(outf)
172                 } else {
173                         outw = outf
174                 }
175         }
176         bufw := bufio.NewWriter(outw)
177         cmd.encoder = gob.NewEncoder(bufw)
178
179         tilelib := &tileLibrary{taglib: taglib, retainNoCalls: cmd.saveIncompleteTiles, skipOOO: cmd.skipOOO}
180         if cmd.outputTiles {
181                 cmd.encoder.Encode(LibraryEntry{TagSet: taglib.Tags()})
182                 tilelib.encoder = cmd.encoder
183         }
184         go func() {
185                 for range time.Tick(10 * time.Minute) {
186                         log.Printf("tilelib.Len() == %d", tilelib.Len())
187                 }
188         }()
189
190         err = cmd.tileInputs(tilelib, infiles)
191         if err != nil {
192                 return 1
193         }
194         err = bufw.Flush()
195         if err != nil {
196                 return 1
197         }
198         err = outw.Close()
199         if err != nil {
200                 return 1
201         }
202         if outf != nil && outf != outw {
203                 err = outf.Close()
204                 if err != nil {
205                         return 1
206                 }
207         }
208         return 0
209 }
210
211 func (cmd *importer) tileFasta(tilelib *tileLibrary, infile string) (tileSeq, []importStats, error) {
212         var input io.ReadCloser
213         input, err := os.Open(infile)
214         if err != nil {
215                 return nil, nil, err
216         }
217         defer input.Close()
218         if strings.HasSuffix(infile, ".gz") {
219                 input, err = gzip.NewReader(input)
220                 if err != nil {
221                         return nil, nil, err
222                 }
223                 defer input.Close()
224         }
225         return tilelib.TileFasta(infile, input, cmd.matchChromosome)
226 }
227
228 func (cmd *importer) loadTagLibrary() (*tagLibrary, error) {
229         log.Printf("tag library %s load starting", cmd.tagLibraryFile)
230         f, err := os.Open(cmd.tagLibraryFile)
231         if err != nil {
232                 return nil, err
233         }
234         defer f.Close()
235         var rdr io.ReadCloser = f
236         if strings.HasSuffix(cmd.tagLibraryFile, ".gz") {
237                 rdr, err = gzip.NewReader(f)
238                 if err != nil {
239                         return nil, fmt.Errorf("%s: gzip: %s", cmd.tagLibraryFile, err)
240                 }
241                 defer rdr.Close()
242         }
243         var taglib tagLibrary
244         err = taglib.Load(rdr)
245         if err != nil {
246                 return nil, err
247         }
248         if taglib.Len() < 1 {
249                 return nil, fmt.Errorf("cannot tile: tag library is empty")
250         }
251         log.Printf("tag library %s load done", cmd.tagLibraryFile)
252         return &taglib, nil
253 }
254
// Filename patterns used to classify input files by type. All of
// them accept an optional ".gz" suffix.
var (
	// fastaFilenameRe matches any fasta file: *.fa or *.fasta.
	fastaFilenameRe = regexp.MustCompile(`\.fa(sta)?(\.gz)?$`)

	// fasta1FilenameRe and fasta2FilenameRe match the two halves
	// of a *.1.fa / *.2.fa pair. Both are also matched by
	// fastaFilenameRe.
	fasta1FilenameRe = regexp.MustCompile(`\.1\.fa(sta)?(\.gz)?$`)
	fasta2FilenameRe = regexp.MustCompile(`\.2\.fa(sta)?(\.gz)?$`)

	// vcfFilenameRe matches *.vcf files.
	vcfFilenameRe = regexp.MustCompile(`\.vcf(\.gz)?$`)
)
261
262 func listInputFiles(paths []string) (files []string, err error) {
263         for _, path := range paths {
264                 if fi, err := os.Stat(path); err != nil {
265                         return nil, fmt.Errorf("%s: stat failed: %s", path, err)
266                 } else if !fi.IsDir() {
267                         if !fasta2FilenameRe.MatchString(path) {
268                                 files = append(files, path)
269                         }
270                         continue
271                 }
272                 d, err := os.Open(path)
273                 if err != nil {
274                         return nil, fmt.Errorf("%s: open failed: %s", path, err)
275                 }
276                 defer d.Close()
277                 names, err := d.Readdirnames(0)
278                 if err != nil {
279                         return nil, fmt.Errorf("%s: readdir failed: %s", path, err)
280                 }
281                 sort.Strings(names)
282                 for _, name := range names {
283                         if vcfFilenameRe.MatchString(name) {
284                                 files = append(files, filepath.Join(path, name))
285                         } else if fastaFilenameRe.MatchString(name) && !fasta2FilenameRe.MatchString(name) {
286                                 files = append(files, filepath.Join(path, name))
287                         }
288                 }
289                 d.Close()
290         }
291         for _, file := range files {
292                 if fastaFilenameRe.MatchString(file) {
293                         continue
294                 } else if vcfFilenameRe.MatchString(file) {
295                         if _, err := os.Stat(file + ".csi"); err == nil {
296                                 continue
297                         } else if _, err = os.Stat(file + ".tbi"); err == nil {
298                                 continue
299                         } else {
300                                 return nil, fmt.Errorf("%s: cannot read without .tbi or .csi index file", file)
301                         }
302                 } else {
303                         return nil, fmt.Errorf("don't know how to handle filename %s", file)
304                 }
305         }
306         return
307 }
308
309 func (cmd *importer) tileInputs(tilelib *tileLibrary, infiles []string) error {
310         starttime := time.Now()
311         errs := make(chan error, 1)
312         todo := make(chan func() error, len(infiles)*2)
313         allstats := make([][]importStats, len(infiles)*2)
314         var encodeJobs sync.WaitGroup
315         for idx, infile := range infiles {
316                 idx, infile := idx, infile
317                 var phases sync.WaitGroup
318                 phases.Add(2)
319                 variants := make([][]tileVariantID, 2)
320                 if fasta1FilenameRe.MatchString(infile) {
321                         todo <- func() error {
322                                 defer phases.Done()
323                                 log.Printf("%s starting", infile)
324                                 defer log.Printf("%s done", infile)
325                                 tseqs, stats, err := cmd.tileFasta(tilelib, infile)
326                                 allstats[idx*2] = stats
327                                 var kept, dropped int
328                                 variants[0], kept, dropped = tseqs.Variants()
329                                 log.Printf("%s found %d unique tags plus %d repeats", infile, kept, dropped)
330                                 return err
331                         }
332                         infile2 := fasta1FilenameRe.ReplaceAllString(infile, `.2.fa$1$2`)
333                         todo <- func() error {
334                                 defer phases.Done()
335                                 log.Printf("%s starting", infile2)
336                                 defer log.Printf("%s done", infile2)
337                                 tseqs, stats, err := cmd.tileFasta(tilelib, infile2)
338                                 allstats[idx*2+1] = stats
339                                 var kept, dropped int
340                                 variants[1], kept, dropped = tseqs.Variants()
341                                 log.Printf("%s found %d unique tags plus %d repeats", infile2, kept, dropped)
342
343                                 return err
344                         }
345                 } else if fastaFilenameRe.MatchString(infile) {
346                         todo <- func() error {
347                                 defer phases.Done()
348                                 defer phases.Done()
349                                 log.Printf("%s starting", infile)
350                                 defer log.Printf("%s done", infile)
351                                 tseqs, stats, err := cmd.tileFasta(tilelib, infile)
352                                 allstats[idx*2] = stats
353                                 if err != nil {
354                                         return err
355                                 }
356                                 totlen := 0
357                                 for _, tseq := range tseqs {
358                                         totlen += len(tseq)
359                                 }
360                                 log.Printf("%s tiled %d seqs, total len %d", infile, len(tseqs), totlen)
361                                 return cmd.encoder.Encode(LibraryEntry{
362                                         CompactSequences: []CompactSequence{{Name: infile, TileSequences: tseqs}},
363                                 })
364                         }
365                         // Don't write out a CompactGenomes entry
366                         continue
367                 } else if vcfFilenameRe.MatchString(infile) {
368                         for phase := 0; phase < 2; phase++ {
369                                 phase := phase
370                                 todo <- func() error {
371                                         defer phases.Done()
372                                         log.Printf("%s phase %d starting", infile, phase+1)
373                                         defer log.Printf("%s phase %d done", infile, phase+1)
374                                         tseqs, stats, err := cmd.tileGVCF(tilelib, infile, phase)
375                                         allstats[idx*2] = stats
376                                         var kept, dropped int
377                                         variants[phase], kept, dropped = tseqs.Variants()
378                                         log.Printf("%s phase %d found %d unique tags plus %d repeats", infile, phase+1, kept, dropped)
379                                         return err
380                                 }
381                         }
382                 } else {
383                         panic(fmt.Sprintf("bug: unhandled filename %q", infile))
384                 }
385                 encodeJobs.Add(1)
386                 go func() {
387                         defer encodeJobs.Done()
388                         phases.Wait()
389                         if len(errs) > 0 {
390                                 return
391                         }
392                         err := cmd.encoder.Encode(LibraryEntry{
393                                 CompactGenomes: []CompactGenome{{Name: infile, Variants: flatten(variants)}},
394                         })
395                         if err != nil {
396                                 select {
397                                 case errs <- err:
398                                 default:
399                                 }
400                         }
401                 }()
402         }
403         go close(todo)
404         var tileJobs sync.WaitGroup
405         var running int64
406         for i := 0; i < runtime.NumCPU()*9/8+1; i++ {
407                 tileJobs.Add(1)
408                 atomic.AddInt64(&running, 1)
409                 go func() {
410                         defer tileJobs.Done()
411                         defer atomic.AddInt64(&running, -1)
412                         for fn := range todo {
413                                 if len(errs) > 0 {
414                                         return
415                                 }
416                                 err := fn()
417                                 if err != nil {
418                                         select {
419                                         case errs <- err:
420                                         default:
421                                         }
422                                 }
423                                 remain := len(todo) + int(atomic.LoadInt64(&running)) - 1
424                                 if remain < cap(todo) {
425                                         ttl := time.Now().Sub(starttime) * time.Duration(remain) / time.Duration(cap(todo)-remain)
426                                         eta := time.Now().Add(ttl)
427                                         log.Printf("progress %d/%d, eta %v (%v)", cap(todo)-remain, cap(todo), eta, ttl)
428                                 }
429                         }
430                 }()
431         }
432         tileJobs.Wait()
433         encodeJobs.Wait()
434
435         go close(errs)
436         err := <-errs
437         if err != nil {
438                 return err
439         }
440
441         if cmd.outputStats != "" {
442                 f, err := os.OpenFile(cmd.outputStats, os.O_CREATE|os.O_WRONLY, 0666)
443                 if err != nil {
444                         return err
445                 }
446                 var flatstats []importStats
447                 for _, stats := range allstats {
448                         flatstats = append(flatstats, stats...)
449                 }
450                 err = json.NewEncoder(f).Encode(flatstats)
451                 if err != nil {
452                         return err
453                 }
454         }
455
456         return nil
457 }
458
459 func (cmd *importer) tileGVCF(tilelib *tileLibrary, infile string, phase int) (tileseq tileSeq, stats []importStats, err error) {
460         if cmd.refFile == "" {
461                 err = errors.New("cannot import vcf: reference data (-ref) not specified")
462                 return
463         }
464         args := []string{"bcftools", "consensus", "--fasta-ref", cmd.refFile, "-H", fmt.Sprint(phase + 1), infile}
465         indexsuffix := ".tbi"
466         if _, err := os.Stat(infile + ".csi"); err == nil {
467                 indexsuffix = ".csi"
468         }
469         if out, err := exec.Command("docker", "image", "ls", "-q", "lightning-runtime").Output(); err == nil && len(out) > 0 {
470                 args = append([]string{
471                         "docker", "run", "--rm",
472                         "--log-driver=none",
473                         "--volume=" + infile + ":" + infile + ":ro",
474                         "--volume=" + infile + indexsuffix + ":" + infile + indexsuffix + ":ro",
475                         "--volume=" + cmd.refFile + ":" + cmd.refFile + ":ro",
476                         "lightning-runtime",
477                 }, args...)
478         }
479         consensus := exec.Command(args[0], args[1:]...)
480         consensus.Stderr = os.Stderr
481         stdout, err := consensus.StdoutPipe()
482         defer stdout.Close()
483         if err != nil {
484                 return
485         }
486         err = consensus.Start()
487         if err != nil {
488                 return
489         }
490         defer consensus.Wait()
491         tileseq, stats, err = tilelib.TileFasta(fmt.Sprintf("%s phase %d", infile, phase+1), stdout, cmd.matchChromosome)
492         if err != nil {
493                 return
494         }
495         err = stdout.Close()
496         if err != nil {
497                 return
498         }
499         err = consensus.Wait()
500         if err != nil {
501                 err = fmt.Errorf("%s phase %d: bcftools: %s", infile, phase, err)
502                 return
503         }
504         return
505 }
506
507 func flatten(variants [][]tileVariantID) []tileVariantID {
508         ntags := 0
509         for _, v := range variants {
510                 if ntags < len(v) {
511                         ntags = len(v)
512                 }
513         }
514         flat := make([]tileVariantID, ntags*2)
515         for i := 0; i < ntags; i++ {
516                 for hap := 0; hap < 2; hap++ {
517                         if i < len(variants[hap]) {
518                                 flat[i*2+hap] = variants[hap][i]
519                         }
520                 }
521         }
522         return flat
523 }
524
525 type importstatsplot struct{}
526
527 func (cmd *importstatsplot) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
528         err := cmd.Plot(stdin, stdout)
529         if err != nil {
530                 log.Errorf("%s", err)
531                 return 1
532         }
533         return 0
534 }
535
536 func (cmd *importstatsplot) Plot(stdin io.Reader, stdout io.Writer) error {
537         var stats []importStats
538         err := json.NewDecoder(stdin).Decode(&stats)
539         if err != nil {
540                 return err
541         }
542
543         p, err := plot.New()
544         if err != nil {
545                 return err
546         }
547         p.Title.Text = "coverage preserved by import (excl X<0.65)"
548         p.X.Label.Text = "input base calls ÷ sequence length"
549         p.Y.Label.Text = "output base calls ÷ input base calls"
550         p.Add(plotter.NewGrid())
551
552         data := map[string]plotter.XYs{}
553         for _, stat := range stats {
554                 data[stat.InputLabel] = append(data[stat.InputLabel], plotter.XY{
555                         X: float64(stat.InputCoverage) / float64(stat.InputLength),
556                         Y: float64(stat.TileCoverage) / float64(stat.InputCoverage),
557                 })
558         }
559
560         labels := []string{}
561         for label := range data {
562                 labels = append(labels, label)
563         }
564         sort.Strings(labels)
565         palette, err := colorful.SoftPalette(len(labels))
566         if err != nil {
567                 return err
568         }
569         nextInPalette := 0
570         for idx, label := range labels {
571                 s, err := plotter.NewScatter(data[label])
572                 if err != nil {
573                         return err
574                 }
575                 s.GlyphStyle.Color = palette[idx]
576                 s.GlyphStyle.Radius = vg.Millimeter / 2
577                 s.GlyphStyle.Shape = draw.CrossGlyph{}
578                 nextInPalette += 7
579                 p.Add(s)
580                 if false {
581                         p.Legend.Add(label, s)
582                 }
583         }
584         p.X.Min = 0.65
585         p.X.Max = 1
586
587         w, err := p.WriterTo(8*vg.Inch, 6*vg.Inch, "svg")
588         if err != nil {
589                 return err
590         }
591         _, err = w.WriteTo(stdout)
592         return err
593 }