19524: Generalize plot colors a little.
[lightning.git] / plot.go
diff --git a/plot.go b/plot.go
index aa4335fec6889561b6a6adae7cb74caa7ded32bf..51f5c8236c27e62a2134f706999ee8752a4ce731 100644 (file)
--- a/plot.go
+++ b/plot.go
@@ -5,6 +5,7 @@
 package lightning
 
 import (
+       _ "embed"
        "flag"
        "fmt"
        "io"
@@ -15,6 +16,9 @@ import (
 
 type pythonPlot struct{}
 
+//go:embed plot.py
+var plotscript string
+
 func (cmd *pythonPlot) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
        var err error
        defer func() {
@@ -27,7 +31,8 @@ func (cmd *pythonPlot) RunCommand(prog string, args []string, stdin io.Reader, s
        projectUUID := flags.String("project", "", "project `UUID` for output data")
        inputFilename := flags.String("i", "-", "input `file`")
        sampleListFilename := flags.String("samples", "", "use second column of `samples.csv` as complete list of sample IDs")
-       colormapFilename := flags.String("colormap", "", "use first two columns of `colormap.csv` as id->color mapping")
+       phenotypeFilename := flags.String("phenotype", "", "use `phenotype.csv` as id->phenotype mapping (column 0 is sample id)")
+       phenotypeColumn := flags.Int("phenotype-column", 1, "0-based column `index` of phenotype in phenotype.csv file")
        priority := flags.Int("priority", 500, "container request priority")
        err = flags.Parse(args)
        if err == flag.ErrHelp {
@@ -51,12 +56,12 @@ func (cmd *pythonPlot) RunCommand(prog string, args []string, stdin io.Reader, s
                        },
                },
        }
-       err = runner.TranslatePaths(inputFilename, sampleListFilename, colormapFilename)
+       err = runner.TranslatePaths(inputFilename, sampleListFilename, phenotypeFilename)
        if err != nil {
                return 1
        }
        runner.Prog = "python3"
-       runner.Args = []string{"/plot.py", *inputFilename, *sampleListFilename, *colormapFilename, "/mnt/output/plot.png"}
+       runner.Args = []string{"/plot.py", *inputFilename, *sampleListFilename, *phenotypeFilename, fmt.Sprintf("%d", *phenotypeColumn), "/mnt/output/plot.png"}
        var output string
        output, err = runner.Run()
        if err != nil {
@@ -65,73 +70,3 @@ func (cmd *pythonPlot) RunCommand(prog string, args []string, stdin io.Reader, s
        fmt.Fprintln(stdout, output+"/plot.png")
        return 0
 }
-
-var plotscript = `
-import csv
-import os
-import os.path
-import scipy
-import sys
-
-infile = sys.argv[1]
-X = scipy.load(infile)
-
-colors = None
-if sys.argv[2]:
-    samples = []
-    labels = {}
-    with open(sys.argv[2], 'rt') as samplelist:
-        for row in csv.reader(samplelist):
-            id = row[1]
-            samples.append(id)
-    with open(sys.argv[3], 'rt') as colormap:
-        for row in csv.reader(colormap):
-            tag = row[0]
-            label = row[1]
-            for id in samples:
-                if tag in id:
-                    labels[id] = label
-    colors = []
-    labelcolors = {
-        'PUR': 'firebrick',
-        'CLM': 'firebrick',
-        'MXL': 'firebrick',
-        'PEL': 'firebrick',
-        'TSI': 'green',
-        'IBS': 'green',
-        'CEU': 'green',
-        'GBR': 'green',
-        'FIN': 'green',
-        'LWK': 'coral',
-        'MSL': 'coral',
-        'GWD': 'coral',
-        'YRI': 'coral',
-        'ESN': 'coral',
-        'ACB': 'coral',
-        'ASW': 'coral',
-        'KHV': 'royalblue',
-        'CDX': 'royalblue',
-        'CHS': 'royalblue',
-        'CHB': 'royalblue',
-        'JPT': 'royalblue',
-        'STU': 'blueviolet',
-        'ITU': 'blueviolet',
-        'BEB': 'blueviolet',
-        'GIH': 'blueviolet',
-        'PJL': 'blueviolet',
-    }
-    for id in samples:
-        if (id in labels) and (labels[id] in labelcolors):
-            colors.append(labelcolors[labels[id]])
-        else:
-            colors.append('black')
-
-from matplotlib.figure import Figure
-from matplotlib.patches import Polygon
-from matplotlib.backends.backend_agg import FigureCanvasAgg
-fig = Figure()
-ax = fig.add_subplot(111)
-ax.scatter(X[:,0], X[:,1], c=colors, s=60, marker='o', alpha=0.5)
-canvas = FigureCanvasAgg(fig)
-canvas.print_figure(sys.argv[4], dpi=80)
-`