#!/usr/bin/env python
# Copyright (C) The Arvados Authors. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

import arvados
import os
import re
import sys
import subprocess
import arvados_picard
from arvados_ipc import *

# Hand the job to the SDK helper that creates one task per input file.
# NOTE(review): presumably the sequence-0 task only queues the per-file
# tasks and then ends -- confirm against arvados.job_setup.
arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)

this_job = arvados.current_job()
this_task = arvados.current_task()

# Extract the collection named by the 'reference' script parameter; the
# returned value is used below as a directory path.
ref_dir = arvados.util.collection_extract(
    collection=this_job['script_parameters']['reference'],
    path='reference',
    decompress=True)

# Collect, in os.listdir() order, the full path of every entry whose name
# ends in ".fasta" or ".fasta.gz".
ref_fasta_files = []
for ref_entry in os.listdir(ref_dir):
    if re.search(r'\.fasta(\.gz)?$', ref_entry):
        ref_fasta_files.append(os.path.join(ref_dir, ref_entry))

# Locator/manifest of the collection this task has to process.
input_collection = this_task['parameters']['input']

# Record the stream name and file name of the FIRST file in the input
# collection; the rest of the script uses them to name the temporary BAM
# and the output files.
#
# A lone ``break`` in the inner loop only leaves that loop, so the scan
# would carry on and let every later stream overwrite the names with its
# own first file.  The outer loop is therefore left explicitly as soon as
# a file has been found.
input_stream_name = None
input_file_name = None
for stream in arvados.CollectionReader(input_collection).all_streams():
    for stream_file in stream.all_files():
        input_stream_name = stream.name()
        input_file_name = stream_file.name()
        break
    if input_file_name is not None:
        break

if input_file_name is None:
    # Without this check an empty collection only shows up later as a
    # NameError on input_file_name; fail here with a usable message.
    raise ValueError(
        'input collection %s contains no files' % (input_collection,))

# picard FixMateInformation cannot read from a pipe, so the input is
# first copied into a regular file under the task's tmpdir and picard is
# pointed at that file instead.
input_bam_path = os.path.join(this_task.tmpdir, input_file_name)
with open(input_bam_path, 'wb') as bam:
    bam_source = arvados.CollectionReader(input_collection)
    # Every chunk of every file in every stream is appended, in order,
    # to the single temporary file.
    for bam_stream in bam_source.all_streams():
        for bam_file in bam_stream.all_files():
            for chunk in bam_file.readall():
                bam.write(chunk)

# Shared bookkeeping for all pipeline stages: ``children`` is iterated at
# the end of the script as (stage name, pid) pairs, and ``pipes`` holds
# file descriptors keyed by (pipe name, 'r' or 'w').
children = {}
pipes = {}

# Stage: picard FixMateInformation reads the temporary BAM file and
# writes its result to the write end of the 'fixmate' pipe.
pipe_setup(pipes, 'fixmate')
if named_fork(children, 'fixmate') == 0:
    # A zero return presumably means "running in the forked child"
    # (os.fork convention) -- this branch never returns to the script.
    pipe_closeallbut(pipes, ('fixmate', 'w'))
    fixmate_params = {
        'i': input_bam_path,
        'o': '/dev/stdout',
        'quiet': 'true',
        'so': 'coordinate',
        'validation_stringency': 'LENIENT',
        'compression_level': 0,
    }
    arvados_picard.run(
        'FixMateInformation',
        params=fixmate_params,
        stdout=os.fdopen(pipes['fixmate', 'w'], 'wb', 2**20))
    os._exit(0)
# The launching process drops its copy of the write end and forgets it.
os.close(pipes.pop(('fixmate', 'w'), None))

# Stage: picard SortSam reads from the 'fixmate' pipe on its stdin and
# writes coordinate-sorted output to the 'sortsam' pipe on its stdout.
pipe_setup(pipes, 'sortsam')
if named_fork(children, 'sortsam') == 0:
    # Keep only the two pipe ends this stage uses.
    pipe_closeallbut(pipes, ('fixmate', 'r'), ('sortsam', 'w'))
    sortsam_params = {
        'i': '/dev/stdin',
        'o': '/dev/stdout',
        'quiet': 'true',
        'so': 'coordinate',
        'validation_stringency': 'LENIENT',
        'compression_level': 0,
    }
    arvados_picard.run(
        'SortSam',
        params=sortsam_params,
        stdin=os.fdopen(pipes['fixmate', 'r'], 'rb', 2**20),
        stdout=os.fdopen(pipes['sortsam', 'w'], 'wb', 2**20))
    os._exit(0)

