#!/usr/bin/env python

import arvados
import arvados_bwa
import arvados_samtools
import os
import re
import sys
import subprocess

# Let the arvados_bwa helper queue one task per pair of input files.
# NOTE(review): with and_end_task=True the helper presumably ends the
# dispatching (sequence 0) task itself, so only per-pair tasks run the
# code below -- confirm against arvados_bwa.
arvados_bwa.one_task_per_pair_input_file(if_sequence=0, and_end_task=True)

this_job = arvados.current_job()
this_task = arvados.current_task()

# Extract the collection named by the job's 'reference_index' parameter;
# the returned path is listed as a directory below.
ref_dir = arvados.util.collection_extract(
    collection=this_job['script_parameters']['reference_index'],
    path='reference',
    decompress=False)

# Find the *.bwt file and strip its suffix: the resulting path prefix is
# what gets passed to bwa as the reference argument.
bwt_suffix = re.compile(r'\.bwt$')
bwt_stems = [bwt_suffix.sub('', name)
             for name in os.listdir(ref_dir)
             if bwt_suffix.search(name)]
if not bwt_stems:
    raise Exception("Could not find *.bwt in reference collection.")
# If several *.bwt files are present, the last one listed wins.
ref_basename = os.path.join(ref_dir, bwt_stems[-1])

# Scratch directory of the current task; reads and alignments go here.
tmp_dir = arvados.current_task().tmpdir

class Aligner:
    """Run ``bwa aln`` on one input collection of the current task.

    After ``aln()`` returns, two children are running on behalf of this
    object and must be reaped by the caller:

    * ``feeder_pid`` -- pid of a forked child that copies the decompressed
      input both into a pipe (bwa's stdin) and into a file in ``tmp_dir``.
    * ``bwa_proc`` -- the ``subprocess.Popen`` object for ``bwa aln``.
    """

    def input_filename(self):
        """Return the decompressed name of the first file in the collection.

        Returns None if the collection contains no files.
        """
        for stream in arvados.CollectionReader(self.collection).all_streams():
            for f in stream.all_files():
                return f.decompressed_name()

    def generate_input(self):
        """Yield the decompressed data of every file in the collection."""
        for stream in arvados.CollectionReader(self.collection).all_streams():
            for f in stream.all_files():
                for chunk in f.readall_decompressed():
                    yield chunk

    def aln(self, input_param):
        """Start ``bwa aln`` on the collection named by a task parameter.

        input_param is the key in the current task's 'parameters' whose
        value identifies the input collection.

        Returns (reads_filename, aln_filename): the path the reads are
        being copied to and the path bwa's output is being written to.
        Both files are complete only once the feeder child
        (``self.feeder_pid``) and bwa (``self.bwa_proc``) have exited.

        Raises Exception if the collection contains no files.
        """
        self.collection = this_task['parameters'][input_param]

        # Look the name up once; each call builds a new CollectionReader.
        input_name = self.input_filename()
        if input_name is None:
            raise Exception(
                "No input files found in collection for %s" % input_param)
        reads_filename = os.path.join(tmp_dir, input_name)
        aln_filename = os.path.join(tmp_dir, input_name + '.sai')

        reads_pipe_r, reads_pipe_w = os.pipe()
        self.feeder_pid = os.fork()
        if self.feeder_pid == 0:
            # child: tee the input to the pipe and to the reads file.
            # An uncaught exception here ends the child with a non-zero
            # status, which the parent checks after waitpid().
            os.close(reads_pipe_r)
            with open(reads_filename, 'wb') as reads_file:
                for chunk in self.generate_input():
                    if len(chunk) != os.write(reads_pipe_w, chunk):
                        raise Exception("short write")
                    reads_file.write(chunk)
            os.close(reads_pipe_w)
            sys.exit(0)
        # parent: drop the write end so bwa sees EOF when the feeder exits.
        os.close(reads_pipe_w)

        reads_pipe = os.fdopen(reads_pipe_r, 'rb', 2**20)
        try:
            with open(aln_filename, 'wb') as aln_file:
                self.bwa_proc = subprocess.Popen(
                    [arvados_bwa.bwa_binary(),
                     'aln', '-t', '16',
                     ref_basename,
                     '-'],
                    stdin=reads_pipe,
                    stdout=aln_file)
        finally:
            # bwa has its own copy of the read end; close ours explicitly
            # so children forked later do not inherit it.
            reads_pipe.close()
        return reads_filename, aln_filename

aligner_1 = Aligner()
aligner_2 = Aligner()
reads_1, alignments_1 = aligner_1.aln('input_1')
reads_2, alignments_2 = aligner_2.aln('input_2')

# Reap exactly the four children started above.  A bare os.wait() returns
# whichever child exits first, and a feeder always exits before the bwa
# process reading from it, so it could report success while bwa aln was
# still writing its output.
feed_pid1, feed_exit1 = os.waitpid(aligner_1.feeder_pid, 0)
feed_pid2, feed_exit2 = os.waitpid(aligner_2.feeder_pid, 0)
exit1 = aligner_1.bwa_proc.wait()
exit2 = aligner_2.bwa_proc.wait()
# Popen return codes: exit status, or negative signal number if killed.
if exit1 != 0 or exit2 != 0:
    raise Exception("bwa aln exited non-zero (%d, %d)" % (exit1, exit2))
# A failed feeder means the reads file on disk may be truncated.
if feed_exit1 != 0 or feed_exit2 != 0:
    raise Exception("reads feeder exited non-zero (0x%x, 0x%x)"
                    % (feed_exit1, feed_exit2))

# output alignments in sam format to pipe
sam_pipe_r, sam_pipe_w = os.pipe()
sam_pid = os.fork()
if sam_pid != 0:
    # parent: only the sampe child writes to this pipe
    os.close(sam_pipe_w)
else:
    # child
    os.close(sam_pipe_r)
    arvados_bwa.run('sampe',
                    [ref_basename,
                     alignments_1, alignments_2,
                     reads_1, reads_2],
                    stdout=os.fdopen(sam_pipe_w, 'wb', 2**20))
    sys.exit(0)

# convert sam (sam_pipe_r) to bam (bam_pipe_w)
bam_pipe_r, bam_pipe_w = os.pipe()
bam_pid = os.fork()
if bam_pid != 0:
    # parent: keep only the read end of the bam pipe
    os.close(bam_pipe_w)
    os.close(sam_pipe_r)
else:
    # child
    os.close(bam_pipe_r)
    arvados_samtools.run('view',
                         ['-S', '-b',
                          '-'],
                         stdin=os.fdopen(sam_pipe_r, 'rb', 2**20),
                         stdout=os.fdopen(bam_pipe_w, 'wb', 2**20))
    sys.exit(0)

# copy bam (bam_pipe_r) to Keep
out_bam_filename = os.path.split(reads_1)[-1] + '.bam'
out = arvados.CollectionWriter()
out.start_new_stream()
out.start_new_file(out_bam_filename)
out.write(os.fdopen(bam_pipe_r, 'rb', 2**20))

# make sure everyone exited nicely
pid3, exit3 = os.waitpid(sam_pid, 0)
if exit3 != 0:
    raise Exception("bwa sampe exited non-zero (0x%x)" % exit3)
pid4, exit4 = os.waitpid(bam_pid, 0)
if exit4 != 0:
    raise Exception("samtools view exited non-zero (0x%x)" % exit4)

# proclaim success
this_task.set_output(out.finish())