Fix diff case
[lightning.git] / pca.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "bufio"
9         "context"
10         "errors"
11         "flag"
12         "fmt"
13         "io"
14         "io/ioutil"
15         "net/http"
16         _ "net/http/pprof"
17         "os"
18         "strings"
19
20         "git.arvados.org/arvados.git/sdk/go/arvados"
21         "github.com/james-bowman/nlp"
22         "github.com/kshedden/gonpy"
23         log "github.com/sirupsen/logrus"
24         "gonum.org/v1/gonum/mat"
25 )
26
27 type pythonPCA struct{}
28
29 func (cmd *pythonPCA) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
30         var err error
31         defer func() {
32                 if err != nil {
33                         fmt.Fprintf(stderr, "%s\n", err)
34                 }
35         }()
36         flags := flag.NewFlagSet("", flag.ContinueOnError)
37         flags.SetOutput(stderr)
38         projectUUID := flags.String("project", "", "project `UUID` for output data")
39         inputFilename := flags.String("i", "-", "input `file`")
40         priority := flags.Int("priority", 500, "container request priority")
41         err = flags.Parse(args)
42         if err == flag.ErrHelp {
43                 err = nil
44                 return 0
45         } else if err != nil {
46                 return 2
47         }
48
49         runner := arvadosContainerRunner{
50                 Name:        "lightning pca",
51                 Client:      arvados.NewClientFromEnv(),
52                 ProjectUUID: *projectUUID,
53                 RAM:         440000000000,
54                 VCPUs:       1,
55                 Priority:    *priority,
56         }
57         err = runner.TranslatePaths(inputFilename)
58         if err != nil {
59                 return 1
60         }
61         runner.Prog = "python3"
62         runner.Args = []string{"-c", `import sys
63 import scipy
64 from sklearn.decomposition import PCA
65 scipy.save(sys.argv[2], PCA(n_components=4).fit_transform(scipy.load(sys.argv[1])))`, *inputFilename, "/mnt/output/pca.npy"}
66         var output string
67         output, err = runner.Run()
68         if err != nil {
69                 return 1
70         }
71         fmt.Fprintln(stdout, output+"/pca.npy")
72         return 0
73 }
74
75 type goPCA struct {
76         filter filter
77 }
78
79 func (cmd *goPCA) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
80         var err error
81         defer func() {
82                 if err != nil {
83                         fmt.Fprintf(stderr, "%s\n", err)
84                 }
85         }()
86         flags := flag.NewFlagSet("", flag.ContinueOnError)
87         flags.SetOutput(stderr)
88         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
89         runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
90         projectUUID := flags.String("project", "", "project `UUID` for output data")
91         priority := flags.Int("priority", 500, "container request priority")
92         inputFilename := flags.String("i", "-", "input `file`")
93         outputFilename := flags.String("o", "-", "output `file`")
94         components := flags.Int("components", 4, "number of components")
95         onehot := flags.Bool("one-hot", false, "recode tile variants as one-hot")
96         cmd.filter.Flags(flags)
97         err = flags.Parse(args)
98         if err == flag.ErrHelp {
99                 err = nil
100                 return 0
101         } else if err != nil {
102                 return 2
103         }
104
105         if *pprof != "" {
106                 go func() {
107                         log.Println(http.ListenAndServe(*pprof, nil))
108                 }()
109         }
110
111         if !*runlocal {
112                 if *outputFilename != "-" {
113                         err = errors.New("cannot specify output file in container mode: not implemented")
114                         return 1
115                 }
116                 runner := arvadosContainerRunner{
117                         Name:        "lightning pca-go",
118                         Client:      arvados.NewClientFromEnv(),
119                         ProjectUUID: *projectUUID,
120                         RAM:         300000000000, // maybe 10x input size?
121                         VCPUs:       16,
122                         Priority:    *priority,
123                 }
124                 err = runner.TranslatePaths(inputFilename)
125                 if err != nil {
126                         return 1
127                 }
128                 runner.Args = []string{"pca-go", "-local=true", fmt.Sprintf("-one-hot=%v", *onehot), "-i", *inputFilename, "-o", "/mnt/output/pca.npy"}
129                 runner.Args = append(runner.Args, cmd.filter.Args()...)
130                 var output string
131                 output, err = runner.Run()
132                 if err != nil {
133                         return 1
134                 }
135                 fmt.Fprintln(stdout, output+"/pca.npy")
136                 return 0
137         }
138
139         var input io.ReadCloser
140         if *inputFilename == "-" {
141                 input = ioutil.NopCloser(stdin)
142         } else {
143                 input, err = os.Open(*inputFilename)
144                 if err != nil {
145                         return 1
146                 }
147                 defer input.Close()
148         }
149         log.Print("reading")
150         tilelib := &tileLibrary{
151                 retainNoCalls:  true,
152                 compactGenomes: map[string][]tileVariantID{},
153         }
154         err = tilelib.LoadGob(context.Background(), input, strings.HasSuffix(*inputFilename, ".gz"))
155         if err != nil {
156                 return 1
157         }
158         err = input.Close()
159         if err != nil {
160                 return 1
161         }
162
163         log.Info("filtering")
164         cmd.filter.Apply(tilelib)
165         log.Info("tidying")
166         tilelib.Tidy()
167
168         log.Print("converting cgs to array")
169         data, rows, cols := cgs2array(tilelib, cgnames(tilelib), lowqual(tilelib), nil, 0, len(tilelib.variant))
170         if *onehot {
171                 log.Printf("recode one-hot: %d rows, %d cols", rows, cols)
172                 data, _, cols = recodeOnehot(data, cols)
173         }
174         tilelib = nil
175
176         log.Printf("creating matrix backed by array: %d rows, %d cols", rows, cols)
177         mtx := array2matrix(rows, cols, data).T()
178
179         log.Print("fitting")
180         transformer := nlp.NewPCA(*components)
181         transformer.Fit(mtx)
182         log.Printf("transforming")
183         mtx, err = transformer.Transform(mtx)
184         if err != nil {
185                 return 1
186         }
187         mtx = mtx.T()
188
189         rows, cols = mtx.Dims()
190         log.Printf("copying result to numpy output array: %d rows, %d cols", rows, cols)
191         out := make([]float64, rows*cols)
192         for i := 0; i < rows; i++ {
193                 for j := 0; j < cols; j++ {
194                         out[i*cols+j] = mtx.At(i, j)
195                 }
196         }
197
198         var output io.WriteCloser
199         if *outputFilename == "-" {
200                 output = nopCloser{stdout}
201         } else {
202                 output, err = os.OpenFile(*outputFilename, os.O_CREATE|os.O_WRONLY, 0777)
203                 if err != nil {
204                         return 1
205                 }
206                 defer output.Close()
207         }
208         bufw := bufio.NewWriter(output)
209         npw, err := gonpy.NewWriter(nopCloser{bufw})
210         if err != nil {
211                 return 1
212         }
213         npw.Shape = []int{rows, cols}
214         log.Printf("writing numpy: %d rows, %d cols", rows, cols)
215         npw.WriteFloat64(out)
216         err = bufw.Flush()
217         if err != nil {
218                 return 1
219         }
220         err = output.Close()
221         if err != nil {
222                 return 1
223         }
224         log.Print("done")
225         return 0
226 }
227
228 func array2matrix(rows, cols int, data []int16) mat.Matrix {
229         floatdata := make([]float64, rows*cols)
230         for i, v := range data {
231                 floatdata[i] = float64(v)
232         }
233         return mat.NewDense(rows, cols, floatdata)
234 }