# Stage: picard ReorderSam reads from the 'sortsam' pipe and writes to
# the 'reordersam' pipe, using the first reference FASTA file found.
pipe_setup(pipes, 'reordersam')
if named_fork(children, 'reordersam') == 0:
    # Keep only the two pipe ends this stage uses.
    pipe_closeallbut(pipes, ('sortsam', 'r'), ('reordersam', 'w'))
    reordersam_params = {
        'i': '/dev/stdin',
        'o': '/dev/stdout',
        'reference': ref_fasta_files[0],
        'quiet': 'true',
        'validation_stringency': 'LENIENT',
        'compression_level': 0,
    }
    arvados_picard.run(
        'ReorderSam',
        params=reordersam_params,
        stdin=os.fdopen(pipes['sortsam', 'r'], 'rb', 2**20),
        stdout=os.fdopen(pipes['reordersam', 'w'], 'wb', 2**20))
    os._exit(0)

# Stage: picard AddOrReplaceReadGroups reads from the 'reordersam' pipe
# and writes to the 'addrg' pipe.  Read-group fields come from the job's
# script parameters, falling back to 0 (or 'illumina' for rgpl).
pipe_setup(pipes, 'addrg')
if named_fork(children, 'addrg') == 0:
    # Keep only the two pipe ends this stage uses.
    pipe_closeallbut(pipes, ('reordersam', 'r'), ('addrg', 'w'))
    job_params = this_job['script_parameters']
    addrg_params = {
        'i': '/dev/stdin',
        'o': '/dev/stdout',
        'quiet': 'true',
        'rglb': job_params.get('rglb', 0),
        'rgpl': job_params.get('rgpl', 'illumina'),
        'rgpu': job_params.get('rgpu', 0),
        'rgsm': job_params.get('rgsm', 0),
        'validation_stringency': 'LENIENT',
    }
    arvados_picard.run(
        'AddOrReplaceReadGroups',
        params=addrg_params,
        stdin=os.fdopen(pipes['reordersam', 'r'], 'rb', 2**20),
        stdout=os.fdopen(pipes['addrg', 'w'], 'wb', 2**20))
    os._exit(0)

# Stage: "tee" the finished BAM stream.  Every buffer read from the
# 'addrg' pipe is forwarded to the 'bam' pipe (read by the index stage)
# and the 'casm_in' pipe (read by the casm stage), and is also written
# into a new collection.  Once the input is exhausted, the collection's
# manifest text is sent down the 'bammanifest' pipe.
pipe_setup(pipes, 'bammanifest')
pipe_setup(pipes, 'bam')
pipe_setup(pipes, 'casm_in')
if 0==named_fork(children, 'bammanifest'):
    # Zero return: presumably running in the forked child (os.fork
    # convention).  Keep the one read end and three write ends used here.
    pipe_closeallbut(pipes,
                     ('addrg', 'r'),
                     ('bammanifest', 'w'),
                     ('bam', 'w'),
                     ('casm_in', 'w'))
    # The output file keeps the input's stream name and file name.
    out = arvados.CollectionWriter()
    out.start_new_stream(input_stream_name)
    out.start_new_file(input_file_name)
    while True:
        # Read up to 1 MiB at a time; an empty result means end of input.
        buf = os.read(pipes['addrg','r'], 2**20)
        if len(buf) == 0:
            break
        # NOTE(review): the return values of os.write are not checked;
        # this relies on each write delivering the whole buffer -- confirm.
        os.write(pipes['bam','w'], buf)
        os.write(pipes['casm_in','w'], buf)
        out.write(buf)
    os.write(pipes['bammanifest','w'], out.manifest_text())
    os.close(pipes['bammanifest','w'])
    # The 'bam' and 'casm_in' write ends are not closed explicitly; they
    # go away when the process exits here.
    os._exit(0)

