13108: Refactor task queue into its own class.
[arvados.git] / sdk / cwl / arvados_cwl / __init__.py
index 73901555acc4a3cc115fe9187abca80c34cbc75d..7affade0734536fb3e7ee241bed9357a995eb949 100644 (file)
@@ -18,6 +18,7 @@ import re
 from functools import partial
 import pkg_resources  # part of setuptools
 import Queue
 from functools import partial
 import pkg_resources  # part of setuptools
 import Queue
+import time
 
 from cwltool.errors import WorkflowException
 import cwltool.main
 
 from cwltool.errors import WorkflowException
 import cwltool.main
@@ -39,6 +40,7 @@ from .arvworkflow import ArvadosWorkflow, upload_workflow
 from .fsaccess import CollectionFsAccess, CollectionFetcher, collectionResolver, CollectionCache
 from .perf import Perf
 from .pathmapper import NoFollowPathMapper
 from .fsaccess import CollectionFsAccess, CollectionFetcher, collectionResolver, CollectionCache
 from .perf import Perf
 from .pathmapper import NoFollowPathMapper
+from .task_queue import TaskQueue
 from ._version import __version__
 
 from cwltool.pack import pack
 from ._version import __version__
 
 from cwltool.pack import pack
@@ -63,10 +65,10 @@ class ArvCwlRunner(object):
 
     """
 
 
     """
 
-    def __init__(self, api_client, work_api=None, keep_client=None, output_name=None, output_tags=None, num_retries=4):
+    def __init__(self, api_client, work_api=None, keep_client=None,
+                 output_name=None, output_tags=None, num_retries=4):
         self.api = api_client
         self.processes = {}
         self.api = api_client
         self.processes = {}
-        self.in_flight = 0
         self.workflow_eval_lock = threading.Condition(threading.RLock())
         self.final_output = None
         self.final_status = None
         self.workflow_eval_lock = threading.Condition(threading.RLock())
         self.final_output = None
         self.final_status = None
@@ -83,7 +85,8 @@ class ArvCwlRunner(object):
         self.intermediate_output_ttl = 0
         self.intermediate_output_collections = []
         self.trash_intermediate = False
         self.intermediate_output_ttl = 0
         self.intermediate_output_collections = []
         self.trash_intermediate = False
-        self.runnable_queue = Queue.Queue()
+        self.thread_count = 4
+        self.poll_interval = 12
 
         if keep_client is not None:
             self.keep_client = keep_client
 
         if keep_client is not None:
             self.keep_client = keep_client
@@ -138,21 +141,15 @@ class ArvCwlRunner(object):
                                                          body={"state": "Failed"}).execute(num_retries=self.num_retries)
             self.final_status = processStatus
             self.final_output = out
                                                          body={"state": "Failed"}).execute(num_retries=self.num_retries)
             self.final_status = processStatus
             self.final_output = out
+            self.workflow_eval_lock.notifyAll()
 
 
-    def runnable_queue_thread(self):
-        while True:
-            runnable, kwargs = self.runnable_queue.get()
-            runnable.run(**kwargs)
 
     def start_run(self, runnable, kwargs):
 
     def start_run(self, runnable, kwargs):
-        with self.workflow_eval_lock:
-            self.in_flight += 1
-        self.runnable_queue.put((runnable, kwargs))
+        self.task_queue.add(partial(runnable.run, **kwargs))
 
     def process_submitted(self, container):
         with self.workflow_eval_lock:
             self.processes[container.uuid] = container
 
     def process_submitted(self, container):
         with self.workflow_eval_lock:
             self.processes[container.uuid] = container
-            self.in_flight -= 1
 
     def process_done(self, uuid):
         with self.workflow_eval_lock:
 
     def process_done(self, uuid):
         with self.workflow_eval_lock:
@@ -162,6 +159,7 @@ class ArvCwlRunner(object):
     def wrapped_callback(self, cb, obj, st):
         with self.workflow_eval_lock:
             cb(obj, st)
     def wrapped_callback(self, cb, obj, st):
         with self.workflow_eval_lock:
             cb(obj, st)
+            self.workflow_eval_lock.notifyAll()
 
     def get_wrapped_callback(self, cb):
         return partial(self.wrapped_callback, cb)
 
     def get_wrapped_callback(self, cb):
         return partial(self.wrapped_callback, cb)
@@ -169,21 +167,19 @@ class ArvCwlRunner(object):
     def on_message(self, event):
         if "object_uuid" in event:
             if event["object_uuid"] in self.processes and event["event_type"] == "update":
     def on_message(self, event):
         if "object_uuid" in event:
             if event["object_uuid"] in self.processes and event["event_type"] == "update":
