More memory + direct Keep access for merge and exportnumpy.
[lightning.git] / pca.go
1 package main
2
3 import (
4         "bufio"
5         "context"
6         "errors"
7         "flag"
8         "fmt"
9         "io"
10         "io/ioutil"
11         "net/http"
12         _ "net/http/pprof"
13         "os"
14         "strings"
15
16         "git.arvados.org/arvados.git/sdk/go/arvados"
17         "github.com/james-bowman/nlp"
18         "github.com/kshedden/gonpy"
19         log "github.com/sirupsen/logrus"
20         "gonum.org/v1/gonum/mat"
21 )
22
23 type pythonPCA struct{}
24
25 func (cmd *pythonPCA) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
26         var err error
27         defer func() {
28                 if err != nil {
29                         fmt.Fprintf(stderr, "%s\n", err)
30                 }
31         }()
32         flags := flag.NewFlagSet("", flag.ContinueOnError)
33         flags.SetOutput(stderr)
34         projectUUID := flags.String("project", "", "project `UUID` for output data")
35         inputFilename := flags.String("i", "-", "input `file`")
36         priority := flags.Int("priority", 500, "container request priority")
37         err = flags.Parse(args)
38         if err == flag.ErrHelp {
39                 err = nil
40                 return 0
41         } else if err != nil {
42                 return 2
43         }
44
45         runner := arvadosContainerRunner{
46                 Name:        "lightning pca",
47                 Client:      arvados.NewClientFromEnv(),
48                 ProjectUUID: *projectUUID,
49                 RAM:         440000000000,
50                 VCPUs:       1,
51                 Priority:    *priority,
52         }
53         err = runner.TranslatePaths(inputFilename)
54         if err != nil {
55                 return 1
56         }
57         runner.Prog = "python3"
58         runner.Args = []string{"-c", `import sys
59 import scipy
60 from sklearn.decomposition import PCA
61 scipy.save(sys.argv[2], PCA(n_components=4).fit_transform(scipy.load(sys.argv[1])))`, *inputFilename, "/mnt/output/pca.npy"}
62         var output string
63         output, err = runner.Run()
64         if err != nil {
65                 return 1
66         }
67         fmt.Fprintln(stdout, output+"/pca.npy")
68         return 0
69 }
70
71 type goPCA struct {
72         filter filter
73 }
74
75 func (cmd *goPCA) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
76         var err error
77         defer func() {
78                 if err != nil {
79                         fmt.Fprintf(stderr, "%s\n", err)
80                 }
81         }()
82         flags := flag.NewFlagSet("", flag.ContinueOnError)
83         flags.SetOutput(stderr)
84         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
85         runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
86         projectUUID := flags.String("project", "", "project `UUID` for output data")
87         priority := flags.Int("priority", 500, "container request priority")
88         inputFilename := flags.String("i", "-", "input `file`")
89         outputFilename := flags.String("o", "-", "output `file`")
90         components := flags.Int("components", 4, "number of components")
91         onehot := flags.Bool("one-hot", false, "recode tile variants as one-hot")
92         cmd.filter.Flags(flags)
93         err = flags.Parse(args)
94         if err == flag.ErrHelp {
95                 err = nil
96                 return 0
97         } else if err != nil {
98                 return 2
99         }
100
101         if *pprof != "" {
102                 go func() {
103                         log.Println(http.ListenAndServe(*pprof, nil))
104                 }()
105         }
106
107         if !*runlocal {
108                 if *outputFilename != "-" {
109                         err = errors.New("cannot specify output file in container mode: not implemented")
110                         return 1
111                 }
112                 runner := arvadosContainerRunner{
113                         Name:        "lightning pca-go",
114                         Client:      arvados.NewClientFromEnv(),
115                         ProjectUUID: *projectUUID,
116                         RAM:         300000000000, // maybe 10x input size?
117                         VCPUs:       16,
118                         Priority:    *priority,
119                 }
120                 err = runner.TranslatePaths(inputFilename)
121                 if err != nil {
122                         return 1
123                 }
124                 runner.Args = []string{"pca-go", "-local=true", fmt.Sprintf("-one-hot=%v", *onehot), "-i", *inputFilename, "-o", "/mnt/output/pca.npy"}
125                 runner.Args = append(runner.Args, cmd.filter.Args()...)
126                 var output string
127                 output, err = runner.Run()
128                 if err != nil {
129                         return 1
130                 }
131                 fmt.Fprintln(stdout, output+"/pca.npy")
132                 return 0
133         }
134
135         var input io.ReadCloser
136         if *inputFilename == "-" {
137                 input = ioutil.NopCloser(stdin)
138         } else {
139                 input, err = os.Open(*inputFilename)
140                 if err != nil {
141                         return 1
142                 }
143                 defer input.Close()
144         }
145         log.Print("reading")
146         tilelib := &tileLibrary{
147                 retainNoCalls:  true,
148                 compactGenomes: map[string][]tileVariantID{},
149         }
150         err = tilelib.LoadGob(context.Background(), input, strings.HasSuffix(*inputFilename, ".gz"), nil)
151         if err != nil {
152                 return 1
153         }
154         err = input.Close()
155         if err != nil {
156                 return 1
157         }
158
159         log.Info("filtering")
160         cmd.filter.Apply(tilelib)
161         log.Info("tidying")
162         tilelib.Tidy()
163
164         log.Print("converting cgs to array")
165         data, rows, cols, _ := cgs2array(tilelib)
166         if *onehot {
167                 log.Printf("recode one-hot: %d rows, %d cols", rows, cols)
168                 data, _, cols = recodeOnehot(data, cols)
169         }
170         tilelib = nil
171
172         log.Printf("creating matrix backed by array: %d rows, %d cols", rows, cols)
173         mtx := array2matrix(rows, cols, data).T()
174
175         log.Print("fitting")
176         transformer := nlp.NewPCA(*components)
177         transformer.Fit(mtx)
178         log.Printf("transforming")
179         mtx, err = transformer.Transform(mtx)
180         if err != nil {
181                 return 1
182         }
183         mtx = mtx.T()
184
185         rows, cols = mtx.Dims()
186         log.Printf("copying result to numpy output array: %d rows, %d cols", rows, cols)
187         out := make([]float64, rows*cols)
188         for i := 0; i < rows; i++ {
189                 for j := 0; j < cols; j++ {
190                         out[i*cols+j] = mtx.At(i, j)
191                 }
192         }
193
194         var output io.WriteCloser
195         if *outputFilename == "-" {
196                 output = nopCloser{stdout}
197         } else {
198                 output, err = os.OpenFile(*outputFilename, os.O_CREATE|os.O_WRONLY, 0777)
199                 if err != nil {
200                         return 1
201                 }
202                 defer output.Close()
203         }
204         bufw := bufio.NewWriter(output)
205         npw, err := gonpy.NewWriter(nopCloser{bufw})
206         if err != nil {
207                 return 1
208         }
209         npw.Shape = []int{rows, cols}
210         log.Printf("writing numpy: %d rows, %d cols", rows, cols)
211         npw.WriteFloat64(out)
212         err = bufw.Flush()
213         if err != nil {
214                 return 1
215         }
216         err = output.Close()
217         if err != nil {
218                 return 1
219         }
220         log.Print("done")
221         return 0
222 }
223
224 func array2matrix(rows, cols int, data []int16) mat.Matrix {
225         floatdata := make([]float64, rows*cols)
226         for i, v := range data {
227                 floatdata[i] = float64(v)
228         }
229         return mat.NewDense(rows, cols, floatdata)
230 }