9701: Merge branch 'master' into 9701-collection-pack-small-files-alt
[arvados.git] / sdk / cwl / arvados_cwl / __init__.py
index ef4f78552989764a8a1e02234939021c1f44415c..c90f8902684304b400cc7ece97068a8e6b094000 100644 (file)
@@ -24,6 +24,7 @@ import arvados.config
 
 from .arvcontainer import ArvadosContainer, RunnerContainer
 from .arvjob import ArvadosJob, RunnerJob, RunnerTemplate
+from. runner import Runner
 from .arvtool import ArvadosCommandTool
 from .arvworkflow import ArvadosWorkflow, upload_workflow
 from .fsaccess import CollectionFsAccess
@@ -31,7 +32,7 @@ from .perf import Perf
 from .pathmapper import FinalOutputPathMapper
 
 from cwltool.pack import pack
-from cwltool.process import shortname, UnsupportedRequirement
+from cwltool.process import shortname, UnsupportedRequirement, getListing
 from cwltool.pathmapper import adjustFileObjs, adjustDirObjs
 from cwltool.draft2tool import compute_checksums
 from arvados.api import OrderedJsonModel
@@ -47,7 +48,7 @@ class ArvCwlRunner(object):
 
     """
 
-    def __init__(self, api_client, work_api=None):
+    def __init__(self, api_client, work_api=None, keep_client=None, output_name=None):
         self.api = api_client
         self.processes = {}
         self.lock = threading.Lock()
@@ -62,6 +63,11 @@ class ArvCwlRunner(object):
         self.poll_api = None
         self.pipeline = None
         self.final_output_collection = None
+        self.output_name = output_name
+        if keep_client is not None:
+            self.keep_client = keep_client
+        else:
+            self.keep_client = arvados.keep.KeepClient(api_client=self.api, num_retries=self.num_retries)
 
         if self.work_api is None:
             # todo: autodetect API to use.
@@ -178,17 +184,24 @@ class ArvCwlRunner(object):
 
         generatemapper = FinalOutputPathMapper(files, "", "", separateDirs=False)
 
-        final = arvados.collection.Collection()
+        final = arvados.collection.Collection(api_client=self.api,
+                                              keep_client=self.keep_client,
+                                              num_retries=self.num_retries)
 
         srccollections = {}
         for k,v in generatemapper.items():
             sp = k.split("/")
             srccollection = sp[0][5:]
             if srccollection not in srccollections:
-                srccollections[srccollection] = arvados.collection.CollectionReader(srccollection)
+                srccollections[srccollection] = arvados.collection.CollectionReader(
+                    srccollection,
+                    api_client=self.api,
+                    keep_client=self.keep_client,
+                    num_retries=self.num_retries)
             reader = srccollections[srccollection]
             try:
-                final.copy("/".join(sp[1:]), v.target, source_collection=reader, overwrite=False)
+                srcpath = "/".join(sp[1:]) if len(sp) > 1 else "."
+                final.copy(srcpath, v.target, source_collection=reader, overwrite=False)
             except IOError as e:
                 logger.warn("While preparing output collection: %s", e)
 
@@ -202,11 +215,13 @@ class ArvCwlRunner(object):
         adjustFileObjs(outputObj, rewrite)
 
         with final.open("cwl.output.json", "w") as f:
-            json.dump(outputObj, f, sort_keys=True, indent=4)
+            json.dump(outputObj, f, sort_keys=True, indent=4, separators=(',',': '))
 
         final.save_new(name=name, owner_uuid=self.project_uuid, ensure_unique_name=True)
 
-        logger.info("Final output collection %s (%s)", final.portable_data_hash(), final.manifest_locator())
+        logger.info("Final output collection %s \"%s\" (%s)", final.portable_data_hash(),
+                    final.api_response()["name"],
+                    final.manifest_locator())
 
         self.final_output_collection = final
 
@@ -218,7 +233,9 @@ class ArvCwlRunner(object):
         useruuid = self.api.users().current().execute()["uuid"]
         self.project_uuid = kwargs.get("project_uuid") if kwargs.get("project_uuid") else useruuid
         self.pipeline = None
-        make_fs_access = kwargs.get("make_fs_access") or partial(CollectionFsAccess, api_client=self.api)
+        make_fs_access = kwargs.get("make_fs_access") or partial(CollectionFsAccess,
+                                                                 api_client=self.api,
+                                                                 keep_client=self.keep_client)
         self.fs_access = make_fs_access(kwargs["basedir"])
 
         if kwargs.get("create_template"):
@@ -257,9 +274,9 @@ class ArvCwlRunner(object):
                                          self.output_callback,
                                          **kwargs).next()
                 else:
-                    runnerjob = RunnerContainer(self, tool, job_order, kwargs.get("enable_reuse"))
+                    runnerjob = RunnerContainer(self, tool, job_order, kwargs.get("enable_reuse"), self.output_name)
             else:
-                runnerjob = RunnerJob(self, tool, job_order, kwargs.get("enable_reuse"))
+                runnerjob = RunnerJob(self, tool, job_order, kwargs.get("enable_reuse"), self.output_name)
 
         if not kwargs.get("submit") and "cwl_runner_job" not in kwargs and not self.work_api == "containers":
             # Create pipeline for local run
@@ -340,13 +357,15 @@ class ArvCwlRunner(object):
         if self.final_output is None:
             raise WorkflowException("Workflow did not return a result.")
 
-        if kwargs.get("submit"):
+        if kwargs.get("submit") and isinstance(runnerjob, Runner):
             logger.info("Final output collection %s", runnerjob.final_output)
         else:
-            self.make_output_collection("Output of %s" % (shortname(tool.tool["id"])),
-                                        self.final_output)
+            if self.output_name is None:
+                self.output_name = "Output of %s" % (shortname(tool.tool["id"]))
+            self.make_output_collection(self.output_name, self.final_output)
 
         if kwargs.get("compute_checksum"):
+            adjustDirObjs(self.final_output, partial(getListing, self.fs_access))
             adjustFileObjs(self.final_output, partial(compute_checksums, self.fs_access))
 
         return self.final_output
@@ -396,6 +415,7 @@ def arg_parser():  # type: () -> argparse.ArgumentParser
                         help="")
 
     parser.add_argument("--project-uuid", type=str, metavar="UUID", help="Project that will own the workflow jobs, if not provided, will go to home project.")
+    parser.add_argument("--output-name", type=str, help="Name to use for collection that stores the final output.", default=None)
     parser.add_argument("--ignore-docker-for-reuse", action="store_true",
                         help="Ignore Docker image version when deciding whether to reuse past jobs.",
                         default=False)
@@ -433,13 +453,14 @@ def add_arv_hints():
     res = pkg_resources.resource_stream(__name__, 'arv-cwl-schema.yml')
     cache["http://arvados.org/cwl"] = res.read()
     res.close()
-    _, cwlnames, _, _ = cwltool.process.get_schema("v1.0")
+    document_loader, cwlnames, _, _ = cwltool.process.get_schema("v1.0")
     _, extnames, _, _ = schema_salad.schema.load_schema("http://arvados.org/cwl", cache=cache)
     for n in extnames.names:
         if not cwlnames.has_name("http://arvados.org/cwl#"+n, ""):
             cwlnames.add_name("http://arvados.org/cwl#"+n, "", extnames.get_name(n, ""))
+        document_loader.idx["http://arvados.org/cwl#"+n] = {}
 
-def main(args, stdout, stderr, api_client=None):
+def main(args, stdout, stderr, api_client=None, keep_client=None):
     parser = arg_parser()
 
     job_order_object = None
@@ -452,7 +473,7 @@ def main(args, stdout, stderr, api_client=None):
     try:
         if api_client is None:
             api_client=arvados.api('v1', model=OrderedJsonModel())
-        runner = ArvCwlRunner(api_client, work_api=arvargs.work_api)
+        runner = ArvCwlRunner(api_client, work_api=arvargs.work_api, keep_client=keep_client, output_name=arvargs.output_name)
     except Exception as e:
         logger.error(e)
         return 1