Gzip gob files.
[lightning.git] / pca.go
1 package main
2
3 import (
4         "bufio"
5         "context"
6         "errors"
7         "flag"
8         "fmt"
9         "io"
10         "io/ioutil"
11         "net/http"
12         _ "net/http/pprof"
13         "os"
14         "strings"
15
16         "git.arvados.org/arvados.git/sdk/go/arvados"
17         "github.com/james-bowman/nlp"
18         "github.com/kshedden/gonpy"
19         log "github.com/sirupsen/logrus"
20         "gonum.org/v1/gonum/mat"
21 )
22
23 type pythonPCA struct{}
24
25 func (cmd *pythonPCA) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
26         var err error
27         defer func() {
28                 if err != nil {
29                         fmt.Fprintf(stderr, "%s\n", err)
30                 }
31         }()
32         flags := flag.NewFlagSet("", flag.ContinueOnError)
33         flags.SetOutput(stderr)
34         projectUUID := flags.String("project", "", "project `UUID` for output data")
35         inputFilename := flags.String("i", "-", "input `file`")
36         priority := flags.Int("priority", 500, "container request priority")
37         err = flags.Parse(args)
38         if err == flag.ErrHelp {
39                 err = nil
40                 return 0
41         } else if err != nil {
42                 return 2
43         }
44
45         runner := arvadosContainerRunner{
46                 Name:        "lightning pca",
47                 Client:      arvados.NewClientFromEnv(),
48                 ProjectUUID: *projectUUID,
49                 RAM:         440000000000,
50                 VCPUs:       1,
51                 Priority:    *priority,
52         }
53         err = runner.TranslatePaths(inputFilename)
54         if err != nil {
55                 return 1
56         }
57         runner.Prog = "python3"
58         runner.Args = []string{"-c", `import sys
59 import scipy
60 from sklearn.decomposition import PCA
61 scipy.save(sys.argv[2], PCA(n_components=4).fit_transform(scipy.load(sys.argv[1])))`, *inputFilename, "/mnt/output/pca.npy"}
62         var output string
63         output, err = runner.Run()
64         if err != nil {
65                 return 1
66         }
67         fmt.Fprintln(stdout, output+"/pca.npy")
68         return 0
69 }
70
71 type goPCA struct{}
72
73 func (cmd *goPCA) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
74         var err error
75         defer func() {
76                 if err != nil {
77                         fmt.Fprintf(stderr, "%s\n", err)
78                 }
79         }()
80         flags := flag.NewFlagSet("", flag.ContinueOnError)
81         flags.SetOutput(stderr)
82         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
83         runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
84         projectUUID := flags.String("project", "", "project `UUID` for output data")
85         priority := flags.Int("priority", 500, "container request priority")
86         inputFilename := flags.String("i", "-", "input `file`")
87         outputFilename := flags.String("o", "-", "output `file`")
88         components := flags.Int("components", 4, "number of components")
89         onehot := flags.Bool("one-hot", false, "recode tile variants as one-hot")
90         err = flags.Parse(args)
91         if err == flag.ErrHelp {
92                 err = nil
93                 return 0
94         } else if err != nil {
95                 return 2
96         }
97
98         if *pprof != "" {
99                 go func() {
100                         log.Println(http.ListenAndServe(*pprof, nil))
101                 }()
102         }
103
104         if !*runlocal {
105                 if *outputFilename != "-" {
106                         err = errors.New("cannot specify output file in container mode: not implemented")
107                         return 1
108                 }
109                 runner := arvadosContainerRunner{
110                         Name:        "lightning pca-go",
111                         Client:      arvados.NewClientFromEnv(),
112                         ProjectUUID: *projectUUID,
113                         RAM:         100000000000, // maybe 10x input size?
114                         VCPUs:       2,
115                         Priority:    *priority,
116                 }
117                 err = runner.TranslatePaths(inputFilename)
118                 if err != nil {
119                         return 1
120                 }
121                 runner.Args = []string{"pca-go", "-local=true", fmt.Sprintf("-one-hot=%v", *onehot), "-i", *inputFilename, "-o", "/mnt/output/pca.npy"}
122                 var output string
123                 output, err = runner.Run()
124                 if err != nil {
125                         return 1
126                 }
127                 fmt.Fprintln(stdout, output+"/pca.npy")
128                 return 0
129         }
130
131         var input io.ReadCloser
132         if *inputFilename == "-" {
133                 input = ioutil.NopCloser(stdin)
134         } else {
135                 input, err = os.Open(*inputFilename)
136                 if err != nil {
137                         return 1
138                 }
139                 defer input.Close()
140         }
141         log.Print("reading")
142         tilelib := tileLibrary{
143                 retainNoCalls:  true,
144                 compactGenomes: map[string][]tileVariantID{},
145         }
146         err = tilelib.LoadGob(context.Background(), input, strings.HasSuffix(*inputFilename, ".gz"), nil)
147         if err != nil {
148                 return 1
149         }
150         err = input.Close()
151         if err != nil {
152                 return 1
153         }
154
155         log.Print("converting cgs to array")
156         data, rows, cols := cgs2array(tilelib.compactGenomes)
157         if *onehot {
158                 log.Printf("recode one-hot: %d rows, %d cols", rows, cols)
159                 data, _, cols = recodeOnehot(data, cols)
160         }
161
162         log.Printf("creating matrix backed by array: %d rows, %d cols", rows, cols)
163         mtx := array2matrix(rows, cols, data).T()
164
165         log.Print("fitting")
166         transformer := nlp.NewPCA(*components)
167         transformer.Fit(mtx)
168         log.Printf("transforming")
169         mtx, err = transformer.Transform(mtx)
170         if err != nil {
171                 return 1
172         }
173         mtx = mtx.T()
174
175         rows, cols = mtx.Dims()
176         log.Printf("copying result to numpy output array: %d rows, %d cols", rows, cols)
177         out := make([]float64, rows*cols)
178         for i := 0; i < rows; i++ {
179                 for j := 0; j < cols; j++ {
180                         out[i*cols+j] = mtx.At(i, j)
181                 }
182         }
183
184         var output io.WriteCloser
185         if *outputFilename == "-" {
186                 output = nopCloser{stdout}
187         } else {
188                 output, err = os.OpenFile(*outputFilename, os.O_CREATE|os.O_WRONLY, 0777)
189                 if err != nil {
190                         return 1
191                 }
192                 defer output.Close()
193         }
194         bufw := bufio.NewWriter(output)
195         npw, err := gonpy.NewWriter(nopCloser{bufw})
196         if err != nil {
197                 return 1
198         }
199         npw.Shape = []int{rows, cols}
200         log.Printf("writing numpy: %d rows, %d cols", rows, cols)
201         npw.WriteFloat64(out)
202         err = bufw.Flush()
203         if err != nil {
204                 return 1
205         }
206         err = output.Close()
207         if err != nil {
208                 return 1
209         }
210         log.Print("done")
211         return 0
212 }
213
214 func array2matrix(rows, cols int, data []uint16) mat.Matrix {
215         floatdata := make([]float64, rows*cols)
216         for i, v := range data {
217                 floatdata[i] = float64(v)
218         }
219         return mat.NewDense(rows, cols, floatdata)
220 }