1c233fac0ad98f4b0421a4e0856b00fd19d1422f
[arvados.git] / sdk / cwl / arvados_cwl / task_queue.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import Queue
6 import threading
7 import logging
8
# Logger used by TaskQueue workers to report exceptions raised by tasks.
logger = logging.getLogger(name='arvados.cwl-runner')
10
class TaskQueue(object):
    """Run callables on a fixed pool of worker threads.

    Tasks reach the workers through a bounded queue whose maximum size equals
    ``thread_count``.  When ``thread_count`` is 1 or less, ``add()`` runs
    each task synchronously in the caller's thread instead (any worker
    threads started by ``__init__`` then stay idle until ``join()``).

    Attributes:
        in_flight: number of tasks accepted by ``add()`` that have neither
            finished running nor been discarded (by ``drain()`` or by
            ``add()`` giving up).  Updated under ``lock``.
        error: the most recent exception raised by a task on a worker
            thread, or None.
    """

    def __init__(self, lock, thread_count):
        """Create the queue and start ``thread_count`` worker threads.

        :param lock: lock guarding ``in_flight``.  NOTE(review): ``add()``
            acquires it while the caller still holds ``unlock``, so when both
            are the same object it presumably must be reentrant (e.g.
            ``threading.RLock`` or a ``Condition`` built on one) — confirm
            against callers.
        :param thread_count: number of worker threads; also the maximum
            number of queued tasks.
        """
        self.thread_count = thread_count
        self.task_queue = Queue.Queue(maxsize=self.thread_count)
        self.task_queue_threads = []
        self.lock = lock
        self.in_flight = 0
        self.error = None

        # range (not xrange) keeps this importable under Python 3 as well;
        # thread_count is small, so building a list on Python 2 is harmless.
        for _ in range(self.thread_count):
            t = threading.Thread(target=self.task_queue_func)
            self.task_queue_threads.append(t)
            t.start()

    def task_queue_func(self):
        """Worker loop: run queued tasks until a ``None`` sentinel arrives.

        An exception raised by a task is logged and stored in ``self.error``.
        The worker then carries on with the next task instead of dying.
        """
        while True:
            task = self.task_queue.get()
            if task is None:
                return
            try:
                task()
            except Exception as e:
                logger.exception("Unhandled exception running task")
                self.error = e

            with self.lock:
                self.in_flight -= 1

    def add(self, task, unlock, check_done):
        """Submit ``task`` for execution.

        With ``thread_count`` <= 1 the task runs immediately in the caller's
        thread, and any exception it raises propagates to the caller.
        Otherwise the task is queued for a worker.  ``unlock`` is released
        while waiting for space in the bounded queue and is re-acquired
        before returning.

        :param task: zero-argument callable to run.
        :param unlock: lock the caller must hold on entry; released while
            waiting for queue space, held again on return.
        :param check_done: object with ``is_set()`` (e.g.
            ``threading.Event``).  Once it is set, the task is abandoned
            instead of queued.
        """
        if self.thread_count <= 1:
            task()
            return

        with self.lock:
            self.in_flight += 1

        while True:
            try:
                unlock.release()
                if check_done.is_set():
                    # The task will never be queued or run.  Undo the
                    # increment above so in_flight does not leak.
                    with self.lock:
                        self.in_flight -= 1
                    return
                # Wait in short slices so check_done is re-checked
                # periodically while the queue stays full.
                self.task_queue.put(task, block=True, timeout=3)
                return
            except Queue.Full:
                pass
            finally:
                unlock.acquire()

    def drain(self):
        """Discard queued tasks that no worker has picked up yet.

        Discarded tasks never run, so each one is also subtracted from
        ``in_flight``.
        """
        try:
            while not self.task_queue.empty():
                # empty() is only a hint: a worker may take the last item
                # first, in which case get() times out with Queue.Empty.
                task = self.task_queue.get(True, .1)
                if task is not None:
                    with self.lock:
                        self.in_flight -= 1
        except Queue.Empty:
            pass

    def join(self):
        """Stop all worker threads and wait for them to exit.

        One ``None`` sentinel is queued per worker.  The queue is FIFO, so
        tasks queued before this call still run first.
        """
        for _ in self.task_queue_threads:
            self.task_queue.put(None)
        for thread in self.task_queue_threads:
            thread.join()