Fix deadlock at container finish.
[lightning.git] / anno2vcf.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "bufio"
9         "bytes"
10         "flag"
11         "fmt"
12         "io"
13         "io/ioutil"
14         "net/http"
15         _ "net/http/pprof"
16         "os"
17         "runtime"
18         "sort"
19         "strconv"
20         "strings"
21         "sync"
22
23         "git.arvados.org/arvados.git/sdk/go/arvados"
24         log "github.com/sirupsen/logrus"
25 )
26
// anno2vcf is the "anno2vcf" subcommand. It carries no state; all
// configuration comes from the command-line flags parsed in RunCommand.
type anno2vcf struct{}
30 func (cmd *anno2vcf) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
31         var err error
32         defer func() {
33                 if err != nil {
34                         fmt.Fprintf(stderr, "%s\n", err)
35                 }
36         }()
37         flags := flag.NewFlagSet("", flag.ContinueOnError)
38         flags.SetOutput(stderr)
39         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
40         runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
41         projectUUID := flags.String("project", "", "project `UUID` for output data")
42         priority := flags.Int("priority", 500, "container request priority")
43         inputDir := flags.String("input-dir", "./in", "input `directory`")
44         outputDir := flags.String("output-dir", "./out", "output `directory`")
45         err = flags.Parse(args)
46         if err == flag.ErrHelp {
47                 err = nil
48                 return 0
49         } else if err != nil {
50                 return 2
51         }
52
53         if *pprof != "" {
54                 go func() {
55                         log.Println(http.ListenAndServe(*pprof, nil))
56                 }()
57         }
58
59         if !*runlocal {
60                 runner := arvadosContainerRunner{
61                         Name:        "lightning anno2vcf",
62                         Client:      arvados.NewClientFromEnv(),
63                         ProjectUUID: *projectUUID,
64                         RAM:         500000000000,
65                         VCPUs:       64,
66                         Priority:    *priority,
67                         KeepCache:   2,
68                         APIAccess:   true,
69                 }
70                 err = runner.TranslatePaths(inputDir)
71                 if err != nil {
72                         return 1
73                 }
74                 runner.Args = []string{"anno2vcf", "-local=true",
75                         "-pprof", ":6060",
76                         "-input-dir", *inputDir,
77                         "-output-dir", "/mnt/output",
78                 }
79                 var output string
80                 output, err = runner.Run()
81                 if err != nil {
82                         return 1
83                 }
84                 fmt.Fprintln(stdout, output)
85                 return 0
86         }
87
88         d, err := open(*inputDir)
89         if err != nil {
90                 log.Print(err)
91                 return 1
92         }
93         defer d.Close()
94         fis, err := d.Readdir(-1)
95         if err != nil {
96                 log.Print(err)
97                 return 1
98         }
99         d.Close()
100         sort.Slice(fis, func(i, j int) bool { return fis[i].Name() < fis[j].Name() })
101
102         type call struct {
103                 tile      int
104                 variant   int
105                 position  int
106                 deletion  []byte
107                 insertion []byte
108                 hgvsID    []byte
109         }
110         allcalls := map[string][]*call{}
111         var mtx sync.Mutex
112         thr := throttle{Max: runtime.GOMAXPROCS(0)}
113         log.Print("reading input files")
114         for _, fi := range fis {
115                 if !strings.HasSuffix(fi.Name(), "annotations.csv") {
116                         continue
117                 }
118                 filename := *inputDir + "/" + fi.Name()
119                 thr.Go(func() error {
120                         log.Printf("reading %s", filename)
121                         buf, err := ioutil.ReadFile(filename)
122                         if err != nil {
123                                 return fmt.Errorf("%s: %s", filename, err)
124                         }
125                         lines := bytes.Split(buf, []byte{'\n'})
126                         calls := map[string][]*call{}
127                         for lineIdx, line := range lines {
128                                 if len(line) == 0 {
129                                         continue
130                                 }
131                                 if lineIdx&0xff == 0 && thr.Err() != nil {
132                                         return nil
133                                 }
134                                 fields := bytes.Split(line, []byte{','})
135                                 if len(fields) < 8 {
136                                         return fmt.Errorf("%s line %d: wrong number of fields (%d < %d): %q", fi.Name(), lineIdx+1, len(fields), 8, line)
137                                 }
138                                 hgvsID := fields[3]
139                                 if len(hgvsID) < 2 {
140                                         // "=" reference or ""
141                                         // non-diffable tile variant
142                                         continue
143                                 }
144                                 tile, _ := strconv.ParseInt(string(fields[0]), 10, 64)
145                                 variant, _ := strconv.ParseInt(string(fields[2]), 10, 64)
146                                 position, _ := strconv.ParseInt(string(fields[5]), 10, 64)
147                                 seq := string(fields[4])
148                                 if calls[seq] == nil {
149                                         calls[seq] = make([]*call, 0, len(lines)/50)
150                                 }
151                                 del := fields[6]
152                                 ins := fields[7]
153                                 if (len(del) == 0 || len(ins) == 0) && len(fields) >= 9 {
154                                         // "123,,AA,T" means 123insAA
155                                         // preceded by T. We record it
156                                         // here as "122 T TAA" to
157                                         // avoid writing an empty
158                                         // "ref" field in our
159                                         // VCF. Similarly, we record
160                                         // deletions as "122 TAA T"
161                                         // rather than "123 AA .".
162                                         del = append(append(make([]byte, 0, len(fields[8])+len(del)), fields[8]...), del...)
163                                         ins = append(append(make([]byte, 0, len(fields[8])+len(ins)), fields[8]...), ins...)
164                                         position -= int64(len(fields[8]))
165                                 } else {
166                                         del = append([]byte(nil), del...)
167                                         ins = append([]byte(nil), ins...)
168                                 }
169                                 calls[seq] = append(calls[seq], &call{
170                                         tile:      int(tile),
171                                         variant:   int(variant),
172                                         position:  int(position),
173                                         deletion:  del,
174                                         insertion: ins,
175                                         hgvsID:    hgvsID,
176                                 })
177                         }
178                         mtx.Lock()
179                         for seq, seqcalls := range calls {
180                                 allcalls[seq] = append(allcalls[seq], seqcalls...)
181                         }
182                         mtx.Unlock()
183                         return nil
184                 })
185         }
186         err = thr.Wait()
187         if err != nil {
188                 return 1
189         }
190         thr = throttle{Max: len(allcalls)}
191         for seq, seqcalls := range allcalls {
192                 seq, seqcalls := seq, seqcalls
193                 thr.Go(func() error {
194                         log.Printf("%s: sorting", seq)
195                         sort.Slice(seqcalls, func(i, j int) bool {
196                                 ii, jj := seqcalls[i], seqcalls[j]
197                                 if cmp := ii.position - jj.position; cmp != 0 {
198                                         return cmp < 0
199                                 }
200                                 if cmp := len(ii.deletion) - len(jj.deletion); cmp != 0 {
201                                         return cmp < 0
202                                 }
203                                 if cmp := bytes.Compare(ii.insertion, jj.insertion); cmp != 0 {
204                                         return cmp < 0
205                                 }
206                                 if cmp := ii.tile - jj.tile; cmp != 0 {
207                                         return cmp < 0
208                                 }
209                                 return ii.variant < jj.variant
210                         })
211
212                         vcfFilename := fmt.Sprintf("%s/annotations.%s.vcf", *outputDir, seq)
213                         log.Printf("%s: writing %s", seq, vcfFilename)
214
215                         f, err := os.Create(vcfFilename)
216                         if err != nil {
217                                 return err
218                         }
219                         defer f.Close()
220                         bufw := bufio.NewWriterSize(f, 1<<20)
221                         _, err = fmt.Fprintf(bufw, `##fileformat=VCFv4.0
222 ##INFO=<ID=TV,Number=.,Type=String,Description="tile-variant">
223 #CHROM  POS     ID      REF     ALT     QUAL    FILTER  INFO
224 `)
225                         if err != nil {
226                                 return err
227                         }
228                         placeholder := []byte{'.'}
229                         for i := 0; i < len(seqcalls); {
230                                 call := seqcalls[i]
231                                 i++
232                                 info := fmt.Sprintf("TV=,%d-%d,", call.tile, call.variant)
233                                 for i < len(seqcalls) &&
234                                         call.position == seqcalls[i].position &&
235                                         len(call.deletion) == len(seqcalls[i].deletion) &&
236                                         bytes.Equal(call.insertion, seqcalls[i].insertion) {
237                                         call = seqcalls[i]
238                                         i++
239                                         info += fmt.Sprintf("%d-%d,", call.tile, call.variant)
240                                 }
241                                 deletion := call.deletion
242                                 if len(deletion) == 0 {
243                                         deletion = placeholder
244                                 }
245                                 insertion := call.insertion
246                                 if len(insertion) == 0 {
247                                         insertion = placeholder
248                                 }
249                                 _, err = fmt.Fprintf(bufw, "%s\t%d\t%s\t%s\t%s\t.\t.\t%s\n", seq, call.position, call.hgvsID, deletion, insertion, info)
250                                 if err != nil {
251                                         return err
252                                 }
253                         }
254                         err = bufw.Flush()
255                         if err != nil {
256                                 return err
257                         }
258                         err = f.Close()
259                         if err != nil {
260                                 return err
261                         }
262                         log.Printf("%s: done", seq)
263                         return nil
264                 })
265         }
266         err = thr.Wait()
267         if err != nil {
268                 return 1
269         }
270         return 0
271 }