Let caller set container output name.
[lightning.git] / plot.go
1 package main
2
3 import (
4         "flag"
5         "fmt"
6         "io"
7         _ "net/http/pprof"
8
9         "git.arvados.org/arvados.git/sdk/go/arvados"
10 )
11
12 type pythonPlot struct{}
13
14 func (cmd *pythonPlot) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
15         var err error
16         defer func() {
17                 if err != nil {
18                         fmt.Fprintf(stderr, "%s\n", err)
19                 }
20         }()
21         flags := flag.NewFlagSet("", flag.ContinueOnError)
22         flags.SetOutput(stderr)
23         projectUUID := flags.String("project", "", "project `UUID` for output data")
24         inputFilename := flags.String("i", "-", "input `file`")
25         sampleCSVFilename := flags.String("labels-csv", "", "use first two columns of `labels.csv` as id->color mapping")
26         sampleFastaDirname := flags.String("sample-fasta-dir", "", "`directory` containing fasta input files")
27         priority := flags.Int("priority", 500, "container request priority")
28         err = flags.Parse(args)
29         if err == flag.ErrHelp {
30                 err = nil
31                 return 0
32         } else if err != nil {
33                 return 2
34         }
35
36         runner := arvadosContainerRunner{
37                 Name:        "lightning plot",
38                 Client:      arvados.NewClientFromEnv(),
39                 ProjectUUID: *projectUUID,
40                 RAM:         4 << 30,
41                 VCPUs:       1,
42                 Priority:    *priority,
43                 Mounts: map[string]map[string]interface{}{
44                         "/plot.py": map[string]interface{}{
45                                 "kind":    "text",
46                                 "content": plotscript,
47                         },
48                 },
49         }
50         err = runner.TranslatePaths(inputFilename, sampleCSVFilename, sampleFastaDirname)
51         if err != nil {
52                 return 1
53         }
54         runner.Prog = "python3"
55         runner.Args = []string{"/plot.py", *inputFilename, *sampleCSVFilename, *sampleFastaDirname, "/mnt/output/plot.png"}
56         var output string
57         output, err = runner.Run()
58         if err != nil {
59                 return 1
60         }
61         fmt.Fprintln(stdout, output+"/plot.png")
62         return 0
63 }
64
65 var plotscript = `
66 import csv
67 import os
68 import scipy
69 import sys
70
71 infile = sys.argv[1]
72 X = scipy.load(infile)
73
74 colors = None
75 if sys.argv[2]:
76     labels = {}
77     for fnm in os.listdir(sys.argv[3]):
78         if '.2.fasta' not in fnm:
79             labels[fnm] = '---'
80     if len(labels) != len(X):
81         raise "len(inputdir) != len(inputarray)"
82     with open(sys.argv[2], 'rt') as csvfile:
83         for row in csv.reader(csvfile):
84             ident=row[0]
85             label=row[1]
86             for fnm in labels:
87                 if row[0] in fnm:
88                     labels[fnm] = row[1]
89     colors = []
90     labelcolors = {
91         'PUR': 'firebrick',
92         'CLM': 'firebrick',
93         'MXL': 'firebrick',
94         'PEL': 'firebrick',
95         'TSI': 'green',
96         'IBS': 'green',
97         'CEU': 'green',
98         'GBR': 'green',
99         'FIN': 'green',
100         'LWK': 'coral',
101         'MSL': 'coral',
102         'GWD': 'coral',
103         'YRI': 'coral',
104         'ESN': 'coral',
105         'ACB': 'coral',
106         'ASW': 'coral',
107         'KHV': 'royalblue',
108         'CDX': 'royalblue',
109         'CHS': 'royalblue',
110         'CHB': 'royalblue',
111         'JPT': 'royalblue',
112         'STU': 'blueviolet',
113         'ITU': 'blueviolet',
114         'BEB': 'blueviolet',
115         'GIH': 'blueviolet',
116         'PJL': 'blueviolet',
117     }
118     for fnm in sorted(labels.keys()):
119         if labels[fnm] in labelcolors:
120             colors.append(labelcolors[labels[fnm]])
121         else:
122             colors.append('black')
123
124 from matplotlib.figure import Figure
125 from matplotlib.patches import Polygon
126 from matplotlib.backends.backend_agg import FigureCanvasAgg
127 fig = Figure()
128 ax = fig.add_subplot(111)
129 ax.scatter(X[:,0], X[:,1], c=colors, s=60, marker='o', alpha=0.5)
130 canvas = FigureCanvasAgg(fig)
131 canvas.print_figure(sys.argv[4], dpi=80)
132 `