Merge branch '8484-sanity-check-collection-count' closes #8484
[arvados.git] sdk/cwl/arvados_cwl/__init__.py
#!/usr/bin/env python

import fnmatch
import logging
import os
import re
import threading

import arvados
import arvados.collection
import arvados.commands.keepdocker
import arvados.commands.run
import arvados.events
import arvados.util
import cwltool.docker
import cwltool.draft2tool
import cwltool.main
import cwltool.pathmapper
import cwltool.process
import cwltool.workflow
from cwltool.process import get_feature, shortname

logger = logging.getLogger('arvados.cwl-runner')
logger.setLevel(logging.INFO)

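# Ensure the Docker image named by a CWL DockerRequirement is available in
# Keep: look it up with arv-keepdocker, and if it is missing, pull it
# locally via cwltool and upload it.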
def arv_docker_get_image(api_client, dockerRequirement, pull_image):
    if "dockerImageId" not in dockerRequirement and "dockerPull" in dockerRequirement:
        dockerRequirement["dockerImageId"] = dockerRequirement["dockerPull"]

    sp = dockerRequirement["dockerImageId"].split(":")
    image_name = sp[0]
    image_tag = sp[1] if len(sp) > 1 else None

    images = arvados.commands.keepdocker.list_images_in_arv(api_client, 3,
                                                            image_name=image_name,
                                                            image_tag=image_tag)

    if not images:
        cwltool.docker.get_image(dockerRequirement, pull_image)
        args = [image_name]
        if image_tag:
            args.append(image_tag)
        logger.info("Uploading Docker image %s", ":".join(args))
        arvados.commands.keepdocker.main(args)

    return dockerRequirement["dockerImageId"]


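# File access layer that resolves "keep:" paths to Arvados collections and
# falls back to the local filesystem for everything else.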
class CollectionFsAccess(cwltool.process.StdFsAccess):
    def __init__(self, basedir):
        super(CollectionFsAccess, self).__init__(basedir)
        self.collections = {}

    def get_collection(self, path):
        p = path.split("/")
        if p[0].startswith("keep:") and arvados.util.keep_locator_pattern.match(p[0][5:]):
            pdh = p[0][5:]
            if pdh not in self.collections:
                self.collections[pdh] = arvados.collection.CollectionReader(pdh)
            return (self.collections[pdh], "/".join(p[1:]))
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo" so just shift past the
            # "./" (doing this once here, rather than once per file, avoids
            # returning duplicate matches).
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a Keep path; fall back to globbing the local filesystem.
            return super(CollectionFsAccess, self).glob(pattern)
        patternsegments = rest.split("/")
        return self._match(collection, patternsegments, "keep:" + collection.manifest_locator())

    def open(self, fn, mode):
        collection, rest = self.get_collection(fn)
        if collection:
            return collection.open(rest, mode)
        else:
            return open(self._abs(fn), mode)

    def exists(self, fn):
        collection, rest = self.get_collection(fn)
        if collection:
            return collection.exists(rest)
        else:
            return os.path.exists(self._abs(fn))


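# A single CWL CommandLineTool invocation, submitted as an Arvados job that
# runs the "crunchrunner" script.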
class ArvadosJob(object):
    def __init__(self, runner):
        self.arvrunner = runner
        self.running = False

    def run(self, dry_run=False, pull_image=True, **kwargs):
        script_parameters = {
            "command": self.command_line
        }
        runtime_constraints = {}

        if self.generatefiles:
            vwd = arvados.collection.Collection()
            script_parameters["task.vwd"] = {}
            for t in self.generatefiles:
                if isinstance(self.generatefiles[t], dict):
                    src, rest = self.arvrunner.fs_access.get_collection(self.generatefiles[t]["path"].replace("$(task.keep)/", "keep:"))
                    vwd.copy(rest, t, source_collection=src)
                else:
                    with vwd.open(t, "w") as f:
                        f.write(self.generatefiles[t])
            vwd.save_new()
            for t in self.generatefiles:
                script_parameters["task.vwd"][t] = "$(task.keep)/%s/%s" % (vwd.portable_data_hash(), t)

        script_parameters["task.env"] = {"TMPDIR": "$(task.tmpdir)"}
        if self.environment:
            script_parameters["task.env"].update(self.environment)

        if self.stdin:
            script_parameters["task.stdin"] = self.pathmapper.mapper(self.stdin)[1]

        if self.stdout:
            script_parameters["task.stdout"] = self.stdout

        (docker_req, docker_is_req) = get_feature(self, "DockerRequirement")
        if docker_req and kwargs.get("use_container") is not False:
            runtime_constraints["docker_image"] = arv_docker_get_image(self.arvrunner.api, docker_req, pull_image)

        try:
            response = self.arvrunner.api.jobs().create(body={
                "script": "crunchrunner",
                "repository": kwargs["repository"],
                "script_version": "master",
                "script_parameters": {"tasks": [script_parameters]},
                "runtime_constraints": runtime_constraints
            }, find_or_create=kwargs.get("enable_reuse", True)).execute(num_retries=self.arvrunner.num_retries)

            self.arvrunner.jobs[response["uuid"]] = self

            self.arvrunner.pipeline["components"][self.name] = {"job": response}
            self.arvrunner.pipeline = self.arvrunner.api.pipeline_instances().update(
                uuid=self.arvrunner.pipeline["uuid"],
                body={"components": self.arvrunner.pipeline["components"]}
            ).execute(num_retries=self.arvrunner.num_retries)

            logger.info("Job %s (%s) is %s", self.name, response["uuid"], response["state"])

            if response["state"] in ("Complete", "Failed", "Cancelled"):
                self.done(response)
        except Exception as e:
            logger.error("Got error %s", str(e))
            self.output_callback({}, "permanentFail")

    def update_pipeline_component(self, record):
        self.arvrunner.pipeline["components"][self.name] = {"job": record}
        self.arvrunner.pipeline = self.arvrunner.api.pipeline_instances().update(
            uuid=self.arvrunner.pipeline["uuid"],
            body={"components": self.arvrunner.pipeline["components"]}
        ).execute(num_retries=self.arvrunner.num_retries)

    def done(self, record):
        try:
            self.update_pipeline_component(record)
        except Exception:
            # Failing to update the pipeline component display should not
            # prevent output collection from proceeding.
            pass

        try:
            if record["state"] == "Complete":
                processStatus = "success"
            else:
                processStatus = "permanentFail"

            outputs = {}
            try:
                if record["output"]:
                    outputs = self.collect_outputs("keep:" + record["output"])
            except Exception:
                logger.exception("Got exception while collecting job outputs:")
                processStatus = "permanentFail"

            self.output_callback(outputs, processStatus)
        finally:
            del self.arvrunner.jobs[record["uuid"]]


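# Maps CWL input file references to "$(task.keep)/..." paths visible inside
# the job, uploading local files to Keep as needed.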
class ArvPathMapper(cwltool.pathmapper.PathMapper):
    def __init__(self, arvrunner, referenced_files, basedir, **kwargs):
        self._pathmap = arvrunner.get_uploaded()
        uploadfiles = []

        pdh_path = re.compile(r'^keep:[0-9a-f]{32}\+\d+/.+')

        for src in referenced_files:
            if isinstance(src, basestring) and pdh_path.match(src):
                self._pathmap[src] = (src, "$(task.keep)/%s" % src[5:])
            if src not in self._pathmap:
                ab = cwltool.pathmapper.abspath(src, basedir)
                st = arvados.commands.run.statfile("", ab, fnPattern="$(task.keep)/%s/%s")
                if kwargs.get("conformance_test"):
                    self._pathmap[src] = (src, ab)
                elif isinstance(st, arvados.commands.run.UploadFile):
                    uploadfiles.append((src, ab, st))
                elif isinstance(st, arvados.commands.run.ArvFile):
                    self._pathmap[src] = (ab, st.fn)
                else:
                    raise cwltool.workflow.WorkflowException("Input file path '%s' is invalid" % st)

        if uploadfiles:
            arvados.commands.run.uploadfiles([u[2] for u in uploadfiles],
                                             arvrunner.api,
                                             dry_run=kwargs.get("dry_run"),
                                             num_retries=3,
                                             fnPattern="$(task.keep)/%s/%s")

        for src, ab, st in uploadfiles:
            arvrunner.add_uploaded(src, (ab, st.fn))
            self._pathmap[src] = (ab, st.fn)


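# CommandLineTool specialization that creates ArvadosJob runners and
# Keep-aware path mappers.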
class ArvadosCommandTool(cwltool.draft2tool.CommandLineTool):
    def __init__(self, arvrunner, toolpath_object, **kwargs):
        super(ArvadosCommandTool, self).__init__(toolpath_object, outdir="$(task.outdir)", tmpdir="$(task.tmpdir)", **kwargs)
        self.arvrunner = arvrunner

    def makeJobRunner(self):
        return ArvadosJob(self.arvrunner)

    def makePathMapper(self, reffiles, input_basedir, **kwargs):
        return ArvPathMapper(self.arvrunner, reffiles, input_basedir, **kwargs)


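# Top-level coordinator: tracks submitted jobs, listens for job state
# changes on the Arvados event bus, and mirrors progress into a pipeline
# instance for display.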
class ArvCwlRunner(object):
    def __init__(self, api_client):
        self.api = api_client
        self.jobs = {}
        self.lock = threading.Lock()
        self.cond = threading.Condition(self.lock)
        self.final_output = None
        self.uploaded = {}
        self.num_retries = 4

    def arvMakeTool(self, toolpath_object, **kwargs):
        if "class" in toolpath_object and toolpath_object["class"] == "CommandLineTool":
            return ArvadosCommandTool(self, toolpath_object, **kwargs)
        else:
            return cwltool.workflow.defaultMakeTool(toolpath_object, **kwargs)

    def output_callback(self, out, processStatus):
        if processStatus == "success":
            logger.info("Overall job status is %s", processStatus)
            self.api.pipeline_instances().update(uuid=self.pipeline["uuid"],
                                                 body={"state": "Complete"}).execute(num_retries=self.num_retries)
        else:
            logger.warn("Overall job status is %s", processStatus)
            self.api.pipeline_instances().update(uuid=self.pipeline["uuid"],
                                                 body={"state": "Failed"}).execute(num_retries=self.num_retries)
        self.final_output = out

    def on_message(self, event):
        if "object_uuid" in event:
            if event["object_uuid"] in self.jobs and event["event_type"] == "update":
                if event["properties"]["new_attributes"]["state"] == "Running" and self.jobs[event["object_uuid"]].running is False:
                    uuid = event["object_uuid"]
                    with self.lock:
                        j = self.jobs[uuid]
                        logger.info("Job %s (%s) is Running", j.name, uuid)
                        j.running = True
                        j.update_pipeline_component(event["properties"]["new_attributes"])
                elif event["properties"]["new_attributes"]["state"] in ("Complete", "Failed", "Cancelled"):
                    uuid = event["object_uuid"]
                    with self.cond:
                        j = self.jobs[uuid]
                        logger.info("Job %s (%s) is %s", j.name, uuid, event["properties"]["new_attributes"]["state"])
                        j.done(event["properties"]["new_attributes"])
                        self.cond.notify()

    def get_uploaded(self):
        return self.uploaded.copy()

    def add_uploaded(self, src, pair):
        self.uploaded[src] = pair

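    # Main execution loop: subscribe to job events, create a pipeline
    # instance for bookkeeping, submit runnable jobs as the CWL engine
    # yields them, and wait on self.cond until every job has finished.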
    def arvExecutor(self, tool, job_order, input_basedir, args, **kwargs):
        # Subscribe with the runner's own API client rather than creating a
        # second one.
        events = arvados.events.subscribe(self.api, [["object_uuid", "is_a", "arvados#job"]], self.on_message)

        self.pipeline = self.api.pipeline_instances().create(body={"name": shortname(tool.tool["id"]),
                                                                   "components": {},
                                                                   "state": "RunningOnClient"}).execute(num_retries=self.num_retries)

        self.fs_access = CollectionFsAccess(input_basedir)

        kwargs["fs_access"] = self.fs_access
        kwargs["enable_reuse"] = args.enable_reuse
        kwargs["repository"] = args.repository

        if kwargs.get("conformance_test"):
            return cwltool.main.single_job_executor(tool, job_order, input_basedir, args, **kwargs)
        else:
            jobiter = tool.job(job_order,
                               input_basedir,
                               self.output_callback,
                               **kwargs)

            for runnable in jobiter:
                if runnable:
                    with self.lock:
                        runnable.run(**kwargs)
                else:
                    if self.jobs:
                        with self.cond:
                            self.cond.wait()
                    else:
                        logger.error("Workflow cannot make any more progress.")
                        break

            while self.jobs:
                with self.cond:
                    self.cond.wait()

            events.close()

            if self.final_output is None:
                raise cwltool.workflow.WorkflowException("Workflow did not return a result.")

            return self.final_output


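# Command-line entry point: extends cwltool's argument parser with
# Arvados-specific options and delegates execution to ArvCwlRunner.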
def main(args, stdout, stderr, api_client=None):
    if api_client is None:
        api_client = arvados.api('v1')
    runner = ArvCwlRunner(api_client=api_client)
    args.insert(0, "--leave-outputs")
    parser = cwltool.main.arg_parser()
    exgroup = parser.add_mutually_exclusive_group()
    exgroup.add_argument("--enable-reuse", action="store_true",
                         default=False, dest="enable_reuse",
                         help="Reuse past jobs with identical inputs where possible.")
    exgroup.add_argument("--disable-reuse", action="store_false",
                         dest="enable_reuse",
                         help="Always create new jobs (default).")

    parser.add_argument('--repository', type=str, default="peter/crunchrunner",
                        help="Repository containing the 'crunchrunner' program.")

    return cwltool.main.main(args, executor=runner.arvExecutor, makeTool=runner.arvMakeTool, parser=parser)
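
# A minimal sketch of how this module might be wired up as a console script.
# The __main__ block below is illustrative only and not part of this file:
#
#     import sys
#
#     if __name__ == "__main__":
#         sys.exit(main(sys.argv[1:], sys.stdout, sys.stderr))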