Default to applying min-coverage filter based on training set only.
[lightning.git] / numpycomvar.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "flag"
9         "fmt"
10         "io"
11         _ "net/http/pprof"
12
13         "git.arvados.org/arvados.git/sdk/go/arvados"
14 )
15
16 type numpyComVar struct{}
17
18 func (cmd *numpyComVar) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
19         var err error
20         defer func() {
21                 if err != nil {
22                         fmt.Fprintf(stderr, "%s\n", err)
23                 }
24         }()
25         flags := flag.NewFlagSet("", flag.ContinueOnError)
26         flags.SetOutput(stderr)
27         projectUUID := flags.String("project", "", "project `UUID` for output data")
28         inputFilename := flags.String("i", "-", "numpy matrix `file`")
29         priority := flags.Int("priority", 500, "container request priority")
30         annotationsFilename := flags.String("annotations", "", "annotations tsv `file`")
31         maxResults := flags.Int("max-results", 256, "maximum number of tile variants to output")
32         minFrequency := flags.Float64("min-frequency", 0.4, "minimum allele frequency")
33         maxFrequency := flags.Float64("max-frequency", 0.6, "maximum allele frequency")
34         err = flags.Parse(args)
35         if err == flag.ErrHelp {
36                 err = nil
37                 return 0
38         } else if err != nil {
39                 return 2
40         } else if flags.NArg() > 0 {
41                 err = fmt.Errorf("errant command line arguments after parsed flags: %v", flags.Args())
42                 return 2
43         }
44
45         runner := arvadosContainerRunner{
46                 Name:        "lightning numpy-comvar",
47                 Client:      arvados.NewClientFromEnv(),
48                 ProjectUUID: *projectUUID,
49                 RAM:         120000000000,
50                 VCPUs:       2,
51                 Priority:    *priority,
52         }
53         err = runner.TranslatePaths(inputFilename, annotationsFilename)
54         if err != nil {
55                 return 1
56         }
57         runner.Prog = "python3"
58         runner.Args = []string{"-c", `import sys
59 import scipy
60 import sys
61 import csv
62
63 numpyFile = sys.argv[1]
64 annotationsFile = sys.argv[2]
65 outputFile = sys.argv[3]
66 maxResults = int(sys.argv[4])
67 minFrequency = float(sys.argv[5])
68 maxFrequency = float(sys.argv[6])
69
70 out = open(outputFile, 'w')
71
72 m = scipy.load(numpyFile)
73
74 commonvariants = {}
75 mincount = m.shape[0] * 2 * minFrequency
76 maxcount = m.shape[0] * 2 * maxFrequency
77 for tag in range(m.shape[1] // 2):
78   example = {}
79   counter = [0, 0, 0, 0, 0]
80   for genome in range(m.shape[0]):
81     for phase in range(2):
82       variant = m[genome][tag*2+phase]
83       if variant > 0 and variant < len(counter):
84         counter[variant] += 1
85         example[variant] = genome
86   for variant, count in enumerate(counter):
87     if count >= mincount and count <= maxcount:
88       commonvariants[tag,variant] = example[variant]
89       # sys.stderr.write('tag {} variant {} count {} example {} have {} commonvariants\n'.format(tag, variant, count, example[variant], len(commonvariants)))
90   if len(commonvariants) >= maxResults:
91     break
92
93 found = {}
94 with open(annotationsFile, newline='') as tsvfile:
95   rdr = csv.reader(tsvfile, delimiter='\t', quotechar='"')
96   for row in rdr:
97     tag = int(row[0])
98     variant = int(row[1])
99     if (tag, variant) in commonvariants:
100       found[tag, variant] = True
101       out.write(','.join(row + [str(commonvariants[tag, variant])]) + '\n')
102     elif len(found) >= len(commonvariants):
103       sys.stderr.write('done\n')
104       break
105
106 out.close()
107 `, *inputFilename, *annotationsFilename, "/mnt/output/commonvariants.csv", fmt.Sprintf("%d", *maxResults), fmt.Sprintf("%f", *minFrequency), fmt.Sprintf("%f", *maxFrequency)}
108         var output string
109         output, err = runner.Run()
110         if err != nil {
111                 return 1
112         }
113         fmt.Fprintln(stdout, output+"/commonvariants.csv")
114         return 0
115 }