#!/usr/bin/env python

import os
import re
import sys

import arvados
import arvados_gatk2
import arvados_picard
import arvados_samtools
from arvados_ipc import *

class InvalidArgumentError(Exception):
    """Raised when the task's input does not have the expected shape.

    This script raises it when the task's input collection does not
    contain exactly one .bam file.
    """

# NOTE(review): presumably this queues one task per input BAM file and, with
# and_end_task=True, ends the initial (sequence 0) task so that the rest of
# this script only does real work in the per-file tasks -- confirm in
# arvados_samtools.
arvados_samtools.one_task_per_bam_file(if_sequence=0, and_end_task=True)

# Handles for the running job and task, plus the task's scratch directory.
this_job = arvados.current_job()
this_task = arvados.current_task()
tmpdir = arvados.current_task().tmpdir
# NOTE(review): presumably empties the task's tmpdir so no files from an
# earlier run are picked up -- confirm in arvados.util.
arvados.util.clear_tmpdir()

# Names of the known-sites VCF files inside the GATK bundle.
# NOTE(review): presumably getjobparam returns the job's 'known_sites'
# parameter and falls back to this default list -- confirm in arvados.
known_sites_files = arvados.getjobparam(
    'known_sites',
    ['dbsnp_137.b37.vcf',
     'Mills_and_1000G_gold_standard.indels.b37.vcf'])

# Extract only what this task needs from the bundle collection: the
# reference (.dict/.fasta/.fai), each known-sites VCF and its .idx index.
reference_files = [
    'human_g1k_v37.dict',
    'human_g1k_v37.fasta',
    'human_g1k_v37.fasta.fai',
]
known_sites_indexes = [v + '.idx' for v in known_sites_files]
bundle_dir = arvados.util.collection_extract(
    collection=this_job['script_parameters']['gatk_bundle'],
    files=reference_files + known_sites_files + known_sites_indexes,
    path='gatk_bundle')

# Full paths of every FASTA (optionally gzipped) found in the extracted
# bundle directory; later stages pass the first one to GATK as -R.
fasta_pattern = re.compile(r'\.fasta(\.gz)?$')
ref_fasta_files = []
for bundle_entry in os.listdir(bundle_dir):
    if fasta_pattern.search(bundle_entry):
        ref_fasta_files.append(os.path.join(bundle_dir, bundle_entry))
# Optional interval restriction: when the job supplies a 'regions'
# collection, every .bed file in it becomes an --intervals argument paired
# with the job's 'region_padding' value. Without 'regions' the list stays
# empty and no interval arguments are passed to GATK.
regions_args = []
if 'regions' in this_job['script_parameters']:
    regions_dir = arvados.util.collection_extract(
        collection=this_job['script_parameters']['regions'],
        path='regions')
    # 'region_padding' is required whenever 'regions' is given.
    region_padding = int(this_job['script_parameters']['region_padding'])
    padding_arg = str(region_padding)
    for bed_name in os.listdir(regions_dir):
        if not re.search(r'\.bed$', bed_name):
            continue
        regions_args.extend([
            '--intervals', os.path.join(regions_dir, bed_name),
            '--interval_padding', padding_arg,
        ])

# Extract this task's input collection under the task tmpdir and locate the
# single BAM file it is expected to contain.
input_collection = this_task['parameters']['input']
input_dir = arvados.util.collection_extract(
    collection=input_collection,
    path=os.path.join(this_task.tmpdir, 'input'))

input_bam_files = []
for relative_path in arvados.util.listdir_recursive(input_dir):
    if not re.search(r'\.bam$', relative_path):
        continue
    # Remember the directory and file name parts of the BAM's relative path;
    # the output manifest reuses them as stream and file names.
    input_stream_name, input_file_name = os.path.split(relative_path)
    input_bam_files.append(os.path.join(input_dir, relative_path))

if len(input_bam_files) != 1:
    raise InvalidArgumentError("Expected exactly one bam file per task.")

# One "-known <path>" argument pair per known-sites VCF, flattened into a
# single argument list shared by both GATK invocations.
known_sites_args = [
    arg
    for vcf_name in known_sites_files
    for arg in ('-known', os.path.join(bundle_dir, vcf_name))
]

# Bookkeeping for the forked pipeline below: `children` is iterated at the
# end as (name, pid) pairs, and `pipes` is indexed by (name, 'r' or 'w') to
# get file descriptors. Both are filled in by the arvados_ipc helpers.
children = {}
pipes = {}

# Run RealignerTargetCreator synchronously; it writes the target intervals
# to a file in the task tmpdir that IndelRealigner reads afterwards.
intervals_path = os.path.join(tmpdir, 'intervals.list')
target_creator_args = [
    '-nt', arvados_gatk2.cpus_per_task(),
    '-T', 'RealignerTargetCreator',
    '-R', ref_fasta_files[0],
    '-I', input_bam_files[0],
    '-o', intervals_path,
]
arvados_gatk2.run(
    args=target_creator_args + known_sites_args + regions_args)

