11167: Refactored tests to use new helper function.
[arvados.git] / crunch_scripts / run-command
1 #!/usr/bin/env python
2 # Copyright (C) The Arvados Authors. All rights reserved.
3 #
4 # SPDX-License-Identifier: Apache-2.0
5
6 import logging
7
8 logger = logging.getLogger('run-command')
9 log_handler = logging.StreamHandler()
10 log_handler.setFormatter(logging.Formatter("run-command: %(message)s"))
11 logger.addHandler(log_handler)
12 logger.setLevel(logging.INFO)
13
14 import arvados
15 import re
16 import os
17 import subprocess
18 import sys
19 import shutil
20 import crunchutil.subst as subst
21 import time
22 import arvados.commands.put as put
23 import signal
24 import stat
25 import copy
26 import traceback
27 import pprint
28 import multiprocessing
29 import crunchutil.robust_put as robust_put
30 import crunchutil.vwd as vwd
31 import argparse
32 import json
33 import tempfile
34 import errno
35
36 parser = argparse.ArgumentParser()
37 parser.add_argument('--dry-run', action='store_true')
38 parser.add_argument('--script-parameters', type=str, default="{}")
39 args = parser.parse_args()
40
41 os.umask(0077)
42
43 if not args.dry_run:
44     api = arvados.api('v1')
45     t = arvados.current_task().tmpdir
46     os.chdir(arvados.current_task().tmpdir)
47     os.mkdir("tmpdir")
48     os.mkdir("output")
49
50     os.chdir("output")
51
52     outdir = os.getcwd()
53
54     taskp = None
55     jobp = arvados.current_job()['script_parameters']
56     if len(arvados.current_task()['parameters']) > 0:
57         taskp = arvados.current_task()['parameters']
58 else:
59     outdir = "/tmp"
60     jobp = json.loads(args.script_parameters)
61     os.environ['JOB_UUID'] = 'zzzzz-8i9sb-1234567890abcde'
62     os.environ['TASK_UUID'] = 'zzzzz-ot0gb-1234567890abcde'
63     os.environ['CRUNCH_SRC'] = '/tmp/crunch-src'
64     if 'TASK_KEEPMOUNT' not in os.environ:
65         os.environ['TASK_KEEPMOUNT'] = '/keep'
66
67 def sub_tmpdir(v):
68     return os.path.join(arvados.current_task().tmpdir, 'tmpdir')
69
70 def sub_outdir(v):
71     return outdir
72
73 def sub_cores(v):
74      return str(multiprocessing.cpu_count())
75
76 def sub_jobid(v):
77      return os.environ['JOB_UUID']
78
79 def sub_taskid(v):
80      return os.environ['TASK_UUID']
81
82 def sub_jobsrc(v):
83      return os.environ['CRUNCH_SRC']
84
85 subst.default_subs["task.tmpdir"] = sub_tmpdir
86 subst.default_subs["task.outdir"] = sub_outdir
87 subst.default_subs["job.srcdir"] = sub_jobsrc
88 subst.default_subs["node.cores"] = sub_cores
89 subst.default_subs["job.uuid"] = sub_jobid
90 subst.default_subs["task.uuid"] = sub_taskid
91
92 class SigHandler(object):
93     def __init__(self):
94         self.sig = None
95
96     def send_signal(self, subprocesses, signum):
97         for sp in subprocesses:
98             sp.send_signal(signum)
99         self.sig = signum
100
101 # http://rightfootin.blogspot.com/2006/09/more-on-python-flatten.html
102 def flatten(l, ltypes=(list, tuple)):
103     ltype = type(l)
104     l = list(l)
105     i = 0
106     while i < len(l):
107         while isinstance(l[i], ltypes):
108             if not l[i]:
109                 l.pop(i)
110                 i -= 1
111                 break
112             else:
113                 l[i:i + 1] = l[i]
114         i += 1
115     return ltype(l)
116
117 def add_to_group(gr, match):
118     m = match.groups()
119     if m not in gr:
120         gr[m] = []
121     gr[m].append(match.group(0))
122
123 class EvaluationError(Exception):
124     pass
125
126 # Return the name of variable ('var') that will take on each value in 'items'
127 # when performing an inner substitution
128 def var_items(p, c, key):
129     if key not in c:
130         raise EvaluationError("'%s' was expected in 'p' but is missing" % key)
131
132     if "var" in c:
133         if not isinstance(c["var"], basestring):
134             raise EvaluationError("Value of 'var' must be a string")
135         # Var specifies the variable name for inner parameter substitution
136         return (c["var"], get_items(p, c[key]))
137     else:
138         # The component function ('key') value is a list, so return the list
139         # directly with no parameter selected.
140         if isinstance(c[key], list):
141             return (None, get_items(p, c[key]))
142         elif isinstance(c[key], basestring):
143             # check if c[key] is a string that looks like a parameter
144             m = re.match("^\$\((.*)\)$", c[key])
145             if m and m.group(1) in p:
146                 return (m.group(1), get_items(p, c[key]))
147             else:
148                 # backwards compatible, foreach specifies bare parameter name to use
149                 return (c[key], get_items(p, p[c[key]]))
150         else:
151             raise EvaluationError("Value of '%s' must be a string or list" % key)
152
153 # "p" is the parameter scope, "c" is the item to be expanded.
154 # If "c" is a dict, apply function expansion.
155 # If "c" is a list, recursively expand each item and return a new list.
156 # If "c" is a string, apply parameter substitution
157 def expand_item(p, c):
158     if isinstance(c, dict):
159         if "foreach" in c and "command" in c:
160             # Expand a command template for each item in the specified user
161             # parameter
162             var, items = var_items(p, c, "foreach")
163             if var is None:
164                 raise EvaluationError("Must specify 'var' in foreach")
165             r = []
166             for i in items:
167                 params = copy.copy(p)
168                 params[var] = i
169                 r.append(expand_item(params, c["command"]))
170             return r
171         elif "list" in c and "index" in c and "command" in c:
172             # extract a single item from a list
173             var, items = var_items(p, c, "list")
174             if var is None:
175                 raise EvaluationError("Must specify 'var' in list")
176             params = copy.copy(p)
177             params[var] = items[int(c["index"])]
178             return expand_item(params, c["command"])
179         elif "regex" in c:
180             pattern = re.compile(c["regex"])
181             if "filter" in c:
182                 # filter list so that it only includes items that match a
183                 # regular expression
184                 _, items = var_items(p, c, "filter")
185                 return [i for i in items if pattern.match(i)]
186             elif "group" in c:
187                 # generate a list of lists, where items are grouped on common
188                 # subexpression match
189                 _, items = var_items(p, c, "group")
190                 groups = {}
191                 for i in items:
192                     match = pattern.match(i)
193                     if match:
194                         add_to_group(groups, match)
195                 return [groups[k] for k in groups]
196             elif "extract" in c:
197                 # generate a list of lists, where items are split by
198                 # subexpression match
199                 _, items = var_items(p, c, "extract")
200                 r = []
201                 for i in items:
202                     match = pattern.match(i)
203                     if match:
204                         r.append(list(match.groups()))
205                 return r
206         elif "batch" in c and "size" in c:
207             # generate a list of lists, where items are split into a batch size
208             _, items = var_items(p, c, "batch")
209             sz = int(c["size"])
210             r = []
211             for j in xrange(0, len(items), sz):
212                 r.append(items[j:j+sz])
213             return r
214         raise EvaluationError("Missing valid list context function")
215     elif isinstance(c, list):
216         return [expand_item(p, arg) for arg in c]
217     elif isinstance(c, basestring):
218         m = re.match("^\$\((.*)\)$", c)
219         if m and m.group(1) in p:
220             return expand_item(p, p[m.group(1)])
221         else:
222             return subst.do_substitution(p, c)
223     else:
224         raise EvaluationError("expand_item() unexpected parameter type %s" % type(c))
225
226 # Evaluate in a list context
227 # "p" is the parameter scope, "value" will be evaluated
228 # if "value" is a list after expansion, return that
229 # if "value" is a path to a directory, return a list consisting of each entry in the directory
230 # if "value" is a path to a file, return a list consisting of each line of the file
231 def get_items(p, value):
232     value = expand_item(p, value)
233     if isinstance(value, list):
234         return value
235     elif isinstance(value, basestring):
236         mode = os.stat(value).st_mode
237         prefix = value[len(os.environ['TASK_KEEPMOUNT'])+1:]
238         if mode is not None:
239             if stat.S_ISDIR(mode):
240                 items = [os.path.join(value, l) for l in os.listdir(value)]
241             elif stat.S_ISREG(mode):
242                 with open(value) as f:
243                     items = [line.rstrip("\r\n") for line in f]
244             return items
245     raise EvaluationError("get_items did not yield a list")
246
247 stdoutname = None
248 stdoutfile = None
249 stdinname = None
250 stdinfile = None
251
252 # Construct the cross product of all values of each variable listed in fvars
253 def recursive_foreach(params, fvars):
254     var = fvars[0]
255     fvars = fvars[1:]
256     items = get_items(params, params[var])
257     logger.info("parallelizing on %s with items %s" % (var, items))
258     if items is not None:
259         for i in items:
260             params = copy.copy(params)
261             params[var] = i
262             if len(fvars) > 0:
263                 recursive_foreach(params, fvars)
264             else:
265                 if not args.dry_run:
266                     arvados.api().job_tasks().create(body={
267                         'job_uuid': arvados.current_job()['uuid'],
268                         'created_by_job_task_uuid': arvados.current_task()['uuid'],
269                         'sequence': 1,
270                         'parameters': params
271                     }).execute()
272                 else:
273                     if isinstance(params["command"][0], list):
274                         for c in params["command"]:
275                             logger.info(flatten(expand_item(params, c)))
276                     else:
277                         logger.info(flatten(expand_item(params, params["command"])))
278     else:
279         logger.error("parameter %s with value %s in task.foreach yielded no items" % (var, params[var]))
280         sys.exit(1)
281
282 try:
283     if "task.foreach" in jobp:
284         if args.dry_run or arvados.current_task()['sequence'] == 0:
285             # This is the first task to start the other tasks and exit
286             fvars = jobp["task.foreach"]
287             if isinstance(fvars, basestring):
288                 fvars = [fvars]
289             if not isinstance(fvars, list) or len(fvars) == 0:
290                 logger.error("value of task.foreach must be a string or non-empty list")
291                 sys.exit(1)
292             recursive_foreach(jobp, jobp["task.foreach"])
293             if not args.dry_run:
294                 if "task.vwd" in jobp:
295                     # Set output of the first task to the base vwd collection so it
296                     # will be merged with output fragments from the other tasks by
297                     # crunch.
298                     arvados.current_task().set_output(subst.do_substitution(jobp, jobp["task.vwd"]))
299                 else:
300                     arvados.current_task().set_output(None)
301             sys.exit(0)
302     else:
303         # This is the only task so taskp/jobp are the same
304         taskp = jobp
305 except Exception as e:
306     logger.exception("caught exception")
307     logger.error("job parameters were:")
308     logger.error(pprint.pformat(jobp))
309     sys.exit(1)
310
311 try:
312     if not args.dry_run:
313         if "task.vwd" in taskp:
314             # Populate output directory with symlinks to files in collection
315             vwd.checkout(subst.do_substitution(taskp, taskp["task.vwd"]), outdir)
316
317         if "task.cwd" in taskp:
318             os.chdir(subst.do_substitution(taskp, taskp["task.cwd"]))
319
320     cmd = []
321     if isinstance(taskp["command"][0], list):
322         for c in taskp["command"]:
323             cmd.append(flatten(expand_item(taskp, c)))
324     else:
325         cmd.append(flatten(expand_item(taskp, taskp["command"])))
326
327     if "task.stdin" in taskp:
328         stdinname = subst.do_substitution(taskp, taskp["task.stdin"])
329         if not args.dry_run:
330             stdinfile = open(stdinname, "rb")
331
332     if "task.stdout" in taskp:
333         stdoutname = subst.do_substitution(taskp, taskp["task.stdout"])
334         if not args.dry_run:
335             stdoutfile = open(stdoutname, "wb")
336
337     if "task.env" in taskp:
338         env = copy.copy(os.environ)
339         for k,v in taskp["task.env"].items():
340             env[k] = subst.do_substitution(taskp, v)
341     else:
342         env = None
343
344     logger.info("{}{}{}".format(' | '.join([' '.join(c) for c in cmd]), (" < " + stdinname) if stdinname is not None else "", (" > " + stdoutname) if stdoutname is not None else ""))
345
346     if args.dry_run:
347         sys.exit(0)
348 except subst.SubstitutionError as e:
349     logger.error(str(e))
350     logger.error("task parameters were:")
351     logger.error(pprint.pformat(taskp))
352     sys.exit(1)
353 except Exception as e:
354     logger.exception("caught exception")
355     logger.error("task parameters were:")
356     logger.error(pprint.pformat(taskp))
357     sys.exit(1)
358
359 # rcode holds the return codes produced by each subprocess
360 rcode = {}
361 try:
362     subprocesses = []
363     close_streams = []
364     if stdinfile:
365         close_streams.append(stdinfile)
366     next_stdin = stdinfile
367
368     for i in xrange(len(cmd)):
369         if i == len(cmd)-1:
370             # this is the last command in the pipeline, so its stdout should go to stdoutfile
371             next_stdout = stdoutfile
372         else:
373             # this is an intermediate command in the pipeline, so its stdout should go to a pipe
374             next_stdout = subprocess.PIPE
375
376         sp = subprocess.Popen(cmd[i], shell=False, stdin=next_stdin, stdout=next_stdout, env=env)
377
378         # Need to close the FDs on our side so that subcommands will get SIGPIPE if the
379         # consuming process ends prematurely.
380         if sp.stdout:
381             close_streams.append(sp.stdout)
382
383         # Send this processes's stdout to to the next process's stdin
384         next_stdin = sp.stdout
385
386         subprocesses.append(sp)
387
388     # File descriptors have been handed off to the subprocesses, so close them here.
389     for s in close_streams:
390         s.close()
391
392     # Set up signal handling
393     sig = SigHandler()
394
395     # Forward terminate signals to the subprocesses.
396     signal.signal(signal.SIGINT, lambda signum, frame: sig.send_signal(subprocesses, signum))
397     signal.signal(signal.SIGTERM, lambda signum, frame: sig.send_signal(subprocesses, signum))
398     signal.signal(signal.SIGQUIT, lambda signum, frame: sig.send_signal(subprocesses, signum))
399
400     active = 1
401     pids = set([s.pid for s in subprocesses])
402     while len(pids) > 0:
403         try:
404             (pid, status) = os.wait()
405         except OSError as e:
406             if e.errno == errno.EINTR:
407                 pass
408             else:
409                 raise
410         else:
411             pids.discard(pid)
412             if not taskp.get("task.ignore_rcode"):
413                 rcode[pid] = (status >> 8)
414             else:
415                 rcode[pid] = 0
416
417     if sig.sig is not None:
418         logger.critical("terminating on signal %s" % sig.sig)
419         sys.exit(2)
420     else:
421         for i in xrange(len(cmd)):
422             r = rcode[subprocesses[i].pid]
423             logger.info("%s completed with exit code %i (%s)" % (cmd[i][0], r, "success" if r == 0 else "failed"))
424
425 except Exception as e:
426     logger.exception("caught exception")
427
428 # restore default signal handlers.
429 signal.signal(signal.SIGINT, signal.SIG_DFL)
430 signal.signal(signal.SIGTERM, signal.SIG_DFL)
431 signal.signal(signal.SIGQUIT, signal.SIG_DFL)
432
433 logger.info("the following output files will be saved to keep:")
434
435 subprocess.call(["find", "-L", ".", "-type", "f", "-printf", "run-command: %12.12s %h/%f\\n"], stdout=sys.stderr, cwd=outdir)
436
437 logger.info("start writing output to keep")
438
439 if "task.vwd" in taskp and "task.foreach" in jobp:
440     for root, dirs, files in os.walk(outdir):
441         for f in files:
442             s = os.lstat(os.path.join(root, f))
443             if stat.S_ISLNK(s.st_mode):
444                 os.unlink(os.path.join(root, f))
445
446 (outcollection, checkin_error) = vwd.checkin(outdir)
447
448 # Success if we ran any subprocess, and they all exited 0.
449 success = rcode and all(status == 0 for status in rcode.itervalues()) and not checkin_error
450
451 api.job_tasks().update(uuid=arvados.current_task()['uuid'],
452                                      body={
453                                          'output': outcollection.manifest_text(),
454                                          'success': success,
455                                          'progress':1.0
456                                      }).execute()
457
458 sys.exit(0 if success else 1)