13306: Changes to arvados-cwl-runner code after running futurize --stage2
[arvados.git] / sdk / cwl / arvados_cwl / task_queue.py
index cc3e86e63633b0041399defc86070998d7e32f25..f7d7c071ce11cb1355debd4442a1586175aa7dc5 100644 (file)
@@ -1,4 +1,12 @@
-import Queue
+from future import standard_library
+standard_library.install_aliases()
+from builtins import range
+from builtins import object
+# Copyright (C) The Arvados Authors. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import queue
 import threading
 import logging
 
@@ -7,46 +15,58 @@ logger = logging.getLogger('arvados.cwl-runner')
 class TaskQueue(object):
     def __init__(self, lock, thread_count):
         self.thread_count = thread_count
-        self.task_queue = Queue.Queue()
+        self.task_queue = queue.Queue(maxsize=self.thread_count)
         self.task_queue_threads = []
         self.lock = lock
         self.in_flight = 0
         self.error = None
 
-        for r in xrange(0, self.thread_count):
+        for r in range(0, self.thread_count):
             t = threading.Thread(target=self.task_queue_func)
             self.task_queue_threads.append(t)
             t.start()
 
     def task_queue_func(self):
+        while True:
+            task = self.task_queue.get()
+            if task is None:
+                return
+            try:
+                task()
+            except Exception as e:
+                logger.exception("Unhandled exception running task")
+                self.error = e
 
-            while True:
-                task = self.task_queue.get()
-                if task is None:
-                    return
-                try:
-                    task()
-                except Exception as e:
-                    logger.exception("Unexpected error running task")
-                    self.error = e
-
-                with self.lock:
-                    self.in_flight -= 1
-
-    def add(self, task):
-        with self.lock:
-            if self.thread_count > 1:
+            with self.lock:
+                self.in_flight -= 1
+
+    def add(self, task, unlock, check_done):
+        if self.thread_count > 1:
+            with self.lock:
                 self.in_flight += 1
-                self.task_queue.put(task)
-            else:
-                task()
+        else:
+            task()
+            return
+
+        while True:
+            try:
+                unlock.release()
+                if check_done.is_set():
+                    return
+                self.task_queue.put(task, block=True, timeout=3)
+                return
+            except queue.Full:
+                pass
+            finally:
+                unlock.acquire()
+
 
     def drain(self):
         try:
             # Drain queue
             while not self.task_queue.empty():
-                self.task_queue.get()
-        except Queue.Empty:
+                self.task_queue.get(True, .1)
+        except queue.Empty:
             pass
 
     def join(self):