19524: Flags choose which PCA components to plot.
author: Tom Clegg <tom@curii.com>
Thu, 20 Oct 2022 17:06:35 +0000 (13:06 -0400)
committer: Tom Clegg <tom@curii.com>
Thu, 20 Oct 2022 17:06:35 +0000 (13:06 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

plot.go
plot.py

diff --git a/plot.go b/plot.go
index 9959b5067c09d42ee671162a8f04729702278ffe..98b7eece27ae48e5c15166ddae70c662f6c7cf3c 100644 (file)
--- a/plot.go
+++ b/plot.go
@@ -35,8 +35,10 @@ func (cmd *pythonPlot) RunCommand(prog string, args []string, stdin io.Reader, s
        outputFilename := flags.String("o", "", "output `filename` (e.g., './plot.png')")
        sampleListFilename := flags.String("samples", "", "use second column of `samples.csv` as complete list of sample IDs")
        phenotypeFilename := flags.String("phenotype", "", "use `phenotype.csv` as id->phenotype mapping (column 0 is sample id)")
-       phenotypeCategoryColumn := flags.Int("phenotype-category-column", -1, "0-based column `index` of 2nd category in phenotype.csv file")
-       phenotypeColumn := flags.Int("phenotype-column", 1, "0-based column `index` of phenotype in phenotype.csv file")
+       cat1Column := flags.Int("phenotype-cat1-column", 1, "0-based column `index` of 1st category in phenotype.csv file")
+       cat2Column := flags.Int("phenotype-cat2-column", -1, "0-based column `index` of 2nd category in phenotype.csv file")
+       xComponent := flags.Int("x", 1, "1-based PCA component to plot on x axis")
+       yComponent := flags.Int("y", 2, "1-based PCA component to plot on y axis")
        priority := flags.Int("priority", 500, "container request priority")
        runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
        err = flags.Parse(args)
@@ -68,7 +70,16 @@ func (cmd *pythonPlot) RunCommand(prog string, args []string, stdin io.Reader, s
                }
                *outputFilename = "/mnt/output/plot.png"
        }
-       args = []string{*inputFilename, *sampleListFilename, *phenotypeFilename, fmt.Sprintf("%d", *phenotypeCategoryColumn), fmt.Sprintf("%d", *phenotypeColumn), *outputFilename}
+       args = []string{
+               *inputFilename,
+               fmt.Sprintf("%d", *xComponent),
+               fmt.Sprintf("%d", *yComponent),
+               *sampleListFilename,
+               *phenotypeFilename,
+               fmt.Sprintf("%d", *cat1Column),
+               fmt.Sprintf("%d", *cat2Column),
+               *outputFilename,
+       }
        if *runlocal {
                if *outputFilename == "" {
                        fmt.Fprintln(stderr, "error: must specify -o filename.png in local mode (or try -help)")
diff --git a/plot.py b/plot.py
index 670218afbdd0ee693e0ebd4e9eba85df2bdbd209..da84cc0df242cbc903031a6bae37fbf86c2978fe 100644 (file)
--- a/plot.py
+++ b/plot.py
@@ -9,35 +9,44 @@ import os.path
 import scipy
 import sys
 
-infile = sys.argv[1]
-X = numpy.load(infile)
+(_,
+ input_path,
+ x_component,
+ y_component,
+ samples_file,
+ phenotype_path,
+ phenotype_cat1_column,
+ phenotype_cat2_column,
+ output_path,
+ ) = sys.argv
+X = numpy.load(input_path)
 
 colors = None
 category = {}
 samples = []
-if sys.argv[2]:
+if samples_file:
     labels = {}
-    with open(sys.argv[2], 'rt', newline='') as samplelist:
+    with open(samples_file, 'rt', newline='') as samplelist:
         for row in csv.reader(samplelist):
             sampleid = row[1]
             samples.append(sampleid)
-    phenotype_category_column = int(sys.argv[4])
-    phenotype_column = int(sys.argv[5])
-    if os.path.isdir(sys.argv[3]):
-        phenotype_files = os.scandir(sys.argv[3])
+    phenotype_cat2_column = int(phenotype_cat2_column)
+    phenotype_cat1_column = int(phenotype_cat1_column)
+    if os.path.isdir(phenotype_path):
+        phenotype_files = os.scandir(phenotype_path)
     else:
-        phenotype_files = [sys.argv[3]]
+        phenotype_files = [phenotype_path]
     for phenotype_file in phenotype_files:
         with open(phenotype_file, 'rt', newline='') as phenotype:
             dialect = csv.Sniffer().sniff(phenotype.read(1024))
             phenotype.seek(0)
             for row in csv.reader(phenotype, dialect):
                 tag = row[0]
-                label = row[phenotype_column]
+                label = row[phenotype_cat1_column]
                 for sampleid in samples:
                     if tag in sampleid:
                         labels[sampleid] = label
-                        if phenotype_category_column >= 0 and row[phenotype_category_column] != '0':
+                        if phenotype_cat2_column >= 0 and row[phenotype_cat2_column] != '0':
                             category[sampleid] = True
     unknown_color = 'grey'
     colors = []
@@ -94,15 +103,15 @@ for marker in ['o', 'x']:
             for i, sampleid in enumerate(samples):
                 if ((colors[i] == unknown_color) == unknownfirst and
                     category.get(sampleid, False) == (marker == 'x')):
-                    x.append(X[i,0])
-                    y.append(X[i,1])
+                    x.append(X[i,int(x_component)-1])
+                    y.append(X[i,int(y_component)-1])
                     c.append(colors[i])
     elif marker == 'x':
         continue
     else:
-        x = X[:,0]
-        y = X[:,1]
+        x = X[:,int(x_component)-1]
+        y = X[:,int(y_component)-1]
         c = None
     ax.scatter(x, y, c=c, s=60, marker=marker, alpha=0.5)
 canvas = FigureCanvasAgg(fig)
-canvas.print_figure(sys.argv[6], dpi=80)
+canvas.print_figure(output_path, dpi=80)