#!/usr/bin/env python

import arvados
import re
import os
import subprocess
import sys
import shutil
import subst
import time
import arvados.commands.put as put
import signal
import stat
import copy
import traceback
import pprint
import multiprocessing

# Files and directories created from here on are accessible to the owner only.
os.umask(0077)

# Temporary directory of the current task, as reported by the Arvados SDK.
t = arvados.current_task().tmpdir

# API client; used at the end of the script to record the task's output.
api = arvados.api('v1')

# Create the "tmpdir" and "output" directories inside the task's temporary
# directory.  os.mkdir raises OSError if either one already exists.
os.chdir(arvados.current_task().tmpdir)
os.mkdir("tmpdir")
os.mkdir("output")

# Work from inside the output directory; the save step at the end of the
# script operates on ".".
os.chdir("output")

# Path of the output directory; read by sub_link and sub_outdir.
outdir = os.getcwd()

taskp = None

# Script parameters of the whole job.
jobp = arvados.current_job()['script_parameters']

# Parameters used to build the command.  Start from the task's own
# parameters when it has any; the main section below replaces this with jobp
# when the job has no "task.foreach" parameter.  p is always bound (None
# when the task has no parameters) so that the error handler below can print
# it instead of failing with a NameError of its own.
p = None
if len(arvados.current_task()['parameters']) > 0:
    p = arvados.current_task()['parameters']

# Paths of the symlinks created by sub_link; they are removed again after the
# command has run.
links = []

def sub_link(v):
    """Symlink v into the output directory and return the path of the link.

    The link is named after the final path component of v and is appended
    to the module-level links list.
    """
    link_path = os.path.join(outdir, os.path.basename(v))
    os.symlink(v, link_path)
    links.append(link_path)
    return link_path

def sub_tmpdir(v):
    """Return the "tmpdir" path under the task's temporary directory; v is ignored."""
    task_root = arvados.current_task().tmpdir
    return os.path.join(task_root, 'tmpdir')

def sub_outdir(v):
    """Return the module-level output directory path; v is ignored."""
    return outdir

def sub_cores(v):
    """Return the CPU count reported by multiprocessing as a string; v is ignored."""
    core_count = multiprocessing.cpu_count()
    return str(core_count)

def sub_jobid(v):
    """Return the JOB_UUID environment variable; v is ignored.

    Raises KeyError if JOB_UUID is not set.
    """
    job_uuid = os.environ['JOB_UUID']
    return job_uuid

def sub_taskid(v):
    """Return the TASK_UUID environment variable; v is ignored.

    Raises KeyError if TASK_UUID is not set.
    """
    task_uuid = os.environ['TASK_UUID']
    return task_uuid

def sub_jobsrc(v):
    """Return the CRUNCH_SRC environment variable; v is ignored.

    Raises KeyError if CRUNCH_SRC is not set.
    """
    src_dir = os.environ['CRUNCH_SRC']
    return src_dir

# Register the substitution handlers defined above in subst.default_subs.
for _sub_name, _sub_handler in (
        ("link ", sub_link),
        ("task.tmpdir", sub_tmpdir),
        ("task.outdir", sub_outdir),
        ("job.srcdir", sub_jobsrc),
        ("node.cores", sub_cores),
        ("job.uuid", sub_jobid),
        ("task.uuid", sub_taskid)):
    subst.default_subs[_sub_name] = _sub_handler

# Exit status of the command.  Stays 1 unless the command is run and waited
# for, in which case it is replaced by the command's return code.
rcode = 1

def machine_progress(bytes_written, bytes_expected):
    """Format a progress line from the bytes written and the expected total.

    When bytes_expected is None, -1 is reported as the total.
    """
    if bytes_expected is None:
        total = -1
    else:
        total = bytes_expected
    return "run-command: wrote {} total {}\n".format(bytes_written, total)

