#!/usr/bin/env python

import logging

logger = logging.getLogger('run-command')
log_handler = logging.StreamHandler()
log_handler.setFormatter(logging.Formatter("run-command: %(message)s"))
logger.addHandler(log_handler)
logger.setLevel(logging.INFO)

import arvados
import re
import os
import subprocess
import sys
import shutil
import crunchutil.subst as subst
import time
import arvados.commands.put as put
import signal
import stat
import copy
import traceback
import pprint
import multiprocessing
import crunchutil.robust_put as robust_put
import crunchutil.vwd as vwd
import argparse
import json
import tempfile

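# Command line flags: --dry-run expands and logs the command without
# contacting the API server or running anything; --job-parameters supplies
# a JSON object to use as the job's script_parameters in that mode.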
parser = argparse.ArgumentParser()
parser.add_argument('--dry-run', action='store_true')
parser.add_argument('--job-parameters', type=str, default="{}")
args = parser.parse_args()

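# Restrict new files and directories to the owner (umask 077).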
os.umask(0077)

if not args.dry_run:
    api = arvados.api('v1')
    os.chdir(arvados.current_task().tmpdir)
    os.mkdir("tmpdir")
    os.mkdir("output")

    os.chdir("output")

    outdir = os.getcwd()

    taskp = None
    jobp = arvados.current_job()['script_parameters']
    if len(arvados.current_task()['parameters']) > 0:
        taskp = arvados.current_task()['parameters']
else:
    outdir = "/tmp"
    jobp = json.loads(args.job_parameters)
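    # Stub environment values so the $(job.uuid), $(task.uuid) and
    # $(job.srcdir) substitutions below resolve during a dry run.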
    os.environ['JOB_UUID'] = 'zzzzz-8i9sb-1234567890abcde'
    os.environ['TASK_UUID'] = 'zzzzz-ot0gb-1234567890abcde'
    os.environ['CRUNCH_SRC'] = '/tmp/crunch-src'
    if 'TASK_KEEPMOUNT' not in os.environ:
        os.environ['TASK_KEEPMOUNT'] = '/keep'

links = []

def sub_tmpdir(v):
    return os.path.join(arvados.current_task().tmpdir, 'tmpdir')

def sub_outdir(v):
    return outdir

def sub_cores(v):
    return str(multiprocessing.cpu_count())

def sub_jobid(v):
    return os.environ['JOB_UUID']

def sub_taskid(v):
    return os.environ['TASK_UUID']

def sub_jobsrc(v):
    return os.environ['CRUNCH_SRC']

subst.default_subs["task.tmpdir"] = sub_tmpdir
subst.default_subs["task.outdir"] = sub_outdir
subst.default_subs["job.srcdir"] = sub_jobsrc
subst.default_subs["node.cores"] = sub_cores
subst.default_subs["job.uuid"] = sub_jobid
subst.default_subs["task.uuid"] = sub_taskid

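# Records the last signal forwarded to the child process so the exit path
# can distinguish "killed by signal" from a normal nonzero exit code.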
class SigHandler(object):
    def __init__(self):
        self.sig = None

    def send_signal(self, sp, signum):
        sp.send_signal(signum)
        self.sig = signum

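# Bucket regex matches by their capture groups: items whose groups() are
# equal share a key ('^_^' is just an unlikely separator) and are
# collected into the same list.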
def add_to_group(gr, match):
    m = '^_^'.join(match.groups())
    if m not in gr:
        gr[m] = []
    gr[m].append(match.group(0))

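# Example (illustrative): with p = {"files": ["a.txt", "b.dat", "c.txt"]},
# the item {"regex": ".*\\.txt$", "filter": "files"} expands to
# ["a.txt", "c.txt"]; a plain string is substituted and wrapped in a list.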
def expand_item(p, c):
    if isinstance(c, dict):
        if "foreach" in c and "command" in c:
            var = c["foreach"]
            items = get_items(p, p[var])
            r = []
            for i in items:
                params = copy.copy(p)
                params[var] = i
                r.extend(expand_list(params, c["command"]))
            return r
        if "list" in c and "index" in c and "command" in c:
            var = c["list"]
            items = get_items(p, p[var])
            params = copy.copy(p)
            params[var] = items[int(c["index"])]
            return expand_list(params, c["command"])
        if "regex" in c:
            pattern = re.compile(c["regex"])
            if "filter" in c:
                items = get_items(p, p[c["filter"]])
                return [i for i in items if pattern.match(i)]
            elif "group" in c:
                items = get_items(p, p[c["group"]])
                groups = {}
                for i in items:
                    match = pattern.match(i)
                    if match:
                        add_to_group(groups, match)
                return [groups[k] for k in groups]
            elif "extract" in c:
                items = get_items(p, p[c["extract"]])
                r = []
                for i in items:
                    match = pattern.match(i)
                    if match:
                        r.append(list(match.groups()))
                return r
    elif isinstance(c, list):
        return expand_list(p, c)
    elif isinstance(c, basestring):
        return [subst.do_substitution(p, c)]

    return []

def expand_list(p, l):
    if isinstance(l, basestring):
        return expand_item(p, l)
    else:
        return [exp for arg in l for exp in expand_item(p, arg)]

def get_items(p, value):
    if isinstance(value, dict):
        return expand_item(p, value)

    if isinstance(value, list):
        return expand_list(p, value)

    fn = subst.do_substitution(p, value)
    # os.stat raises OSError on failure rather than returning None, so
    # there is no need to test the mode for None here.
    mode = os.stat(fn).st_mode
    if stat.S_ISDIR(mode):
        # A directory's items are its immediate contents.
        return [os.path.join(fn, l) for l in os.listdir(fn)]
    elif stat.S_ISREG(mode):
        # A regular file's items are its lines.
        with open(fn) as f:
            return [line.rstrip("\r\n") for line in f]
    return None

stdoutname = None
stdoutfile = None
stdinname = None
stdinfile = None
rcode = 1

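# Example (illustrative): with task.foreach = ["a", "b"], a = ["1", "2"]
# and b = ["3", "4"], this queues one new job task per (a, b) combination,
# four in total, each seeing a single value for a and for b.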
def recursive_foreach(params, fvars):
    var = fvars[0]
    fvars = fvars[1:]
    items = get_items(params, params[var])
    logger.info("parallelizing on %s with items %s" % (var, items))
    if items is not None:
        for i in items:
            params = copy.copy(params)
            params[var] = i
            if len(fvars) > 0:
                recursive_foreach(params, fvars)
            else:
                if not args.dry_run:
                    arvados.api().job_tasks().create(body={
                        'job_uuid': arvados.current_job()['uuid'],
                        'created_by_job_task_uuid': arvados.current_task()['uuid'],
                        'sequence': 1,
                        'parameters': params
                    }).execute()
                else:
                    logger.info(expand_list(params, params["command"]))
    else:
        logger.error("parameter %s with value %s in task.foreach yielded no items" % (var, params[var]))
        sys.exit(1)

try:
    if "task.foreach" in jobp:
        if args.dry_run or arvados.current_task()['sequence'] == 0:
            # This is the first task: create the other tasks, then exit.
            fvars = jobp["task.foreach"]
            if isinstance(fvars, basestring):
                fvars = [fvars]
            if not isinstance(fvars, list) or len(fvars) == 0:
                logger.error("value of task.foreach must be a string or non-empty list")
                sys.exit(1)
            recursive_foreach(jobp, fvars)
            if not args.dry_run:
                if "task.vwd" in jobp:
                    # Set output of the first task to the base vwd collection so it
                    # will be merged with output fragments from the other tasks by
                    # crunch.
                    arvados.current_task().set_output(subst.do_substitution(jobp, jobp["task.vwd"]))
                else:
                    arvados.current_task().set_output(None)
            sys.exit(0)
    else:
        # This is the only task so taskp/jobp are the same
        taskp = jobp

    if not args.dry_run:
        if "task.vwd" in taskp:
            # Populate output directory with symlinks to files in collection
            vwd.checkout(subst.do_substitution(taskp, taskp["task.vwd"]), outdir)

        if "task.cwd" in taskp:
            os.chdir(subst.do_substitution(taskp, taskp["task.cwd"]))

    cmd = expand_list(taskp, taskp["command"])

    if not args.dry_run:
        if "task.stdin" in taskp:
            stdinname = subst.do_substitution(taskp, taskp["task.stdin"])
            stdinfile = open(stdinname, "rb")

        if "task.stdout" in taskp:
            stdoutname = subst.do_substitution(taskp, taskp["task.stdout"])
            stdoutfile = open(stdoutname, "wb")

    logger.info("{}{}{}".format(' '.join(cmd), (" < " + stdinname) if stdinname is not None else "", (" > " + stdoutname) if stdoutname is not None else ""))

    if args.dry_run:
        sys.exit(0)
except subst.SubstitutionError as e:
    logger.error(str(e))
    logger.error("task parameters were:")
    logger.error(pprint.pformat(taskp))
    sys.exit(1)
except Exception:
    logger.exception("caught exception")
    logger.error("task parameters were:")
    logger.error(pprint.pformat(taskp))
    sys.exit(1)

try:
    sp = subprocess.Popen(cmd, shell=False, stdin=stdinfile, stdout=stdoutfile)
    sig = SigHandler()

    # forward signals to the process.
    signal.signal(signal.SIGINT, lambda signum, frame: sig.send_signal(sp, signum))
    signal.signal(signal.SIGTERM, lambda signum, frame: sig.send_signal(sp, signum))
    signal.signal(signal.SIGQUIT, lambda signum, frame: sig.send_signal(sp, signum))

    # wait for process to complete.
    rcode = sp.wait()

    if sig.sig is not None:
        logger.critical("terminating on signal %s" % sig.sig)
        sys.exit(2)
    else:
        logger.info("completed with exit code %i (%s)" % (rcode, "success" if rcode == 0 else "failed"))

except Exception:
    logger.exception("caught exception")

# restore default signal handlers.
signal.signal(signal.SIGINT, signal.SIG_DFL)
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGQUIT, signal.SIG_DFL)

for l in links:
    os.unlink(l)

logger.info("the following output files will be saved to keep:")

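# Log each output file's size and path (find -printf '%s %h/%f').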
subprocess.call(["find", ".", "-type", "f", "-printf", "run-command: %12.12s %h/%f\\n"], stdout=sys.stderr)

logger.info("start writing output to keep")

if "task.vwd" in taskp:
    if "task.foreach" in jobp:
        # This is a subtask, so don't merge with the original collection;
        # that merge happens at the end.
        outcollection = vwd.checkin(subst.do_substitution(taskp, taskp["task.vwd"]), outdir, merge=False).manifest_text()
    else:
        # Just a single task, so do merge with the original collection
        outcollection = vwd.checkin(subst.do_substitution(taskp, taskp["task.vwd"]), outdir, merge=True).manifest_text()
else:
    outcollection = robust_put.upload(outdir, logger)

api.job_tasks().update(uuid=arvados.current_task()['uuid'],
                       body={
                           'output': outcollection,
                           'success': (rcode == 0),
                           'progress': 1.0
                       }).execute()

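# Propagate the wrapped command's exit code as this task's exit status.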
sys.exit(rcode)