19566: Merge branch 'main'
[lightning.git] / pca_plot.py
1 # Copyright (C) The Lightning Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: AGPL-3.0
4
5 import csv
6 import numpy
7 import os
8 import os.path
9 import scipy
10 import sys
11
12 (_,
13  input_path,
14  x_component,
15  y_component,
16  samples_file,
17  phenotype_path,
18  phenotype_cat1_column,
19  phenotype_cat2_column,
20  output_path,
21  ) = sys.argv
22 X = numpy.load(input_path)
23
24 colors = None
25 category = {}
26 samples = []
27 if samples_file:
28     labels = {}
29     with open(samples_file, 'rt', newline='') as samplelist:
30         for row in csv.reader(samplelist):
31             if row[0] == "Index":
32                 continue
33             sampleid = row[1]
34             samples.append(sampleid)
35     phenotype_cat2_column = int(phenotype_cat2_column)
36     phenotype_cat1_column = int(phenotype_cat1_column)
37     if os.path.isdir(phenotype_path):
38         phenotype_files = os.scandir(phenotype_path)
39     else:
40         phenotype_files = [phenotype_path]
41     for phenotype_file in phenotype_files:
42         with open(phenotype_file, 'rt', newline='') as phenotype:
43             dialect = csv.Sniffer().sniff(phenotype.read(1024))
44             phenotype.seek(0)
45             for row in csv.reader(phenotype, dialect):
46                 tag = row[0]
47                 label = row[phenotype_cat1_column]
48                 for sampleid in samples:
49                     if tag in sampleid:
50                         labels[sampleid] = label
51                         if phenotype_cat2_column >= 0 and row[phenotype_cat2_column] != '0':
52                             category[sampleid] = True
53     unknown_color = 'grey'
54     colors = []
55     labelcolors = {
56         'PUR': 'firebrick',
57         'CLM': 'firebrick',
58         'MXL': 'firebrick',
59         'PEL': 'firebrick',
60         '1': 'firebrick',
61         'TSI': 'green',
62         'IBS': 'green',
63         'CEU': 'green',
64         'GBR': 'green',
65         'FIN': 'green',
66         '5': 'green',
67         'LWK': 'coral',
68         'MSL': 'coral',
69         'GWD': 'coral',
70         'YRI': 'coral',
71         'ESN': 'coral',
72         'ACB': 'coral',
73         'ASW': 'coral',
74         '4': 'coral',
75         'KHV': 'royalblue',
76         'CDX': 'royalblue',
77         'CHS': 'royalblue',
78         'CHB': 'royalblue',
79         'JPT': 'royalblue',
80         '2': 'royalblue',
81         'STU': 'blueviolet',
82         'ITU': 'blueviolet',
83         'BEB': 'blueviolet',
84         'GIH': 'blueviolet',
85         'PJL': 'blueviolet',
86         '3': 'navy',
87     }
88     for sampleid in samples:
89         if (sampleid in labels) and (labels[sampleid] in labelcolors):
90             colors.append(labelcolors[labels[sampleid]])
91         else:
92             colors.append(unknown_color)
93
94 from matplotlib.figure import Figure
95 from matplotlib.patches import Polygon
96 from matplotlib.backends.backend_agg import FigureCanvasAgg
97 fig = Figure()
98 ax = fig.add_subplot(111)
99 for marker in ['o', 'x']:
100     x = []
101     y = []
102     if samples:
103         c = []
104         for unknownfirst in [True, False]:
105             for i, sampleid in enumerate(samples):
106                 if ((colors[i] == unknown_color) == unknownfirst and
107                     category.get(sampleid, False) == (marker == 'x')):
108                     x.append(X[i,int(x_component)-1])
109                     y.append(X[i,int(y_component)-1])
110                     c.append(colors[i])
111     elif marker == 'x':
112         continue
113     else:
114         x = X[:,int(x_component)-1]
115         y = X[:,int(y_component)-1]
116         c = None
117     ax.scatter(x, y, c=c, s=60, marker=marker, alpha=0.5)
118 canvas = FigureCanvasAgg(fig)
119 canvas.print_figure(output_path, dpi=80)