19524: Fix matrix alloc size.
[lightning.git] / plot.py
1 # Copyright (C) The Lightning Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: AGPL-3.0
4
5 import csv
6 import numpy
7 import os
8 import os.path
9 import scipy
10 import sys
11
12 (_,
13  input_path,
14  x_component,
15  y_component,
16  samples_file,
17  phenotype_path,
18  phenotype_cat1_column,
19  phenotype_cat2_column,
20  output_path,
21  ) = sys.argv
22 X = numpy.load(input_path)
23
24 colors = None
25 category = {}
26 samples = []
27 if samples_file:
28     labels = {}
29     with open(samples_file, 'rt', newline='') as samplelist:
30         for row in csv.reader(samplelist):
31             sampleid = row[1]
32             samples.append(sampleid)
33     phenotype_cat2_column = int(phenotype_cat2_column)
34     phenotype_cat1_column = int(phenotype_cat1_column)
35     if os.path.isdir(phenotype_path):
36         phenotype_files = os.scandir(phenotype_path)
37     else:
38         phenotype_files = [phenotype_path]
39     for phenotype_file in phenotype_files:
40         with open(phenotype_file, 'rt', newline='') as phenotype:
41             dialect = csv.Sniffer().sniff(phenotype.read(1024))
42             phenotype.seek(0)
43             for row in csv.reader(phenotype, dialect):
44                 tag = row[0]
45                 label = row[phenotype_cat1_column]
46                 for sampleid in samples:
47                     if tag in sampleid:
48                         labels[sampleid] = label
49                         if phenotype_cat2_column >= 0 and row[phenotype_cat2_column] != '0':
50                             category[sampleid] = True
51     unknown_color = 'grey'
52     colors = []
53     labelcolors = {
54         'PUR': 'firebrick',
55         'CLM': 'firebrick',
56         'MXL': 'firebrick',
57         'PEL': 'firebrick',
58         '1': 'firebrick',
59         'TSI': 'green',
60         'IBS': 'green',
61         'CEU': 'green',
62         'GBR': 'green',
63         'FIN': 'green',
64         '5': 'green',
65         'LWK': 'coral',
66         'MSL': 'coral',
67         'GWD': 'coral',
68         'YRI': 'coral',
69         'ESN': 'coral',
70         'ACB': 'coral',
71         'ASW': 'coral',
72         '4': 'coral',
73         'KHV': 'royalblue',
74         'CDX': 'royalblue',
75         'CHS': 'royalblue',
76         'CHB': 'royalblue',
77         'JPT': 'royalblue',
78         '2': 'royalblue',
79         'STU': 'blueviolet',
80         'ITU': 'blueviolet',
81         'BEB': 'blueviolet',
82         'GIH': 'blueviolet',
83         'PJL': 'blueviolet',
84         '3': 'navy',
85     }
86     for sampleid in samples:
87         if (sampleid in labels) and (labels[sampleid] in labelcolors):
88             colors.append(labelcolors[labels[sampleid]])
89         else:
90             colors.append(unknown_color)
91
92 from matplotlib.figure import Figure
93 from matplotlib.patches import Polygon
94 from matplotlib.backends.backend_agg import FigureCanvasAgg
95 fig = Figure()
96 ax = fig.add_subplot(111)
97 for marker in ['o', 'x']:
98     x = []
99     y = []
100     if samples:
101         c = []
102         for unknownfirst in [True, False]:
103             for i, sampleid in enumerate(samples):
104                 if ((colors[i] == unknown_color) == unknownfirst and
105                     category.get(sampleid, False) == (marker == 'x')):
106                     x.append(X[i,int(x_component)-1])
107                     y.append(X[i,int(y_component)-1])
108                     c.append(colors[i])
109     elif marker == 'x':
110         continue
111     else:
112         x = X[:,int(x_component)-1]
113         y = X[:,int(y_component)-1]
114         c = None
115     ax.scatter(x, y, c=c, s=60, marker=marker, alpha=0.5)
116 canvas = FigureCanvasAgg(fig)
117 canvas.print_figure(output_path, dpi=80)