# Run IndelRealigner in a child process, with its BAM output directed at the
# write end of the 'IndelRealigner' pipe instead of a file on disk.
pipe_setup(pipes, 'IndelRealigner')
if 0 == named_fork(children, 'IndelRealigner'):
    # Child: keep only the write end of the pipe.
    pipe_closeallbut(pipes, ('IndelRealigner', 'w'))
    arvados_gatk2.run(
        args=[
        '-T', 'IndelRealigner',
        '-R', ref_fasta_files[0],
        '-targetIntervals', os.path.join(tmpdir, 'intervals.list'),
        '-I', input_bam_files[0],
        # The output "file" is the pipe, addressed through its /dev/fd path.
        # NOTE(review): close_fds=False below is presumably what keeps this
        # fd open in the GATK process -- confirm in arvados_gatk2.run.
        '-o', '/dev/fd/' + str(pipes['IndelRealigner','w']),
        '--disable_bam_indexing',
        ] + known_sites_args + regions_args,
        close_fds=False)
    # Exit right away so the child never runs the parent's code below.
    os._exit(0)
# Parent: close (and forget) its copy of the write end, so the reader of
# this pipe sees EOF once the child is done writing.
os.close(pipes.pop(('IndelRealigner','w'), None))

# Tee child: copies the realigned BAM stream from the 'IndelRealigner' pipe
# to two destinations at once -- the 'bam' pipe and an Arvados
# CollectionWriter -- then sends the writer's manifest text through the
# 'bammanifest' pipe.
pipe_setup(pipes, 'bammanifest')
pipe_setup(pipes, 'bam')
if 0==named_fork(children, 'bammanifest'):
    # Child: keep the BAM input plus the two outputs it writes to.
    pipe_closeallbut(pipes,
                     ('IndelRealigner', 'r'),
                     ('bammanifest', 'w'),
                     ('bam', 'w'))
    # Store the output under the same stream and file name as the input BAM.
    out = arvados.CollectionWriter()
    out.start_new_stream(input_stream_name)
    out.start_new_file(input_file_name)
    while True:
        # Read up to 1 MiB at a time; an empty read means EOF.
        buf = os.read(pipes['IndelRealigner','r'], 2**20)
        if len(buf) == 0:
            break
        os.write(pipes['bam','w'], buf)
        out.write(buf)
    # Hand the manifest text to the parent, then close the write end so the
    # parent's read reaches EOF.
    os.write(pipes['bammanifest','w'], out.manifest_text())
    os.close(pipes['bammanifest','w'])
    # Exit right away so the child never runs the parent's code below.
    os._exit(0)

# Index child: runs Picard BuildBamIndex with the 'bam' pipe as input and
# the 'index' pipe as output, both addressed through /dev/fd paths.
pipe_setup(pipes, 'index')
if 0==named_fork(children, 'index'):
    # Child: keep only the BAM read end and the index write end.
    pipe_closeallbut(pipes, ('bam', 'r'), ('index', 'w'))
    arvados_picard.run(
        'BuildBamIndex',
        params={
            'i': '/dev/fd/' + str(pipes['bam','r']),
            'o': '/dev/fd/' + str(pipes['index','w']),
            'quiet': 'true',
            'validation_stringency': 'LENIENT'
            },
        # NOTE(review): presumably needed so the pipe fds named above stay
        # open in the Picard process -- confirm in arvados_picard.run.
        close_fds=False)
    # Exit right away so the child never runs the parent's code below.
    os._exit(0)

# Index-manifest child: stores the BAM index arriving on the 'index' pipe in
# an Arvados CollectionWriter, then sends the writer's manifest text through
# the 'indexmanifest' pipe.
pipe_setup(pipes, 'indexmanifest')
if 0==named_fork(children, 'indexmanifest'):
    # Child: keep only the index read end and the manifest write end.
    pipe_closeallbut(pipes, ('index', 'r'), ('indexmanifest', 'w'))
    out = arvados.CollectionWriter()
    out.start_new_stream(input_stream_name)
    # Name the index after the input BAM, with .bai in place of .bam.
    # The pattern is a raw string so that "\." is a regex escape rather
    # than an (invalid) string escape sequence.
    out.start_new_file(re.sub(r'\.bam$', '.bai', input_file_name))
    while True:
        # Read up to 1 MiB at a time; an empty read means EOF.
        buf = os.read(pipes['index','r'], 2**20)
        if len(buf) == 0:
            break
        out.write(buf)
    # Hand the manifest text to the parent, then close the write end so the
    # parent's read reaches EOF.
    os.write(pipes['indexmanifest','w'], out.manifest_text())
    os.close(pipes['indexmanifest','w'])
    # Exit right away so the child never runs the parent's code below.
    os._exit(0)

# Parent: every pipe end except the two manifest read ends is closed here;
# the children hold their own copies of the ends they use.
pipe_closeallbut(pipes, ('bammanifest', 'r'), ('indexmanifest', 'r'))

# Concatenate the BAM manifest and the index manifest, in that order, into
# the task's output manifest. Each pipe is read until read() returns ''.
outmanifest = ''
for manifest_name in ['bammanifest', 'indexmanifest']:
    with os.fdopen(pipes[manifest_name,'r'], 'rb', 2**20) as manifest_pipe:
        for manifest_chunk in iter(manifest_pipe.read, ''):
            outmanifest += manifest_chunk

# Wait for every child and check its exit status. The wait is done
# unconditionally for each child: folding it into `all_ok and ...` would
# short-circuit after the first failure, leaving the remaining children
# un-reaped and their failures unreported.
all_ok = True
for (childname, pid) in children.items():
    child_ok = waitpid_and_check_exit(pid, childname)
    all_ok = all_ok and child_ok

# Record the combined manifest as the task output only when the whole
# pipeline succeeded; otherwise exit non-zero.
if all_ok:
    this_task.set_output(outmanifest)
else:
    sys.exit(1)