8766: Refactor logic to copy output collection and add tests.
[arvados.git] / sdk / cwl / arvados_cwl / __init__.py
1 #!/usr/bin/env python
2
3 import argparse
4 import arvados
5 import arvados.events
6 import arvados.commands.keepdocker
7 import arvados.commands.run
8 import arvados.collection
9 import arvados.util
10 import cwltool.draft2tool
11 import cwltool.workflow
12 import cwltool.main
13 from cwltool.process import shortname
14 from cwltool.errors import WorkflowException
15 import threading
16 import cwltool.docker
17 import fnmatch
18 import logging
19 import re
20 import os
21 import sys
22
23 from cwltool.process import get_feature
24 from arvados.api import OrderedJsonModel
25
# Module-level logger for the Arvados CWL runner.
logger = logging.getLogger('arvados.cwl-runner')
logger.setLevel(logging.INFO)

# Portable data hash of the Keep collection holding the crunchrunner binary
# that submitted jobs run (see ArvCwlRunner.arvExecutor, which uploads it
# when it is missing from the API server).
crunchrunner_pdh = "ff6fc71e593081ef9733afacaeee15ea+140"
# Fallback download URLs for the crunchrunner binary and the CA certificate
# bundle, fetched when the collection above is not present.
crunchrunner_download = "https://cloud.curoverse.com/collections/download/qr1hi-4zz18-n3m1yxd0vx78jic/1i1u2qtq66k1atziv4ocfgsg5nu5tj11n4r6e0bhvjg03rix4m/crunchrunner"
certs_download = "https://cloud.curoverse.com/collections/download/qr1hi-4zz18-n3m1yxd0vx78jic/1i1u2qtq66k1atziv4ocfgsg5nu5tj11n4r6e0bhvjg03rix4m/ca-certificates.crt"

# Patterns matching crunchrunner log lines that report the job's
# $(task.tmpdir), $(task.outdir) and $(task.keep) paths; used to recover
# those paths from the log collection after the job finishes
# (see ArvadosJob.done).
tmpdirre = re.compile(r"^\S+ \S+ \d+ \d+ stderr \S+ \S+ crunchrunner: \$\(task\.tmpdir\)=(.*)")
outdirre = re.compile(r"^\S+ \S+ \d+ \d+ stderr \S+ \S+ crunchrunner: \$\(task\.outdir\)=(.*)")
keepre = re.compile(r"^\S+ \S+ \d+ \d+ stderr \S+ \S+ crunchrunner: \$\(task\.keep\)=(.*)")
36
37
def arv_docker_get_image(api_client, dockerRequirement, pull_image, project_uuid):
    """Ensure the Docker image named by dockerRequirement is stored in Arvados.

    If "dockerImageId" is absent it is defaulted from "dockerPull".  When the
    image is not already present on the API server, it is pulled locally
    (honoring pull_image) and uploaded into project_uuid via arv-keepdocker.

    Returns the image id string recorded in dockerRequirement["dockerImageId"].
    """
    if "dockerImageId" not in dockerRequirement and "dockerPull" in dockerRequirement:
        dockerRequirement["dockerImageId"] = dockerRequirement["dockerPull"]

    # Split "name[:tag]".  Split on the *last* colon, and only treat the
    # suffix as a tag when it contains no "/": a colon earlier in the
    # reference belongs to a registry host:port (e.g. "myregistry:5000/img"),
    # which the previous split(":") mis-parsed.
    image_id = dockerRequirement["dockerImageId"]
    name, sep, tag = image_id.rpartition(":")
    if sep and "/" not in tag:
        image_name = name
        image_tag = tag
    else:
        image_name = image_id
        image_tag = None

    images = arvados.commands.keepdocker.list_images_in_arv(api_client, 3,
                                                            image_name=image_name,
                                                            image_tag=image_tag)

    if not images:
        # Not in Keep yet: pull the image locally (unless pull_image is
        # False) and upload it into the target project.
        imageId = cwltool.docker.get_image(dockerRequirement, pull_image)
        args = ["--project-uuid="+project_uuid, image_name]
        if image_tag:
            args.append(image_tag)
        logger.info("Uploading Docker image %s", ":".join(args[1:]))
        arvados.commands.keepdocker.main(args)

    return dockerRequirement["dockerImageId"]
