#!/usr/bin/env python

import os
import re
import sys
import threading
import arvados
import arvados_gatk2
import arvados_picard
from arvados_ipc import *
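# arvados_ipc supplies the process-plumbing helpers used below:
# pipe_setup, pipe_closeallbut, named_fork and waitpid_and_check_children.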

class InvalidArgumentError(Exception):
    pass

this_job = arvados.current_job()
this_task = arvados.current_task()
tmpdir = arvados.current_task().tmpdir
arvados.util.clear_tmpdir()

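# Fetch the GATK resource bundle (the b37 reference sequence and dbSNP
# build 137) from Keep into this task's working area.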
bundle_dir = arvados.util.collection_extract(
    collection = this_job['script_parameters']['gatk_bundle'],
    files = [
        'human_g1k_v37.dict',
        'human_g1k_v37.fasta',
        'human_g1k_v37.fasta.fai',
        'dbsnp_137.b37.vcf',
        'dbsnp_137.b37.vcf.idx',
        ],
    path = 'gatk_bundle')
ref_fasta_files = [os.path.join(bundle_dir, f)
                   for f in os.listdir(bundle_dir)
                   if re.search(r'\.fasta(\.gz)?$', f)]
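# If the job supplies a 'regions' collection, restrict GATK to its BED
# intervals, padded by 'region_padding' base pairs.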
regions_args = []
if 'regions' in this_job['script_parameters']:
    regions_dir = arvados.util.collection_extract(
        collection = this_job['script_parameters']['regions'],
        path = 'regions')
    region_padding = int(this_job['script_parameters']['region_padding'])
    for f in os.listdir(regions_dir):
        if re.search(r'\.bed$', f):
            regions_args += [
                '--intervals', os.path.join(regions_dir, f),
                '--interval_padding', str(region_padding)
                ]


# Start a child process for each input file, feeding data to picard.
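# Each child streams one BAM file out of Keep into its own pipe; picard
# reads the data back through /dev/fd/<read-fd>.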

input_child_names = []
children = {}
pipes = {}

input_collection = this_job['script_parameters']['input']
input_index = 0
for s in arvados.CollectionReader(input_collection).all_streams():
    for f in s.all_files():
        if not re.search(r'\.bam$', f.name()):
            continue
        input_index += 1
        childname = 'input-' + str(input_index)
        input_child_names += [childname]
        pipe_setup(pipes, childname)
        childpid = named_fork(children, childname)
        if childpid == 0:
            pipe_closeallbut(pipes, (childname, 'w'))
            for buf in f.readall():
                os.write(pipes[childname, 'w'], buf)
            os.close(pipes[childname, 'w'])
            os._exit(0)
        sys.stderr.write("pid %d writing %s to fd %d->%d\n" %
                         (childpid,
                          s.name()+'/'+f.name(),
                          pipes[childname, 'w'],
                          pipes[childname, 'r']))
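        # Parent: close this child's write end of the pipe, keeping the
        # read ends of all input children open for picard to consume.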
        pipe_closeallbut(pipes, *[(name, 'r')
                                  for name in input_child_names])


# Merge-sort the input files to merge.bam

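# picard reads each input BAM via /dev/fd/<fd>, so close_fds=False keeps
# the inherited pipe read ends open in the picard process.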
arvados_picard.run(
    'MergeSamFiles',
    args=[
        'I=/dev/fd/' + str(pipes[childname, 'r'])
        for childname in input_child_names
        ],
    params={
        'o': 'merge.bam',
        'quiet': 'true',
        'so': 'coordinate',
        'use_threading': 'true',
        'create_index': 'true',
        'validation_stringency': 'LENIENT',
        },
    close_fds=False,
    )
pipe_closeallbut(pipes)


# Run CoverageBySample on merge.bam

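# CoverageBySample runs in a forked child, writing its log and its
# per-locus output to two pipes; the parent consumes both concurrently
# (threads below) so that neither pipe fills up and stalls GATK.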
pipe_setup(pipes, 'stats_log')
pipe_setup(pipes, 'stats_out')
if 0 == named_fork(children, 'GATK'):
    pipe_closeallbut(pipes,
                     ('stats_log', 'w'),
                     ('stats_out', 'w'))
    arvados_gatk2.run(
        args=[
            '-T', 'CoverageBySample',
            '-R', ref_fasta_files[0],
            '-I', 'merge.bam',
            '-o', '/dev/fd/' + str(pipes['stats_out', 'w']),
            '--log_to_file', '/dev/fd/' + str(pipes['stats_log', 'w']),
            ]
        + regions_args,
        close_fds=False)
    pipe_closeallbut(pipes)
    os._exit(0)
pipe_closeallbut(pipes, ('stats_log', 'r'), ('stats_out', 'r'))


# Start two threads to read from CoverageBySample pipes

class ExceptionPropagatingThread(threading.Thread):
    """
    If a subclassed thread calls _raise(e) in run(), running join() on
    the thread will raise e in the thread that calls join().
    """
    def __init__(self, *args, **kwargs):
        super(ExceptionPropagatingThread, self).__init__(*args, **kwargs)
        self.__exception = None
    def join(self, *args, **kwargs):
        ret = super(ExceptionPropagatingThread, self).join(*args, **kwargs)
        if self.__exception:
            raise self.__exception
        return ret
    def _raise(self, exception):
        self.__exception = exception

class StatsLogReader(ExceptionPropagatingThread):
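    """
    Watch the GATK log for the 'Processing N bp from intervals' line
    and remember N, the total number of base pairs processed.
    """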
    def __init__(self, **kwargs):
        super(StatsLogReader, self).__init__()
        self.args = kwargs
    def run(self):
        try:
            for logline in self.args['infile']:
                x = re.search(r'Processing (\d+) bp from intervals', logline)
                if x:
                    self._total_bp = int(x.group(1))
        except Exception as e:
            self._raise(e)
    def total_bp(self):
        self.join()
        return self._total_bp
stats_log_thr = StatsLogReader(infile=os.fdopen(pipes.pop(('stats_log', 'r'))))
stats_log_thr.start()

class StatsOutReader(ExceptionPropagatingThread):
    """
    Read output of CoverageBySample and collect a histogram of
    coverage (last column) -> number of loci (number of rows).
    """
    def __init__(self, **kwargs):
        super(StatsOutReader, self).__init__()
        self.args = kwargs
    def run(self):
        try:
            hist = [0]
            histtot = 0
            for line in self.args['infile']:
                try:
                    i = int(line.split()[-1])
                except ValueError:
                    continue
                if i >= 1:
                    if len(hist) <= i:
                        hist.extend([0] * (i + 1 - len(hist)))
                    hist[i] += 1
                    histtot += 1
            hist[0] = stats_log_thr.total_bp() - histtot
            self._histogram = hist
        except Exception as e:
            self._raise(e)
    def histogram(self):
        self.join()
        return self._histogram
stats_out_thr = StatsOutReader(infile=os.fdopen(pipes.pop(('stats_out', 'r'))))
stats_out_thr.start()


# Run UnifiedGenotyper on merge.bam

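# Genotype the merged BAM against the bundle reference, using every
# core on this node (-nt) and emitting the annotations listed below;
# extra arguments can be passed in via the GATK2_UnifiedGenotyper_args
# job parameter.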
arvados_gatk2.run(
    args=[
        '-nt', str(arvados_gatk2.cpus_on_this_node()),
        '-T', 'UnifiedGenotyper',
        '-R', ref_fasta_files[0],
        '-I', 'merge.bam',
        '-o', os.path.join(tmpdir, 'out.vcf'),
        '--dbsnp', os.path.join(bundle_dir, 'dbsnp_137.b37.vcf'),
        '-metrics', 'UniGenMetrics',
        '-A', 'DepthOfCoverage',
        '-A', 'AlleleBalance',
        '-A', 'QualByDepth',
        '-A', 'HaplotypeScore',
        '-A', 'MappingQualityRankSumTest',
        '-A', 'ReadPosRankSumTest',
        '-A', 'FisherStrand',
        '-glm', 'both',
        ]
    + regions_args
    + arvados.getjobparam('GATK2_UnifiedGenotyper_args', []))

# Copy the output VCF file to Keep

out = arvados.CollectionWriter()
out.start_new_stream()
out.start_new_file('out.vcf')
out.write(open(os.path.join(tmpdir, 'out.vcf'), 'rb').read())


# Write statistics to Keep

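# Row i of mincoverage_nlocus.csv gives: minimum coverage i, the number
# of loci covered at least i-fold, and that count as a percentage of
# all base pairs processed.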
out.start_new_file('mincoverage_nlocus.csv')
sofar = 0
hist = stats_out_thr.histogram()
total_bp = stats_log_thr.total_bp()
for i in range(len(hist)):
    out.write("%d,%d,%f\n" %
              (i,
               total_bp - sofar,
               100.0 * (total_bp - sofar) / total_bp))
    sofar += hist[i]

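# Reap all children; record the task output only if every child exited
# cleanly.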
if waitpid_and_check_children(children):
    this_task.set_output(out.finish())
else:
    sys.exit(1)