5a0258461883d2b6c627beaf11abd5f20987c235
[lightning.git] / gvcf2numpy.go
1 package main
2
3 import (
4         "bufio"
5         "compress/gzip"
6         "flag"
7         "fmt"
8         "io"
9         "log"
10         "os"
11         "os/exec"
12         "path/filepath"
13         "regexp"
14         "runtime"
15         "sort"
16         "strings"
17         "sync"
18         "time"
19
20         "github.com/kshedden/gonpy"
21 )
22
// gvcf2numpy holds the settings for the gvcf2numpy command: the input
// paths populated from command line flags by RunCommand, and the
// destination of the resulting numpy array.
type gvcf2numpy struct {
	// Paths supplied via the -tag-library and -ref flags.
	tagLibraryFile, refFile string
	// output receives the array written by printVariants; RunCommand
	// points it at the stdout writer it was given.
	output io.Writer
}
28
29 func (cmd *gvcf2numpy) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
30         var err error
31         defer func() {
32                 if err != nil {
33                         fmt.Fprintf(stderr, "%s\n", err)
34                 }
35         }()
36         flags := flag.NewFlagSet("", flag.ContinueOnError)
37         flags.SetOutput(stderr)
38         flags.StringVar(&cmd.tagLibraryFile, "tag-library", "", "tag library fasta `file`")
39         flags.StringVar(&cmd.refFile, "ref", "", "reference fasta `file`")
40         err = flags.Parse(args)
41         if err == flag.ErrHelp {
42                 err = nil
43                 return 0
44         } else if err != nil {
45                 return 2
46         } else if cmd.refFile == "" || cmd.tagLibraryFile == "" {
47                 fmt.Fprintln(os.Stderr, "cannot run without -tag-library and -ref arguments")
48                 return 2
49         } else if flags.NArg() == 0 {
50                 flags.Usage()
51                 return 2
52         }
53         cmd.output = stdout
54
55         infiles, err := listInputFiles(flags.Args())
56         if err != nil {
57                 return 1
58         }
59
60         tilelib, err := cmd.loadTileLibrary()
61         if err != nil {
62                 return 1
63         }
64         go func() {
65                 for range time.Tick(10 * time.Second) {
66                         log.Printf("tilelib.Len() == %d", tilelib.Len())
67                 }
68         }()
69         tseqs, err := cmd.tileGVCFs(tilelib, infiles)
70         if err != nil {
71                 return 1
72         }
73         err = cmd.printVariants(tseqs)
74         if err != nil {
75                 return 1
76         }
77         return 0
78 }
79
80 func (cmd *gvcf2numpy) tileFasta(tilelib *tileLibrary, infile string) (tileSeq, error) {
81         var input io.ReadCloser
82         input, err := os.Open(infile)
83         if err != nil {
84                 return nil, err
85         }
86         defer input.Close()
87         if strings.HasSuffix(infile, ".gz") {
88                 input, err = gzip.NewReader(input)
89                 if err != nil {
90                         return nil, err
91                 }
92                 defer input.Close()
93         }
94         return tilelib.TileFasta(infile, input)
95 }
96
97 func (cmd *gvcf2numpy) loadTileLibrary() (*tileLibrary, error) {
98         log.Printf("tag library %s load starting", cmd.tagLibraryFile)
99         f, err := os.Open(cmd.tagLibraryFile)
100         if err != nil {
101                 return nil, err
102         }
103         defer f.Close()
104         var rdr io.ReadCloser = f
105         if strings.HasSuffix(cmd.tagLibraryFile, ".gz") {
106                 rdr, err = gzip.NewReader(f)
107                 if err != nil {
108                         return nil, fmt.Errorf("%s: gzip: %s", cmd.tagLibraryFile, err)
109                 }
110                 defer rdr.Close()
111         }
112         var taglib tagLibrary
113         err = taglib.Load(rdr)
114         if err != nil {
115                 return nil, err
116         }
117         if taglib.Len() < 1 {
118                 return nil, fmt.Errorf("cannot tile: tag library is empty")
119         }
120         log.Printf("tag library %s load done", cmd.tagLibraryFile)
121         return &tileLibrary{taglib: &taglib}, nil
122 }
123
// listInputFiles expands the given paths into the list of input files
// to process.
//
// A path that is a regular file is used as-is, except that "*.2.fasta"
// and "*.2.fasta.gz" files are skipped: tileGVCFs derives the second
// phase's filename from the matching "*.1.fasta[.gz]" entry. A path
// that is a directory is scanned (non-recursively, in sorted order)
// for names ending in .vcf, .vcf.gz, .1.fasta, or .1.fasta.gz.
//
// Every resulting file that is not a .1.fasta[.gz] file must have a
// ".csi" or ".tbi" index file next to it; otherwise an error is
// returned. On error the returned file list is nil.
func listInputFiles(paths []string) (files []string, err error) {
	hasAnySuffix := func(s string, suffixes ...string) bool {
		for _, suffix := range suffixes {
			if strings.HasSuffix(s, suffix) {
				return true
			}
		}
		return false
	}

	for _, path := range paths {
		fi, statErr := os.Stat(path)
		if statErr != nil {
			return nil, fmt.Errorf("%s: stat failed: %s", path, statErr)
		}
		if !fi.IsDir() {
			// Skip both the plain and the gzipped second-phase
			// fasta; neither should be listed on its own.
			if !hasAnySuffix(path, ".2.fasta", ".2.fasta.gz") {
				files = append(files, path)
			}
			continue
		}

		d, openErr := os.Open(path)
		if openErr != nil {
			return nil, fmt.Errorf("%s: open failed: %s", path, openErr)
		}
		names, readErr := d.Readdirnames(0)
		// Close right away instead of deferring inside the loop, so
		// directory handles do not pile up until return. The handle
		// was only read from, so a close error carries no
		// information worth reporting.
		d.Close()
		if readErr != nil {
			return nil, fmt.Errorf("%s: readdir failed: %s", path, readErr)
		}
		sort.Strings(names)
		for _, name := range names {
			if hasAnySuffix(name, ".vcf", ".vcf.gz", ".1.fasta", ".1.fasta.gz") {
				files = append(files, filepath.Join(path, name))
			}
		}
	}

	for _, file := range files {
		if hasAnySuffix(file, ".1.fasta", ".1.fasta.gz") {
			continue
		}
		if _, statErr := os.Stat(file + ".csi"); statErr == nil {
			continue
		}
		if _, statErr := os.Stat(file + ".tbi"); statErr == nil {
			continue
		}
		return nil, fmt.Errorf("%s: cannot read without .tbi or .csi index file", file)
	}
	return files, nil
}
166
167 func (cmd *gvcf2numpy) tileGVCFs(tilelib *tileLibrary, infiles []string) ([]tileSeq, error) {
168         starttime := time.Now()
169         errs := make(chan error, 1)
170         tseqs := make([]tileSeq, len(infiles)*2)
171         todo := make(chan func() error, len(infiles)*2)
172         var wg sync.WaitGroup
173         for i, infile := range infiles {
174                 i, infile := i, infile
175                 if strings.HasSuffix(infile, ".1.fasta") || strings.HasSuffix(infile, ".1.fasta.gz") {
176                         todo <- func() (err error) {
177                                 log.Printf("%s starting", infile)
178                                 defer log.Printf("%s done", infile)
179                                 tseqs[i*2], err = cmd.tileFasta(tilelib, infile)
180                                 return
181                         }
182                         infile2 := regexp.MustCompile(`\.1\.fasta(\.gz)?$`).ReplaceAllString(infile, `.2.fasta$1`)
183                         todo <- func() (err error) {
184                                 log.Printf("%s starting", infile2)
185                                 defer log.Printf("%s done", infile2)
186                                 tseqs[i*2+1], err = cmd.tileFasta(tilelib, infile2)
187                                 return
188                         }
189                 } else {
190                         for phase := 0; phase < 2; phase++ {
191                                 phase := phase
192                                 todo <- func() (err error) {
193                                         log.Printf("%s phase %d starting", infile, phase+1)
194                                         defer log.Printf("%s phase %d done", infile, phase+1)
195                                         tseqs[i*2+phase], err = cmd.tileGVCF(tilelib, infile, phase)
196                                         return
197                                 }
198                         }
199                 }
200         }
201         go close(todo)
202         for i := 0; i < runtime.NumCPU(); i++ {
203                 wg.Add(1)
204                 go func() {
205                         defer wg.Done()
206                         for fn := range todo {
207                                 if len(errs) > 0 {
208                                         return
209                                 }
210                                 err := fn()
211                                 if err != nil {
212                                         select {
213                                         case errs <- err:
214                                         default:
215                                         }
216                                 }
217                                 remain := len(todo)
218                                 ttl := time.Now().Sub(starttime) * time.Duration(remain) / time.Duration(cap(todo)-remain)
219                                 eta := time.Now().Add(ttl)
220                                 log.Printf("progress %d/%d, eta %v (%v)", cap(todo)-remain, cap(todo), eta, ttl)
221                         }
222                 }()
223         }
224         wg.Wait()
225         go close(errs)
226         return tseqs, <-errs
227 }
228
229 func (cmd *gvcf2numpy) printVariants(tseqs []tileSeq) error {
230         maxtag := tagID(-1)
231         for _, tseq := range tseqs {
232                 for _, path := range tseq {
233                         for _, tvar := range path {
234                                 if maxtag < tvar.tag {
235                                         maxtag = tvar.tag
236                                 }
237                         }
238                 }
239         }
240         rows := len(tseqs) / 2
241         cols := 2 * (int(maxtag) + 1)
242         out := make([]uint16, rows*cols)
243         for row := 0; row < len(tseqs)/2; row++ {
244                 for phase := 0; phase < 2; phase++ {
245                         for _, path := range tseqs[row*2+phase] {
246                                 for _, tvar := range path {
247                                         out[row*cols+2*int(tvar.tag)+phase] = uint16(tvar.variant)
248                                 }
249                         }
250                 }
251         }
252         w := bufio.NewWriter(cmd.output)
253         npw, err := gonpy.NewWriter(nopCloser{w})
254         if err != nil {
255                 return err
256         }
257         npw.Shape = []int{rows, cols}
258         npw.WriteUint16(out)
259         return w.Flush()
260 }
261
// nopCloser wraps an io.Writer and adds a Close method that does
// nothing, for use where a writer that can be closed is required but
// the underlying writer must stay usable.
type nopCloser struct {
	io.Writer
}

