Fix wrong index in chunk>0 case.
[lightning.git] / plot.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "flag"
9         "fmt"
10         "io"
11         _ "net/http/pprof"
12
13         "git.arvados.org/arvados.git/sdk/go/arvados"
14 )
15
16 type pythonPlot struct{}
17
18 func (cmd *pythonPlot) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
19         var err error
20         defer func() {
21                 if err != nil {
22                         fmt.Fprintf(stderr, "%s\n", err)
23                 }
24         }()
25         flags := flag.NewFlagSet("", flag.ContinueOnError)
26         flags.SetOutput(stderr)
27         projectUUID := flags.String("project", "", "project `UUID` for output data")
28         inputFilename := flags.String("i", "-", "input `file`")
29         sampleCSVFilename := flags.String("labels-csv", "", "use first two columns of `labels.csv` as id->color mapping")
30         sampleFastaDirname := flags.String("sample-fasta-dir", "", "`directory` containing fasta input files")
31         priority := flags.Int("priority", 500, "container request priority")
32         err = flags.Parse(args)
33         if err == flag.ErrHelp {
34                 err = nil
35                 return 0
36         } else if err != nil {
37                 return 2
38         }
39
40         runner := arvadosContainerRunner{
41                 Name:        "lightning plot",
42                 Client:      arvados.NewClientFromEnv(),
43                 ProjectUUID: *projectUUID,
44                 RAM:         4 << 30,
45                 VCPUs:       1,
46                 Priority:    *priority,
47                 Mounts: map[string]map[string]interface{}{
48                         "/plot.py": map[string]interface{}{
49                                 "kind":    "text",
50                                 "content": plotscript,
51                         },
52                 },
53         }
54         err = runner.TranslatePaths(inputFilename, sampleCSVFilename, sampleFastaDirname)
55         if err != nil {
56                 return 1
57         }
58         runner.Prog = "python3"
59         runner.Args = []string{"/plot.py", *inputFilename, *sampleCSVFilename, *sampleFastaDirname, "/mnt/output/plot.png"}
60         var output string
61         output, err = runner.Run()
62         if err != nil {
63                 return 1
64         }
65         fmt.Fprintln(stdout, output+"/plot.png")
66         return 0
67 }
68
69 var plotscript = `
70 import csv
71 import os
72 import scipy
73 import sys
74
75 infile = sys.argv[1]
76 X = scipy.load(infile)
77
78 colors = None
79 if sys.argv[2]:
80     labels = {}
81     for fnm in os.listdir(sys.argv[3]):
82         if '.2.fasta' not in fnm:
83             labels[fnm] = '---'
84     if len(labels) != len(X):
85         raise "len(inputdir) != len(inputarray)"
86     with open(sys.argv[2], 'rt') as csvfile:
87         for row in csv.reader(csvfile):
88             ident=row[0]
89             label=row[1]
90             for fnm in labels:
91                 if row[0] in fnm:
92                     labels[fnm] = row[1]
93     colors = []
94     labelcolors = {
95         'PUR': 'firebrick',
96         'CLM': 'firebrick',
97         'MXL': 'firebrick',
98         'PEL': 'firebrick',
99         'TSI': 'green',
100         'IBS': 'green',
101         'CEU': 'green',
102         'GBR': 'green',
103         'FIN': 'green',
104         'LWK': 'coral',
105         'MSL': 'coral',
106         'GWD': 'coral',
107         'YRI': 'coral',
108         'ESN': 'coral',
109         'ACB': 'coral',
110         'ASW': 'coral',
111         'KHV': 'royalblue',
112         'CDX': 'royalblue',
113         'CHS': 'royalblue',
114         'CHB': 'royalblue',
115         'JPT': 'royalblue',
116         'STU': 'blueviolet',
117         'ITU': 'blueviolet',
118         'BEB': 'blueviolet',
119         'GIH': 'blueviolet',
120         'PJL': 'blueviolet',
121     }
122     for fnm in sorted(labels.keys()):
123         if labels[fnm] in labelcolors:
124             colors.append(labelcolors[labels[fnm]])
125         else:
126             colors.append('black')
127
128 from matplotlib.figure import Figure
129 from matplotlib.patches import Polygon
130 from matplotlib.backends.backend_agg import FigureCanvasAgg
131 fig = Figure()
132 ax = fig.add_subplot(111)
133 ax.scatter(X[:,0], X[:,1], c=colors, s=60, marker='o', alpha=0.5)
134 canvas = FigureCanvasAgg(fig)
135 canvas.print_figure(sys.argv[4], dpi=80)
136 `