#!/usr/bin/env python
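# Crunch job script: expands the job's "script_parameters" into one or more
# command pipelines, optionally queues one new task per item of
# "task.foreach", runs the resulting command(s), and saves the contents of
# the output directory back to Keep as the task output.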

import logging

logger = logging.getLogger('run-command')
log_handler = logging.StreamHandler()
log_handler.setFormatter(logging.Formatter("run-command: %(message)s"))
logger.addHandler(log_handler)
logger.setLevel(logging.INFO)

import arvados
import re
import os
import subprocess
import sys
import shutil
import crunchutil.subst as subst
import time
import arvados.commands.put as put
import signal
import stat
import copy
import traceback
import pprint
import multiprocessing
import crunchutil.robust_put as robust_put
import crunchutil.vwd as vwd
import argparse
import json
import tempfile
import errno

parser = argparse.ArgumentParser()
parser.add_argument('--dry-run', action='store_true')
parser.add_argument('--script-parameters', type=str, default="{}")
args = parser.parse_args()
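
# For local testing outside of a crunch job, --dry-run together with
# --script-parameters previews command expansion without contacting the API
# server; a hypothetical invocation:
#   run-command --dry-run --script-parameters '{"command": ["echo", "$(x)"], "x": "hello"}'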

os.umask(0077)

if not args.dry_run:
    api = arvados.api('v1')
    t = arvados.current_task().tmpdir
    os.chdir(t)
    os.mkdir("tmpdir")
    os.mkdir("output")

    os.chdir("output")

    outdir = os.getcwd()

    taskp = None
    jobp = arvados.current_job()['script_parameters']
    if len(arvados.current_task()['parameters']) > 0:
        taskp = arvados.current_task()['parameters']
else:
    outdir = "/tmp"
    jobp = json.loads(args.script_parameters)
    os.environ['JOB_UUID'] = 'zzzzz-8i9sb-1234567890abcde'
    os.environ['TASK_UUID'] = 'zzzzz-ot0gb-1234567890abcde'
    os.environ['CRUNCH_SRC'] = '/tmp/crunch-src'
    if 'TASK_KEEPMOUNT' not in os.environ:
        os.environ['TASK_KEEPMOUNT'] = '/keep'

def sub_tmpdir(v):
    return os.path.join(arvados.current_task().tmpdir, 'tmpdir')

def sub_outdir(v):
    return outdir

def sub_cores(v):
    return str(multiprocessing.cpu_count())

def sub_jobid(v):
    return os.environ['JOB_UUID']

def sub_taskid(v):
    return os.environ['TASK_UUID']

def sub_jobsrc(v):
    return os.environ['CRUNCH_SRC']

subst.default_subs["task.tmpdir"] = sub_tmpdir
subst.default_subs["task.outdir"] = sub_outdir
subst.default_subs["job.srcdir"] = sub_jobsrc
subst.default_subs["node.cores"] = sub_cores
subst.default_subs["job.uuid"] = sub_jobid
subst.default_subs["task.uuid"] = sub_taskid

class SigHandler(object):
    def __init__(self):
        self.sig = None

    def send_signal(self, subprocesses, signum):
        for sp in subprocesses:
            sp.send_signal(signum)
        self.sig = signum

# http://rightfootin.blogspot.com/2006/09/more-on-python-flatten.html
def flatten(l, ltypes=(list, tuple)):
    ltype = type(l)
    l = list(l)
    i = 0
    while i < len(l):
        while isinstance(l[i], ltypes):
            if not l[i]:
                l.pop(i)
                i -= 1
                break
            else:
                l[i:i + 1] = l[i]
        i += 1
    return ltype(l)
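
# For example, flatten([1, [2, [3, 4]], 5]) returns [1, 2, 3, 4, 5], and the
# original container type is preserved: flatten((1, (2, 3))) returns (1, 2, 3).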

def add_to_group(gr, match):
    m = match.groups()
    if m not in gr:
        gr[m] = []
    gr[m].append(match.group(0))

class EvaluationError(Exception):
    pass

# Return a tuple (var, items): 'var' is the name of the variable that will
# take on each value in 'items' when performing an inner substitution.
def var_items(p, c, key):
    if key not in c:
        raise EvaluationError("'%s' was expected in the list context function but is missing" % key)

    if "var" in c:
        if not isinstance(c["var"], basestring):
            raise EvaluationError("Value of 'var' must be a string")
        # Var specifies the variable name for inner parameter substitution
        return (c["var"], get_items(p, c[key]))
    else:
        # The component function ('key') value is a list, so return the list
        # directly with no parameter selected.
        if isinstance(c[key], list):
            return (None, get_items(p, c[key]))
        elif isinstance(c[key], basestring):
            # check if c[key] is a string that looks like a parameter
            m = re.match(r"^\$\((.*)\)$", c[key])
            if m and m.group(1) in p:
                return (m.group(1), get_items(p, c[key]))
            else:
                # backwards compatible, foreach specifies bare parameter name to use
                return (c[key], get_items(p, p[c[key]]))
        else:
            raise EvaluationError("Value of '%s' must be a string or list" % key)

# "p" is the parameter scope, "c" is the item to be expanded.
# If "c" is a dict, apply function expansion.
# If "c" is a list, recursively expand each item and return a new list.
# If "c" is a string, apply parameter substitution
def expand_item(p, c):
    if isinstance(c, dict):
        if "foreach" in c and "command" in c:
            # Expand a command template for each item in the specified user
            # parameter
            var, items = var_items(p, c, "foreach")
            if var is None:
                raise EvaluationError("Must specify 'var' in foreach")
            r = []
            for i in items:
                params = copy.copy(p)
                params[var] = i
                r.append(expand_item(params, c["command"]))
            return r
        elif "list" in c and "index" in c and "command" in c:
            # extract a single item from a list
            var, items = var_items(p, c, "list")
            if var is None:
                raise EvaluationError("Must specify 'var' in list")
            params = copy.copy(p)
            params[var] = items[int(c["index"])]
            return expand_item(params, c["command"])
        elif "regex" in c:
            pattern = re.compile(c["regex"])
            if "filter" in c:
                # filter list so that it only includes items that match a
                # regular expression
                _, items = var_items(p, c, "filter")
                return [i for i in items if pattern.match(i)]
            elif "group" in c:
                # generate a list of lists, where items are grouped on common
                # subexpression match
                _, items = var_items(p, c, "group")
                groups = {}
                for i in items:
                    match = pattern.match(i)
                    if match:
                        add_to_group(groups, match)
                return [groups[k] for k in groups]
            elif "extract" in c:
                # generate a list of lists, where items are split by
                # subexpression match
                _, items = var_items(p, c, "extract")
                r = []
                for i in items:
                    match = pattern.match(i)
                    if match:
                        r.append(list(match.groups()))
                return r
        elif "batch" in c and "size" in c:
            # generate a list of lists, where items are split into a batch size
            _, items = var_items(p, c, "batch")
            sz = int(c["size"])
            r = []
            for j in xrange(0, len(items), sz):
                r.append(items[j:j+sz])
            return r
        raise EvaluationError("Missing valid list context function")
    elif isinstance(c, list):
        return [expand_item(p, arg) for arg in c]
    elif isinstance(c, basestring):
        m = re.match(r"^\$\((.*)\)$", c)
        if m and m.group(1) in p:
            return expand_item(p, p[m.group(1)])
        else:
            return subst.do_substitution(p, c)
    else:
        raise EvaluationError("expand_item() unexpected parameter type %s" % type(c))

# Evaluate in a list context
# "p" is the parameter scope, "value" will be evaluated
# if "value" is a list after expansion, return that
# if "value" is a path to a directory, return a list consisting of each entry in the directory
# if "value" is a path to a file, return a list consisting of each line of the file
def get_items(p, value):
    value = expand_item(p, value)
    if isinstance(value, list):
        return value
    elif isinstance(value, basestring):
        mode = os.stat(value).st_mode
        if stat.S_ISDIR(mode):
            return [os.path.join(value, l) for l in os.listdir(value)]
        elif stat.S_ISREG(mode):
            with open(value) as f:
                return [line.rstrip("\r\n") for line in f]
    raise EvaluationError("get_items did not yield a list")

stdoutname = None
stdoutfile = None
stdinname = None
stdinfile = None

# Construct the cross product of all values of each variable listed in fvars
def recursive_foreach(params, fvars):
    var = fvars[0]
    fvars = fvars[1:]
    items = get_items(params, params[var])
    logger.info("parallelizing on %s with items %s" % (var, items))
    if items is not None:
        for i in items:
            params = copy.copy(params)
            params[var] = i
            if len(fvars) > 0:
                recursive_foreach(params, fvars)
            else:
                if not args.dry_run:
                    api.job_tasks().create(body={
                        'job_uuid': arvados.current_job()['uuid'],
                        'created_by_job_task_uuid': arvados.current_task()['uuid'],
                        'sequence': 1,
                        'parameters': params
                    }).execute()
                else:
                    if isinstance(params["command"][0], list):
                        for c in params["command"]:
                            logger.info(flatten(expand_item(params, c)))
                    else:
                        logger.info(flatten(expand_item(params, params["command"])))
    else:
        logger.error("parameter %s with value %s in task.foreach yielded no items" % (var, params[var]))
        sys.exit(1)
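
# For example, script_parameters like
#   {"task.foreach": ["sample", "lane"],
#    "sample": ["s1", "s2"], "lane": ["1", "2"],
#    "command": ["process", "$(sample)", "$(lane)"]}
# make recursive_foreach queue one new job task per (sample, lane) combination,
# four tasks in this case; each queued task re-runs this script with $(sample)
# and $(lane) bound to a single value.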

try:
    if "task.foreach" in jobp:
        if args.dry_run or arvados.current_task()['sequence'] == 0:
            # This is the first task: create the other tasks, then exit
            fvars = jobp["task.foreach"]
            if isinstance(fvars, basestring):
                fvars = [fvars]
            if not isinstance(fvars, list) or len(fvars) == 0:
                logger.error("value of task.foreach must be a string or non-empty list")
                sys.exit(1)
            recursive_foreach(jobp, fvars)
            if not args.dry_run:
                if "task.vwd" in jobp:
                    # Set output of the first task to the base vwd collection so it
                    # will be merged with output fragments from the other tasks by
                    # crunch.
                    arvados.current_task().set_output(subst.do_substitution(jobp, jobp["task.vwd"]))
                else:
                    arvados.current_task().set_output(None)
            sys.exit(0)
    else:
        # This is the only task so taskp/jobp are the same
        taskp = jobp
except Exception as e:
    logger.exception("caught exception")
    logger.error("job parameters were:")
    logger.error(pprint.pformat(jobp))
    sys.exit(1)

try:
    if not args.dry_run:
        if "task.vwd" in taskp:
            # Populate output directory with symlinks to files in collection
            vwd.checkout(subst.do_substitution(taskp, taskp["task.vwd"]), outdir)

        if "task.cwd" in taskp:
            os.chdir(subst.do_substitution(taskp, taskp["task.cwd"]))

    cmd = []
    if isinstance(taskp["command"][0], list):
        for c in taskp["command"]:
            cmd.append(flatten(expand_item(taskp, c)))
    else:
        cmd.append(flatten(expand_item(taskp, taskp["command"])))
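
    # A hypothetical nested "command" such as [["zcat", "$(input)"], ["grep", "-c", "ACGT"]]
    # expands to a two-stage pipeline (zcat | grep); the stages are wired
    # together below with subprocess.PIPE.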

    if "task.stdin" in taskp:
        stdinname = subst.do_substitution(taskp, taskp["task.stdin"])
        if not args.dry_run:
            stdinfile = open(stdinname, "rb")

    if "task.stdout" in taskp:
        stdoutname = subst.do_substitution(taskp, taskp["task.stdout"])
        if not args.dry_run:
            stdoutfile = open(stdoutname, "wb")

    if "task.env" in taskp:
        env = copy.copy(os.environ)
        for k,v in taskp["task.env"].items():
            env[k] = subst.do_substitution(taskp, v)
    else:
        env = None

    logger.info("{}{}{}".format(' | '.join([' '.join(c) for c in cmd]), (" < " + stdinname) if stdinname is not None else "", (" > " + stdoutname) if stdoutname is not None else ""))

    if args.dry_run:
        sys.exit(0)
except subst.SubstitutionError as e:
    logger.error(str(e))
    logger.error("task parameters were:")
    logger.error(pprint.pformat(taskp))
    sys.exit(1)
except Exception as e:
    logger.exception("caught exception")
    logger.error("task parameters were:")
    logger.error(pprint.pformat(taskp))
    sys.exit(1)

# rcode maps each subprocess pid to the exit code it produced
rcode = {}
try:
    subprocesses = []
    close_streams = []
    if stdinfile:
        close_streams.append(stdinfile)
    next_stdin = stdinfile

    for i in xrange(len(cmd)):
        if i == len(cmd)-1:
            # this is the last command in the pipeline, so its stdout should go to stdoutfile
            next_stdout = stdoutfile
        else:
            # this is an intermediate command in the pipeline, so its stdout should go to a pipe
            next_stdout = subprocess.PIPE

        sp = subprocess.Popen(cmd[i], shell=False, stdin=next_stdin, stdout=next_stdout, env=env)

        # Need to close the FDs on our side so that subcommands will get SIGPIPE if the
        # consuming process ends prematurely.
        if sp.stdout:
            close_streams.append(sp.stdout)

        # Send this process's stdout to the next process's stdin
        next_stdin = sp.stdout

        subprocesses.append(sp)

    # File descriptors have been handed off to the subprocesses, so close them here.
    for s in close_streams:
        s.close()

    # Set up signal handling
    sig = SigHandler()

    # Forward terminate signals to the subprocesses.
    signal.signal(signal.SIGINT, lambda signum, frame: sig.send_signal(subprocesses, signum))
    signal.signal(signal.SIGTERM, lambda signum, frame: sig.send_signal(subprocesses, signum))
    signal.signal(signal.SIGQUIT, lambda signum, frame: sig.send_signal(subprocesses, signum))

    pids = set([s.pid for s in subprocesses])
    while len(pids) > 0:
        (pid, status) = os.wait()
        pids.discard(pid)
        if not taskp.get("task.ignore_rcode"):
            rcode[pid] = (status >> 8)
        else:
            rcode[pid] = 0

    if sig.sig is not None:
        logger.critical("terminating on signal %s" % sig.sig)
        sys.exit(2)
    else:
        for i in xrange(len(cmd)):
            r = rcode[subprocesses[i].pid]
            logger.info("%s completed with exit code %i (%s)" % (cmd[i][0], r, "success" if r == 0 else "failed"))

except Exception as e:
    logger.exception("caught exception")

# restore default signal handlers.
signal.signal(signal.SIGINT, signal.SIG_DFL)
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGQUIT, signal.SIG_DFL)

logger.info("the following output files will be saved to keep:")

subprocess.call(["find", "-L", ".", "-type", "f", "-printf", "run-command: %12.12s %h/%f\\n"], stdout=sys.stderr, cwd=outdir)

logger.info("start writing output to keep")

if "task.vwd" in taskp and "task.foreach" in jobp:
    for root, dirs, files in os.walk(outdir):
        for f in files:
            s = os.lstat(os.path.join(root, f))
            if stat.S_ISLNK(s.st_mode):
                os.unlink(os.path.join(root, f))

(outcollection, checkin_error) = vwd.checkin(outdir)

# Success means at least one subprocess ran, every subprocess exited 0, and
# checking the output back in did not report an error.
success = rcode and all(status == 0 for status in rcode.itervalues()) and not checkin_error

api.job_tasks().update(uuid=arvados.current_task()['uuid'],
                       body={
                           'output': outcollection.manifest_text(),
                           'success': success,
                           'progress': 1.0
                       }).execute()

sys.exit(0 if success else 1)