Fix deadlock from unclosed chan.
[lightning.git] / gvcf2numpy.go
1 package main
2
3 import (
4         "bufio"
5         "compress/gzip"
6         "flag"
7         "fmt"
8         "io"
9         "log"
10         "os"
11         "os/exec"
12         "path/filepath"
13         "runtime"
14         "sort"
15         "strings"
16         "sync"
17         "time"
18
19         "github.com/kshedden/gonpy"
20 )
21
22 type gvcf2numpy struct {
23         tagLibraryFile string
24         refFile        string
25         output         io.Writer
26 }
27
28 func (cmd *gvcf2numpy) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
29         var err error
30         defer func() {
31                 if err != nil {
32                         fmt.Fprintf(stderr, "%s\n", err)
33                 }
34         }()
35         flags := flag.NewFlagSet("", flag.ContinueOnError)
36         flags.SetOutput(stderr)
37         flags.StringVar(&cmd.tagLibraryFile, "tag-library", "", "tag library fasta `file`")
38         flags.StringVar(&cmd.refFile, "ref", "", "reference fasta `file`")
39         err = flags.Parse(args)
40         if err == flag.ErrHelp {
41                 err = nil
42                 return 0
43         } else if err != nil {
44                 return 2
45         } else if cmd.refFile == "" || cmd.tagLibraryFile == "" {
46                 fmt.Fprintln(os.Stderr, "cannot run without -tag-library and -ref arguments")
47                 return 2
48         } else if flags.NArg() == 0 {
49                 flags.Usage()
50                 return 2
51         }
52         cmd.output = stdout
53
54         infiles, err := listVCFFiles(flags.Args())
55         if err != nil {
56                 return 1
57         }
58
59         log.Printf("tag library %s load starting", cmd.tagLibraryFile)
60         f, err := os.Open(cmd.tagLibraryFile)
61         if err != nil {
62                 return 1
63         }
64         var rdr io.ReadCloser = f
65         if strings.HasSuffix(cmd.tagLibraryFile, ".gz") {
66                 rdr, err = gzip.NewReader(f)
67                 if err != nil {
68                         err = fmt.Errorf("%s: gzip: %s", cmd.tagLibraryFile, err)
69                         return 1
70                 }
71         }
72         var taglib tagLibrary
73         err = taglib.Load(rdr)
74         if err != nil {
75                 return 1
76         }
77         if taglib.Len() < 1 {
78                 err = fmt.Errorf("cannot tile: tag library is empty")
79                 return 1
80         }
81         log.Printf("tag library %s load done", cmd.tagLibraryFile)
82
83         tilelib := tileLibrary{taglib: &taglib}
84         go func() {
85                 for range time.Tick(10 * time.Second) {
86                         log.Printf("tilelib.Len() == %d", tilelib.Len())
87                 }
88         }()
89         tseqs, err := cmd.tileGVCFs(&tilelib, infiles)
90         if err != nil {
91                 return 1
92         }
93         err = cmd.printVariants(tseqs)
94         if err != nil {
95                 return 1
96         }
97         return 0
98 }
99
100 func listVCFFiles(paths []string) (files []string, err error) {
101         for _, path := range paths {
102                 if fi, err := os.Stat(path); err != nil {
103                         return nil, fmt.Errorf("%s: stat failed: %s", path, err)
104                 } else if !fi.IsDir() {
105                         files = append(files, path)
106                         continue
107                 }
108                 d, err := os.Open(path)
109                 if err != nil {
110                         return nil, fmt.Errorf("%s: open failed: %s", path, err)
111                 }
112                 defer d.Close()
113                 names, err := d.Readdirnames(0)
114                 if err != nil {
115                         return nil, fmt.Errorf("%s: readdir failed: %s", path, err)
116                 }
117                 sort.Strings(names)
118                 for _, name := range names {
119                         if strings.HasSuffix(name, ".vcf") || strings.HasSuffix(name, ".vcf.gz") {
120                                 files = append(files, filepath.Join(path, name))
121                         }
122                 }
123                 d.Close()
124         }
125         for _, file := range files {
126                 if _, err := os.Stat(file + ".csi"); err == nil {
127                         continue
128                 } else if _, err = os.Stat(file + ".tbi"); err == nil {
129                         continue
130                 } else {
131                         return nil, fmt.Errorf("%s: cannot read without .tbi or .csi index file", file)
132                 }
133         }
134         return
135 }
136
137 func (cmd *gvcf2numpy) tileGVCFs(tilelib *tileLibrary, infiles []string) ([]tileSeq, error) {
138         starttime := time.Now()
139         errs := make(chan error, 1)
140         tseqs := make([]tileSeq, len(infiles)*2)
141         todo := make(chan func() error, len(infiles)*2)
142         var wg sync.WaitGroup
143         for i, infile := range infiles {
144                 for phase := 0; phase < 2; phase++ {
145                         i, infile, phase := i, infile, phase
146                         todo <- func() (err error) {
147                                 log.Printf("%s phase %d starting", infile, phase+1)
148                                 defer log.Printf("%s phase %d done", infile, phase+1)
149                                 tseqs[i*2+phase], err = cmd.tileGVCF(tilelib, infile, phase)
150                                 return
151                         }
152                 }
153         }
154         go close(todo)
155         for i := 0; i < runtime.NumCPU(); i++ {
156                 wg.Add(1)
157                 go func() {
158                         defer wg.Done()
159                         for fn := range todo {
160                                 if len(errs) > 0 {
161                                         return
162                                 }
163                                 err := fn()
164                                 if err != nil {
165                                         select {
166                                         case errs <- err:
167                                         default:
168                                         }
169                                 }
170                                 remain := len(todo)
171                                 ttl := time.Now().Sub(starttime) * time.Duration(remain) / time.Duration(cap(todo)-remain)
172                                 eta := time.Now().Add(ttl)
173                                 log.Printf("progress %d/%d, eta %v (%v)", cap(todo)-remain, cap(todo), eta, ttl)
174                         }
175                 }()
176         }
177         wg.Wait()
178         go close(errs)
179         return tseqs, <-errs
180 }
181
182 func (cmd *gvcf2numpy) printVariants(tseqs []tileSeq) error {
183         maxtag := tagID(-1)
184         for _, tseq := range tseqs {
185                 for _, path := range tseq {
186                         for _, tvar := range path {
187                                 if maxtag < tvar.tag {
188                                         maxtag = tvar.tag
189                                 }
190                         }
191                 }
192         }
193         out := make([]uint16, len(tseqs)*int(maxtag+1))
194         for i := 0; i < len(tseqs)/2; i++ {
195                 for phase := 0; phase < 2; phase++ {
196                         for _, path := range tseqs[i*2+phase] {
197                                 for _, tvar := range path {
198                                         out[2*int(tvar.tag)+phase] = uint16(tvar.variant)
199                                 }
200                         }
201                 }
202         }
203         w := bufio.NewWriter(cmd.output)
204         npw, err := gonpy.NewWriter(nopCloser{w})
205         if err != nil {
206                 return err
207         }
208         npw.Shape = []int{len(tseqs) / 2, 2 * (int(maxtag) + 1)}
209         npw.WriteUint16(out)
210         return w.Flush()
211 }
212
213 type nopCloser struct {
214         io.Writer
215 }
216
217 func (nopCloser) Close() error { return nil }
218
219 func (cmd *gvcf2numpy) tileGVCF(tilelib *tileLibrary, infile string, phase int) (tileseq tileSeq, err error) {
220         args := []string{"bcftools", "consensus", "--fasta-ref", cmd.refFile, "-H", fmt.Sprint(phase + 1), infile}
221         indexsuffix := ".tbi"
222         if _, err := os.Stat(infile + ".csi"); err == nil {
223                 indexsuffix = ".csi"
224         }
225         if out, err := exec.Command("docker", "image", "ls", "-q", "lightning-runtime").Output(); err == nil && len(out) > 0 {
226                 args = append([]string{
227                         "docker", "run", "--rm",
228                         "--log-driver=none",
229                         "--volume=" + infile + ":" + infile + ":ro",
230                         "--volume=" + infile + indexsuffix + ":" + infile + indexsuffix + ":ro",
231                         "--volume=" + cmd.refFile + ":" + cmd.refFile + ":ro",
232                         "lightning-runtime",
233                 }, args...)
234         }
235         consensus := exec.Command(args[0], args[1:]...)
236         consensus.Stderr = os.Stderr
237         stdout, err := consensus.StdoutPipe()
238         if err != nil {
239                 return
240         }
241         err = consensus.Start()
242         if err != nil {
243                 return
244         }
245         tileseq, err = tilelib.TileFasta(fmt.Sprintf("%s phase %d", infile, phase+1), stdout)
246         if err != nil {
247                 return
248         }
249         err = stdout.Close()
250         if err != nil {
251                 return
252         }
253         err = consensus.Wait()
254         if err != nil {
255                 err = fmt.Errorf("%s phase %d: bcftools: %s", infile, phase, err)
256                 return
257         }
258         return
259 }