class SigHandler(object):
    """Forwards signals to a process object and remembers the last one sent."""

    def __init__(self):
        # Number of the most recently forwarded signal; None until one is sent.
        self.sig = None

    def send_signal(self, sp, signum):
        """Pass signum to sp.send_signal, then record it in self.sig.

        self.sig is left untouched if sp.send_signal raises.
        """
        sp.send_signal(signum)
        self.sig = signum

def expand_item(p, c):
    """Expand one command element c into a list of strings.

    p is the parameter dictionary used for substitution.  c may be:
      - a dict with both "foreach" and "command" keys: "command" is expanded
        once per item obtained (via get_items) from the parameter named by
        "foreach", with that parameter bound to the item in a copy of p;
      - a list: expanded with expand_list;
      - a str or unicode: passed through subst.do_substitution and returned
        as a single-element list.
    Anything else, including a dict lacking either key, yields [].
    """
    if isinstance(c, dict):
        if "foreach" not in c or "command" not in c:
            return []
        loop_var = c["foreach"]
        expanded = []
        for item in get_items(p, p[loop_var]):
            # Shallow copy, so rebinding the loop parameter does not modify
            # the caller's dictionary.
            bound = copy.copy(p)
            bound[loop_var] = item
            expanded.extend(expand_list(bound, c["command"]))
        return expanded

    if isinstance(c, list):
        return expand_list(p, c)

    if isinstance(c, str) or isinstance(c, unicode):
        return [subst.do_substitution(p, c)]

    return []

def expand_list(p, l):
    """Expand each element of l with expand_item and return the results as one flat list."""
    flattened = []
    for arg in l:
        flattened.extend(expand_item(p, arg))
    return flattened

def get_items(p, value):
    """Return the list of items that value stands for.

    p is the parameter dictionary used for substitution.  value is either a
    list, which is expanded with expand_list, or a string that names a path
    once subst.do_substitution has been applied to it:

      - a directory yields one "$(dir <prefix>/<entry>/)" string per
        directory entry, where <prefix> is the path with the leading
        TASK_KEEPMOUNT directory (and the following separator) removed;
      - a regular file yields one item per line, without the line ending.

    Returns None when the path is neither a directory nor a regular file.
    Raises OSError if the path cannot be stat'ed, and KeyError if the path
    is a directory while TASK_KEEPMOUNT is not set in the environment.
    """
    if isinstance(value, list):
        return expand_list(p, value)

    fn = subst.do_substitution(p, value)
    mode = os.stat(fn).st_mode

    if stat.S_ISDIR(mode):
        # Only directories need the mount-relative prefix, so the
        # environment lookup is done here rather than for every path.
        prefix = fn[len(os.environ['TASK_KEEPMOUNT'])+1:]
        return ["$(dir %s/%s/)" % (prefix, l) for l in os.listdir(fn)]

    if stat.S_ISREG(mode):
        with open(fn) as f:
            # Strip line terminators so they do not become part of the items
            # (and, from there, of the command arguments built from them).
            return [line.rstrip("\r\n") for line in f]

    # Unsupported file type: report "no items" explicitly instead of failing
    # on an unassigned local variable.
    return None

# Name of the file that receives the command's stdout, and the open file
# object for it.  Both stay None unless "save.stdout" is given.
stdoutname = None
stdoutfile = None

try:
    if "task.foreach" not in jobp:
        # No per-item tasks requested: use the job's parameters directly.
        p = jobp
    elif arvados.current_task()['sequence'] == 0:
        # Sequence 0 of a "task.foreach" job: create one sequence-1 task per
        # item, each with the named parameter bound to that item, then exit
        # without running the command.
        var = jobp["task.foreach"]
        items = get_items(jobp, jobp[var])
        print("run-command: parallelizing on %s with items %s" % (var, items))
        if items is None:
            sys.exit(1)
        for item in items:
            params = copy.copy(jobp)
            params[var] = item
            arvados.api().job_tasks().create(body={
                'job_uuid': arvados.current_job()['uuid'],
                'created_by_job_task_uuid': arvados.current_task()['uuid'],
                'sequence': 1,
                'parameters': params
                }
            ).execute()
        arvados.current_task().set_output(None)
        sys.exit(0)
    # For any other sequence of a "task.foreach" job, p is whatever was
    # taken from the task's own parameters earlier in the script.

    cmd = expand_list(p, p["command"])

    if "save.stdout" in p:
        stdoutname = subst.do_substitution(p, p["save.stdout"])
        stdoutfile = open(stdoutname, "wb")

    cmdline = ' '.join(cmd)
    redirect = "" if stdoutname is None else (" > " + stdoutname)
    print("run-command: {}{}".format(cmdline, redirect))