59
60
class CollectionFsAccess(cwltool.process.StdFsAccess):
    """File access routed through Arvados Keep collections.

    Paths of the form "keep:<portable data hash>/..." are resolved against
    the named collection; anything else falls back to plain local-filesystem
    access relative to basedir (via the StdFsAccess base class).
    """

    def __init__(self, basedir):
        # Cache of CollectionReader objects keyed by portable data hash so
        # each collection is fetched from the API server at most once.
        self.collections = {}
        self.basedir = basedir

    def get_collection(self, path):
        """Split a "keep:PDH/rest" path into (CollectionReader, "rest").

        Returns (None, path) unchanged when path is not a Keep locator.
        """
        p = path.split("/")
        if p[0].startswith("keep:") and arvados.util.keep_locator_pattern.match(p[0][5:]):
            pdh = p[0][5:]
            if pdh not in self.collections:
                self.collections[pdh] = arvados.collection.CollectionReader(pdh)
            return (self.collections[pdh], "/".join(p[1:]))
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        # Recursively match glob pattern segments against collection entries,
        # returning the list of matching "parent/name" paths.
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        # A "." segment (from patterns like "./foo") refers to the current
        # collection itself; skip past it exactly once.  Fix: the previous
        # implementation recursed on "." inside the per-entry loop below,
        # which duplicated every result len(collection) times.
        if patternsegments[0] == '.':
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        # NOTE(review): assumes pattern is a keep: locator path; a non-keep
        # pattern makes get_collection return None and this would raise on
        # collection.manifest_locator() -- confirm callers only pass keep:
        # paths here.
        collection, rest = self.get_collection(pattern)
        patternsegments = rest.split("/")
        return self._match(collection, patternsegments, "keep:" + collection.manifest_locator())

    def open(self, fn, mode):
        """Open a file either inside a Keep collection or on the local disk."""
        collection, rest = self.get_collection(fn)
        if collection:
            return collection.open(rest, mode)
        else:
            return open(self._abs(fn), mode)

    def exists(self, fn):
        """Test existence either inside a Keep collection or on local disk."""
        collection, rest = self.get_collection(fn)
        if collection:
            return collection.exists(rest)
        else:
            return os.path.exists(self._abs(fn))
