16535: Merge branch 'master'
[arvados.git] / sdk / cwl / arvados_cwl / task_queue.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import range
8 from builtins import object
9
10 import queue
11 import threading
12 import logging
13
# Module logger: task failures from the worker threads are reported under
# the "arvados.cwl-runner" logger name.
logger = logging.getLogger(name="arvados.cwl-runner")
15
class TaskQueue(object):
    """Run callables on a fixed pool of worker threads.

    Tasks reach the workers through a bounded queue whose capacity equals
    ``thread_count``.  When ``thread_count`` is 1 or less, add() runs the task
    inline instead and it is never counted as in flight.

    Attributes:
        in_flight: number of tasks accepted by add() that have neither
            finished running nor been discarded by drain() (or by add()
            itself on shutdown).  Only modified while holding ``lock``.
        error: the most recent exception raised by a task, or None.

    The worker threads are started in __init__ and are not marked as daemons
    here, so join() must be called to shut them down.
    """

    def __init__(self, lock, thread_count):
        """Start ``thread_count`` worker threads.

        Args:
            lock: lock guarding ``in_flight``; used as a context manager.
                add() acquires it while the caller may still hold ``unlock``,
                so if both are the same object it must be reentrant
                (presumably a threading.RLock/Condition -- confirm at callers).
            thread_count: number of worker threads and queue capacity.
        """
        self.thread_count = thread_count
        self.task_queue = queue.Queue(maxsize=self.thread_count)
        self.task_queue_threads = []
        self.lock = lock
        self.in_flight = 0
        self.error = None

        for _ in range(0, self.thread_count):
            t = threading.Thread(target=self.task_queue_func)
            self.task_queue_threads.append(t)
            t.start()

    def task_queue_func(self):
        """Worker loop: run queued tasks until a ``None`` sentinel arrives."""
        while True:
            task = self.task_queue.get()
            if task is None:
                return
            try:
                task()
            except Exception as e:
                logger.exception("Unhandled exception running task")
                self.error = e
            finally:
                # Always retire the task, even if it raised something that
                # is not an Exception, so in_flight can still reach zero.
                with self.lock:
                    self.in_flight -= 1

    def add(self, task, unlock, check_done):
        """Schedule ``task`` to run on a worker thread.

        With ``thread_count`` <= 1 the task is run synchronously and add()
        returns afterwards.

        Args:
            task: zero-argument callable.
            unlock: lock that the caller must hold on entry.  It is released
                while waiting for queue space, so other threads can make
                progress, and is re-acquired before add() returns.
            check_done: Event-like object (has ``is_set()``).  If it is set
                before the task could be queued, the task is dropped.
        """
        if self.thread_count <= 1:
            task()
            return

        with self.lock:
            self.in_flight += 1

        while True:
            try:
                unlock.release()
                if check_done.is_set():
                    break
                # Bounded wait so check_done is re-examined periodically
                # while the queue stays full.
                self.task_queue.put(task, block=True, timeout=3)
                return
            except queue.Full:
                pass
            finally:
                unlock.acquire()

        # Shutdown was requested before the task could be queued.  No worker
        # will ever run it, so undo the in-flight count taken above.
        with self.lock:
            self.in_flight -= 1

    def drain(self):
        """Discard tasks still waiting in the queue without running them.

        Tasks already picked up by a worker are unaffected.  Discarded tasks
        are removed from ``in_flight`` so the counter stays accurate.

        NOTE: call this before join().  Draining afterwards could also remove
        the ``None`` shutdown sentinels, leaving workers blocked forever.
        """
        discarded = 0
        while True:
            try:
                item = self.task_queue.get_nowait()
            except queue.Empty:
                break
            if item is not None:
                discarded += 1

        if discarded:
            with self.lock:
                self.in_flight -= discarded

    def join(self):
        """Send one shutdown sentinel per worker and wait for all to exit."""
        for _ in self.task_queue_threads:
            self.task_queue.put(None)
        for t in self.task_queue_threads:
            t.join()