cd5f0707a9d2334f6791d35f6ae237a997ef084a
[lightning.git] / plot.py
1 # Copyright (C) The Lightning Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: AGPL-3.0
4
5 import csv
6 import numpy
7 import os
8 import os.path
9 import scipy
10 import sys
11
12 infile = sys.argv[1]
13 X = numpy.load(infile)
14
15 colors = None
16 category = {}
17 samples = []
18 if sys.argv[2]:
19     labels = {}
20     with open(sys.argv[2], 'rt', newline='') as samplelist:
21         for row in csv.reader(samplelist):
22             sampleid = row[1]
23             samples.append(sampleid)
24     phenotype_category_column = int(sys.argv[4])
25     phenotype_column = int(sys.argv[5])
26     if os.path.isdir(sys.argv[3]):
27         phenotype_files = os.scandir(sys.argv[3])
28     else:
29         phenotype_files = [sys.argv[3]]
30     for phenotype_file in phenotype_files:
31         with open(phenotype_file, 'rt', newline='') as phenotype:
32             dialect = csv.Sniffer().sniff(phenotype.read(1024))
33             phenotype.seek(0)
34             for row in csv.reader(phenotype, dialect):
35                 tag = row[0]
36                 label = row[phenotype_column]
37                 for sampleid in samples:
38                     if tag in sampleid:
39                         labels[sampleid] = label
40                         if phenotype_category_column >= 0 and row[phenotype_category_column] != '0':
41                             category[sampleid] = True
42     colors = []
43     labelcolors = {
44         'PUR': 'firebrick',
45         'CLM': 'firebrick',
46         'MXL': 'firebrick',
47         'PEL': 'firebrick',
48         '1': 'firebrick',
49         'TSI': 'green',
50         'IBS': 'green',
51         'CEU': 'green',
52         'GBR': 'green',
53         'FIN': 'green',
54         '2': 'green',
55         'LWK': 'coral',
56         'MSL': 'coral',
57         'GWD': 'coral',
58         'YRI': 'coral',
59         'ESN': 'coral',
60         'ACB': 'coral',
61         'ASW': 'coral',
62         '3': 'coral',
63         'KHV': 'royalblue',
64         'CDX': 'royalblue',
65         'CHS': 'royalblue',
66         'CHB': 'royalblue',
67         'JPT': 'royalblue',
68         '4': 'royalblue',
69         'STU': 'blueviolet',
70         'ITU': 'blueviolet',
71         'BEB': 'blueviolet',
72         'GIH': 'blueviolet',
73         'PJL': 'blueviolet',
74         '5': 'blueviolet',
75         '6': 'black',           # unknown?
76     }
77     for sampleid in samples:
78         if (sampleid in labels) and (labels[sampleid] in labelcolors):
79             colors.append(labelcolors[labels[sampleid]])
80         else:
81             colors.append('black')
82
83 from matplotlib.figure import Figure
84 from matplotlib.patches import Polygon
85 from matplotlib.backends.backend_agg import FigureCanvasAgg
86 fig = Figure()
87 ax = fig.add_subplot(111)
88 for marker in ['o', 'x']:
89     x = []
90     y = []
91     if samples:
92         c = []
93         for i, sampleid in enumerate(samples):
94             if category.get(sampleid, False) == (marker == 'x'):
95                 x.append(X[i,0])
96                 y.append(X[i,1])
97                 c.append(colors[i])
98     elif marker == 'x':
99         continue
100     else:
101         x = X[:,0]
102         y = X[:,1]
103         c = None
104     ax.scatter(x, y, c=c, s=60, marker=marker, alpha=0.5)
105 canvas = FigureCanvasAgg(fig)
106 canvas.print_figure(sys.argv[6], dpi=80)