# arvados.git: sdk/cwl/arvados_cwl/__init__.py
#!/usr/bin/env python

import fnmatch
import logging
import os
import re
import sys
import threading

import arvados
import arvados.collection
import arvados.commands.keepdocker
import arvados.commands.run
import arvados.errors
import arvados.events
import arvados.util
from arvados.api import OrderedJsonModel

import cwltool.docker
import cwltool.draft2tool
import cwltool.main
import cwltool.pathmapper
import cwltool.process
import cwltool.workflow
from cwltool.process import get_feature, shortname

logger = logging.getLogger('arvados.cwl-runner')
logger.setLevel(logging.INFO)

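# Portable data hash of the Keep collection holding the crunchrunner binary,
# plus download URLs used to fetch and re-upload it when that collection is
# missing from the cluster (see ArvCwlRunner.arvExecutor below).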
crunchrunner_pdh = "83db29f08544e1c319572a6bd971088a+140"
crunchrunner_download = "https://cloud.curoverse.com/collections/download/qr1hi-4zz18-n3m1yxd0vx78jic/1i1u2qtq66k1atziv4ocfgsg5nu5tj11n4r6e0bhvjg03rix4m/crunchrunner"
certs_download = "https://cloud.curoverse.com/collections/download/qr1hi-4zz18-n3m1yxd0vx78jic/1i1u2qtq66k1atziv4ocfgsg5nu5tj11n4r6e0bhvjg03rix4m/ca-certificates.crt"

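# crunchrunner reports the task's tmpdir, outdir, and keep mount paths on
# stderr; these patterns recover those paths from the job log so outputs can
# be mapped back to Keep locations (see ArvadosJob.done).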
tmpdirre = re.compile(r"^\S+ \S+ \d+ \d+ stderr \S+ \S+ crunchrunner: \$\(task\.tmpdir\)=(.*)")
outdirre = re.compile(r"^\S+ \S+ \d+ \d+ stderr \S+ \S+ crunchrunner: \$\(task\.outdir\)=(.*)")
keepre = re.compile(r"^\S+ \S+ \d+ \d+ stderr \S+ \S+ crunchrunner: \$\(task\.keep\)=(.*)")


def arv_docker_get_image(api_client, dockerRequirement, pull_image, project_uuid):
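    """Return the Docker image id for a DockerRequirement, uploading the
    image to Arvados with arv-keepdocker if it is not already in Keep."""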
    if "dockerImageId" not in dockerRequirement and "dockerPull" in dockerRequirement:
        dockerRequirement["dockerImageId"] = dockerRequirement["dockerPull"]

    sp = dockerRequirement["dockerImageId"].split(":")
    image_name = sp[0]
    image_tag = sp[1] if len(sp) > 1 else None

    images = arvados.commands.keepdocker.list_images_in_arv(api_client, 3,
                                                            image_name=image_name,
                                                            image_tag=image_tag)

    if not images:
        cwltool.docker.get_image(dockerRequirement, pull_image)
        args = ["--project-uuid=" + project_uuid, image_name]
        if image_tag:
            args.append(image_tag)
        logger.info("Uploading Docker image %s", ":".join(args[1:]))
        arvados.commands.keepdocker.main(args)

    return dockerRequirement["dockerImageId"]


class CollectionFsAccess(cwltool.process.StdFsAccess):
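    """FsAccess implementation that resolves "keep:" references against
    Arvados collections, so tools can glob, open, and stat files in Keep
    without staging them locally."""
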
    def __init__(self, basedir):
        self.collections = {}
        self.basedir = basedir

    def get_collection(self, path):
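        """Split a "keep:<pdh>/<path>" reference into a (CollectionReader,
        path-within-collection) pair; return (None, path) for ordinary
        local paths."""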
        p = path.split("/")
        if p[0].startswith("keep:") and arvados.util.keep_locator_pattern.match(p[0][5:]):
            pdh = p[0][5:]
            if pdh not in self.collections:
                self.collections[pdh] = arvados.collection.CollectionReader(pdh)
            return (self.collections[pdh], "/".join(p[1:]))
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
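        """Recursively match one glob pattern segment per collection level,
        returning the "keep:" paths of all matching files."""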
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo" so just shift past the
            # "./" (doing this outside the loop below avoids returning one
            # copy of each match per directory entry).
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a "keep:" reference; fall back to plain filesystem globbing.
            return super(CollectionFsAccess, self).glob(pattern)
        patternsegments = rest.split("/")
        return self._match(collection, patternsegments, "keep:" + collection.manifest_locator())

    def open(self, fn, mode):
        collection, rest = self.get_collection(fn)
        if collection:
            return collection.open(rest, mode)
        else:
            return open(self._abs(fn), mode)

    def exists(self, fn):
        collection, rest = self.get_collection(fn)
        if collection:
            return collection.exists(rest)
        else:
            return os.path.exists(self._abs(fn))


class ArvadosJob(object):
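    """Submit and monitor a single Arvados job running one CWL command."""
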
    def __init__(self, runner):
        self.arvrunner = runner
        self.running = False

    def run(self, dry_run=False, pull_image=True, **kwargs):
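        """Build crunchrunner script_parameters from this job's command line,
        staging files, environment, and Docker requirement, then create the
        Arvados job and register it with the runner's pipeline instance."""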
        script_parameters = {
            "command": self.command_line
        }
        runtime_constraints = {}

        if self.generatefiles:
            vwd = arvados.collection.Collection()
            script_parameters["task.vwd"] = {}
            for t in self.generatefiles:
                if isinstance(self.generatefiles[t], dict):
                    src, rest = self.arvrunner.fs_access.get_collection(self.generatefiles[t]["path"].replace("$(task.keep)/", "keep:"))
                    vwd.copy(rest, t, source_collection=src)
                else:
                    with vwd.open(t, "w") as f:
                        f.write(self.generatefiles[t])
            vwd.save_new()
            for t in self.generatefiles:
                script_parameters["task.vwd"][t] = "$(task.keep)/%s/%s" % (vwd.portable_data_hash(), t)

        script_parameters["task.env"] = {"TMPDIR": "$(task.tmpdir)"}
        if self.environment:
            script_parameters["task.env"].update(self.environment)

        if self.stdin:
            script_parameters["task.stdin"] = self.pathmapper.mapper(self.stdin)[1]

        if self.stdout:
            script_parameters["task.stdout"] = self.stdout

        (docker_req, docker_is_req) = get_feature(self, "DockerRequirement")
        if docker_req and kwargs.get("use_container") is not False:
            runtime_constraints["docker_image"] = arv_docker_get_image(self.arvrunner.api, docker_req, pull_image, self.arvrunner.project_uuid)

        try:
            response = self.arvrunner.api.jobs().create(body={
                "owner_uuid": self.arvrunner.project_uuid,
                "script": "crunchrunner",
                "repository": "arvados",
                "script_version": "master",
                "minimum_script_version": "9e5b98e8f5f4727856b53447191f9c06e3da2ba6",
                "script_parameters": {"tasks": [script_parameters], "crunchrunner": crunchrunner_pdh + "/crunchrunner"},
                "runtime_constraints": runtime_constraints
            }, find_or_create=kwargs.get("enable_reuse", True)).execute(num_retries=self.arvrunner.num_retries)

            self.arvrunner.jobs[response["uuid"]] = self

            self.arvrunner.pipeline["components"][self.name] = {"job": response}
            self.arvrunner.pipeline = self.arvrunner.api.pipeline_instances().update(uuid=self.arvrunner.pipeline["uuid"],
                                                                                     body={
                                                                                         "components": self.arvrunner.pipeline["components"]
                                                                                     }).execute(num_retries=self.arvrunner.num_retries)

            logger.info("Job %s (%s) is %s", self.name, response["uuid"], response["state"])

            if response["state"] in ("Complete", "Failed", "Cancelled"):
                self.done(response)
        except Exception as e:
            logger.error("Got error %s", e)
            self.output_callback({}, "permanentFail")

    def update_pipeline_component(self, record):
        self.arvrunner.pipeline["components"][self.name] = {"job": record}
        self.arvrunner.pipeline = self.arvrunner.api.pipeline_instances().update(uuid=self.arvrunner.pipeline["uuid"],
                                                                                 body={
                                                                                     "components": self.arvrunner.pipeline["components"]
                                                                                 }).execute(num_retries=self.arvrunner.num_retries)

    def done(self, record):
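        """Record the finished job in the pipeline instance, scrape the job
        log for the task's outdir/keep paths, collect outputs from the output
        collection, and report the result to the CWL engine."""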
        try:
            self.update_pipeline_component(record)
        except Exception:
            # Updating the pipeline record is best-effort; the job outcome is
            # still reported through output_callback below.
            pass

        try:
            if record["state"] == "Complete":
                processStatus = "success"
            else:
                processStatus = "permanentFail"

            try:
                outputs = {}
                if record["output"]:
                    logc = arvados.collection.Collection(record["log"])
                    log = logc.open(logc.keys()[0])
                    tmpdir = None
                    outdir = None
                    keepdir = None
                    for l in log.readlines():
                        g = tmpdirre.match(l)
                        if g:
                            tmpdir = g.group(1)
                        g = outdirre.match(l)
                        if g:
                            outdir = g.group(1)
                        g = keepre.match(l)
                        if g:
                            keepdir = g.group(1)
                        if tmpdir and outdir and keepdir:
                            break

                    self.builder.outdir = outdir
                    self.builder.pathmapper.keepdir = keepdir
                    outputs = self.collect_outputs("keep:" + record["output"])
            except Exception:
                logger.exception("Got exception while collecting job outputs:")
                processStatus = "permanentFail"

            self.output_callback(outputs, processStatus)
        finally:
            del self.arvrunner.jobs[record["uuid"]]


class ArvPathMapper(cwltool.pathmapper.PathMapper):
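    """PathMapper that maps input files to their Keep locations, uploading
    local files to Keep as needed and rewriting paths to the $(task.keep)
    mount point used inside crunchrunner tasks."""
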
    def __init__(self, arvrunner, referenced_files, basedir, **kwargs):
        self._pathmap = arvrunner.get_uploaded()
        uploadfiles = []

        pdh_path = re.compile(r'^keep:[0-9a-f]{32}\+\d+/.+')

        for src in referenced_files:
            if isinstance(src, basestring) and pdh_path.match(src):
                self._pathmap[src] = (src, "$(task.keep)/%s" % src[5:])
            if src not in self._pathmap:
                ab = cwltool.pathmapper.abspath(src, basedir)
                st = arvados.commands.run.statfile("", ab, fnPattern="$(task.keep)/%s/%s")
                if kwargs.get("conformance_test"):
                    self._pathmap[src] = (src, ab)
                elif isinstance(st, arvados.commands.run.UploadFile):
                    uploadfiles.append((src, ab, st))
                elif isinstance(st, arvados.commands.run.ArvFile):
                    self._pathmap[src] = (ab, st.fn)
                else:
                    raise cwltool.workflow.WorkflowException("Input file path '%s' is invalid" % st)

        if uploadfiles:
            arvados.commands.run.uploadfiles([u[2] for u in uploadfiles],
                                             arvrunner.api,
                                             dry_run=kwargs.get("dry_run"),
                                             num_retries=3,
                                             fnPattern="$(task.keep)/%s/%s",
                                             project=arvrunner.project_uuid)

        for src, ab, st in uploadfiles:
            arvrunner.add_uploaded(src, (ab, st.fn))
            self._pathmap[src] = (ab, st.fn)

        self.keepdir = None

    def reversemap(self, target):
        if target.startswith("keep:"):
            return (target, target)
        elif self.keepdir and target.startswith(self.keepdir):
            return (target, "keep:" + target[len(self.keepdir)+1:])
        else:
            return super(ArvPathMapper, self).reversemap(target)


class ArvadosCommandTool(cwltool.draft2tool.CommandLineTool):
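    """CommandLineTool subclass that runs its jobs on Arvados instead of the
    local machine."""
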
    def __init__(self, arvrunner, toolpath_object, **kwargs):
        super(ArvadosCommandTool, self).__init__(toolpath_object, **kwargs)
        self.arvrunner = arvrunner

    def makeJobRunner(self):
        return ArvadosJob(self.arvrunner)

    def makePathMapper(self, reffiles, input_basedir, **kwargs):
        return ArvPathMapper(self.arvrunner, reffiles, input_basedir, **kwargs)


class ArvCwlRunner(object):
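    """Execute a CWL workflow on Arvados: one Arvados job per CWL step,
    tracked together in a pipeline instance, with job state changes arriving
    over the websocket event stream."""
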
    def __init__(self, api_client):
        self.api = api_client
        self.jobs = {}
        self.lock = threading.Lock()
        self.cond = threading.Condition(self.lock)
        self.final_output = None
        self.uploaded = {}
        self.num_retries = 4

    def arvMakeTool(self, toolpath_object, **kwargs):
        if "class" in toolpath_object and toolpath_object["class"] == "CommandLineTool":
            return ArvadosCommandTool(self, toolpath_object, **kwargs)
        else:
            return cwltool.workflow.defaultMakeTool(toolpath_object, **kwargs)

    def output_callback(self, out, processStatus):
        if processStatus == "success":
            logger.info("Overall job status is %s", processStatus)
            self.api.pipeline_instances().update(uuid=self.pipeline["uuid"],
                                                 body={"state": "Complete"}).execute(num_retries=self.num_retries)
        else:
            logger.warn("Overall job status is %s", processStatus)
            self.api.pipeline_instances().update(uuid=self.pipeline["uuid"],
                                                 body={"state": "Failed"}).execute(num_retries=self.num_retries)
        self.final_output = out

    def on_message(self, event):
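        """Websocket event handler: update the pipeline record when one of
        our jobs starts running, and finalize the job (collecting outputs and
        waking the main thread) when it reaches a terminal state."""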
321         if "object_uuid" in event:
322             if event["object_uuid"] in self.jobs and event["event_type"] == "update":
323                 if event["properties"]["new_attributes"]["state"] == "Running" and self.jobs[event["object_uuid"]].running is False:
324                     uuid = event["object_uuid"]
325                     with self.lock:
326                         j = self.jobs[uuid]
327                         logger.info("Job %s (%s) is Running", j.name, uuid)
328                         j.running = True
329                         j.update_pipeline_component(event["properties"]["new_attributes"])
330                 elif event["properties"]["new_attributes"]["state"] in ("Complete", "Failed", "Cancelled"):
331                     uuid = event["object_uuid"]
332                     try:
333                         self.cond.acquire()
334                         j = self.jobs[uuid]
335                         logger.info("Job %s (%s) is %s", j.name, uuid, event["properties"]["new_attributes"]["state"])
336                         j.done(event["properties"]["new_attributes"])
337                         self.cond.notify()
338                     finally:
339                         self.cond.release()
340
    def get_uploaded(self):
        return self.uploaded.copy()

    def add_uploaded(self, src, pair):
        self.uploaded[src] = pair

    def arvExecutor(self, tool, job_order, input_basedir, args, **kwargs):
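        """Main execution loop: make sure the crunchrunner collection exists,
        create a pipeline instance, submit a job for each runnable step, and
        wait (via the condition variable) until all jobs have finished."""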
        events = arvados.events.subscribe(arvados.api('v1'), [["object_uuid", "is_a", "arvados#job"]], self.on_message)

        try:
            self.api.collections().get(uuid=crunchrunner_pdh).execute()
        except arvados.errors.ApiError:
            # The crunchrunner collection is not on this cluster yet;
            # download the binary and CA bundle and upload them to Keep.
            import httplib2
            h = httplib2.Http(ca_certs=arvados.util.ca_certs_path())
            resp, content = h.request(crunchrunner_download, "GET")
            resp2, content2 = h.request(certs_download, "GET")
            with arvados.collection.Collection() as col:
                with col.open("crunchrunner", "w") as f:
                    f.write(content)
                with col.open("ca-certificates.crt", "w") as f:
                    f.write(content2)

                col.save_new("crunchrunner binary", ensure_unique_name=True)

        self.fs_access = CollectionFsAccess(input_basedir)

        kwargs["fs_access"] = self.fs_access
        if args:
            kwargs["enable_reuse"] = args.enable_reuse

        kwargs["outdir"] = "$(task.outdir)"
        kwargs["tmpdir"] = "$(task.tmpdir)"

        useruuid = self.api.users().current().execute()["uuid"]
        self.project_uuid = args.project_uuid if args.project_uuid else useruuid

        if kwargs.get("conformance_test"):
            return cwltool.main.single_job_executor(tool, job_order, input_basedir, args, **kwargs)
        else:
            self.pipeline = self.api.pipeline_instances().create(
                body={
                    "owner_uuid": self.project_uuid,
                    "name": shortname(tool.tool["id"]),
                    "components": {},
                    "state": "RunningOnClient"}).execute(num_retries=self.num_retries)

            logger.info("Pipeline instance %s", self.pipeline["uuid"])

            jobiter = tool.job(job_order,
                               input_basedir,
                               self.output_callback,
                               docker_outdir="$(task.outdir)",
                               **kwargs)

            # Hold the lock for the duration of this block, except while in
            # cond.wait(), at which point on_message can update job state and
            # process output callbacks.
            self.cond.acquire()
            try:
                for runnable in jobiter:
                    if runnable:
                        runnable.run(**kwargs)
                    else:
                        if self.jobs:
                            self.cond.wait(1)
                        else:
                            logger.error("Workflow is deadlocked, no runnable jobs and not waiting on any pending jobs.")
                            break

                while self.jobs:
                    self.cond.wait(1)

                events.close()

                if self.final_output is None:
                    raise cwltool.workflow.WorkflowException("Workflow did not return a result.")

                # create final output collection
            except:
                if sys.exc_info()[0] is KeyboardInterrupt:
                    logger.error("Interrupted, marking pipeline as failed")
                else:
                    logger.exception("Caught unhandled exception, marking pipeline as failed")
                self.api.pipeline_instances().update(uuid=self.pipeline["uuid"],
                                                     body={"state": "Failed"}).execute(num_retries=self.num_retries)
            finally:
                self.cond.release()

            return self.final_output


def main(args, stdout, stderr, api_client=None):
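    """Command line entry point: parse arvados-cwl-runner arguments and hand
    off to cwltool.main.main with the Arvados executor and tool factory."""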
    args.insert(0, "--leave-outputs")
    parser = cwltool.main.arg_parser()
    exgroup = parser.add_mutually_exclusive_group()
    exgroup.add_argument("--enable-reuse", action="store_true",
                         default=True, dest="enable_reuse",
                         help="Reuse past job results with matching inputs when possible (default).")
    exgroup.add_argument("--disable-reuse", action="store_false",
                         default=True, dest="enable_reuse",
                         help="Always run new jobs instead of reusing past results.")
    parser.add_argument("--project-uuid", type=str, help="Project that will own the workflow jobs")

    try:
        if api_client is None:
            api_client = arvados.api('v1', model=OrderedJsonModel())
        runner = ArvCwlRunner(api_client=api_client)
    except Exception as e:
        logger.error(e)
        return 1

    return cwltool.main.main(args, executor=runner.arvExecutor, makeTool=runner.arvMakeTool, parser=parser)
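

# A minimal sketch of how this entry point is invoked, assuming the installed
# arv-cwl-runner script calls main() the same way; not part of the original
# module.
if __name__ == "__main__":
    sys.exit(main(sys.argv[1:], sys.stdout, sys.stderr))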