Hotfix: use a recursive lock for closed_lock so that EventClient.close() can be
[arvados.git] / sdk / python / arvados / events.py
1 import arvados
2 import config
3 import errors
4
5 import logging
6 import json
7 import threading
8 import time
9 import os
10 import re
11 import ssl
12 from ws4py.client.threadedclient import WebSocketClient
13
14 _logger = logging.getLogger('arvados.events')
15
16 class EventClient(WebSocketClient):
17     def __init__(self, url, filters, on_event, last_log_id):
18         ssl_options = {'ca_certs': arvados.util.ca_certs_path()}
19         if config.flag_is_true('ARVADOS_API_HOST_INSECURE'):
20             ssl_options['cert_reqs'] = ssl.CERT_NONE
21         else:
22             ssl_options['cert_reqs'] = ssl.CERT_REQUIRED
23
24         # Warning: If the host part of url resolves to both IPv6 and
25         # IPv4 addresses (common with "localhost"), only one of them
26         # will be attempted -- and it might not be the right one. See
27         # ws4py's WebSocketBaseClient.__init__.
28         super(EventClient, self).__init__(url, ssl_options=ssl_options)
29         self.filters = filters
30         self.on_event = on_event
31         self.last_log_id = last_log_id
32         self.closed_lock = threading.RLock()
33         self.closed = False
34
35     def opened(self):
36         self.subscribe(self.filters, self.last_log_id)
37
38     def received_message(self, m):
39         with self.closed_lock:
40             if not self.closed:
41                 self.on_event(json.loads(str(m)))
42
43     def close(self, code=1000, reason=''):
44         """Close event client and wait for it to finish."""
45         super(EventClient, self).close(code, reason)
46         with self.closed_lock:
47             # make sure we don't process any more messages.
48             self.closed = True
49
50     def subscribe(self, filters, last_log_id=None):
51         m = {"method": "subscribe", "filters": filters}
52         if last_log_id is not None:
53             m["last_log_id"] = last_log_id
54         self.send(json.dumps(m))
55
56     def unsubscribe(self, filters):
57         self.send(json.dumps({"method": "unsubscribe", "filters": filters}))
58
59 class PollClient(threading.Thread):
60     def __init__(self, api, filters, on_event, poll_time, last_log_id):
61         super(PollClient, self).__init__()
62         self.api = api
63         if filters:
64             self.filters = [filters]
65         else:
66             self.filters = [[]]
67         self.on_event = on_event
68         self.poll_time = poll_time
69         self.daemon = True
70         self.stop = threading.Event()
71         self.last_log_id = last_log_id
72
73     def run(self):
74         self.id = 0
75         if self.last_log_id != None:
76             self.id = self.last_log_id
77         else:
78             for f in self.filters:
79                 items = self.api.logs().list(limit=1, order="id desc", filters=f).execute()['items']
80                 if items:
81                     if items[0]['id'] > self.id:
82                         self.id = items[0]['id']
83
84         self.on_event({'status': 200})
85
86         while not self.stop.isSet():
87             max_id = self.id
88             moreitems = False
89             for f in self.filters:
90                 items = self.api.logs().list(order="id asc", filters=f+[["id", ">", str(self.id)]]).execute()
91                 for i in items["items"]:
92                     if i['id'] > max_id:
93                         max_id = i['id']
94                     self.on_event(i)
95                 if items["items_available"] > len(items["items"]):
96                     moreitems = True
97             self.id = max_id
98             if not moreitems:
99                 self.stop.wait(self.poll_time)
100
101     def run_forever(self):
102         # Have to poll here, otherwise KeyboardInterrupt will never get processed.
103         while not self.stop.is_set():
104             self.stop.wait(1)
105
106     def close(self):
107         """Close poll client and wait for it to finish."""
108
109         self.stop.set()
110         try:
111             self.join()
112         except RuntimeError:
113             # "join() raises a RuntimeError if an attempt is made to join the
114             # current thread as that would cause a deadlock. It is also an
115             # error to join() a thread before it has been started and attempts
116             # to do so raises the same exception."
117             pass
118
119     def subscribe(self, filters):
120         self.on_event({'status': 200})
121         self.filters.append(filters)
122
123     def unsubscribe(self, filters):
124         del self.filters[self.filters.index(filters)]
125
126
127 def _subscribe_websocket(api, filters, on_event, last_log_id=None):
128     endpoint = api._rootDesc.get('websocketUrl', None)
129     if not endpoint:
130         raise errors.FeatureNotEnabledError(
131             "Server does not advertise a websocket endpoint")
132     try:
133         uri_with_token = "{}?api_token={}".format(endpoint, api.api_token)
134         client = EventClient(uri_with_token, filters, on_event, last_log_id)
135         ok = False
136         try:
137             client.connect()
138             ok = True
139             return client
140         finally:
141             if not ok:
142                 client.close_connection()
143     except:
144         _logger.warn("Failed to connect to websockets on %s" % endpoint)
145         raise
146
147
148 def subscribe(api, filters, on_event, poll_fallback=15, last_log_id=None):
149     """
150     :api:
151       a client object retrieved from arvados.api(). The caller should not use this client object for anything else after calling subscribe().
152     :filters:
153       Initial subscription filters.
154     :on_event:
155       The callback when a message is received.
156     :poll_fallback:
157       If websockets are not available, fall back to polling every N seconds.  If poll_fallback=False, this will return None if websockets are not available.
158     :last_log_id:
159       Log rows that are newer than the log id
160     """
161
162     if not poll_fallback:
163         return _subscribe_websocket(api, filters, on_event, last_log_id)
164
165     try:
166         return _subscribe_websocket(api, filters, on_event, last_log_id)
167     except Exception as e:
168         _logger.warn("Falling back to polling after websocket error: %s" % e)
169     p = PollClient(api, filters, on_event, poll_fallback, last_log_id)
170     p.start()
171     return p