116
class ArvadosJob(object):
    """Runs one CWL command invocation as an Arvados crunch job.

    Instances are created via ArvadosCommandTool.makeJobRunner(); cwltool's
    CommandLineTool machinery sets attributes such as command_line,
    generatefiles, environment, stdin, stdout, builder, name, pathmapper and
    output_callback on the instance before calling run().
    """

    def __init__(self, runner):
        # runner is the coordinating ArvCwlRunner, used for API access, the
        # shared pipeline instance record, and in-flight job bookkeeping.
        self.arvrunner = runner
        self.running = False

    def run(self, dry_run=False, pull_image=True, **kwargs):
        """Submit this command as a crunchrunner job on the API server.

        On success the job is registered in self.arvrunner.jobs and added as
        a component of the runner's pipeline instance; on any failure the
        output callback is invoked with "permanentFail".
        """
        script_parameters = {
            "command": self.command_line
        }
        runtime_constraints = {}

        if self.generatefiles:
            # Stage files cwltool wants created in the job's working
            # directory ("virtual working directory") into a new collection.
            vwd = arvados.collection.Collection()
            script_parameters["task.vwd"] = {}
            for t in self.generatefiles:
                if isinstance(self.generatefiles[t], dict):
                    # A reference to an existing file in Keep: copy it into
                    # the vwd collection instead of rewriting its contents.
                    src, rest = self.arvrunner.fs_access.get_collection(self.generatefiles[t]["path"].replace("$(task.keep)/", "keep:"))
                    vwd.copy(rest, t, source_collection=src)
                else:
                    # Literal file contents supplied by cwltool.
                    with vwd.open(t, "w") as f:
                        f.write(self.generatefiles[t])
            vwd.save_new()
            # Point each staged file at its location in the saved collection.
            for t in self.generatefiles:
                script_parameters["task.vwd"][t] = "$(task.keep)/%s/%s" % (vwd.portable_data_hash(), t)

        script_parameters["task.env"] = {"TMPDIR": "$(task.tmpdir)"}
        if self.environment:
            script_parameters["task.env"].update(self.environment)

        if self.stdin:
            # Map the stdin reference to its $(task.keep) path.
            script_parameters["task.stdin"] = self.pathmapper.mapper(self.stdin)[1]

        if self.stdout:
            script_parameters["task.stdout"] = self.stdout

        (docker_req, docker_is_req) = get_feature(self, "DockerRequirement")
        if docker_req and kwargs.get("use_container") is not False:
            # Make sure the image is available in Arvados and pin the job to it.
            runtime_constraints["docker_image"] = arv_docker_get_image(self.arvrunner.api, docker_req, pull_image, self.arvrunner.project_uuid)

        resources = self.builder.resources
        if resources is not None:
            runtime_constraints["min_cores_per_node"] = resources.get("cores", 1)
            runtime_constraints["min_ram_mb_per_node"] = resources.get("ram")
            runtime_constraints["min_scratch_mb_per_node"] = resources.get("tmpdirSize", 0) + resources.get("outdirSize", 0)

        try:
            # find_or_create enables crunch job reuse when --enable-reuse is on.
            response = self.arvrunner.api.jobs().create(body={
                "owner_uuid": self.arvrunner.project_uuid,
                "script": "crunchrunner",
                "repository": "arvados",
                "script_version": "master",
                "minimum_script_version": "9e5b98e8f5f4727856b53447191f9c06e3da2ba6",
                "script_parameters": {"tasks": [script_parameters], "crunchrunner": crunchrunner_pdh+"/crunchrunner"},
                "runtime_constraints": runtime_constraints
            }, find_or_create=kwargs.get("enable_reuse", True)).execute(num_retries=self.arvrunner.num_retries)

            self.arvrunner.jobs[response["uuid"]] = self

            # Record the new job as a component of the pipeline instance so
            # it is visible in Workbench.
            self.arvrunner.pipeline["components"][self.name] = {"job": response}
            self.arvrunner.pipeline = self.arvrunner.api.pipeline_instances().update(uuid=self.arvrunner.pipeline["uuid"],
                                                                                     body={
                                                                                         "components": self.arvrunner.pipeline["components"]
                                                                                     }).execute(num_retries=self.arvrunner.num_retries)

            logger.info("Job %s (%s) is %s", self.name, response["uuid"], response["state"])

            # A reused job may already be finished; report it immediately
            # instead of waiting for a state-change event that never comes.
            if response["state"] in ("Complete", "Failed", "Cancelled"):
                self.done(response)
        except Exception as e:
            logger.error("Got error %s" % str(e))
            self.output_callback({}, "permanentFail")

    def update_pipeline_component(self, record):
        # Push the latest job record into this step's pipeline component so
        # Workbench reflects current job state.
        self.arvrunner.pipeline["components"][self.name] = {"job": record}
        self.arvrunner.pipeline = self.arvrunner.api.pipeline_instances().update(uuid=self.arvrunner.pipeline["uuid"],
                                                                                 body={
                                                                                    "components": self.arvrunner.pipeline["components"]
                                                                                 }).execute(num_retries=self.arvrunner.num_retries)

    def done(self, record):
        """Handle a finished job record: collect outputs and report status."""
        try:
            self.update_pipeline_component(record)
        except:
            # Best-effort: failure to update the Workbench display must not
            # prevent output collection below.
            pass

        try:
            if record["state"] == "Complete":
                processStatus = "success"
            else:
                processStatus = "permanentFail"

            try:
                outputs = {}
                if record["output"]:
                    logc = arvados.collection.Collection(record["log"])
                    log = logc.open(logc.keys()[0])
                    tmpdir = None
                    outdir = None
                    keepdir = None
                    for l in log:
                        # Determine the tmpdir, outdir and keepdir paths from
                        # the job run.  Unfortunately, we can't take the first
                        # values we find (which are expected to be near the
                        # top) and stop scanning, because if the node fails
                        # and the job restarts on a different node these
                        # values will differ between runs, and we need to
                        # know about the final run that actually produced
                        # output.

                        g = tmpdirre.match(l)
                        if g:
                            tmpdir = g.group(1)
                        g = outdirre.match(l)
                        if g:
                            outdir = g.group(1)
                        g = keepre.match(l)
                        if g:
                            keepdir = g.group(1)

                    colname = "Output %s of %s" % (record["output"][0:7], self.name)

                    # check if collection already exists with same owner, name and content
                    collection_exists = self.arvrunner.api.collections().list(
                        filters=[["owner_uuid", "=", self.arvrunner.project_uuid],
                                 ['portable_data_hash', '=', record["output"]],
                                 ["name", "=", colname]]
                    ).execute(num_retries=self.arvrunner.num_retries)

                    if not collection_exists["items"]:
                        # Create a collection located in the same project as the
                        # pipeline with the contents of the output.
                        # First, get output record.
                        collections = self.arvrunner.api.collections().list(
                            limit=1,
                            filters=[['portable_data_hash', '=', record["output"]]],
                            select=["manifest_text"]
                        ).execute(num_retries=self.arvrunner.num_retries)

                        if not collections["items"]:
                            raise WorkflowException(
                                "Job output '%s' cannot be found on API server" % (
                                    record["output"]))

                        # Create new collection in the parent project
                        # with the output contents.
                        self.arvrunner.api.collections().create(body={
                            "owner_uuid": self.arvrunner.project_uuid,
                            "name": colname,
                            "portable_data_hash": record["output"],
                            "manifest_text": collections["items"][0]["manifest_text"]
                        }, ensure_unique_name=True).execute(
                            num_retries=self.arvrunner.num_retries)

                    # Let collect_outputs resolve output paths against the
                    # job's actual outdir/keep mount recovered from the log.
                    self.builder.outdir = outdir
                    self.builder.pathmapper.keepdir = keepdir
                    outputs = self.collect_outputs("keep:" + record["output"])
            except WorkflowException as e:
                logger.error("Error while collecting job outputs:\n%s", e, exc_info=(e if self.arvrunner.debug else False))
                processStatus = "permanentFail"
            except Exception as e:
                logger.exception("Got unknown exception while collecting job outputs:")
                processStatus = "permanentFail"

            self.output_callback(outputs, processStatus)
        finally:
            # The job is finished: stop tracking it so arvExecutor can exit.
            del self.arvrunner.jobs[record["uuid"]]
