Update storage_classes support for arvados_cwl_runner to work correctly
author Fuad Muhic <fmuhic@capeannenterprises.com>
Tue, 12 Jun 2018 13:10:23 +0000 (15:10 +0200)
committer Fuad Muhic <fmuhic@capeannenterprises.com>
Tue, 12 Jun 2018 13:10:23 +0000 (15:10 +0200)
when arvados_cwl_runner is run in submit mode.

Arvados-DCO-1.1-Signed-off-by: Fuad Muhic <fmuhic@capeannenterprises.com>

sdk/cwl/arvados_cwl/__init__.py
sdk/cwl/arvados_cwl/arvcontainer.py
sdk/cwl/arvados_cwl/runner.py
sdk/cwl/tests/test_submit.py

index 4a6ccc0b2a331626f055d75eb1c54ac9d9a45b6a..b2b93bf9e7870e553eab1ec3086f16e2f795b29b 100644 (file)
@@ -69,8 +69,8 @@ class ArvCwlRunner(object):
     """
 
     def __init__(self, api_client, work_api=None, keep_client=None,
-                 output_name=None, output_tags=None, num_retries=4,
-                 thread_count=4):
+                 output_name=None, output_tags=None, default_storage_classes="default",
+                 num_retries=4, thread_count=4):
         self.api = api_client
         self.processes = {}
         self.workflow_eval_lock = threading.Condition(threading.RLock())
@@ -90,6 +90,7 @@ class ArvCwlRunner(object):
         self.trash_intermediate = False
         self.thread_count = thread_count
         self.poll_interval = 12
+        self.default_storage_classes = default_storage_classes
 
         if keep_client is not None:
             self.keep_client = keep_client
@@ -485,6 +486,7 @@ class ArvCwlRunner(object):
                                                 submit_runner_image=kwargs.get("submit_runner_image"),
                                                 intermediate_output_ttl=kwargs.get("intermediate_output_ttl"),
                                                 merged_map=merged_map,
+                                                default_storage_classes=self.default_storage_classes,
                                                 priority=kwargs.get("priority"),
                                                 secret_store=self.secret_store)
             elif self.work_api == "jobs":
@@ -597,7 +599,7 @@ class ArvCwlRunner(object):
             if self.output_tags is None:
                 self.output_tags = ""
 
-            storage_classes = kwargs.get("storage_classes")
+            storage_classes = kwargs.get("storage_classes").strip().split(",")
             self.final_output, self.final_output_collection = self.make_output_collection(self.output_name, storage_classes, self.output_tags, self.final_output)
             self.set_crunch_output()
 
@@ -720,7 +722,7 @@ def arg_parser():  # type: () -> argparse.ArgumentParser
     parser.add_argument("--enable-dev", action="store_true",
                         help="Enable loading and running development versions "
                              "of CWL spec.", default=False)
-    parser.add_argument('--storage-classes', default="default",
+    parser.add_argument('--storage-classes', default="default", type=str,
                         help="Specify comma separated list of storage classes to be used when saving workflow output to Keep.")
 
     parser.add_argument("--intermediate-output-ttl", type=int, metavar="N",
@@ -783,8 +785,7 @@ def main(args, stdout, stderr, api_client=None, keep_client=None,
     job_order_object = None
     arvargs = parser.parse_args(args)
 
-    arvargs.storage_classes = arvargs.storage_classes.strip().split(',')
-    if len(arvargs.storage_classes) > 1:
+    if len(arvargs.storage_classes.strip().split(',')) > 1:
         logger.error("Multiple storage classes are not supported currently.")
         return 1
 
@@ -817,7 +818,7 @@ def main(args, stdout, stderr, api_client=None, keep_client=None,
             keep_client = arvados.keep.KeepClient(api_client=api_client, num_retries=4)
         runner = ArvCwlRunner(api_client, work_api=arvargs.work_api, keep_client=keep_client,
                               num_retries=4, output_name=arvargs.output_name,
-                              output_tags=arvargs.output_tags,
+                              output_tags=arvargs.output_tags, default_storage_classes=parser.get_default("storage_classes"),
                               thread_count=arvargs.thread_count)
     except Exception as e:
         logger.error(e)
index 0bec692643ad805c02d6b8358fae8a65841c1367..05c0212fc6eeae41b2a10279f04ebd71000f1eaf 100644 (file)
@@ -427,6 +427,9 @@ class RunnerContainer(Runner):
         if kwargs.get("debug"):
             command.append("--debug")
 
+        if kwargs.get("storage_classes") and kwargs.get("storage_classes") != self.default_storage_classes:
+            command.append("--storage-classes=" + kwargs.get("storage_classes"))
+
         if self.on_error:
             command.append("--on-error=" + self.on_error)
 
index cf91f69f818cd51e721c658cd05d5a81e9df6e05..f907d33951c45a5d707bb15dd18c9154ae1b5bad 100644 (file)
@@ -353,8 +353,8 @@ class Runner(object):
     def __init__(self, runner, tool, job_order, enable_reuse,
                  output_name, output_tags, submit_runner_ram=0,
                  name=None, on_error=None, submit_runner_image=None,
-                 intermediate_output_ttl=0, merged_map=None, priority=None,
-                 secret_store=None):
+                 intermediate_output_ttl=0, merged_map=None, default_storage_classes="default",
+                 priority=None, secret_store=None):
         self.arvrunner = runner
         self.tool = tool
         self.job_order = job_order
@@ -376,6 +376,7 @@ class Runner(object):
         self.intermediate_output_ttl = intermediate_output_ttl
         self.priority = priority
         self.secret_store = secret_store
+        self.default_storage_classes = default_storage_classes
 
         if submit_runner_ram:
             self.submit_runner_ram = submit_runner_ram
index c1e9629fc8cc28e2a25efe1e81c6040e6e102506..c7ededd79639b8c435171b093c2e70ba3492de1d 100644 (file)
@@ -612,6 +612,30 @@ class TestSubmit(unittest.TestCase):
         self.assertEqual(capture_stdout.getvalue(),
                          stubs.expect_container_request_uuid + '\n')
 
+    @stubs
+    def test_submit_storage_classes(self, stubs):
+        capture_stdout = cStringIO.StringIO()
+        try:
+            exited = arvados_cwl.main(
+                ["--debug", "--submit", "--no-wait", "--api=containers", "--storage-classes=foo",
+                 "tests/wf/submit_wf.cwl", "tests/submit_test_job.json"],
+                capture_stdout, sys.stderr, api_client=stubs.api, keep_client=stubs.keep_client)
+            self.assertEqual(exited, 0)
+        except:
+            logging.exception("")
+
+        expect_container = copy.deepcopy(stubs.expect_container_spec)
+        expect_container["command"] = ['arvados-cwl-runner', '--local', '--api=containers',
+                                       '--no-log-timestamps', '--disable-validate',
+                                       '--eval-timeout=20', '--thread-count=4',
+                                       '--enable-reuse', "--debug", 
+                                       "--storage-classes=foo", '--on-error=continue',
+                                       '/var/lib/cwl/workflow.json#main', '/var/lib/cwl/cwl.input.json']
+
+        stubs.api.container_requests().create.assert_called_with(
+            body=JsonDiffMatcher(expect_container))
+        self.assertEqual(capture_stdout.getvalue(),
+                         stubs.expect_container_request_uuid + '\n')
 
     @stubs
     def test_submit_container_output_ttl(self, stubs):