Update storage_classes support for arvados_cwl_runner to work correctly
[arvados.git] / sdk / cwl / arvados_cwl / __init__.py
index 6d4eda4f4790717dde489d7839973a8a0d7fa6fe..b2b93bf9e7870e553eab1ec3086f16e2f795b29b 100644 (file)
@@ -37,7 +37,7 @@ import arvados.commands._util as arv_cmd
 
 from .arvcontainer import ArvadosContainer, RunnerContainer
 from .arvjob import ArvadosJob, RunnerJob, RunnerTemplate
-from. runner import Runner, upload_docker, upload_job_order, upload_workflow_deps, upload_dependencies
+from. runner import Runner, upload_docker, upload_job_order, upload_workflow_deps
 from .arvtool import ArvadosCommandTool
 from .arvworkflow import ArvadosWorkflow, upload_workflow
 from .fsaccess import CollectionFsAccess, CollectionFetcher, collectionResolver, CollectionCache
@@ -69,14 +69,13 @@ class ArvCwlRunner(object):
     """
 
     def __init__(self, api_client, work_api=None, keep_client=None,
-                 output_name=None, output_tags=None, num_retries=4,
-                 thread_count=4):
+                 output_name=None, output_tags=None, default_storage_classes="default",
+                 num_retries=4, thread_count=4):
         self.api = api_client
         self.processes = {}
         self.workflow_eval_lock = threading.Condition(threading.RLock())
         self.final_output = None
         self.final_status = None
-        self.uploaded = {}
         self.num_retries = num_retries
         self.uuid = None
         self.stop_polling = threading.Event()
@@ -91,6 +90,7 @@ class ArvCwlRunner(object):
         self.trash_intermediate = False
         self.thread_count = thread_count
         self.poll_interval = 12
+        self.default_storage_classes = default_storage_classes
 
         if keep_client is not None:
             self.keep_client = keep_client
@@ -238,12 +238,6 @@ class ArvCwlRunner(object):
         finally:
             self.stop_polling.set()
 
-    def get_uploaded(self):
-        return self.uploaded.copy()
-
-    def add_uploaded(self, src, pair):
-        self.uploaded[src] = pair
-
     def add_intermediate_output(self, uuid):
         if uuid:
             self.intermediate_output_collections.append(uuid)
@@ -401,6 +395,9 @@ class ArvCwlRunner(object):
         if self.intermediate_output_ttl < 0:
             raise Exception("Invalid value %d for --intermediate-output-ttl, cannot be less than zero" % self.intermediate_output_ttl)
 
+        if kwargs.get("submit_request_uuid") and self.work_api != "containers":
+            raise Exception("--submit-request-uuid requires containers API, but using '{}' api".format(self.work_api))
+
         if not kwargs.get("name"):
             kwargs["name"] = self.name = tool.tool.get("label") or tool.metadata.get("label") or os.path.basename(tool.tool["id"])
 
@@ -489,6 +486,7 @@ class ArvCwlRunner(object):
                                                 submit_runner_image=kwargs.get("submit_runner_image"),
                                                 intermediate_output_ttl=kwargs.get("intermediate_output_ttl"),
                                                 merged_map=merged_map,
+                                                default_storage_classes=self.default_storage_classes,
                                                 priority=kwargs.get("priority"),
                                                 secret_store=self.secret_store)
             elif self.work_api == "jobs":
@@ -601,10 +599,7 @@ class ArvCwlRunner(object):
             if self.output_tags is None:
                 self.output_tags = ""
 
-            storage_classes = ["default"]
-            if kwargs.get("storage_classes"):
-                storage_classes = kwargs.get("storage_classes").strip().split(",")
-
+            storage_classes = kwargs.get("storage_classes").strip().split(",")
             self.final_output, self.final_output_collection = self.make_output_collection(self.output_name, storage_classes, self.output_tags, self.final_output)
             self.set_crunch_output()
 
@@ -712,6 +707,10 @@ def arg_parser():  # type: () -> argparse.ArgumentParser
                         help="Docker image for workflow runner job, default arvados/jobs:%s" % __version__,
                         default=None)
 
+    parser.add_argument("--submit-request-uuid", type=str,
+                        default=None,
+                        help="Update and commit supplied container request instead of creating a new one (containers API only).")
+
     parser.add_argument("--name", type=str,
                         help="Name to use for workflow execution instance.",
                         default=None)
@@ -723,8 +722,8 @@ def arg_parser():  # type: () -> argparse.ArgumentParser
     parser.add_argument("--enable-dev", action="store_true",
                         help="Enable loading and running development versions "
                              "of CWL spec.", default=False)
-    parser.add_argument('--storage-classes', 
-                        help="Specify comma separated list of storage classes to be used when saving wortkflow output to Keep.")
+    parser.add_argument('--storage-classes', default="default", type=str,
+                        help="Specify comma separated list of storage classes to be used when saving workflow output to Keep.")
 
     parser.add_argument("--intermediate-output-ttl", type=int, metavar="N",
                         help="If N > 0, intermediate output collections will be trashed N seconds after creation.  Default is 0 (don't trash).",
@@ -786,6 +785,10 @@ def main(args, stdout, stderr, api_client=None, keep_client=None,
     job_order_object = None
     arvargs = parser.parse_args(args)
 
+    if len(arvargs.storage_classes.strip().split(',')) > 1:
+        logger.error("Multiple storage classes are not supported currently.")
+        return 1
+
     if install_sig_handlers:
         arv_cmd.install_signal_handlers()
 
@@ -815,7 +818,7 @@ def main(args, stdout, stderr, api_client=None, keep_client=None,
             keep_client = arvados.keep.KeepClient(api_client=api_client, num_retries=4)
         runner = ArvCwlRunner(api_client, work_api=arvargs.work_api, keep_client=keep_client,
                               num_retries=4, output_name=arvargs.output_name,
-                              output_tags=arvargs.output_tags,
+                              output_tags=arvargs.output_tags, default_storage_classes=parser.get_default("storage_classes"),
                               thread_count=arvargs.thread_count)
     except Exception as e:
         logger.error(e)