except Exception:
    # sys.exit() above raises SystemExit, which is not caught here.
    print("run-command: caught exception:")
    traceback.print_exc(file=sys.stdout)
    print("run-command: task parameters was:")
    pprint.pprint(p)
    sys.exit(1)

# Run the command, forwarding termination signals to it while it runs.
try:
    sp = subprocess.Popen(cmd, shell=False, stdout=stdoutfile)
    sig = SigHandler()

    # Forward SIGINT, SIGTERM and SIGQUIT to the child process; SigHandler
    # records which signal was forwarded.
    for forwarded in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        signal.signal(forwarded, lambda signum, frame: sig.send_signal(sp, signum))

    # wait for process to complete.
    rcode = sp.wait()

    if sig.sig is not None:
        # A signal was forwarded while the command ran: stop here with
        # status 2 instead of going on to save the output.
        print("run-command: terminating on signal %s" % sig.sig)
        sys.exit(2)
    else:
        print("run-command: completed with exit code %i (%s)" % (rcode, "success" if rcode == 0 else "failed"))

except Exception as e:
    # Failure to start or wait for the command is reported but not fatal;
    # rcode keeps its previous value and the script carries on.
    print("run-command: caught exception:")
    traceback.print_exc(file=sys.stdout)

finally:
    # Release this process's handle on the stdout capture file so it is not
    # left open for the rest of the script.
    if stdoutfile is not None:
        stdoutfile.close()

# Put the default handlers back for the signals that were being forwarded.
for _restored in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
    signal.signal(_restored, signal.SIG_DFL)

# Remove the symlinks recorded by sub_link before the output directory is
# listed and saved.
for link_path in links:
    os.unlink(link_path)

print("run-command: the following output files will be saved to keep:")

# Have find(1) print one "run-command:"-prefixed line per regular file under
# the current directory.
find_cmd = ["find", ".", "-type", "f", "-printf", "run-command: %12.12s %h/%f\\n"]
subprocess.call(find_cmd)

print("run-command: start writing output to keep")

# Checkpoint file under the task's temporary directory, handed to
# put.ResumeCache and reused on every attempt below.
checkpoint_path = os.path.join(arvados.current_task().tmpdir, "upload-output-checkpoint")
resume_cache = put.ResumeCache(checkpoint_path)
reporter = put.progress_writer(machine_progress)
bytes_expected = put.expected_bytes_for(".")

# Retry until the current directory has been written and the result has been
# recorded on the task.  Any ordinary exception is printed and followed by a
# 5 second pause; KeyboardInterrupt ends the script with status 2.
while True:
    try:
        out = put.ArvPutCollectionWriter.from_cache(resume_cache, reporter, bytes_expected)
        out.do_queued_work()
        out.write_directory_tree(".", max_manifest_depth=0)
        outuuid = out.finish()
        task_update = {
            'output': outuuid,
            'success': (rcode == 0),
            'progress': 1.0
        }
        api.job_tasks().update(uuid=arvados.current_task()['uuid'],
                               body=task_update).execute()
        break
    except KeyboardInterrupt:
        print("run-command: terminating on signal 2")
        sys.exit(2)
    except Exception:
        print("run-command: caught exception:")
        traceback.print_exc(file=sys.stdout)
        time.sleep(5)

# Exit with the command's own status.
sys.exit(rcode)