65c69eb6e3a495657fad6284230fc015058e47e4
[lightning.git] / pca.go
1 package main
2
3 import (
4         "bufio"
5         "context"
6         "errors"
7         "flag"
8         "fmt"
9         "io"
10         "io/ioutil"
11         "net/http"
12         _ "net/http/pprof"
13         "os"
14
15         "git.arvados.org/arvados.git/sdk/go/arvados"
16         "github.com/james-bowman/nlp"
17         "github.com/kshedden/gonpy"
18         log "github.com/sirupsen/logrus"
19         "gonum.org/v1/gonum/mat"
20 )
21
22 type pythonPCA struct{}
23
24 func (cmd *pythonPCA) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
25         var err error
26         defer func() {
27                 if err != nil {
28                         fmt.Fprintf(stderr, "%s\n", err)
29                 }
30         }()
31         flags := flag.NewFlagSet("", flag.ContinueOnError)
32         flags.SetOutput(stderr)
33         projectUUID := flags.String("project", "", "project `UUID` for output data")
34         inputFilename := flags.String("i", "-", "input `file`")
35         priority := flags.Int("priority", 500, "container request priority")
36         err = flags.Parse(args)
37         if err == flag.ErrHelp {
38                 err = nil
39                 return 0
40         } else if err != nil {
41                 return 2
42         }
43
44         runner := arvadosContainerRunner{
45                 Name:        "lightning pca",
46                 Client:      arvados.NewClientFromEnv(),
47                 ProjectUUID: *projectUUID,
48                 RAM:         440000000000,
49                 VCPUs:       1,
50                 Priority:    *priority,
51         }
52         err = runner.TranslatePaths(inputFilename)
53         if err != nil {
54                 return 1
55         }
56         runner.Prog = "python3"
57         runner.Args = []string{"-c", `import sys
58 import scipy
59 from sklearn.decomposition import PCA
60 scipy.save(sys.argv[2], PCA(n_components=4).fit_transform(scipy.load(sys.argv[1])))`, *inputFilename, "/mnt/output/pca.npy"}
61         var output string
62         output, err = runner.Run()
63         if err != nil {
64                 return 1
65         }
66         fmt.Fprintln(stdout, output+"/pca.npy")
67         return 0
68 }
69
70 type goPCA struct{}
71
72 func (cmd *goPCA) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
73         var err error
74         defer func() {
75                 if err != nil {
76                         fmt.Fprintf(stderr, "%s\n", err)
77                 }
78         }()
79         flags := flag.NewFlagSet("", flag.ContinueOnError)
80         flags.SetOutput(stderr)
81         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
82         runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
83         projectUUID := flags.String("project", "", "project `UUID` for output data")
84         priority := flags.Int("priority", 500, "container request priority")
85         inputFilename := flags.String("i", "-", "input `file`")
86         outputFilename := flags.String("o", "-", "output `file`")
87         components := flags.Int("components", 4, "number of components")
88         onehot := flags.Bool("one-hot", false, "recode tile variants as one-hot")
89         err = flags.Parse(args)
90         if err == flag.ErrHelp {
91                 err = nil
92                 return 0
93         } else if err != nil {
94                 return 2
95         }
96
97         if *pprof != "" {
98                 go func() {
99                         log.Println(http.ListenAndServe(*pprof, nil))
100                 }()
101         }
102
103         if !*runlocal {
104                 if *outputFilename != "-" {
105                         err = errors.New("cannot specify output file in container mode: not implemented")
106                         return 1
107                 }
108                 runner := arvadosContainerRunner{
109                         Name:        "lightning pca-go",
110                         Client:      arvados.NewClientFromEnv(),
111                         ProjectUUID: *projectUUID,
112                         RAM:         100000000000, // maybe 10x input size?
113                         VCPUs:       2,
114                         Priority:    *priority,
115                 }
116                 err = runner.TranslatePaths(inputFilename)
117                 if err != nil {
118                         return 1
119                 }
120                 runner.Args = []string{"pca-go", "-local=true", fmt.Sprintf("-one-hot=%v", *onehot), "-i", *inputFilename, "-o", "/mnt/output/pca.npy"}
121                 var output string
122                 output, err = runner.Run()
123                 if err != nil {
124                         return 1
125                 }
126                 fmt.Fprintln(stdout, output+"/pca.npy")
127                 return 0
128         }
129
130         var input io.ReadCloser
131         if *inputFilename == "-" {
132                 input = ioutil.NopCloser(stdin)
133         } else {
134                 input, err = os.Open(*inputFilename)
135                 if err != nil {
136                         return 1
137                 }
138                 defer input.Close()
139         }
140         log.Print("reading")
141         tilelib := tileLibrary{
142                 retainNoCalls:  true,
143                 compactGenomes: map[string][]tileVariantID{},
144         }
145         err = tilelib.LoadGob(context.Background(), input, nil)
146         if err != nil {
147                 return 1
148         }
149         err = input.Close()
150         if err != nil {
151                 return 1
152         }
153
154         log.Print("converting cgs to array")
155         data, rows, cols := cgs2array(tilelib.compactGenomes)
156         if *onehot {
157                 log.Printf("recode one-hot: %d rows, %d cols", rows, cols)
158                 data, _, cols = recodeOnehot(data, cols)
159         }
160
161         log.Printf("creating matrix backed by array: %d rows, %d cols", rows, cols)
162         mtx := array2matrix(rows, cols, data).T()
163
164         log.Print("fitting")
165         transformer := nlp.NewPCA(*components)
166         transformer.Fit(mtx)
167         log.Printf("transforming")
168         mtx, err = transformer.Transform(mtx)
169         if err != nil {
170                 return 1
171         }
172         mtx = mtx.T()
173
174         rows, cols = mtx.Dims()
175         log.Printf("copying result to numpy output array: %d rows, %d cols", rows, cols)
176         out := make([]float64, rows*cols)
177         for i := 0; i < rows; i++ {
178                 for j := 0; j < cols; j++ {
179                         out[i*cols+j] = mtx.At(i, j)
180                 }
181         }
182
183         var output io.WriteCloser
184         if *outputFilename == "-" {
185                 output = nopCloser{stdout}
186         } else {
187                 output, err = os.OpenFile(*outputFilename, os.O_CREATE|os.O_WRONLY, 0777)
188                 if err != nil {
189                         return 1
190                 }
191                 defer output.Close()
192         }
193         bufw := bufio.NewWriter(output)
194         npw, err := gonpy.NewWriter(nopCloser{bufw})
195         if err != nil {
196                 return 1
197         }
198         npw.Shape = []int{rows, cols}
199         log.Printf("writing numpy: %d rows, %d cols", rows, cols)
200         npw.WriteFloat64(out)
201         err = bufw.Flush()
202         if err != nil {
203                 return 1
204         }
205         err = output.Close()
206         if err != nil {
207                 return 1
208         }
209         log.Print("done")
210         return 0
211 }
212
213 func array2matrix(rows, cols int, data []uint16) mat.Matrix {
214         floatdata := make([]float64, rows*cols)
215         for i, v := range data {
216                 floatdata[i] = float64(v)
217         }
218         return mat.NewDense(rows, cols, floatdata)
219 }