19566: Option to limit pca components used in glm. Fix onehot use.
[lightning.git] / plot.go
diff --git a/plot.go b/plot.go
index 4339484235eb11a50d57c02b4e18bef47b44f67b..6deaff6fc7d01057b0b4cfd1c49537cbacc7edb0 100644 (file)
--- a/plot.go
+++ b/plot.go
@@ -1,16 +1,26 @@
+// Copyright (C) The Lightning Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
 package lightning
 
 import (
+       _ "embed"
        "flag"
        "fmt"
        "io"
        _ "net/http/pprof"
+       "os/exec"
+       "strings"
 
        "git.arvados.org/arvados.git/sdk/go/arvados"
 )
 
 type pythonPlot struct{}
 
+//go:embed plot.py
+var plotscript string
+
 func (cmd *pythonPlot) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
        var err error
        defer func() {
@@ -22,15 +32,24 @@ func (cmd *pythonPlot) RunCommand(prog string, args []string, stdin io.Reader, s
        flags.SetOutput(stderr)
        projectUUID := flags.String("project", "", "project `UUID` for output data")
        inputFilename := flags.String("i", "-", "input `file`")
-       sampleCSVFilename := flags.String("labels-csv", "", "use first two columns of `labels.csv` as id->color mapping")
-       sampleFastaDirname := flags.String("sample-fasta-dir", "", "`directory` containing fasta input files")
+       outputFilename := flags.String("o", "", "output `filename` (e.g., './plot.png')")
+       sampleListFilename := flags.String("samples", "", "use second column of `samples.csv` as complete list of sample IDs")
+       phenotypeFilename := flags.String("phenotype", "", "use `phenotype.csv` as id->phenotype mapping (column 0 is sample id)")
+       cat1Column := flags.Int("phenotype-cat1-column", 1, "0-based column `index` of 1st category in phenotype.csv file")
+       cat2Column := flags.Int("phenotype-cat2-column", -1, "0-based column `index` of 2nd category in phenotype.csv file")
+       xComponent := flags.Int("x", 1, "1-based PCA component to plot on x axis")
+       yComponent := flags.Int("y", 2, "1-based PCA component to plot on y axis")
        priority := flags.Int("priority", 500, "container request priority")
+       runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
        err = flags.Parse(args)
        if err == flag.ErrHelp {
                err = nil
                return 0
        } else if err != nil {
                return 2
+       } else if flags.NArg() > 0 {
+               err = fmt.Errorf("errant command line arguments after parsed flags: %v", flags.Args())
+               return 2
        }
 
        runner := arvadosContainerRunner{
@@ -47,12 +66,40 @@ func (cmd *pythonPlot) RunCommand(prog string, args []string, stdin io.Reader, s
                        },
                },
        }
-       err = runner.TranslatePaths(inputFilename, sampleCSVFilename, sampleFastaDirname)
-       if err != nil {
-               return 1
+       if !*runlocal {
+               err = runner.TranslatePaths(inputFilename, sampleListFilename, phenotypeFilename)
+               if err != nil {
+                       return 1
+               }
+               *outputFilename = "/mnt/output/plot.png"
+       }
+       args = []string{
+               *inputFilename,
+               fmt.Sprintf("%d", *xComponent),
+               fmt.Sprintf("%d", *yComponent),
+               *sampleListFilename,
+               *phenotypeFilename,
+               fmt.Sprintf("%d", *cat1Column),
+               fmt.Sprintf("%d", *cat2Column),
+               *outputFilename,
+       }
+       if *runlocal {
+               if *outputFilename == "" {
+                       fmt.Fprintln(stderr, "error: must specify -o filename.png in local mode (or try -help)")
+                       return 1
+               }
+               cmd := exec.Command("python3", append([]string{"-"}, args...)...)
+               cmd.Stdin = strings.NewReader(plotscript)
+               cmd.Stdout = stdout
+               cmd.Stderr = stderr
+               err = cmd.Run()
+               if err != nil {
+                       return 1
+               }
+               return 0
        }
        runner.Prog = "python3"
-       runner.Args = []string{"/plot.py", *inputFilename, *sampleCSVFilename, *sampleFastaDirname, "/mnt/output/plot.png"}
+       runner.Args = append([]string{"/plot.py"}, args...)
        var output string
        output, err = runner.Run()
        if err != nil {
@@ -61,72 +108,3 @@ func (cmd *pythonPlot) RunCommand(prog string, args []string, stdin io.Reader, s
        fmt.Fprintln(stdout, output+"/plot.png")
        return 0
 }
-
-var plotscript = `
-import csv
-import os
-import scipy
-import sys
-
-infile = sys.argv[1]
-X = scipy.load(infile)
-
-colors = None
-if sys.argv[2]:
-    labels = {}
-    for fnm in os.listdir(sys.argv[3]):
-        if '.2.fasta' not in fnm:
-            labels[fnm] = '---'
-    if len(labels) != len(X):
-        raise "len(inputdir) != len(inputarray)"
-    with open(sys.argv[2], 'rt') as csvfile:
-        for row in csv.reader(csvfile):
-            ident=row[0]
-            label=row[1]
-            for fnm in labels:
-                if row[0] in fnm:
-                    labels[fnm] = row[1]
-    colors = []
-    labelcolors = {
-        'PUR': 'firebrick',
-        'CLM': 'firebrick',
-        'MXL': 'firebrick',
-        'PEL': 'firebrick',
-        'TSI': 'green',
-        'IBS': 'green',
-        'CEU': 'green',
-        'GBR': 'green',
-        'FIN': 'green',
-        'LWK': 'coral',
-        'MSL': 'coral',
-        'GWD': 'coral',
-        'YRI': 'coral',
-        'ESN': 'coral',
-        'ACB': 'coral',
-        'ASW': 'coral',
-        'KHV': 'royalblue',
-        'CDX': 'royalblue',
-        'CHS': 'royalblue',
-        'CHB': 'royalblue',
-        'JPT': 'royalblue',
-        'STU': 'blueviolet',
-        'ITU': 'blueviolet',
-        'BEB': 'blueviolet',
-        'GIH': 'blueviolet',
-        'PJL': 'blueviolet',
-    }
-    for fnm in sorted(labels.keys()):
-        if labels[fnm] in labelcolors:
-            colors.append(labelcolors[labels[fnm]])
-        else:
-            colors.append('black')
-
-from matplotlib.figure import Figure
-from matplotlib.patches import Polygon
-from matplotlib.backends.backend_agg import FigureCanvasAgg
-fig = Figure()
-ax = fig.add_subplot(111)
-ax.scatter(X[:,0], X[:,1], c=colors, s=60, marker='o', alpha=0.5)
-canvas = FigureCanvasAgg(fig)
-canvas.print_figure(sys.argv[4], dpi=80)
-`