282
283
class ArvPathMapper(cwltool.pathmapper.PathMapper):
    """Maps client-side file references to $(task.keep) paths inside jobs.

    References already in Keep are mapped directly; local files are uploaded
    (at most once per runner, via the runner's uploaded-files cache) and
    mapped to their new Keep locations.
    """

    def __init__(self, arvrunner, referenced_files, basedir, **kwargs):
        # Seed the map from the runner-wide cache of prior uploads.
        self._pathmap = arvrunner.get_uploaded()
        pending_uploads = []

        # Matches "keep:<portable data hash>/<path>" references.
        keep_file_pat = re.compile(r'^keep:[0-9a-f]{32}\+\d+/.+')

        for src in referenced_files:
            if isinstance(src, basestring) and keep_file_pat.match(src):
                # Already a Keep file reference; strip the "keep:" prefix
                # and address it through the job's keep mount.
                self._pathmap[src] = (src, "$(task.keep)/%s" % src[5:])
            if src in self._pathmap:
                continue
            ab = cwltool.pathmapper.abspath(src, basedir)
            st = arvados.commands.run.statfile("", ab, fnPattern="$(task.keep)/%s/%s")
            if kwargs.get("conformance_test"):
                # Conformance tests run locally; keep the local path.
                self._pathmap[src] = (src, ab)
            elif isinstance(st, arvados.commands.run.UploadFile):
                pending_uploads.append((src, ab, st))
            elif isinstance(st, arvados.commands.run.ArvFile):
                self._pathmap[src] = (ab, st.fn)
            else:
                raise cwltool.workflow.WorkflowException("Input file path '%s' is invalid" % st)

        if pending_uploads:
            # One batched upload for everything not yet in Keep.
            arvados.commands.run.uploadfiles([u[2] for u in pending_uploads],
                                             arvrunner.api,
                                             dry_run=kwargs.get("dry_run"),
                                             num_retries=3,
                                             fnPattern="$(task.keep)/%s/%s",
                                             project=arvrunner.project_uuid)

        for src, ab, st in pending_uploads:
            # Remember the upload runner-wide, then record the mapping.
            arvrunner.add_uploaded(src, (ab, st.fn))
            self._pathmap[src] = (ab, st.fn)

        # Set by ArvadosJob.done() once the job's keep mount is known.
        self.keepdir = None

    def reversemap(self, target):
        """Translate a task-side path back into its keep: reference."""
        if target.startswith("keep:"):
            return (target, target)
        if self.keepdir and target.startswith(self.keepdir):
            return (target, "keep:" + target[len(self.keepdir)+1:])
        return super(ArvPathMapper, self).reversemap(target)
