8558: include min_scratch_mb_per_node in the keys propagated
[arvados.git] sdk/cwl/arvados_cwl/__init__.py
#!/usr/bin/env python

import fnmatch
import logging
import os
import re
import sys
import threading

import arvados
import arvados.collection
import arvados.commands.keepdocker
import arvados.commands.run
import arvados.errors
import arvados.events
import arvados.util
from arvados.api import OrderedJsonModel

import cwltool.docker
import cwltool.draft2tool
import cwltool.main
import cwltool.pathmapper
import cwltool.process
import cwltool.workflow
from cwltool.process import get_feature, shortname

logger = logging.getLogger('arvados.cwl-runner')
logger.setLevel(logging.INFO)

crunchrunner_pdh = "83db29f08544e1c319572a6bd971088a+140"
crunchrunner_download = "https://cloud.curoverse.com/collections/download/qr1hi-4zz18-n3m1yxd0vx78jic/1i1u2qtq66k1atziv4ocfgsg5nu5tj11n4r6e0bhvjg03rix4m/crunchrunner"
certs_download = "https://cloud.curoverse.com/collections/download/qr1hi-4zz18-n3m1yxd0vx78jic/1i1u2qtq66k1atziv4ocfgsg5nu5tj11n4r6e0bhvjg03rix4m/ca-certificates.crt"

tmpdirre = re.compile(r"^\S+ \S+ \d+ \d+ stderr \S+ \S+ crunchrunner: \$\(task\.tmpdir\)=(.*)")
outdirre = re.compile(r"^\S+ \S+ \d+ \d+ stderr \S+ \S+ crunchrunner: \$\(task\.outdir\)=(.*)")
keepre = re.compile(r"^\S+ \S+ \d+ \d+ stderr \S+ \S+ crunchrunner: \$\(task\.keep\)=(.*)")
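# Example of a crunchrunner log line these patterns are meant to match
# (field values are hypothetical):
#   2016-03-01_17:55:41 qr1hi-8i9sb-xxxxxxxxxxxxxxx 1234 0 stderr 2016-03-01_17:55:41 run-command crunchrunner: $(task.tmpdir)=/tmp/crunch-job-task-work/tmpdir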


def arv_docker_get_image(api_client, dockerRequirement, pull_image):
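    """Ensure the Docker image named by dockerRequirement is available in Keep.

    If the image is not already in Arvados, pull it locally and upload it
    with arv-keepdocker.  Returns the dockerImageId to use."""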
38     if "dockerImageId" not in dockerRequirement and "dockerPull" in dockerRequirement:
39         dockerRequirement["dockerImageId"] = dockerRequirement["dockerPull"]
40
41     sp = dockerRequirement["dockerImageId"].split(":")
42     image_name = sp[0]
43     image_tag = sp[1] if len(sp) > 1 else None
44
45     images = arvados.commands.keepdocker.list_images_in_arv(api_client, 3,
46                                                             image_name=image_name,
47                                                             image_tag=image_tag)
48
49     if not images:
50         imageId = cwltool.docker.get_image(dockerRequirement, pull_image)
51         args = [image_name]
52         if image_tag:
53             args.append(image_tag)
54         logger.info("Uploading Docker image %s", ":".join(args))
55         arvados.commands.keepdocker.main(args)
56
57     return dockerRequirement["dockerImageId"]


class CollectionFsAccess(cwltool.process.StdFsAccess):
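    """Implement the cwltool FsAccess interface for Keep collections.

    Paths of the form "keep:<portable data hash>/..." are resolved against
    Keep; anything else falls through to the local filesystem.

    Illustrative example (hypothetical portable data hash):

        fs = CollectionFsAccess("/tmp")
        fs.glob("keep:99999999999999999999999999999999+99/*.txt")
    """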
    def __init__(self, basedir):
        self.collections = {}
        self.basedir = basedir

    def get_collection(self, path):
        p = path.split("/")
        if p[0].startswith("keep:") and arvados.util.keep_locator_pattern.match(p[0][5:]):
            pdh = p[0][5:]
            if pdh not in self.collections:
                self.collections[pdh] = arvados.collection.CollectionReader(pdh)
            return (self.collections[pdh], "/".join(p[1:]))
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo": shift past the "./"
            # once, rather than once per entry in the collection (which
            # would duplicate every match).
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a keep: reference; glob the local filesystem instead.
            return super(CollectionFsAccess, self).glob(pattern)
        patternsegments = rest.split("/")
        return self._match(collection, patternsegments, "keep:" + collection.manifest_locator())

    def open(self, fn, mode):
        collection, rest = self.get_collection(fn)
        if collection:
            return collection.open(rest, mode)
        else:
            return open(self._abs(fn), mode)

    def exists(self, fn):
        collection, rest = self.get_collection(fn)
        if collection:
            return collection.exists(rest)
        else:
            return os.path.exists(self._abs(fn))


class ArvadosJob(object):
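    """Submit and monitor a single Arvados job running one CWL tool invocation."""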
    def __init__(self, runner):
        self.arvrunner = runner
        self.running = False

    def run(self, dry_run=False, pull_image=True, **kwargs):
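        """Build crunchrunner script_parameters and runtime_constraints for
        this tool invocation and submit it via the jobs API."""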
        script_parameters = {
            "command": self.command_line
        }
        runtime_constraints = {}

        if self.generatefiles:
            vwd = arvados.collection.Collection()
            script_parameters["task.vwd"] = {}
            for t in self.generatefiles:
                if isinstance(self.generatefiles[t], dict):
                    src, rest = self.arvrunner.fs_access.get_collection(self.generatefiles[t]["path"].replace("$(task.keep)/", "keep:"))
                    vwd.copy(rest, t, source_collection=src)
                else:
                    with vwd.open(t, "w") as f:
                        f.write(self.generatefiles[t])
            vwd.save_new()
            for t in self.generatefiles:
                script_parameters["task.vwd"][t] = "$(task.keep)/%s/%s" % (vwd.portable_data_hash(), t)

        script_parameters["task.env"] = {"TMPDIR": "$(task.tmpdir)"}
        if self.environment:
            script_parameters["task.env"].update(self.environment)

        if self.stdin:
            script_parameters["task.stdin"] = self.pathmapper.mapper(self.stdin)[1]

        if self.stdout:
            script_parameters["task.stdout"] = self.stdout

        (docker_req, docker_is_req) = get_feature(self, "DockerRequirement")
        if docker_req and kwargs.get("use_container") is not False:
            runtime_constraints["docker_image"] = arv_docker_get_image(self.arvrunner.api, docker_req, pull_image)

        resources = self.builder.resources
        if resources is not None:
            if "coresMin" in resources:
                try:
                    runtime_constraints["min_cores_per_node"] = int(resources["coresMin"])
                except (TypeError, ValueError):
                    runtime_constraints["min_cores_per_node"] = None
            if "ramMin" in resources:
                try:
                    runtime_constraints["min_ram_mb_per_node"] = int(resources["ramMin"])
                except (TypeError, ValueError):
                    runtime_constraints["min_ram_mb_per_node"] = None
            if "tmpdirMin" in resources:
                try:
                    runtime_constraints["min_scratch_mb_per_node"] = int(resources["tmpdirMin"])
                except (TypeError, ValueError):
                    runtime_constraints["min_scratch_mb_per_node"] = None
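        # For example (illustrative), a tool requesting coresMin: 2,
        # ramMin: 4096 and tmpdirMin: 10240 yields runtime_constraints of
        # {"min_cores_per_node": 2, "min_ram_mb_per_node": 4096,
        #  "min_scratch_mb_per_node": 10240}.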

        try:
            response = self.arvrunner.api.jobs().create(body={
                "script": "crunchrunner",
                "repository": "arvados",
                "script_version": "8488-cwl-crunchrunner-collection",
                "script_parameters": {"tasks": [script_parameters], "crunchrunner": crunchrunner_pdh+"/crunchrunner"},
                "runtime_constraints": runtime_constraints
            }, find_or_create=kwargs.get("enable_reuse", True)).execute(num_retries=self.arvrunner.num_retries)

            self.arvrunner.jobs[response["uuid"]] = self

            self.arvrunner.pipeline["components"][self.name] = {"job": response}
            self.arvrunner.pipeline = self.arvrunner.api.pipeline_instances().update(uuid=self.arvrunner.pipeline["uuid"],
                                                                                     body={
                                                                                         "components": self.arvrunner.pipeline["components"]
                                                                                     }).execute(num_retries=self.arvrunner.num_retries)

            logger.info("Job %s (%s) is %s", self.name, response["uuid"], response["state"])

            if response["state"] in ("Complete", "Failed", "Cancelled"):
                self.done(response)
        except Exception as e:
            logger.error("Got error %s", e)
            self.output_callback({}, "permanentFail")

    def update_pipeline_component(self, record):
        self.arvrunner.pipeline["components"][self.name] = {"job": record}
        self.arvrunner.pipeline = self.arvrunner.api.pipeline_instances().update(uuid=self.arvrunner.pipeline["uuid"],
                                                                                 body={
                                                                                     "components": self.arvrunner.pipeline["components"]
                                                                                 }).execute(num_retries=self.arvrunner.num_retries)

    def done(self, record):
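        """Handle a job reaching a final state: record it in the pipeline,
        recover the $(task.outdir)/$(task.keep) paths from the crunchrunner
        log, and hand the collected outputs to the output callback."""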
        try:
            self.update_pipeline_component(record)
        except Exception:
            # Updating the pipeline component is best-effort.
            pass

        try:
            if record["state"] == "Complete":
                processStatus = "success"
            else:
                processStatus = "permanentFail"

            outputs = {}
            try:
                if record["output"]:
                    logc = arvados.collection.Collection(record["log"])
                    log = logc.open(logc.keys()[0])
                    tmpdir = None
                    outdir = None
                    keepdir = None
                    for l in log.readlines():
                        g = tmpdirre.match(l)
                        if g:
                            tmpdir = g.group(1)
                        g = outdirre.match(l)
                        if g:
                            outdir = g.group(1)
                        g = keepre.match(l)
                        if g:
                            keepdir = g.group(1)
                        if tmpdir and outdir and keepdir:
                            break

                    self.builder.outdir = outdir
                    self.builder.pathmapper.keepdir = keepdir
                    outputs = self.collect_outputs("keep:" + record["output"])
            except Exception:
                logger.exception("Got exception while collecting job outputs:")
                processStatus = "permanentFail"

            self.output_callback(outputs, processStatus)
        finally:
            del self.arvrunner.jobs[record["uuid"]]


class ArvPathMapper(cwltool.pathmapper.PathMapper):
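    """Map CWL input references to $(task.keep) paths visible inside the job,
    uploading local files to Keep as needed (and caching what was uploaded)."""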
    def __init__(self, arvrunner, referenced_files, basedir, **kwargs):
        self._pathmap = arvrunner.get_uploaded()
        uploadfiles = []

        pdh_path = re.compile(r'^keep:[0-9a-f]{32}\+\d+/.+')

        for src in referenced_files:
            if isinstance(src, basestring) and pdh_path.match(src):
                self._pathmap[src] = (src, "$(task.keep)/%s" % src[5:])
            if src not in self._pathmap:
                ab = cwltool.pathmapper.abspath(src, basedir)
                st = arvados.commands.run.statfile("", ab, fnPattern="$(task.keep)/%s/%s")
                if kwargs.get("conformance_test"):
                    self._pathmap[src] = (src, ab)
                elif isinstance(st, arvados.commands.run.UploadFile):
                    uploadfiles.append((src, ab, st))
                elif isinstance(st, arvados.commands.run.ArvFile):
                    self._pathmap[src] = (ab, st.fn)
                else:
                    raise cwltool.workflow.WorkflowException("Input file path '%s' is invalid" % st)

        if uploadfiles:
            arvados.commands.run.uploadfiles([u[2] for u in uploadfiles],
                                             arvrunner.api,
                                             dry_run=kwargs.get("dry_run"),
                                             num_retries=3,
                                             fnPattern="$(task.keep)/%s/%s")

        for src, ab, st in uploadfiles:
            arvrunner.add_uploaded(src, (ab, st.fn))
            self._pathmap[src] = (ab, st.fn)

        self.keepdir = None

    def reversemap(self, target):
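        """Map a container path back to a keep: reference.

        Illustrative example (hypothetical hash): with keepdir "/keep",
        reversemap("/keep/99999999999999999999999999999999+99/out.txt")
        returns "keep:99999999999999999999999999999999+99/out.txt"."""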
        if target.startswith("keep:"):
            return target
        elif self.keepdir and target.startswith(self.keepdir):
            return "keep:" + target[len(self.keepdir)+1:]
        else:
            return super(ArvPathMapper, self).reversemap(target)


class ArvadosCommandTool(cwltool.draft2tool.CommandLineTool):
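    """CommandLineTool specialized to run on Arvados: its jobs are ArvadosJob
    instances and its paths go through ArvPathMapper."""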
    def __init__(self, arvrunner, toolpath_object, **kwargs):
        super(ArvadosCommandTool, self).__init__(toolpath_object, **kwargs)
        self.arvrunner = arvrunner

    def makeJobRunner(self):
        return ArvadosJob(self.arvrunner)

    def makePathMapper(self, reffiles, input_basedir, **kwargs):
        return ArvPathMapper(self.arvrunner, reffiles, input_basedir, **kwargs)


class ArvCwlRunner(object):
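    """Execute a CWL workflow as an Arvados pipeline instance, one job per
    tool invocation, tracking job state through websocket events."""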
    def __init__(self, api_client):
        self.api = api_client
        self.jobs = {}
        self.lock = threading.Lock()
        self.cond = threading.Condition(self.lock)
        self.final_output = None
        self.uploaded = {}
        self.num_retries = 4

    def arvMakeTool(self, toolpath_object, **kwargs):
        if "class" in toolpath_object and toolpath_object["class"] == "CommandLineTool":
            return ArvadosCommandTool(self, toolpath_object, **kwargs)
        else:
            return cwltool.workflow.defaultMakeTool(toolpath_object, **kwargs)

    def output_callback(self, out, processStatus):
        if processStatus == "success":
            logger.info("Overall job status is %s", processStatus)
            self.api.pipeline_instances().update(uuid=self.pipeline["uuid"],
                                                 body={"state": "Complete"}).execute(num_retries=self.num_retries)
        else:
            logger.warn("Overall job status is %s", processStatus)
            self.api.pipeline_instances().update(uuid=self.pipeline["uuid"],
                                                 body={"state": "Failed"}).execute(num_retries=self.num_retries)
        self.final_output = out

    def on_message(self, event):
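        """Websocket event handler: mark jobs as running, and on completion
        pass the final job record to done() and wake the executor thread."""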
337         if "object_uuid" in event:
338                 if event["object_uuid"] in self.jobs and event["event_type"] == "update":
339                     if event["properties"]["new_attributes"]["state"] == "Running" and self.jobs[event["object_uuid"]].running is False:
340                         uuid = event["object_uuid"]
341                         with self.lock:
342                             j = self.jobs[uuid]
343                             logger.info("Job %s (%s) is Running", j.name, uuid)
344                             j.running = True
345                             j.update_pipeline_component(event["properties"]["new_attributes"])
346                     elif event["properties"]["new_attributes"]["state"] in ("Complete", "Failed", "Cancelled"):
347                         uuid = event["object_uuid"]
348                         try:
349                             self.cond.acquire()
350                             j = self.jobs[uuid]
351                             logger.info("Job %s (%s) is %s", j.name, uuid, event["properties"]["new_attributes"]["state"])
352                             j.done(event["properties"]["new_attributes"])
353                             self.cond.notify()
354                         finally:
355                             self.cond.release()
356
    def get_uploaded(self):
        return self.uploaded.copy()

    def add_uploaded(self, src, pair):
        self.uploaded[src] = pair

    def arvExecutor(self, tool, job_order, input_basedir, args, **kwargs):
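        """Run the workflow: ensure the crunchrunner binary is present in
        Keep, create a pipeline instance, submit jobs as they become
        runnable, and return the workflow's final output."""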
        events = arvados.events.subscribe(arvados.api('v1'), [["object_uuid", "is_a", "arvados#job"]], self.on_message)

        try:
            self.api.collections().get(uuid=crunchrunner_pdh).execute()
        except arvados.errors.ApiError:
            # The expected crunchrunner collection is not on this cluster;
            # fetch the binary and CA bundle and save them as a new collection.
            import httplib2
            h = httplib2.Http(ca_certs=arvados.util.ca_certs_path())
            resp, content = h.request(crunchrunner_download, "GET")
            resp2, content2 = h.request(certs_download, "GET")
            with arvados.collection.Collection() as col:
                with col.open("crunchrunner", "w") as f:
                    f.write(content)
                with col.open("ca-certificates.crt", "w") as f:
                    f.write(content2)

                col.save_new("crunchrunner binary", ensure_unique_name=True)

        self.fs_access = CollectionFsAccess(input_basedir)

        kwargs["fs_access"] = self.fs_access
        kwargs["enable_reuse"] = args.enable_reuse

        kwargs["outdir"] = "$(task.outdir)"
        kwargs["tmpdir"] = "$(task.tmpdir)"

        if kwargs.get("conformance_test"):
            return cwltool.main.single_job_executor(tool, job_order, input_basedir, args, **kwargs)
        else:
            self.pipeline = self.api.pipeline_instances().create(body={"name": shortname(tool.tool["id"]),
                                                                       "components": {},
                                                                       "state": "RunningOnClient"}).execute(num_retries=self.num_retries)

            jobiter = tool.job(job_order,
                               input_basedir,
                               self.output_callback,
                               docker_outdir="$(task.outdir)",
                               **kwargs)

            try:
                for runnable in jobiter:
                    if runnable:
                        with self.lock:
                            runnable.run(**kwargs)
                    else:
                        if self.jobs:
                            with self.cond:
                                self.cond.wait(1)
                        else:
                            logger.error("Workflow cannot make any more progress.")
                            break

                while self.jobs:
                    with self.cond:
                        self.cond.wait(1)

                events.close()

                if self.final_output is None:
                    raise cwltool.workflow.WorkflowException("Workflow did not return a result.")

            except:
                # Bare except is deliberate: a KeyboardInterrupt should also
                # mark the pipeline as Failed, just without a traceback.
                if sys.exc_info()[0] is not KeyboardInterrupt:
                    logger.exception("Caught unhandled exception, marking pipeline as failed")
                self.api.pipeline_instances().update(uuid=self.pipeline["uuid"],
                                                     body={"state": "Failed"}).execute(num_retries=self.num_retries)

            return self.final_output


def main(args, stdout, stderr, api_client=None):
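    """Command-line entry point: wrap cwltool's argument parser and main loop
    with the Arvados executor and tool factory."""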
    args.insert(0, "--leave-outputs")
    parser = cwltool.main.arg_parser()
    exgroup = parser.add_mutually_exclusive_group()
    exgroup.add_argument("--enable-reuse", action="store_true",
                        default=False, dest="enable_reuse",
                        help="Enable job reuse (let an equivalent existing job "
                        "satisfy a step instead of running a new one)")
    exgroup.add_argument("--disable-reuse", action="store_false",
                        default=False, dest="enable_reuse",
                        help="Disable job reuse (always run new jobs)")

    try:
        runner = ArvCwlRunner(api_client=arvados.api('v1', model=OrderedJsonModel()))
    except Exception as e:
        logger.error(e)
        return 1

    return cwltool.main.main(args, executor=runner.arvExecutor, makeTool=runner.arvMakeTool, parser=parser)
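
# Typical invocation, assuming this module is exposed as the
# "arvados-cwl-runner" console script (illustrative filenames):
#   arvados-cwl-runner --enable-reuse my-workflow.cwl my-inputs.yml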