13108: Refactor task queue into its own class.
author Peter Amstutz <pamstutz@veritasgenetics.com>
Fri, 6 Apr 2018 15:32:03 +0000 (11:32 -0400)
committer Peter Amstutz <pamstutz@veritasgenetics.com>
Fri, 6 Apr 2018 17:04:14 +0000 (13:04 -0400)
Arvados-DCO-1.1-Signed-off-by: Peter Amstutz <pamstutz@veritasgenetics.com>

sdk/cwl/arvados_cwl/__init__.py
sdk/cwl/arvados_cwl/task_queue.py [new file with mode: 0644]

index 16f1bf473a34a09530636b2ab7f957a3add36672..7affade0734536fb3e7ee241bed9357a995eb949 100644 (file)
@@ -40,6 +40,7 @@ from .arvworkflow import ArvadosWorkflow, upload_workflow
 from .fsaccess import CollectionFsAccess, CollectionFetcher, collectionResolver, CollectionCache
 from .perf import Perf
 from .pathmapper import NoFollowPathMapper
+from .task_queue import TaskQueue
 from ._version import __version__
 
 from cwltool.pack import pack
@@ -65,11 +66,9 @@ class ArvCwlRunner(object):
     """
 
     def __init__(self, api_client, work_api=None, keep_client=None,
-                 output_name=None, output_tags=None, num_retries=4,
-                 thread_count=4):
+                 output_name=None, output_tags=None, num_retries=4):
         self.api = api_client
         self.processes = {}
-        self.in_flight = 0
         self.workflow_eval_lock = threading.Condition(threading.RLock())
         self.final_output = None
         self.final_status = None
@@ -86,9 +85,7 @@ class ArvCwlRunner(object):
         self.intermediate_output_ttl = 0
         self.intermediate_output_collections = []
         self.trash_intermediate = False
-        self.task_queue = Queue.Queue()
-        self.task_queue_threads = []
-        self.thread_count = thread_count
+        self.thread_count = 4
         self.poll_interval = 12
 
         if keep_client is not None:
@@ -146,25 +143,9 @@ class ArvCwlRunner(object):
             self.final_output = out
             self.workflow_eval_lock.notifyAll()
 
-    def task_queue_func(self):
-        while True:
-            task = self.task_queue.get()
-            if task is None:
-                return
-            task()
-            with self.workflow_eval_lock:
-                self.in_flight -= 1
-
-    def task_queue_add(self, task):
-        with self.workflow_eval_lock:
-            if self.thread_count > 1:
-                self.in_flight += 1
-                self.task_queue.put(task)
-            else:
-                task()
 
     def start_run(self, runnable, kwargs):
-        self.task_queue_add(partial(runnable.run, **kwargs))
+        self.task_queue.add(partial(runnable.run, **kwargs))
 
     def process_submitted(self, container):
         with self.workflow_eval_lock:
@@ -197,7 +178,7 @@ class ArvCwlRunner(object):
                 elif event["properties"]["new_attributes"]["state"] in ("Complete", "Failed", "Cancelled", "Final"):
                     with self.workflow_eval_lock:
                         j = self.processes[uuid]
-                    self.task_queue_add(partial(j.done, event["properties"]["new_attributes"]))
+                    self.task_queue.add(partial(j.done, event["properties"]["new_attributes"]))
                     logger.info("%s %s is %s", self.label(j), uuid, event["properties"]["new_attributes"]["state"])
 
     def label(self, obj):
@@ -405,6 +386,7 @@ class ArvCwlRunner(object):
                                                                  collection_cache=self.collection_cache)
         self.fs_access = make_fs_access(kwargs["basedir"])
         self.secret_store = kwargs.get("secret_store")
+        self.thread_count = kwargs.get("thread_count", 4)
 
         self.trash_intermediate = kwargs["trash_intermediate"]
         if self.trash_intermediate and self.work_api != "containers":
@@ -533,10 +515,7 @@ class ArvCwlRunner(object):
         self.polling_thread = threading.Thread(target=self.poll_states)
         self.polling_thread.start()
 
-        for r in xrange(0, self.thread_count):
-            t = threading.Thread(target=self.task_queue_func)
-            self.task_queue_threads.append(t)
-            t.start()
+        self.task_queue = TaskQueue(self.workflow_eval_lock, self.thread_count)
 
         if runnerjob:
             jobiter = iter((runnerjob,))
@@ -562,11 +541,14 @@ class ArvCwlRunner(object):
                 if self.stop_polling.is_set():
                     break
 
+                if self.task_queue.error is not None:
+                    raise self.task_queue.error
+
                 if runnable:
                     with Perf(metrics, "run"):
                         self.start_run(runnable, kwargs)
                 else:
-                    if (self.in_flight + len(self.processes)) > 0:
+                    if (self.task_queue.in_flight + len(self.processes)) > 0:
                         self.workflow_eval_lock.wait(3)
                     else:
                         logger.error("Workflow is deadlocked, no runnable jobs and not waiting on any pendingjobs.")
@@ -574,7 +556,9 @@ class ArvCwlRunner(object):
                 loopperf.__enter__()
             loopperf.__exit__()
 
-            while (self.in_flight + len(self.processes)) > 0:
+            while (self.task_queue.in_flight + len(self.processes)) > 0:
+                if self.task_queue.error is not None:
+                    raise self.task_queue.error
                 self.workflow_eval_lock.wait(3)
 
         except UnsupportedRequirement:
@@ -592,18 +576,10 @@ class ArvCwlRunner(object):
                                                      body={"priority": "0"}).execute(num_retries=self.num_retries)
         finally:
             self.workflow_eval_lock.release()
-            try:
-                # Drain queue
-                while not self.task_queue.empty():
-                    self.task_queue.get()
-            except Queue.Empty:
-                pass
+            self.task_queue.drain()
             self.stop_polling.set()
             self.polling_thread.join()
-            for t in self.task_queue_threads:
-                self.task_queue.put(None)
-            for t in self.task_queue_threads:
-                t.join()
+            self.task_queue.join()
 
         if self.final_status == "UnsupportedRequirement":
             raise UnsupportedRequirement("Check log for details.")
diff --git a/sdk/cwl/arvados_cwl/task_queue.py b/sdk/cwl/arvados_cwl/task_queue.py
new file mode 100644 (file)
index 0000000..cc3e86e
--- /dev/null
@@ -0,0 +1,94 @@
+import Queue
+import threading
+import logging
+
+logger = logging.getLogger('arvados.cwl-runner')
+
+class TaskQueue(object):
+    """Run callables on a pool of worker threads, tracking outstanding work.
+
+    ``in_flight`` is only modified while holding ``lock``, which the
+    caller shares with its own state (ArvCwlRunner passes its
+    ``workflow_eval_lock``).  When ``thread_count`` is 1 or less, no
+    worker threads are started and tasks run synchronously in ``add()``.
+
+    Attributes:
+        in_flight: number of tasks queued or running on a worker thread.
+        error: the most recent exception raised by a task on a worker
+            thread, or None if no worker task has failed.
+    """
+
+    def __init__(self, lock, thread_count):
+        self.thread_count = thread_count
+        self.task_queue = Queue.Queue()
+        self.task_queue_threads = []
+        self.lock = lock
+        self.in_flight = 0
+        self.error = None
+
+        # add() only queues work when thread_count > 1; otherwise tasks
+        # run inline, so worker threads would sit idle until join().
+        if self.thread_count > 1:
+            for _ in xrange(0, self.thread_count):
+                t = threading.Thread(target=self.task_queue_func)
+                self.task_queue_threads.append(t)
+                t.start()
+
+    def task_queue_func(self):
+        """Worker loop: run queued tasks until a None sentinel arrives."""
+        while True:
+            task = self.task_queue.get()
+            if task is None:
+                return
+            try:
+                task()
+            except Exception as e:
+                logger.exception("Unexpected error running task")
+                self.error = e
+            finally:
+                # Account for the task even if it raised something not
+                # derived from Exception, so in_flight cannot leak.
+                with self.lock:
+                    self.in_flight -= 1
+
+    def add(self, task):
+        """Schedule ``task``, a callable taking no arguments.
+
+        With more than one thread the task is queued for a worker and
+        counted in ``in_flight``; otherwise it runs immediately in the
+        calling thread while ``lock`` is held, and any exception it
+        raises propagates to the caller.
+        """
+        with self.lock:
+            if self.thread_count > 1:
+                self.in_flight += 1
+                self.task_queue.put(task)
+            else:
+                task()
+
+    def drain(self):
+        """Discard queued tasks that no worker has picked up yet."""
+        # Workers may still be consuming from the queue, so checking
+        # empty() and then calling a blocking get() can race and hang
+        # forever; get_nowait() raises Queue.Empty instead of blocking.
+        while True:
+            try:
+                task = self.task_queue.get_nowait()
+            except Queue.Empty:
+                return
+            if task is not None:
+                # A discarded task will never run, so stop counting it.
+                with self.lock:
+                    self.in_flight -= 1
+
+    def join(self):
+        """Stop the worker threads and wait for them to exit.
+
+        One None sentinel is queued per worker.  Tasks already queued
+        ahead of the sentinels still run; call drain() first to discard
+        them.
+        """
+        for _ in self.task_queue_threads:
+            self.task_queue.put(None)
+        for t in self.task_queue_threads:
+            t.join()