#!/usr/bin/env python

import os
import re
import sys

import arvados
import arvados_gatk2
import arvados_samtools
from arvados_ipc import *

class InvalidArgumentError(Exception):
    """Raised when the task input does not have the shape this script needs."""

# Hand task scheduling to the samtools helper before doing any work.
# NOTE(review): judging by its name and arguments this presumably queues one
# task per BAM file when run as the sequence-0 task and then ends that task --
# confirm against arvados_samtools.
arvados_samtools.one_task_per_bam_file(if_sequence=0, and_end_task=True)

# Job and task records used throughout the rest of the script.
this_job = arvados.current_job()
this_task = arvados.current_task()
# Scratch directory of the current task; the recalibration table is written
# here further down.
tmpdir = arvados.current_task().tmpdir
# NOTE(review): presumably empties the task's scratch directory so leftovers
# from an earlier attempt cannot interfere -- confirm in arvados.util.
arvados.util.clear_tmpdir()

# Known-sites VCFs come from the 'known_sites' job parameter; the list passed
# as the second argument is presumably the fallback when it is not set.
known_sites_files = arvados.getjobparam(
    'known_sites',
    ['dbsnp_137.b37.vcf',
     'Mills_and_1000G_gold_standard.indels.b37.vcf'])

# Extract the reference genome files, the known-sites VCFs and one '.idx'
# companion per VCF from the GATK bundle collection.
reference_files = [
    'human_g1k_v37.dict',
    'human_g1k_v37.fasta',
    'human_g1k_v37.fasta.fai',
]
known_sites_indexes = [vcf + '.idx' for vcf in known_sites_files]
bundle_dir = arvados.util.collection_extract(
    collection=this_job['script_parameters']['gatk_bundle'],
    files=reference_files + known_sites_files + known_sites_indexes,
    path='gatk_bundle')

# Collect every extracted file whose name ends in .fasta or .fasta.gz; the
# first one is used as the reference for the GATK runs below.
fasta_pattern = re.compile(r'\.fasta(\.gz)?$')
ref_fasta_files = []
for entry in os.listdir(bundle_dir):
    if fasta_pattern.search(entry):
        ref_fasta_files.append(os.path.join(bundle_dir, entry))

# Extract this task's input collection into an 'input' directory under the
# task tmpdir.
input_collection = this_task['parameters']['input']
input_dir = arvados.util.collection_extract(
    collection=input_collection,
    path=os.path.join(this_task.tmpdir, 'input'))

# Find the BAM file(s) in the extracted tree.  The paths returned by
# listdir_recursive are joined onto input_dir, i.e. treated as relative to it.
bam_pattern = re.compile(r'\.bam$')
bam_relpaths = [relpath
                for relpath in arvados.util.listdir_recursive(input_dir)
                if bam_pattern.search(relpath)]
input_bam_files = [os.path.join(input_dir, relpath)
                   for relpath in bam_relpaths]
if len(input_bam_files) != 1:
    raise InvalidArgumentError("Expected exactly one bam file per task.")

# Remember the directory part and file name of the BAM inside the collection;
# the output is written under the same stream and file name.
input_stream_name, input_file_name = os.path.split(bam_relpaths[0])

# One "-knownSites <path>" option pair per known-sites file in the bundle.
known_sites_args = [
    arg
    for site_file in known_sites_files
    for arg in ('-knownSites', os.path.join(bundle_dir, site_file))]

# Recalibration table: written by BaseRecalibrator, then fed to PrintReads
# and copied into the task output.
recal_file = os.path.join(tmpdir, 'recal.csv')

# Bookkeeping dicts passed to the arvados_ipc helpers (fork / pipe handling).
children = {}
pipes = {}

# First GATK pass: compute the base quality score recalibration table.
base_recalibrator_args = [
    '-nct', arvados_gatk2.cpus_on_this_node(),
    '-T', 'BaseRecalibrator',
    '-R', ref_fasta_files[0],
    '-I', input_bam_files[0],
    '-o', recal_file,
]
arvados_gatk2.run(args=base_recalibrator_args + known_sites_args)

# Second GATK pass (PrintReads) runs in a forked child and streams the
# recalibrated BAM back to this process through a pipe named 'BQSR'.
# After pipe_setup the two ends are looked up as pipes['BQSR','r'] and
# pipes['BQSR','w'] (file descriptors, since they are used with os.read and
# os.close).
pipe_setup(pipes, 'BQSR')
# NOTE(review): named_fork presumably wraps os.fork and records the child pid
# in `children`; a return value of 0 is taken to mean "we are the child".
if 0 == named_fork(children, 'BQSR'):
    # Child: keep only the write end of the pipe (presumably every other
    # registered end is closed -- confirm in arvados_ipc).
    pipe_closeallbut(pipes, ('BQSR', 'w'))
    arvados_gatk2.run(
        args=[
        '-T', 'PrintReads',
        '-R', ref_fasta_files[0],
        '-I', input_bam_files[0],
        # Have GATK write the BAM straight into the pipe's write end.
        '-o', '/dev/fd/' + str(pipes['BQSR','w']),
        '-BQSR', recal_file,
        '--disable_bam_indexing',
        ],
        # Presumably needed so the pipe fd stays open in the GATK subprocess.
        close_fds=False)
    # Leave the child immediately so it never runs the parent-only code below.
    # NOTE(review): this line is only reached when run() returns; if run()
    # raises, the child exits through the normal uncaught-exception path
    # instead -- confirm that is acceptable.
    os._exit(0)
# Parent: close our copy of the write end so the read loop below can see
# end-of-file (a zero-length read) once the child side is closed.
os.close(pipes.pop(('BQSR','w'), None))

# Assemble the task output: a collection holding the recalibration table and
# the recalibrated BAM, stored under the same stream name as the input BAM.
out = arvados.CollectionWriter()
out.start_new_stream(input_stream_name)

# Copy the recalibration table next to the BAM as '<bam name>.recal.csv'.
# The file is read in 1 MiB chunks inside a ``with`` block so the handle is
# closed deterministically (the previous code handed an anonymous open() to
# the writer and never closed it).
out.start_new_file(input_file_name + '.recal.csv')
with open(recal_file, 'rb') as recal_fh:
    while True:
        buf = recal_fh.read(2**20)
        if len(buf) == 0:
            break
        out.write(buf)

# Stream the recalibrated BAM from the PrintReads child into the collection,
# 1 MiB at a time, until the pipe reports end-of-file (zero-length read).
out.start_new_file(input_file_name)
while True:
    buf = os.read(pipes['BQSR','r'], 2**20)
    if len(buf) == 0:
        break
    out.write(buf)
# Called with no ends to keep; presumably closes every pipe end that is still
# registered (here the read end) -- confirm in arvados_ipc.
pipe_closeallbut(pipes)

# Only publish the output if the helper reports success for the forked
# children; otherwise exit with a non-zero status.
if waitpid_and_check_children(children):
    this_task.set_output(out.finish())
else:
    sys.exit(1)