-                if event["properties"]["new_attributes"]["state"] == "Running" and self.processes[event["object_uuid"]].running is False:
-                    uuid = event["object_uuid"]
+                uuid = event["object_uuid"]
+                if event["properties"]["new_attributes"]["state"] == "Running":
                     with self.workflow_eval_lock:
                         j = self.processes[uuid]
                     with self.workflow_eval_lock:
                         j = self.processes[uuid]
-                        logger.info("%s %s is Running", self.label(j), uuid)
-                        j.running = True
-                        j.update_pipeline_component(event["properties"]["new_attributes"])
+                        if j.running is False:
+                            j.running = True
+                            j.update_pipeline_component(event["properties"]["new_attributes"])
+                            logger.info("%s %s is Running", self.label(j), uuid)
                 elif event["properties"]["new_attributes"]["state"] in ("Complete", "Failed", "Cancelled", "Final"):
                 elif event["properties"]["new_attributes"]["state"] in ("Complete", "Failed", "Cancelled", "Final"):
-                    uuid = event["object_uuid"]
                     with self.workflow_eval_lock:
                         j = self.processes[uuid]
                     with self.workflow_eval_lock:
                         j = self.processes[uuid]
-                        logger.info("%s %s is %s", self.label(j), uuid, event["properties"]["new_attributes"]["state"])
-                        with Perf(metrics, "done %s" % j.name):
-                            j.done(event["properties"]["new_attributes"])
-                        self.workflow_eval_lock.notify()
+                    self.task_queue.add(partial(j.done, event["properties"]["new_attributes"]))
+                    logger.info("%s %s is %s", self.label(j), uuid, event["properties"]["new_attributes"]["state"])
 
     def label(self, obj):
         return "[%s %s]" % (self.work_api[0:-1], obj.name)
 
     def label(self, obj):
         return "[%s %s]" % (self.work_api[0:-1], obj.name)
@@ -195,15 +191,19 @@ class ArvCwlRunner(object):
         """
 
         try:
         """
 
         try:
+            remain_wait = self.poll_interval
             while True:
             while True:
-                self.stop_polling.wait(15)
+                if remain_wait > 0:
+                    self.stop_polling.wait(remain_wait)
                 if self.stop_polling.is_set():
                     break
                 with self.workflow_eval_lock:
                     keys = list(self.processes.keys())
                 if not keys:
                 if self.stop_polling.is_set():
                     break
                 with self.workflow_eval_lock:
                     keys = list(self.processes.keys())
                 if not keys:
+                    remain_wait = self.poll_interval
                     continue
 
                     continue
 
+                begin_poll = time.time()
                 if self.work_api == "containers":
                     table = self.poll_api.container_requests()
                 elif self.work_api == "jobs":
                 if self.work_api == "containers":
                     table = self.poll_api.container_requests()
                 elif self.work_api == "jobs":
@@ -213,6 +213,7 @@ class ArvCwlRunner(object):
                     proc_states = table.list(filters=[["uuid", "in", keys]]).execute(num_retries=self.num_retries)
                 except Exception as e:
                     logger.warn("Error checking states on API server: %s", e)
                     proc_states = table.list(filters=[["uuid", "in", keys]]).execute(num_retries=self.num_retries)
                 except Exception as e:
                     logger.warn("Error checking states on API server: %s", e)
+                    remain_wait = self.poll_interval
                     continue
 
                 for p in proc_states["items"]:
                     continue
 
                 for p in proc_states["items"]:
@@ -223,11 +224,13 @@ class ArvCwlRunner(object):
                             "new_attributes": p
                         }
                     })
                             "new_attributes": p
                         }
                     })
+                finish_poll = time.time()
+                remain_wait = self.poll_interval - (finish_poll - begin_poll)
         except:
         except:
-            logger.error("Fatal error in state polling thread.", exc_info=(sys.exc_info()[1] if self.debug else False))
-            with workflow_eval_lock:
+            logger.exception("Fatal error in state polling thread.")
+            with self.workflow_eval_lock:
                 self.processes.clear()
                 self.processes.clear()
-                self.workflow_eval_lock.notify()
+                self.workflow_eval_lock.notifyAll()
         finally:
             self.stop_polling.set()
 
         finally:
             self.stop_polling.set()
 
@@ -383,6 +386,7 @@ class ArvCwlRunner(object):
                                                                  collection_cache=self.collection_cache)
         self.fs_access = make_fs_access(kwargs["basedir"])
         self.secret_store = kwargs.get("secret_store")
                                                                  collection_cache=self.collection_cache)
         self.fs_access = make_fs_access(kwargs["basedir"])
         self.secret_store = kwargs.get("secret_store")
