3609: Add support for batch size, improve ability to pass lists of lists
[arvados.git] / crunch_scripts / run-command
1 #!/usr/bin/env python
2
3 import logging
4
5 logger = logging.getLogger('run-command')
6 log_handler = logging.StreamHandler()
7 log_handler.setFormatter(logging.Formatter("run-command: %(message)s"))
8 logger.addHandler(log_handler)
9 logger.setLevel(logging.INFO)
10
11 import arvados
12 import re
13 import os
14 import subprocess
15 import sys
16 import shutil
17 import crunchutil.subst as subst
18 import time
19 import arvados.commands.put as put
20 import signal
21 import stat
22 import copy
23 import traceback
24 import pprint
25 import multiprocessing
26 import crunchutil.robust_put as robust_put
27 import crunchutil.vwd as vwd
28 import argparse
29 import json
30 import tempfile
31 import errno
32
33 parser = argparse.ArgumentParser()
34 parser.add_argument('--dry-run', action='store_true')
35 parser.add_argument('--script-parameters', type=str, default="{}")
36 args = parser.parse_args()
37
38 os.umask(0077)
39
40 if not args.dry_run:
41     api = arvados.api('v1')
42     t = arvados.current_task().tmpdir
43     os.chdir(arvados.current_task().tmpdir)
44     os.mkdir("tmpdir")
45     os.mkdir("output")
46
47     os.chdir("output")
48
49     outdir = os.getcwd()
50
51     taskp = None
52     jobp = arvados.current_job()['script_parameters']
53     if len(arvados.current_task()['parameters']) > 0:
54         taskp = arvados.current_task()['parameters']
55 else:
56     outdir = "/tmp"
57     jobp = json.loads(args.script_parameters)
58     os.environ['JOB_UUID'] = 'zzzzz-8i9sb-1234567890abcde'
59     os.environ['TASK_UUID'] = 'zzzzz-ot0gb-1234567890abcde'
60     os.environ['CRUNCH_SRC'] = '/tmp/crunche-src'
61     if 'TASK_KEEPMOUNT' not in os.environ:
62         os.environ['TASK_KEEPMOUNT'] = '/keep'
63
64 links = []
65
66 def sub_tmpdir(v):
67     return os.path.join(arvados.current_task().tmpdir, 'tmpdir')
68
69 def sub_outdir(v):
70     return outdir
71
72 def sub_cores(v):
73      return str(multiprocessing.cpu_count())
74
75 def sub_jobid(v):
76      return os.environ['JOB_UUID']
77
78 def sub_taskid(v):
79      return os.environ['TASK_UUID']
80
81 def sub_jobsrc(v):
82      return os.environ['CRUNCH_SRC']
83
84 subst.default_subs["task.tmpdir"] = sub_tmpdir
85 subst.default_subs["task.outdir"] = sub_outdir
86 subst.default_subs["job.srcdir"] = sub_jobsrc
87 subst.default_subs["node.cores"] = sub_cores
88 subst.default_subs["job.uuid"] = sub_jobid
89 subst.default_subs["task.uuid"] = sub_taskid
90
91 class SigHandler(object):
92     def __init__(self):
93         self.sig = None
94
95     def send_signal(self, subprocesses, signum):
96         for sp in subprocesses:
97             sp.send_signal(signum)
98         self.sig = signum
99
100 def add_to_group(gr, match):
101     m = match.groups()
102     if m not in gr:
103         gr[m] = []
104     gr[m].append(match.group(0))
105
106 def expand_item(p, c, flatten=True):
107     if isinstance(c, dict):
108         if "foreach" in c and "command" in c:
109             var = c["foreach"]
110             items = get_items(p, p[var])
111             r = []
112             for i in items:
113                 params = copy.copy(p)
114                 params[var] = i
115                 r.extend(expand_item(params, c["command"]))
116             return r
117         if "list" in c and "index" in c and "command" in c:
118             var = c["list"]
119             items = get_items(p, p[var])
120             params = copy.copy(p)
121             params[var] = items[int(c["index"])]
122             return expand_list(params, c["command"])
123         if "regex" in c:
124             pattern = re.compile(c["regex"])
125             if "filter" in c:
126                 items = get_items(p, p[c["filter"]])
127                 return [i for i in items if pattern.match(i)]
128             elif "group" in c:
129                 items = get_items(p, p[c["group"]])
130                 groups = {}
131                 for i in items:
132                     match = pattern.match(i)
133                     if match:
134                         add_to_group(groups, match)
135                 return [groups[k] for k in groups]
136             elif "extract" in c:
137                 items = get_items(p, p[c["extract"]])
138                 r = []
139                 for i in items:
140                     match = pattern.match(i)
141                     if match:
142                         r.append(list(match.groups()))
143                 return r
144     elif isinstance(c, list):
145         return expand_list(p, c)
146     elif isinstance(c, basestring):
147         if flatten:
148             return [subst.do_substitution(p, c)]
149         else:
150             return subst.do_substitution(p, c)
151
152     return []
153
154 def expand_list(p, l, flatten=True):
155     if isinstance(l, basestring):
156         return expand_item(p, l)
157     elif flatten:
158         return [exp for arg in l for exp in expand_item(p, arg, flatten)]
159     else:
160         return [expand_item(p, arg, flatten) for arg in l]
161
162 def get_items(p, value, flatten=True):
163     if isinstance(value, dict):
164         return expand_item(p, value)
165
166     if isinstance(value, list):
167         return expand_list(p, value, flatten)
168
169     fn = subst.do_substitution(p, value)
170     mode = os.stat(fn).st_mode
171     prefix = fn[len(os.environ['TASK_KEEPMOUNT'])+1:]
172     if mode is not None:
173         if stat.S_ISDIR(mode):
174             items = [os.path.join(fn, l) for l in os.listdir(fn)]
175         elif stat.S_ISREG(mode):
176             with open(fn) as f:
177                 items = [line.rstrip("\r\n") for line in f]
178         return items
179     else:
180         return None
181
182 stdoutname = None
183 stdoutfile = None
184 stdinname = None
185 stdinfile = None
186
187 def recursive_foreach(params, fvars):
188     var = fvars[0]
189     fvars = fvars[1:]
190     items = get_items(params, params[var], False)
191     logger.info("parallelizing on %s with items %s" % (var, items))
192     if items is not None:
193         for i in items:
194             params = copy.copy(params)
195             params[var] = i
196             if len(fvars) > 0:
197                 recursive_foreach(params, fvars)
198             else:
199                 if not args.dry_run:
200                     arvados.api().job_tasks().create(body={
201                         'job_uuid': arvados.current_job()['uuid'],
202                         'created_by_job_task_uuid': arvados.current_task()['uuid'],
203                         'sequence': 1,
204                         'parameters': params
205                     }).execute()
206                 else:
207                     if isinstance(params["command"][0], list):
208                         logger.info(expand_list(params, params["command"], False))
209                     else:
210                         logger.info(expand_list(params, params["command"], True))
211     else:
212         logger.error("parameter %s with value %s in task.foreach yielded no items" % (var, params[var]))
213         sys.exit(1)
214
215 try:
216     if "task.foreach" in jobp:
217         if args.dry_run or arvados.current_task()['sequence'] == 0:
218             # This is the first task to start the other tasks and exit
219             fvars = jobp["task.foreach"]
220             if isinstance(fvars, basestring):
221                 fvars = [fvars]
222             if not isinstance(fvars, list) or len(fvars) == 0:
223                 logger.error("value of task.foreach must be a string or non-empty list")
224                 sys.exit(1)
225             recursive_foreach(jobp, jobp["task.foreach"])
226             if not args.dry_run:
227                 if "task.vwd" in jobp:
228                     # Set output of the first task to the base vwd collection so it
229                     # will be merged with output fragments from the other tasks by
230                     # crunch.
231                     arvados.current_task().set_output(subst.do_substitution(jobp, jobp["task.vwd"]))
232                 else:
233                     arvados.current_task().set_output(None)
234             sys.exit(0)
235     else:
236         # This is the only task so taskp/jobp are the same
237         taskp = jobp
238 except Exception as e:
239     logger.exception("caught exception")
240     logger.error("job parameters were:")
241     logger.error(pprint.pformat(jobp))
242     sys.exit(1)
243
244 try:
245     if not args.dry_run:
246         if "task.vwd" in taskp:
247             # Populate output directory with symlinks to files in collection
248             vwd.checkout(subst.do_substitution(taskp, taskp["task.vwd"]), outdir)
249
250         if "task.cwd" in taskp:
251             os.chdir(subst.do_substitution(taskp, taskp["task.cwd"]))
252
253     cmd = []
254     if isinstance(taskp["command"][0], list):
255         cmd.append(expand_list(taskp, taskp["command"], False))
256     else:
257         cmd.append(expand_list(taskp, taskp["command"], True))
258
259     if "task.stdin" in taskp:
260         stdinname = subst.do_substitution(taskp, taskp["task.stdin"])
261         if not args.dry_run:
262             stdinfile = open(stdinname, "rb")
263
264     if "task.stdout" in taskp:
265         stdoutname = subst.do_substitution(taskp, taskp["task.stdout"])
266         if not args.dry_run:
267             stdoutfile = open(stdoutname, "wb")
268
269     logger.info("{}{}{}".format(' | '.join([' '.join(c) for c in cmd]), (" < " + stdinname) if stdinname is not None else "", (" > " + stdoutname) if stdoutname is not None else ""))
270
271     if args.dry_run:
272         sys.exit(0)
273 except subst.SubstitutionError as e:
274     logger.error(str(e))
275     logger.error("task parameters were:")
276     logger.error(pprint.pformat(taskp))
277     sys.exit(1)
278 except Exception as e:
279     logger.exception("caught exception")
280     logger.error("task parameters were:")
281     logger.error(pprint.pformat(taskp))
282     sys.exit(1)
283
284 try:
285     subprocesses = []
286     close_streams = []
287     if stdinfile:
288         close_streams.append(stdinfile)
289     next_stdin = stdinfile
290
291     for i in xrange(len(cmd)):
292         if i == len(cmd)-1:
293             # this is the last command in the pipeline, so its stdout should go to stdoutfile
294             next_stdout = stdoutfile
295         else:
296             # this is an intermediate command in the pipeline, so its stdout should go to a pipe
297             next_stdout = subprocess.PIPE
298
299         sp = subprocess.Popen(cmd[i], shell=False, stdin=next_stdin, stdout=next_stdout)
300
301         # Need to close the FDs on our side so that subcommands will get SIGPIPE if the
302         # consuming process ends prematurely.
303         if sp.stdout:
304             close_streams.append(sp.stdout)
305
306         # Send this processes's stdout to to the next process's stdin
307         next_stdin = sp.stdout
308
309         subprocesses.append(sp)
310
311     # File descriptors have been handed off to the subprocesses, so close them here.
312     for s in close_streams:
313         s.close()
314
315     # Set up signal handling
316     sig = SigHandler()
317
318     # Forward terminate signals to the subprocesses.
319     signal.signal(signal.SIGINT, lambda signum, frame: sig.send_signal(subprocesses, signum))
320     signal.signal(signal.SIGTERM, lambda signum, frame: sig.send_signal(subprocesses, signum))
321     signal.signal(signal.SIGQUIT, lambda signum, frame: sig.send_signal(subprocesses, signum))
322
323     active = 1
324     pids = set([s.pid for s in subprocesses])
325     rcode = {}
326     while len(pids) > 0:
327         (pid, status) = os.wait()
328         pids.discard(pid)
329         rcode[pid] = (status >> 8)
330
331     if sig.sig is not None:
332         logger.critical("terminating on signal %s" % sig.sig)
333         sys.exit(2)
334     else:
335         for i in xrange(len(cmd)):
336             r = rcode[subprocesses[i].pid]
337             logger.info("%s completed with exit code %i (%s)" % (cmd[i][0], r, "success" if r == 0 else "failed"))
338
339 except Exception as e:
340     logger.exception("caught exception")
341
342 # restore default signal handlers.
343 signal.signal(signal.SIGINT, signal.SIG_DFL)
344 signal.signal(signal.SIGTERM, signal.SIG_DFL)
345 signal.signal(signal.SIGQUIT, signal.SIG_DFL)
346
347 for l in links:
348     os.unlink(l)
349
350 logger.info("the following output files will be saved to keep:")
351
352 subprocess.call(["find", ".", "-type", "f", "-printf", "run-command: %12.12s %h/%f\\n"], stdout=sys.stderr)
353
354 logger.info("start writing output to keep")
355
356 if "task.vwd" in taskp:
357     if "task.foreach" in jobp:
358         # This is a subtask, so don't merge with the original collection, that will happen at the end
359         outcollection = vwd.checkin(subst.do_substitution(taskp, taskp["task.vwd"]), outdir, merge=False).manifest_text()
360     else:
361         # Just a single task, so do merge with the original collection
362         outcollection = vwd.checkin(subst.do_substitution(taskp, taskp["task.vwd"]), outdir, merge=True).manifest_text()
363 else:
364     outcollection = robust_put.upload(outdir, logger)
365
366 success = reduce(lambda x, y: x & (y == 0), [True]+rcode.values())
367
368 api.job_tasks().update(uuid=arvados.current_task()['uuid'],
369                                      body={
370                                          'output': outcollection,
371                                          'success': success,
372                                          'progress':1.0
373                                      }).execute()
374
375 sys.exit(rcode)