13108: Add test for taskqueue
[arvados.git] / sdk / cwl / arvados_cwl / __init__.py
index 7affade0734536fb3e7ee241bed9357a995eb949..c2f43fe368b47c1c75447195feb2ed96d27b39d6 100644 (file)
@@ -66,7 +66,8 @@ class ArvCwlRunner(object):
     """
 
     def __init__(self, api_client, work_api=None, keep_client=None,
-                 output_name=None, output_tags=None, num_retries=4):
+                 output_name=None, output_tags=None, num_retries=4,
+                 thread_count=4):
         self.api = api_client
         self.processes = {}
         self.workflow_eval_lock = threading.Condition(threading.RLock())
@@ -85,7 +86,7 @@ class ArvCwlRunner(object):
         self.intermediate_output_ttl = 0
         self.intermediate_output_collections = []
         self.trash_intermediate = False
-        self.thread_count = 4
+        self.thread_count = thread_count
         self.poll_interval = 12
 
         if keep_client is not None:
@@ -165,21 +166,20 @@ class ArvCwlRunner(object):
         return partial(self.wrapped_callback, cb)
 
     def on_message(self, event):
-        if "object_uuid" in event:
-            if event["object_uuid"] in self.processes and event["event_type"] == "update":
-                uuid = event["object_uuid"]
-                if event["properties"]["new_attributes"]["state"] == "Running":
-                    with self.workflow_eval_lock:
-                        j = self.processes[uuid]
-                        if j.running is False:
-                            j.running = True
-                            j.update_pipeline_component(event["properties"]["new_attributes"])
-                            logger.info("%s %s is Running", self.label(j), uuid)
-                elif event["properties"]["new_attributes"]["state"] in ("Complete", "Failed", "Cancelled", "Final"):
-                    with self.workflow_eval_lock:
-                        j = self.processes[uuid]
-                    self.task_queue.add(partial(j.done, event["properties"]["new_attributes"]))
-                    logger.info("%s %s is %s", self.label(j), uuid, event["properties"]["new_attributes"]["state"])
+        if event.get("object_uuid") in self.processes and event["event_type"] == "update":  # only "update" events for tracked processes; .get() tolerates events with no "object_uuid"
+            uuid = event["object_uuid"]
+            if event["properties"]["new_attributes"]["state"] == "Running":
+                with self.workflow_eval_lock:  # hold the eval lock while mutating the process record
+                    j = self.processes[uuid]  # NOTE(review): membership was checked above without the lock; confirm nothing removes uuid concurrently
+                    if j.running is False:  # act only on the first Running event for this process
+                        j.running = True
+                        j.update_pipeline_component(event["properties"]["new_attributes"])
+                        logger.info("%s %s is Running", self.label(j), uuid)
+            elif event["properties"]["new_attributes"]["state"] in ("Complete", "Failed", "Cancelled", "Final"):  # terminal states
+                with self.workflow_eval_lock:
+                    j = self.processes[uuid]
+                self.task_queue.add(partial(j.done, event["properties"]["new_attributes"]))  # defer j.done to task_queue; queued after the lock is released
+                logger.info("%s %s is %s", self.label(j), uuid, event["properties"]["new_attributes"]["state"])
 
     def label(self, obj):
         return "[%s %s]" % (self.work_api[0:-1], obj.name)
@@ -386,7 +386,6 @@ class ArvCwlRunner(object):
                                                                  collection_cache=self.collection_cache)
         self.fs_access = make_fs_access(kwargs["basedir"])
         self.secret_store = kwargs.get("secret_store")
-        self.thread_count = kwargs.get("thread_count", 4)
 
         self.trash_intermediate = kwargs["trash_intermediate"]
         if self.trash_intermediate and self.work_api != "containers":
@@ -551,7 +550,7 @@ class ArvCwlRunner(object):
                     if (self.task_queue.in_flight + len(self.processes)) > 0:
                         self.workflow_eval_lock.wait(3)
                     else:
-                        logger.error("Workflow is deadlocked, no runnable jobs and not waiting on any pendingjobs.")
+                        logger.error("Workflow is deadlocked, no runnable processes and not waiting on any pending processes.")
                         break
                 loopperf.__enter__()
             loopperf.__exit__()
@@ -794,7 +793,8 @@ def main(args, stdout, stderr, api_client=None, keep_client=None):
             keep_client = arvados.keep.KeepClient(api_client=api_client, num_retries=4)
         runner = ArvCwlRunner(api_client, work_api=arvargs.work_api, keep_client=keep_client,
                               num_retries=4, output_name=arvargs.output_name,
-                              output_tags=arvargs.output_tags)
+                              output_tags=arvargs.output_tags,
+                              thread_count=arvargs.thread_count)
     except Exception as e:
         logger.error(e)
         return 1