+        self.thread_count = kwargs.get("thread_count", 4)
 
         self.trash_intermediate = kwargs["trash_intermediate"]
         if self.trash_intermediate and self.work_api != "containers":
 
         self.trash_intermediate = kwargs["trash_intermediate"]
         if self.trash_intermediate and self.work_api != "containers":
@@ -504,14 +508,14 @@ class ArvCwlRunner(object):
             logger.info("Pipeline instance %s", self.pipeline["uuid"])
 
         if runnerjob and not kwargs.get("wait"):
             logger.info("Pipeline instance %s", self.pipeline["uuid"])
 
         if runnerjob and not kwargs.get("wait"):
-            runnerjob.run(wait=kwargs.get("wait"))
+            runnerjob.run(**kwargs)
             return (runnerjob.uuid, "success")
 
         self.poll_api = arvados.api('v1')
         self.polling_thread = threading.Thread(target=self.poll_states)
         self.polling_thread.start()
 
             return (runnerjob.uuid, "success")
 
         self.poll_api = arvados.api('v1')
         self.polling_thread = threading.Thread(target=self.poll_states)
         self.polling_thread.start()
 
-        threading.Thread(target=self.runnable_queue_thread).start()
+        self.task_queue = TaskQueue(self.workflow_eval_lock, self.thread_count)
 
         if runnerjob:
             jobiter = iter((runnerjob,))
 
         if runnerjob:
             jobiter = iter((runnerjob,))
@@ -537,20 +541,25 @@ class ArvCwlRunner(object):
                 if self.stop_polling.is_set():
                     break
 
                 if self.stop_polling.is_set():
                     break
 
+                if self.task_queue.error is not None:
+                    raise self.task_queue.error
+
                 if runnable:
                     with Perf(metrics, "run"):
                         self.start_run(runnable, kwargs)
                 else:
                 if runnable:
                     with Perf(metrics, "run"):
                         self.start_run(runnable, kwargs)
                 else:
-                    if (self.in_flight + len(self.processes)) > 0:
-                        self.workflow_eval_lock.wait(1)
+                    if (self.task_queue.in_flight + len(self.processes)) > 0:
+                        self.workflow_eval_lock.wait(3)
                     else:
                     else:
-                        logger.error("Workflow is deadlocked, no runnable jobs and not waiting on any pending jobs.")
+                        logger.error("Workflow is deadlocked, no runnable jobs and not waiting on any pendingjobs.")
                         break
                 loopperf.__enter__()
             loopperf.__exit__()
 
                         break
                 loopperf.__enter__()
             loopperf.__exit__()
 
-            while self.processes:
-                self.workflow_eval_lock.wait(1)
+            while (self.task_queue.in_flight + len(self.processes)) > 0:
+                if self.task_queue.error is not None:
+                    raise self.task_queue.error
+                self.workflow_eval_lock.wait(3)
 
         except UnsupportedRequirement:
             raise
 
         except UnsupportedRequirement:
             raise
@@ -567,8 +576,10 @@ class ArvCwlRunner(object):
                                                      body={"priority": "0"}).execute(num_retries=self.num_retries)
         finally:
             self.workflow_eval_lock.release()
                                                      body={"priority": "0"}).execute(num_retries=self.num_retries)
         finally:
             self.workflow_eval_lock.release()
+            self.task_queue.drain()
             self.stop_polling.set()
             self.polling_thread.join()
             self.stop_polling.set()
             self.polling_thread.join()
+            self.task_queue.join()
 
         if self.final_status == "UnsupportedRequirement":
             raise UnsupportedRequirement("Check log for details.")
 
         if self.final_status == "UnsupportedRequirement":
             raise UnsupportedRequirement("Check log for details.")
@@ -718,6 +729,9 @@ def arg_parser():  # type: () -> argparse.ArgumentParser
                         action="store_true", default=False,
                         help=argparse.SUPPRESS)
 
                         action="store_true", default=False,
                         help=argparse.SUPPRESS)
 
+    parser.add_argument("--thread-count", type=int,
+                        default=4, help="Number of threads to use for job submit and output collection.")
+
     exgroup = parser.add_mutually_exclusive_group()
     exgroup.add_argument("--trash-intermediate", action="store_true",
                         default=False, dest="trash_intermediate",
     exgroup = parser.add_mutually_exclusive_group()
     exgroup.add_argument("--trash-intermediate", action="store_true",
                         default=False, dest="trash_intermediate",