Merge branch '13164-fix-zero-priority-after-race'
[arvados.git] / sdk / cwl / arvados_cwl / task_queue.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import Queue
6 import threading
7 import logging
8
# Named logger used by TaskQueue workers to report unhandled task exceptions.
logger = logging.getLogger(name='arvados.cwl-runner')
10
class TaskQueue(object):
    """Run zero-argument callables, optionally on a pool of worker threads.

    When ``thread_count`` is greater than 1, ``add()`` enqueues tasks for a
    pool of ``thread_count`` worker threads.  Otherwise ``add()`` runs each
    task immediately in the calling thread and no worker threads are started.

    Attributes:
        thread_count: number of workers requested at construction.
        task_queue: FIFO of pending tasks; ``None`` is the shutdown sentinel.
        task_queue_threads: the started worker threads (empty when
            ``thread_count <= 1``).
        lock: lock guarding ``in_flight``; used as a context manager.
        in_flight: number of tasks that have been queued and have not yet
            finished running or been discarded by ``drain()``.
        error: the most recent ``Exception`` raised by a task on a worker
            thread, or ``None``.
    """

    def __init__(self, lock, thread_count):
        """Create the queue and, in pool mode, start the worker threads.

        Args:
            lock: lock shared with the caller; acquired with ``with`` around
                every update of ``in_flight``.  In single-threaded mode tasks
                also run while it is held, so it should be reentrant if a
                task acquires it itself.
            thread_count: number of worker threads.  Values <= 1 mean tasks
                run synchronously inside ``add()``.
        """
        self.thread_count = thread_count
        self.task_queue = Queue.Queue()
        self.task_queue_threads = []
        self.lock = lock
        self.in_flight = 0
        self.error = None

        # add() only hands work to the pool when thread_count > 1.  Starting
        # workers otherwise just leaves idle non-daemon threads that keep the
        # process alive if join() is never reached.
        if self.thread_count > 1:
            for _ in range(self.thread_count):
                t = threading.Thread(target=self.task_queue_func)
                self.task_queue_threads.append(t)
                t.start()

    def task_queue_func(self):
        """Worker loop: run queued tasks until a ``None`` sentinel arrives.

        An ``Exception`` raised by a task is logged and stored in
        ``self.error`` instead of killing the worker.  ``in_flight`` is
        decremented in a ``finally`` so the count stays accurate even if a
        task raises something that is not an ``Exception``.
        """
        while True:
            task = self.task_queue.get()
            if task is None:
                return
            try:
                task()
            except Exception as e:
                logger.exception("Unhandled exception running task")
                self.error = e
            finally:
                with self.lock:
                    self.in_flight -= 1

    def add(self, task):
        """Schedule ``task``, a zero-argument callable.

        In pool mode the task is counted in ``in_flight`` and queued for a
        worker.  Otherwise it runs immediately in the calling thread while
        ``self.lock`` is held, and any exception it raises propagates to the
        caller.
        """
        with self.lock:
            if self.thread_count > 1:
                self.in_flight += 1
                self.task_queue.put(task)
            else:
                task()

    def drain(self):
        """Discard every task still waiting in the queue.

        Tasks already running on a worker are unaffected.  Discarded tasks
        will never run, so they are subtracted from ``in_flight``.  Any
        shutdown sentinels encountered are put back so ``join()`` can still
        stop every worker.
        """
        discarded = 0
        sentinels = 0
        while True:
            try:
                task = self.task_queue.get_nowait()
            except Queue.Empty:
                break
            if task is None:
                sentinels += 1
            else:
                discarded += 1

        for _ in range(sentinels):
            self.task_queue.put(None)

        if discarded:
            with self.lock:
                self.in_flight -= discarded

    def join(self):
        """Stop the workers after all queued tasks have run, and wait for them.

        One ``None`` sentinel per worker is queued behind any pending tasks,
        so every task added before this call still runs.
        """
        for _ in self.task_queue_threads:
            self.task_queue.put(None)
        for t in self.task_queue_threads:
            t.join()