327
328
class ArvadosCommandTool(cwltool.draft2tool.CommandLineTool):
    """cwltool CommandLineTool that executes through Arvados.

    Overrides the job-runner and path-mapper factories so that each tool
    invocation becomes an ArvadosJob and file references are resolved via
    Keep by ArvPathMapper.
    """

    def __init__(self, arvrunner, toolpath_object, **kwargs):
        super(ArvadosCommandTool, self).__init__(toolpath_object, **kwargs)
        # Coordinating runner, shared with jobs and path mappers we create.
        self.arvrunner = arvrunner

    def makeJobRunner(self):
        """Return a fresh Arvados job wrapper for one tool invocation."""
        return ArvadosJob(self.arvrunner)

    def makePathMapper(self, reffiles, input_basedir, **kwargs):
        """Return a mapper that resolves reffiles to $(task.keep) paths."""
        return ArvPathMapper(self.arvrunner, reffiles, input_basedir, **kwargs)
339
340
class ArvCwlRunner(object):
    """Executes a CWL workflow on Arvados, one crunch job per tool step.

    In-flight jobs live in self.jobs (uuid -> ArvadosJob).  Job state
    changes arrive over the websocket event stream via on_message(), which
    coordinates with the arvExecutor main loop through self.cond.
    """

    def __init__(self, api_client):
        self.api = api_client
        # uuid -> ArvadosJob for jobs submitted but not yet finished.
        self.jobs = {}
        self.lock = threading.Lock()
        # Condition variable used to wake arvExecutor when a job completes.
        self.cond = threading.Condition(self.lock)
        self.final_output = None
        # Cache of local files already uploaded to Keep (see ArvPathMapper).
        self.uploaded = {}
        self.num_retries = 4

    def arvMakeTool(self, toolpath_object, **kwargs):
        # makeTool hook for cwltool: run CommandLineTools as Arvados jobs;
        # everything else (e.g. Workflow) keeps cwltool's default behavior.
        if "class" in toolpath_object and toolpath_object["class"] == "CommandLineTool":
            return ArvadosCommandTool(self, toolpath_object, **kwargs)
        else:
            return cwltool.workflow.defaultMakeTool(toolpath_object, **kwargs)

    def output_callback(self, out, processStatus):
        # Invoked by cwltool with the workflow's final output; records the
        # result and marks the pipeline instance Complete or Failed.
        if processStatus == "success":
            logger.info("Overall job status is %s", processStatus)
            self.api.pipeline_instances().update(uuid=self.pipeline["uuid"],
                                                 body={"state": "Complete"}).execute(num_retries=self.num_retries)

        else:
            logger.warn("Overall job status is %s", processStatus)
            self.api.pipeline_instances().update(uuid=self.pipeline["uuid"],
                                                 body={"state": "Failed"}).execute(num_retries=self.num_retries)
        self.final_output = out


    def on_message(self, event):
        """Websocket callback: react to state changes of tracked jobs."""
        if "object_uuid" in event:
            if event["object_uuid"] in self.jobs and event["event_type"] == "update":
                if event["properties"]["new_attributes"]["state"] == "Running" and self.jobs[event["object_uuid"]].running is False:
                    uuid = event["object_uuid"]
                    with self.lock:
                        j = self.jobs[uuid]
                        logger.info("Job %s (%s) is Running", j.name, uuid)
                        j.running = True
                        j.update_pipeline_component(event["properties"]["new_attributes"])
                elif event["properties"]["new_attributes"]["state"] in ("Complete", "Failed", "Cancelled"):
                    uuid = event["object_uuid"]
                    try:
                        self.cond.acquire()
                        j = self.jobs[uuid]
                        logger.info("Job %s (%s) is %s", j.name, uuid, event["properties"]["new_attributes"]["state"])
                        # done() collects outputs and removes the job from
                        # self.jobs; notify() wakes the arvExecutor wait loop.
                        j.done(event["properties"]["new_attributes"])
                        self.cond.notify()
                    finally:
                        self.cond.release()

    def get_uploaded(self):
        # Return a copy so callers (ArvPathMapper) cannot mutate the cache.
        return self.uploaded.copy()

    def add_uploaded(self, src, pair):
        # Record that local file src maps to the (abspath, keep path) pair.
        self.uploaded[src] = pair

    def arvExecutor(self, tool, job_order, input_basedir, args, **kwargs):
        """cwltool executor hook: run the workflow, return its final output."""
        # NOTE(review): subscribes with a fresh arvados.api('v1') client
        # rather than self.api -- confirm this is intentional.
        events = arvados.events.subscribe(arvados.api('v1'), [["object_uuid", "is_a", "arvados#job"]], self.on_message)

        self.debug = args.debug

        # Make sure the crunchrunner binary collection referenced by
        # submitted jobs exists on this API server; if not, download the
        # binary and CA bundle and save them as a new collection.
        try:
            self.api.collections().get(uuid=crunchrunner_pdh).execute()
        except arvados.errors.ApiError as e:
            import httplib2
            h = httplib2.Http(ca_certs=arvados.util.ca_certs_path())
            resp, content = h.request(crunchrunner_download, "GET")
            resp2, content2 = h.request(certs_download, "GET")
            with arvados.collection.Collection() as col:
                with col.open("crunchrunner", "w") as f:
                    f.write(content)
                with col.open("ca-certificates.crt", "w") as f:
                    f.write(content2)

                col.save_new("crunchrunner binary", ensure_unique_name=True)

        self.fs_access = CollectionFsAccess(input_basedir)

        kwargs["fs_access"] = self.fs_access
        kwargs["enable_reuse"] = args.enable_reuse

        # Working paths are deferred to crunchrunner's $(task.*) substitution.
        kwargs["outdir"] = "$(task.outdir)"
        kwargs["tmpdir"] = "$(task.tmpdir)"

        useruuid = self.api.users().current().execute()["uuid"]
        # Jobs and output collections are owned by --project-uuid, falling
        # back to the current user's home project.
        self.project_uuid = args.project_uuid if args.project_uuid else useruuid

        if kwargs.get("conformance_test"):
            # Conformance tests bypass Arvados and run locally via cwltool.
            return cwltool.main.single_job_executor(tool, job_order, input_basedir, args, **kwargs)
        else:
            self.pipeline = self.api.pipeline_instances().create(
                body={
                    "owner_uuid": self.project_uuid,
                    "name": shortname(tool.tool["id"]),
                    "components": {},
                    "state": "RunningOnClient"}).execute(num_retries=self.num_retries)

            logger.info("Pipeline instance %s", self.pipeline["uuid"])

            jobiter = tool.job(job_order,
                               input_basedir,
                               self.output_callback,
                               docker_outdir="$(task.outdir)",
                               **kwargs)

            try:
                self.cond.acquire()
                # Will continue to hold the lock for the duration of this code
                # except when in cond.wait(), at which point on_message can update
                # job state and process output callbacks.

                for runnable in jobiter:
                    if runnable:
                        runnable.run(**kwargs)
                    else:
                        # jobiter yielded nothing runnable: wait for pending
                        # jobs to finish so their outputs unblock successors.
                        if self.jobs:
                            self.cond.wait(1)
                        else:
                            logger.error("Workflow is deadlocked, no runnable jobs and not waiting on any pending jobs.")
                            break

                # Everything submitted; wait for the last jobs to finish.
                while self.jobs:
                    self.cond.wait(1)

                events.close()

                if self.final_output is None:
                    raise cwltool.workflow.WorkflowException("Workflow did not return a result.")

                # TODO: create final output collection
            except:
                if sys.exc_info()[0] is KeyboardInterrupt:
                    logger.error("Interrupted, marking pipeline as failed")
                else:
                    logger.error("Caught unhandled exception, marking pipeline as failed.  Error was: %s", sys.exc_info()[0], exc_info=(sys.exc_info()[1] if self.debug else False))
                self.api.pipeline_instances().update(uuid=self.pipeline["uuid"],
                                                     body={"state": "Failed"}).execute(num_retries=self.num_retries)
            finally:
                self.cond.release()

            return self.final_output
