Merge branch '10534-go-systemd-sdnotify-v14' of https://github.com/wtsi-hgi/arvados
[arvados.git] / sdk / cwl / arvados_cwl / __init__.py
index e11f6a12a679f0d5cf2b6bf6e6b1726d512637a8..92be92d6e0469fba63a1504a5f0c834fb4a9b2b7 100644 (file)
@@ -124,7 +124,8 @@ class ArvCwlRunner(object):
                     try:
                         self.cond.acquire()
                         j = self.processes[uuid]
-                        logger.info("Job %s (%s) is %s", j.name, uuid, event["properties"]["new_attributes"]["state"])
+                        txt = self.work_api[0].upper() + self.work_api[1:-1]
+                        logger.info("%s %s (%s) is %s", txt, j.name, uuid, event["properties"]["new_attributes"]["state"])
                         with Perf(metrics, "done %s" % j.name):
                             j.done(event["properties"]["new_attributes"])
                         self.cond.notify()
@@ -200,14 +201,28 @@ class ArvCwlRunner(object):
 
         srccollections = {}
         for k,v in generatemapper.items():
+            if k.startswith("_:"):
+                if v.type == "Directory":
+                    continue
+                if v.type == "CreateFile":
+                    with final.open(v.target, "wb") as f:
+                        f.write(v.resolved.encode("utf-8"))
+                    continue
+
+            if not k.startswith("keep:"):
+                raise Exception("Output source is not in keep or a literal")
             sp = k.split("/")
             srccollection = sp[0][5:]
             if srccollection not in srccollections:
-                srccollections[srccollection] = arvados.collection.CollectionReader(
-                    srccollection,
-                    api_client=self.api,
-                    keep_client=self.keep_client,
-                    num_retries=self.num_retries)
+                try:
+                    srccollections[srccollection] = arvados.collection.CollectionReader(
+                        srccollection,
+                        api_client=self.api,
+                        keep_client=self.keep_client,
+                        num_retries=self.num_retries)
+                except arvados.errors.ArgumentError as e:
+                    logger.error("Creating CollectionReader for '%s' '%s': %s", k, v, e)
+                    raise
             reader = srccollections[srccollection]
             try:
                 srcpath = "/".join(sp[1:]) if len(sp) > 1 else "."
@@ -217,7 +232,7 @@ class ArvCwlRunner(object):
 
         def rewrite(fileobj):
             fileobj["location"] = generatemapper.mapper(fileobj["location"]).target
-            for k in ("basename", "size", "listing"):
+            for k in ("basename", "listing", "contents"):
                 if k in fileobj:
                     del fileobj[k]
 
@@ -233,7 +248,31 @@ class ArvCwlRunner(object):
                     final.api_response()["name"],
                     final.manifest_locator())
 
-        self.final_output_collection = final
+        def finalcollection(fileobj):
+            fileobj["location"] = "keep:%s/%s" % (final.portable_data_hash(), fileobj["location"])
+
+        adjustDirObjs(outputObj, finalcollection)
+        adjustFileObjs(outputObj, finalcollection)
+
+        return (outputObj, final)
+
+    def set_crunch_output(self):  # Report the final output collection's portable data hash to the record selected by self.work_api.
+        if self.work_api == "containers":  # containers API: update the record returned by containers().current()
+            try:
+                current = self.api.containers().current().execute(num_retries=self.num_retries)  # NOTE(review): presumably fails when not running inside a container -- confirm
+                self.api.containers().update(uuid=current['uuid'],
+                                             body={
+                                                 'output': self.final_output_collection.portable_data_hash(),
+                                             }).execute(num_retries=self.num_retries)
+            except Exception as e:  # best effort: any failure in this branch is logged below and not re-raised
+                logger.info("Setting container output: %s", e)
+        elif self.work_api == "jobs" and "TASK_UUID" in os.environ:  # jobs API: only when TASK_UUID is present in the environment
+            self.api.job_tasks().update(uuid=os.environ["TASK_UUID"],  # no try/except here: errors from this call propagate to the caller
+                                   body={
+                                       'output': self.final_output_collection.portable_data_hash(),
+                                       'success': self.final_status == "success",  # True only when final_status equals "success"
+                                       'progress':1.0
+                                   }).execute(num_retries=self.num_retries)
 
     def arv_executor(self, tool, job_order, **kwargs):
         self.debug = kwargs.get("debug")
@@ -363,9 +402,6 @@ class ArvCwlRunner(object):
         if self.final_status == "UnsupportedRequirement":
             raise UnsupportedRequirement("Check log for details.")
 
-        if self.final_status != "success":
-            raise WorkflowException("Workflow failed.")
-
         if self.final_output is None:
             raise WorkflowException("Workflow did not return a result.")
 
@@ -374,7 +410,11 @@ class ArvCwlRunner(object):
         else:
             if self.output_name is None:
                 self.output_name = "Output of %s" % (shortname(tool.tool["id"]))
-            self.make_output_collection(self.output_name, self.final_output)
+            self.final_output, self.final_output_collection = self.make_output_collection(self.output_name, self.final_output)
+            self.set_crunch_output()
+
+        if self.final_status != "success":
+            raise WorkflowException("Workflow failed.")
 
         if kwargs.get("compute_checksum"):
             adjustDirObjs(self.final_output, partial(getListing, self.fs_access))
@@ -449,7 +489,7 @@ def arg_parser():  # type: () -> argparse.ArgumentParser
 
     parser.add_argument("--api", type=str,
                         default=None, dest="work_api",
-                        help="Select work submission API, one of 'jobs' or 'containers'.")
+                        help="Select work submission API, one of 'jobs' or 'containers'. Default is 'jobs' if that API is available, otherwise 'containers'.")
 
     parser.add_argument("--compute-checksum", action="store_true", default=False,
                         help="Compute checksum of contents while collecting outputs",