Fix some tests.
[lightning.git] / anno2vcf.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "bufio"
9         "bytes"
10         "flag"
11         "fmt"
12         "io"
13         "net/http"
14         _ "net/http/pprof"
15         "os"
16         "runtime"
17         "sort"
18         "strconv"
19         "strings"
20         "sync"
21
22         "git.arvados.org/arvados.git/sdk/go/arvados"
23         log "github.com/sirupsen/logrus"
24 )
25
// anno2vcf implements the "lightning anno2vcf" subcommand, which turns
// per-tile annotation CSV files into one VCF file per reference sequence.
// It carries no state of its own; all configuration arrives through the
// command-line flags parsed in RunCommand.
type anno2vcf struct{}
28
29 func (cmd *anno2vcf) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
30         var err error
31         defer func() {
32                 if err != nil {
33                         fmt.Fprintf(stderr, "%s\n", err)
34                 }
35         }()
36         flags := flag.NewFlagSet("", flag.ContinueOnError)
37         flags.SetOutput(stderr)
38         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
39         runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
40         projectUUID := flags.String("project", "", "project `UUID` for output data")
41         priority := flags.Int("priority", 500, "container request priority")
42         inputDir := flags.String("input-dir", "./in", "input `directory`")
43         outputDir := flags.String("output-dir", "./out", "output `directory`")
44         err = flags.Parse(args)
45         if err == flag.ErrHelp {
46                 err = nil
47                 return 0
48         } else if err != nil {
49                 return 2
50         } else if flags.NArg() > 0 {
51                 err = fmt.Errorf("errant command line arguments after parsed flags: %v", flags.Args())
52                 return 2
53         }
54
55         if *pprof != "" {
56                 go func() {
57                         log.Println(http.ListenAndServe(*pprof, nil))
58                 }()
59         }
60
61         if !*runlocal {
62                 runner := arvadosContainerRunner{
63                         Name:        "lightning anno2vcf",
64                         Client:      arvados.NewClientFromEnv(),
65                         ProjectUUID: *projectUUID,
66                         RAM:         500000000000,
67                         VCPUs:       64,
68                         Priority:    *priority,
69                         KeepCache:   2,
70                         APIAccess:   true,
71                 }
72                 err = runner.TranslatePaths(inputDir)
73                 if err != nil {
74                         return 1
75                 }
76                 runner.Args = []string{"anno2vcf", "-local=true",
77                         "-pprof", ":6060",
78                         "-input-dir", *inputDir,
79                         "-output-dir", "/mnt/output",
80                 }
81                 var output string
82                 output, err = runner.Run()
83                 if err != nil {
84                         return 1
85                 }
86                 fmt.Fprintln(stdout, output)
87                 return 0
88         }
89
90         d, err := open(*inputDir)
91         if err != nil {
92                 log.Print(err)
93                 return 1
94         }
95         defer d.Close()
96         fis, err := d.Readdir(-1)
97         if err != nil {
98                 log.Print(err)
99                 return 1
100         }
101         d.Close()
102         sort.Slice(fis, func(i, j int) bool { return fis[i].Name() < fis[j].Name() })
103
104         type call struct {
105                 tile      int
106                 variant   int
107                 position  int
108                 deletion  []byte
109                 insertion []byte
110                 hgvsID    []byte
111         }
112         allcalls := map[string][]*call{}
113         var mtx sync.Mutex
114         thr := throttle{Max: runtime.GOMAXPROCS(0)}
115         log.Print("reading input files")
116         for _, fi := range fis {
117                 if !strings.HasSuffix(fi.Name(), "annotations.csv") {
118                         continue
119                 }
120                 filename := *inputDir + "/" + fi.Name()
121                 thr.Go(func() error {
122                         log.Printf("reading %s", filename)
123                         f, err := open(filename)
124                         if err != nil {
125                                 return err
126                         }
127                         defer f.Close()
128                         buf, err := io.ReadAll(f)
129                         if err != nil {
130                                 return fmt.Errorf("%s: %s", filename, err)
131                         }
132                         f.Close()
133                         lines := bytes.Split(buf, []byte{'\n'})
134                         calls := map[string][]*call{}
135                         for lineIdx, line := range lines {
136                                 if len(line) == 0 {
137                                         continue
138                                 }
139                                 if lineIdx&0xff == 0 && thr.Err() != nil {
140                                         return nil
141                                 }
142                                 fields := bytes.Split(line, []byte{','})
143                                 if len(fields) < 8 {
144                                         return fmt.Errorf("%s line %d: wrong number of fields (%d < %d): %q", fi.Name(), lineIdx+1, len(fields), 8, line)
145                                 }
146                                 hgvsID := fields[3]
147                                 if len(hgvsID) < 2 {
148                                         // "=" reference or ""
149                                         // non-diffable tile variant
150                                         continue
151                                 }
152                                 tile, _ := strconv.ParseInt(string(fields[0]), 10, 64)
153                                 variant, _ := strconv.ParseInt(string(fields[2]), 10, 64)
154                                 position, _ := strconv.ParseInt(string(fields[5]), 10, 64)
155                                 seq := string(fields[4])
156                                 if calls[seq] == nil {
157                                         calls[seq] = make([]*call, 0, len(lines)/50)
158                                 }
159                                 del := fields[6]
160                                 ins := fields[7]
161                                 if (len(del) == 0 || len(ins) == 0) && len(fields) >= 9 {
162                                         // "123,,AA,T" means 123insAA
163                                         // preceded by T. We record it
164                                         // here as "122 T TAA" to
165                                         // avoid writing an empty
166                                         // "ref" field in our
167                                         // VCF. Similarly, we record
168                                         // deletions as "122 TAA T"
169                                         // rather than "123 AA .".
170                                         del = append(append(make([]byte, 0, len(fields[8])+len(del)), fields[8]...), del...)
171                                         ins = append(append(make([]byte, 0, len(fields[8])+len(ins)), fields[8]...), ins...)
172                                         position -= int64(len(fields[8]))
173                                 } else {
174                                         del = append([]byte(nil), del...)
175                                         ins = append([]byte(nil), ins...)
176                                 }
177                                 calls[seq] = append(calls[seq], &call{
178                                         tile:      int(tile),
179                                         variant:   int(variant),
180                                         position:  int(position),
181                                         deletion:  del,
182                                         insertion: ins,
183                                         hgvsID:    hgvsID,
184                                 })
185                         }
186                         mtx.Lock()
187                         for seq, seqcalls := range calls {
188                                 allcalls[seq] = append(allcalls[seq], seqcalls...)
189                         }
190                         mtx.Unlock()
191                         return nil
192                 })
193         }
194         err = thr.Wait()
195         if err != nil {
196                 return 1
197         }
198         thr = throttle{Max: len(allcalls)}
199         for seq, seqcalls := range allcalls {
200                 seq, seqcalls := seq, seqcalls
201                 thr.Go(func() error {
202                         log.Printf("%s: sorting", seq)
203                         sort.Slice(seqcalls, func(i, j int) bool {
204                                 ii, jj := seqcalls[i], seqcalls[j]
205                                 if cmp := ii.position - jj.position; cmp != 0 {
206                                         return cmp < 0
207                                 }
208                                 if cmp := len(ii.deletion) - len(jj.deletion); cmp != 0 {
209                                         return cmp < 0
210                                 }
211                                 if cmp := bytes.Compare(ii.insertion, jj.insertion); cmp != 0 {
212                                         return cmp < 0
213                                 }
214                                 if cmp := ii.tile - jj.tile; cmp != 0 {
215                                         return cmp < 0
216                                 }
217                                 return ii.variant < jj.variant
218                         })
219
220                         vcfFilename := fmt.Sprintf("%s/annotations.%s.vcf", *outputDir, seq)
221                         log.Printf("%s: writing %s", seq, vcfFilename)
222
223                         f, err := os.Create(vcfFilename)
224                         if err != nil {
225                                 return err
226                         }
227                         defer f.Close()
228                         bufw := bufio.NewWriterSize(f, 1<<20)
229                         _, err = fmt.Fprintf(bufw, `##fileformat=VCFv4.0
230 ##INFO=<ID=TV,Number=.,Type=String,Description="tile-variant">
231 #CHROM  POS     ID      REF     ALT     QUAL    FILTER  INFO
232 `)
233                         if err != nil {
234                                 return err
235                         }
236                         placeholder := []byte{'.'}
237                         for i := 0; i < len(seqcalls); {
238                                 call := seqcalls[i]
239                                 i++
240                                 info := fmt.Sprintf("TV=,%d-%d,", call.tile, call.variant)
241                                 for i < len(seqcalls) &&
242                                         call.position == seqcalls[i].position &&
243                                         len(call.deletion) == len(seqcalls[i].deletion) &&
244                                         bytes.Equal(call.insertion, seqcalls[i].insertion) {
245                                         call = seqcalls[i]
246                                         i++
247                                         info += fmt.Sprintf("%d-%d,", call.tile, call.variant)
248                                 }
249                                 deletion := call.deletion
250                                 if len(deletion) == 0 {
251                                         deletion = placeholder
252                                 }
253                                 insertion := call.insertion
254                                 if len(insertion) == 0 {
255                                         insertion = placeholder
256                                 }
257                                 _, err = fmt.Fprintf(bufw, "%s\t%d\t%s\t%s\t%s\t.\t.\t%s\n", seq, call.position, call.hgvsID, deletion, insertion, info)
258                                 if err != nil {
259                                         return err
260                                 }
261                         }
262                         err = bufw.Flush()
263                         if err != nil {
264                                 return err
265                         }
266                         err = f.Close()
267                         if err != nil {
268                                 return err
269                         }
270                         log.Printf("%s: done", seq)
271                         return nil
272                 })
273         }
274         err = thr.Wait()
275         if err != nil {
276                 return 1
277         }
278         return 0
279 }