482
483
def main(args, stdout, stderr, api_client=None):
    """Command-line entry point for arvados-cwl-runner.

    Parses cwltool's argument set (extended with reuse and project options)
    and runs the workflow through an ArvCwlRunner.  Returns a process exit
    code.  If api_client is supplied it is used directly (enables test
    injection); otherwise a new API client is constructed.
    """
    # Prepend rather than insert in place so the caller's list is not mutated.
    args = ["--leave-outputs"] + list(args)
    parser = cwltool.main.arg_parser()
    exgroup = parser.add_mutually_exclusive_group()
    exgroup.add_argument("--enable-reuse", action="store_true",
                        default=True, dest="enable_reuse",
                        help="Enable job reuse (default)")
    exgroup.add_argument("--disable-reuse", action="store_false",
                        default=True, dest="enable_reuse",
                        help="Disable job reuse")
    parser.add_argument("--project-uuid", type=str, help="Project that will own the workflow jobs")

    try:
        # Fix: honor the api_client parameter instead of always creating a
        # new client (the old code silently ignored it).
        if api_client is None:
            api_client = arvados.api('v1', model=OrderedJsonModel())
        runner = ArvCwlRunner(api_client=api_client)
    except Exception as e:
        logger.error(e)
        return 1

    return cwltool.main.main(args, executor=runner.arvExecutor, makeTool=runner.arvMakeTool, parser=parser)