// Close always succeeds and leaves the embedded writer untouched.
func (nopCloser) Close() error {
	return nil
}
267
268 func (cmd *gvcf2numpy) tileGVCF(tilelib *tileLibrary, infile string, phase int) (tileseq tileSeq, err error) {
269         args := []string{"bcftools", "consensus", "--fasta-ref", cmd.refFile, "-H", fmt.Sprint(phase + 1), infile}
270         indexsuffix := ".tbi"
271         if _, err := os.Stat(infile + ".csi"); err == nil {
272                 indexsuffix = ".csi"
273         }
274         if out, err := exec.Command("docker", "image", "ls", "-q", "lightning-runtime").Output(); err == nil && len(out) > 0 {
275                 args = append([]string{
276                         "docker", "run", "--rm",
277                         "--log-driver=none",
278                         "--volume=" + infile + ":" + infile + ":ro",
279                         "--volume=" + infile + indexsuffix + ":" + infile + indexsuffix + ":ro",
280                         "--volume=" + cmd.refFile + ":" + cmd.refFile + ":ro",
281                         "lightning-runtime",
282                 }, args...)
283         }
284         consensus := exec.Command(args[0], args[1:]...)
285         consensus.Stderr = os.Stderr
286         stdout, err := consensus.StdoutPipe()
287         defer stdout.Close()
288         if err != nil {
289                 return
290         }
291         err = consensus.Start()
292         if err != nil {
293                 return
294         }
295         defer consensus.Wait()
296         tileseq, err = tilelib.TileFasta(fmt.Sprintf("%s phase %d", infile, phase+1), stdout)
297         if err != nil {
298                 return
299         }
300         err = stdout.Close()
301         if err != nil {
302                 return
303         }
304         err = consensus.Wait()
305         if err != nil {
306                 err = fmt.Errorf("%s phase %d: bcftools: %s", infile, phase, err)
307                 return
308         }
309         return
310 }