# Stage: picard CollectAlignmentSummaryMetrics.  Unlike the other picard
# stages it is given its pipes as /dev/fd/N paths instead of
# stdin/stdout, so the descriptors must stay open in the picard process
# (hence close_fds=False).
pipe_setup(pipes, 'casm')
if named_fork(children, 'casm') == 0:
    # Keep only the two pipe ends this stage uses.
    pipe_closeallbut(pipes, ('casm_in', 'r'), ('casm', 'w'))
    casm_input_path = '/dev/fd/' + str(pipes['casm_in', 'r'])
    casm_output_path = '/dev/fd/' + str(pipes['casm', 'w'])
    casm_params = {
        'input': casm_input_path,
        'output': casm_output_path,
        'reference_sequence': ref_fasta_files[0],
        'validation_stringency': 'LENIENT',
    }
    arvados_picard.run(
        'CollectAlignmentSummaryMetrics',
        params=casm_params,
        close_fds=False)
    os._exit(0)

# Stage: picard BuildBamIndex reads the BAM copy from the 'bam' pipe and
# writes the index to the 'index' pipe.
pipe_setup(pipes, 'index')
if named_fork(children, 'index') == 0:
    # Keep only the two pipe ends this stage uses.
    pipe_closeallbut(pipes, ('bam', 'r'), ('index', 'w'))
    index_params = {
        'i': '/dev/stdin',
        'o': '/dev/stdout',
        'quiet': 'true',
        'validation_stringency': 'LENIENT',
    }
    arvados_picard.run(
        'BuildBamIndex',
        params=index_params,
        stdin=os.fdopen(pipes['bam', 'r'], 'rb', 2**20),
        stdout=os.fdopen(pipes['index', 'w'], 'wb', 2**20))
    os._exit(0)

# Stage: read the BAM index from the 'index' pipe and store it in a new
# collection, under the input's stream name and with a trailing ".bam"
# in the file name replaced by ".bai".  The collection's manifest text
# is then sent down the 'indexmanifest' pipe.
pipe_setup(pipes, 'indexmanifest')
if 0 == named_fork(children, 'indexmanifest'):
    # Zero return: presumably running in the forked child (os.fork
    # convention).  Keep only the two pipe ends this stage uses.
    pipe_closeallbut(pipes, ('index', 'r'), ('indexmanifest', 'w'))
    out = arvados.CollectionWriter()
    out.start_new_stream(input_stream_name)
    # The pattern is a raw string: "\." inside a plain string literal is
    # an invalid escape sequence, which newer Python versions warn about.
    # The regular expression itself is unchanged.
    out.start_new_file(re.sub(r'\.bam$', '.bai', input_file_name))
    while True:
        # Read up to 1 MiB at a time; an empty result means end of input.
        buf = os.read(pipes['index', 'r'], 2**20)
        if not buf:
            break
        out.write(buf)
    os.write(pipes['indexmanifest', 'w'], out.manifest_text())
    os.close(pipes['indexmanifest', 'w'])
    os._exit(0)

# From here on only the three result pipes are read; every other pipe
# end still held by this process is closed.
pipe_closeallbut(pipes,
                 ('bammanifest', 'r'),
                 ('indexmanifest', 'r'),
                 ('casm', 'r'))

# Concatenate the manifest text produced by the two manifest stages, BAM
# first and index second.
outmanifest = ''
for manifest_name in ('bammanifest', 'indexmanifest'):
    manifest_fd = pipes[manifest_name, 'r']
    with os.fdopen(manifest_fd, 'rb', 2**20) as manifest_pipe:
        # Keep calling read() until it returns the empty string.
        for piece in iter(manifest_pipe.read, ''):
            outmanifest += piece

# Store the metrics read from the 'casm' pipe as "<input name>.casm.tsv"
# in the same stream, and append that collection's manifest text as well.
casm_out = arvados.CollectionWriter()
casm_out.start_new_stream(input_stream_name)
casm_out.start_new_file(input_file_name + '.casm.tsv')
casm_out.write(os.fdopen(pipes.pop(('casm', 'r'))))

outmanifest += casm_out.manifest_text()

# Wait for EVERY child before deciding the outcome.
#
# Folding the result in with ``all_ok and waitpid_and_check_exit(...)``
# short-circuits: once one child has failed, the remaining children are
# never waited on or checked.  Calling the helper unconditionally avoids
# that.
# NOTE(review): waitpid_and_check_exit is assumed to return a truthy
# value when the child exited cleanly -- confirm against arvados_ipc.
all_ok = True
for childname, pid in children.items():
    if not waitpid_and_check_exit(pid, childname):
        all_ok = False

if all_ok:
    # Publish the combined manifest as this task's output.
    this_task.set_output(outmanifest)
else:
    sys.exit(1)