21230: Add license header
[arvados.git] / sdk / python / tests / test_events.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import json
6 import logging
7 import mock
8 import queue
9 import sys
10 import threading
11 import time
12 import unittest
13
14 import websockets.exceptions as ws_exc
15
16 import arvados
17 from . import arvados_testutil as tutil
18 from . import run_test_server
19
class FakeWebsocketClient:
    """Fake self-contained version of websockets.sync.client.ClientConnection

    This provides enough of the API to test EventClient. It loosely mimics
    the Arvados WebSocket API by acknowledging subscribe messages. You can use
    `mock_wrapper` to test calls. You can set `_check_lock` to test that the
    given lock is acquired before `send` is called.
    """

    def __init__(self):
        # Optional lock; when set, `send` fails unless it is already held.
        self._check_lock = None
        # Set the first time `close` is called.
        self._closed = threading.Event()
        # Items consumed by `__iter__`: strings are yielded, exception
        # instances are raised.
        self._messages = queue.Queue()

    def mock_wrapper(self):
        """Return a `mock.Mock` wrapping this client so calls can be asserted."""
        wrapper = mock.Mock(wraps=self)
        # Iteration is wired up explicitly so that iterating the wrapper
        # drains this client's message queue.
        wrapper.__iter__ = lambda _: self.__iter__()
        return wrapper

    def __iter__(self):
        """Yield queued messages forever, blocking while the queue is empty.

        A queued exception instance is raised instead of being yielded.
        """
        while True:
            msg = self._messages.get()
            self._messages.task_done()
            if isinstance(msg, Exception):
                raise msg
            else:
                yield msg

    def close(self, code=1000, reason=''):
        """Mark the client closed and end iteration; repeat calls do nothing.

        `code` and `reason` are accepted and ignored.
        """
        if not self._closed.is_set():
            self._closed.set()
            self.force_disconnect()

    def force_disconnect(self):
        """Queue a `ConnectionClosed` error for the iterating consumer.

        Unlike `close`, this does not mark the client as closed.
        """
        self._messages.put(ws_exc.ConnectionClosed(None, None))

    def send(self, msg):
        """Acknowledge `msg` by queueing a JSON status reply.

        The reply is `{"status": 200}` when `msg` parses as JSON and
        `{"status": 400}` when parsing raises `ValueError`.

        Raises:
            AssertionError: `_check_lock` is set and could be acquired here,
                meaning the caller was not holding it.
            websockets.exceptions.ConnectionClosed: `close` was already called.
        """
        if self._check_lock is not None and self._check_lock.acquire(blocking=False):
            # Acquiring succeeded, so the caller did not hold the lock.
            self._check_lock.release()
            raise AssertionError("called ws_client.send() without lock")
        elif self._closed.is_set():
            raise ws_exc.ConnectionClosed(None, None)
        try:
            # Only validity matters; the parsed value is not used.
            json.loads(msg)
        except ValueError:
            status = 400
        else:
            status = 200
        self._messages.put(json.dumps({'status': status}))
69
70
class WebsocketTest(run_test_server.TestCaseWithServers):
    """Exercise `arvados.events.subscribe`, `EventClient` and `PollClient`.

    The `_test_subscribe` and `_test_websocket_reconnect` helpers create
    records through `arvados.api('v1')` after calling
    `run_test_server.authorize_with`. The `*_locking` and reconnect-retry
    tests patch `arvados.events.ws_client.connect` and run against
    `FakeWebsocketClient` instead.
    """
    MAIN_SERVER = {}

    # One hour either side of the moment this class body is evaluated.
    TIME_PAST = time.time()-3600
    TIME_FUTURE = time.time()+3600
    MOCK_WS_URL = 'wss://[{}]/'.format(tutil.TEST_HOST)

    TEST_TIMEOUT = 10.0

    def setUp(self):
        # Tests store their client here so tearDown can close it.
        self.ws = None

    def tearDown(self):
        try:
            if self.ws:
                self.ws.close()
        except Exception as e:
            # Best effort: a failed close must not stop the rest of teardown.
            print("Error in teardown: ", e)
        super().tearDown()
        run_test_server.reset()

    def _test_subscribe(self, poll_fallback, expect_type, start_time=None, expected=1):
        """Subscribe to human-object events and check which ones arrive.

        Arguments:

        * `poll_fallback` is passed through to `arvados.events.subscribe`.

        * `expect_type` is the class the returned client must be an
          instance of.

        * `start_time`, when given, is added as a `created_at >=` filter and
          makes the subscription pass `last_log_id=1`.

        * `expected` selects the events waited for: 0 waits for none, 1 for
          the object created after subscribing, 2 for that object and the
          one created before subscribing. Below 2, the helper also checks
          that nothing further arrives within two seconds.
        """
        run_test_server.authorize_with('active')
        events = queue.Queue(100)

        # Create ancestor before subscribing.
        # When listening with start_time in the past, this should also be retrieved.
        # However, when start_time is omitted in subscribe, this should not be fetched.
        ancestor = arvados.api('v1').humans().create(body={}).execute()

        filters = [['object_uuid', 'is_a', 'arvados#human']]
        if start_time:
            filters.append(['created_at', '>=', start_time])

        self.ws = arvados.events.subscribe(
            arvados.api('v1'), filters,
            events.put_nowait,
            poll_fallback=poll_fallback,
            last_log_id=(1 if start_time else None))
        self.assertIsInstance(self.ws, expect_type)
        self.assertEqual(200, events.get(True, 5)['status'])

        if hasattr(self.ws, '_skip_old_events'):
            # Avoid race by waiting for the first "find ID threshold"
            # poll to finish.
            deadline = time.time() + 10
            while not self.ws._skip_old_events:
                self.assertLess(time.time(), deadline)
                time.sleep(0.1)
        human = arvados.api('v1').humans().create(body={}).execute()

        want_uuids = []
        if expected > 0:
            want_uuids.append(human['uuid'])
        if expected > 1:
            want_uuids.append(ancestor['uuid'])
        log_object_uuids = []
        # Keep reading until every wanted uuid has been seen; `get` raises
        # queue.Empty (failing the test) if one never arrives.
        while set(want_uuids) - set(log_object_uuids):
            log_object_uuids.append(events.get(True, 5)['object_uuid'])

        if expected < 2:
            with self.assertRaises(queue.Empty):
                # assertEqual just serves to show us what unexpected
                # thing comes out of the queue when the assertRaises
                # fails; when the test passes, this assertEqual
                # doesn't get called.
                self.assertEqual(events.get(True, 2), None)

    def test_subscribe_websocket(self):
        self._test_subscribe(
            poll_fallback=False, expect_type=arvados.events.EventClient, expected=1)

    @mock.patch('arvados.events.EventClient.__init__')
    def test_subscribe_poll(self, event_client_constr):
        # Make EventClient construction fail; the test then expects
        # subscribe() to hand back a PollClient.
        event_client_constr.side_effect = Exception('All is well')
        self._test_subscribe(
            poll_fallback=0.25, expect_type=arvados.events.PollClient, expected=1)

    def test_subscribe_poll_retry(self):
        api_mock = mock.MagicMock()
        n = []
        def on_ev(ev):
            n.append(ev)

        error_mock = mock.MagicMock()
        error_mock.resp.status = 0
        error_mock._get_reason.return_value = "testing"
        # Alternate API errors with successful responses.
        api_mock.logs().list().execute.side_effect = (
            arvados.errors.ApiError(error_mock, b""),
            {"items": [{"id": 1}], "items_available": 1},
            arvados.errors.ApiError(error_mock, b""),
            {"items": [{"id": 1}], "items_available": 1},
        )
        poll_time = 15
        pc = arvados.events.PollClient(api_mock, [], on_ev, poll_time, None)
        pc.start()
        try:
            # Bound the wait so a regression fails the test instead of
            # hanging the run. The bound is several poll intervals because
            # the client may sleep a full interval between polls.
            deadline = time.time() + 4 * poll_time
            while len(n) < 2:
                self.assertLess(
                    time.time(), deadline,
                    "timed out waiting for polled events")
                time.sleep(.1)
        finally:
            pc.close()

    def test_subscribe_websocket_with_start_time_past(self):
        self._test_subscribe(
            poll_fallback=False, expect_type=arvados.events.EventClient,
            start_time=self.localiso(self.TIME_PAST),
            expected=2)

    @mock.patch('arvados.events.EventClient.__init__')
    def test_subscribe_poll_with_start_time_past(self, event_client_constr):
        event_client_constr.side_effect = Exception('All is well')
        self._test_subscribe(
            poll_fallback=0.25, expect_type=arvados.events.PollClient,
            start_time=self.localiso(self.TIME_PAST),
            expected=2)

    def test_subscribe_websocket_with_start_time_future(self):
        self._test_subscribe(
            poll_fallback=False, expect_type=arvados.events.EventClient,
            start_time=self.localiso(self.TIME_FUTURE),
            expected=0)

    @mock.patch('arvados.events.EventClient.__init__')
    def test_subscribe_poll_with_start_time_future(self, event_client_constr):
        event_client_constr.side_effect = Exception('All is well')
        self._test_subscribe(
            poll_fallback=0.25, expect_type=arvados.events.PollClient,
            start_time=self.localiso(self.TIME_FUTURE),
            expected=0)

    def test_subscribe_websocket_with_start_time_past_utc(self):
        self._test_subscribe(
            poll_fallback=False, expect_type=arvados.events.EventClient,
            start_time=self.utciso(self.TIME_PAST),
            expected=2)

    def test_subscribe_websocket_with_start_time_future_utc(self):
        self._test_subscribe(
            poll_fallback=False, expect_type=arvados.events.EventClient,
            start_time=self.utciso(self.TIME_FUTURE),
            expected=0)

    def utciso(self, t):
        """Format epoch seconds `t` as a UTC timestamp with a `Z` suffix."""
        return time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime(t))

    def localiso(self, t):
        """Format epoch seconds `t` as local time with a numeric UTC offset.

        The offset comes from the same `struct_time` that is formatted, so
        it stays correct when daylight saving time is in effect at `t`.
        """
        local = time.localtime(t)
        # tm_gmtoff is in seconds east of UTC; isotz expects minutes.
        return time.strftime('%Y-%m-%dT%H:%M:%S', local) + self.isotz(local.tm_gmtoff // 60)

    def isotz(self, offset):
        """Convert minutes-east-of-UTC to RFC3339- and ISO-compatible time zone designator"""
        # Split the magnitude, then apply the sign. Floor division on a
        # negative offset would turn -210 minutes into "-04:30".
        sign = '-' if offset < 0 else '+'
        hours, minutes = divmod(abs(offset), 60)
        return '{}{:02d}:{:02d}'.format(sign, hours, minutes)

    # Test websocket reconnection on (un)expected close
    def _test_websocket_reconnect(self, close_unexpected):
        """Close a live subscription and check whether events keep arriving.

        When `close_unexpected` is true, only the underlying connection
        (`self.ws._client`) is closed; the test then expects the
        "Unexpected close" log message and an event for an object created
        afterwards. Otherwise the client itself is closed, and neither the
        log message nor any further event is expected.
        """
        run_test_server.authorize_with('active')
        events = queue.Queue(100)

        logstream = tutil.StringIO()
        rootLogger = logging.getLogger()
        streamHandler = logging.StreamHandler(logstream)
        rootLogger.addHandler(streamHandler)
        # Remove the handler even when an assertion fails, so it cannot
        # leak into later tests.
        try:
            filters = [['object_uuid', 'is_a', 'arvados#human']]
            filters.append(['created_at', '>=', self.localiso(self.TIME_PAST)])
            self.ws = arvados.events.subscribe(
                arvados.api('v1'), filters,
                events.put_nowait,
                poll_fallback=False,
                last_log_id=None)
            self.assertIsInstance(self.ws, arvados.events.EventClient)
            self.assertEqual(200, events.get(True, 5)['status'])

            # create obj
            human = arvados.api('v1').humans().create(body={}).execute()

            # expect an event
            self.assertIn(human['uuid'], events.get(True, 5)['object_uuid'])
            with self.assertRaises(queue.Empty):
                self.assertEqual(events.get(True, 2), None)

            # close (im)properly
            if close_unexpected:
                self.ws._client.close()
            else:
                self.ws.close()

            # create one more obj
            human2 = arvados.api('v1').humans().create(body={}).execute()

            # (un)expect the object creation event
            if close_unexpected:
                log_object_uuids = []
                # Read two messages and keep only those that carry an
                # object_uuid.
                for _ in range(2):
                    event = events.get(True, 5)
                    if event.get('object_uuid') is not None:
                        log_object_uuids.append(event['object_uuid'])
                with self.assertRaises(queue.Empty):
                    self.assertEqual(events.get(True, 2), None)
                self.assertNotIn(human['uuid'], log_object_uuids)
                self.assertIn(human2['uuid'], log_object_uuids)
            else:
                with self.assertRaises(queue.Empty):
                    self.assertEqual(events.get(True, 2), None)

            # The reconnect warning must be logged only for an unexpected
            # close.
            log_messages = logstream.getvalue()
            if close_unexpected:
                self.assertIn("Unexpected close. Reconnecting.", log_messages)
            else:
                self.assertNotIn("Unexpected close. Reconnecting.", log_messages)
        finally:
            rootLogger.removeHandler(streamHandler)

    def test_websocket_reconnect_on_unexpected_close(self):
        self._test_websocket_reconnect(True)

    def test_websocket_no_reconnect_on_close_by_user(self):
        self._test_websocket_reconnect(False)

    # Test websocket reconnection retry
    @mock.patch('arvados.events.ws_client.connect')
    def test_websocket_reconnect_retry(self, ws_conn):
        logstream = tutil.StringIO()
        rootLogger = logging.getLogger()
        streamHandler = logging.StreamHandler(logstream)
        rootLogger.addHandler(streamHandler)
        try:
            msg_event, wss_client, self.ws = self.fake_client(ws_conn)
            self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for setup callback")
            msg_event.clear()
            # The next connect attempt fails; the one after it succeeds.
            ws_conn.side_effect = [Exception('EventClient.connect error'), wss_client]
            wss_client.force_disconnect()
            self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for reconnect callback")
            # verify log messages to ensure retry happened
            self.assertIn("Error 'EventClient.connect error' during websocket reconnect.", logstream.getvalue())
            # Initial connect, failed reconnect, successful reconnect.
            self.assertEqual(ws_conn.call_count, 3)
        finally:
            rootLogger.removeHandler(streamHandler)

    @mock.patch('arvados.events.ws_client.connect')
    def test_run_forever_survives_reconnects(self, websocket_client):
        client = arvados.events.EventClient(
            self.MOCK_WS_URL, [], lambda event: None, None)
        forever_thread = threading.Thread(target=client.run_forever)
        forever_thread.start()
        # Simulate an unexpected disconnect, then check that run_forever is
        # still running and that a second connection was made.
        try:
            client.on_closed()
            self.assertTrue(forever_thread.is_alive())
            self.assertEqual(2, websocket_client.call_count)
        finally:
            client.close()
            forever_thread.join()

    @staticmethod
    def fake_client(conn_patch, filters=None, url=MOCK_WS_URL):
        """Set up EventClient test infrastructure

        Given a patch of `arvados.events.ws_client.connect`,
        this returns a 3-tuple:

        * `msg_event` is a `threading.Event` that is set as the test client
          event callback. You can wait for this event to confirm that a
          sent message has been acknowledged and processed.

        * `mock_client` is a `mock.Mock` wrapper around `FakeWebsocketClient`.
          Use this to assert `EventClient` calls the right methods. It tests
          that `EventClient` acquires a lock before calling `send`.

        * `client` is the `EventClient` that uses `mock_client` under the hood
          that you exercise methods of.

        Other arguments are passed to initialize `EventClient`.
        """
        msg_event = threading.Event()
        fake_client = FakeWebsocketClient()
        mock_client = fake_client.mock_wrapper()
        conn_patch.return_value = mock_client
        client = arvados.events.EventClient(url, filters, lambda _: msg_event.set())
        # The lock only exists once the client does, so the lock check is
        # switched on after construction.
        fake_client._check_lock = client._subscribe_lock
        return msg_event, mock_client, client

    @mock.patch('arvados.events.ws_client.connect')
    def test_subscribe_locking(self, ws_conn):
        f = [['created_at', '>=', '2023-12-01T00:00:00.000Z']]
        msg_event, wss_client, self.ws = self.fake_client(ws_conn)
        self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for setup callback")
        msg_event.clear()
        # Forget the sends made during setup so only the next one is checked.
        wss_client.send.reset_mock()
        self.ws.subscribe(f)
        self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for subscribe callback")
        wss_client.send.assert_called()
        (msg,), _ = wss_client.send.call_args
        self.assertEqual(
            json.loads(msg),
            {'method': 'subscribe', 'filters': f},
        )

    @mock.patch('arvados.events.ws_client.connect')
    def test_unsubscribe_locking(self, ws_conn):
        f = [['created_at', '>=', '2023-12-01T01:00:00.000Z']]
        msg_event, wss_client, self.ws = self.fake_client(ws_conn, f)
        self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for setup callback")
        msg_event.clear()
        # Forget the sends made during setup so only the next one is checked.
        wss_client.send.reset_mock()
        self.ws.unsubscribe(f)
        self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for unsubscribe callback")
        wss_client.send.assert_called()
        (msg,), _ = wss_client.send.call_args
        self.assertEqual(
            json.loads(msg),
            {'method': 'unsubscribe', 'filters': f},
        )

    @mock.patch('arvados.events.ws_client.connect')
    def test_resubscribe_locking(self, ws_conn):
        f = [['created_at', '>=', '2023-12-01T02:00:00.000Z']]
        msg_event, wss_client, self.ws = self.fake_client(ws_conn, f)
        self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for setup callback")
        msg_event.clear()
        # Forget the sends made during setup so only the next one is checked.
        wss_client.send.reset_mock()
        # Dropping the connection should make the client send its
        # subscription again.
        wss_client.force_disconnect()
        self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for resubscribe callback")
        wss_client.send.assert_called()
        (msg,), _ = wss_client.send.call_args
        self.assertEqual(
            json.loads(msg),
            {'method': 'subscribe', 'filters': f},
        )
398
399
class PollClientTestCase(unittest.TestCase):
    """Tests for `arvados.events.PollClient` driven by a mocked API object."""
    TEST_TIMEOUT = 10.0

    class MockLogs:
        """Stand-in for the result of `logs().list(...).execute()`.

        `list_func` must be assigned by the test fixture before
        `return_list` is called; it is the mock whose call history holds
        the `filters` keyword of the most recent `list(...)` call.
        """

        def __init__(self):
            self.logs = []
            self.lock = threading.Lock()
            self.api_called = threading.Event()

        def add(self, log):
            """Queue `log` for the next `return_list` call that has an id filter."""
            with self.lock:
                self.logs.append(log)

        def return_list(self, num_retries=None):
            """Answer an `execute()` call based on the latest `list()` filters."""
            self.api_called.set()
            _, kwargs = self.list_func.call_args_list[-1]
            has_id_filter = any(
                f[0] == 'id' and f[1] == '>'
                for f in kwargs.get('filters', [])
            )
            if not has_id_filter:
                # No 'id' filter was given -- this must be the probe
                # to determine the most recent id.
                return {'items': [{'id': 1}], 'items_available': 1}
            # Hand over everything queued so far and start a fresh list.
            with self.lock:
                pending, self.logs = self.logs, []
            return {'items': pending, 'items_available': len(pending)}

    def setUp(self):
        self.logs = self.MockLogs()
        self.arv = mock.MagicMock(name='arvados.api()')
        self.arv.logs().list().execute.side_effect = self.logs.return_list
        # The "execute" stub above has to look at the call history to
        # find X in ....logs().list(filters=X).execute(), so give it the
        # mock that records those calls.
        self.logs.list_func = self.arv.logs().list
        self.status_ok = threading.Event()
        self.event_received = threading.Event()
        self.recv_events = []

    def tearDown(self):
        # Tests that already disposed of their client delete the attribute.
        if hasattr(self, 'client'):
            self.client.close(timeout=None)

    def callback(self, event):
        """Record `event`: status 200 sets `status_ok`, anything else is stored."""
        if event.get('status') != 200:
            self.recv_events.append(event)
            self.event_received.set()
        else:
            self.status_ok.set()

    def build_client(self, filters=None, callback=None, last_log_id=None, poll_time=99):
        """Create `self.client`, defaulting to no filters and `self.callback`."""
        filters = [] if filters is None else filters
        callback = self.callback if callback is None else callback
        self.client = arvados.events.PollClient(
            self.arv, filters, callback, poll_time, last_log_id)

    def was_filter_used(self, target):
        """Return True if any recorded `list()` call had `target` in its filters."""
        for call in self.arv.logs().list.call_args_list:
            if target in call[-1].get('filters', []):
                return True
        return False

    def test_callback(self):
        test_log = {'id': 12345, 'testkey': 'testtext'}
        self.logs.add({'id': 123})
        self.build_client(poll_time=.01)
        self.client.start()
        self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
        self.assertTrue(self.event_received.wait(self.TEST_TIMEOUT))
        # Re-arm the event, then check that a log added later is delivered.
        self.event_received.clear()
        self.logs.add(test_log.copy())
        self.assertTrue(self.event_received.wait(self.TEST_TIMEOUT))
        self.assertIn(test_log, self.recv_events)

    def test_subscribe(self):
        client_filter = ['kind', '=', 'arvados#test']
        self.build_client()
        self.client.unsubscribe([])
        self.client.subscribe([client_filter[:]])
        self.client.start()
        self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
        self.assertTrue(self.logs.api_called.wait(self.TEST_TIMEOUT))
        self.assertTrue(self.was_filter_used(client_filter))

    def test_unsubscribe(self):
        should_filter = ['foo', '=', 'foo']
        should_not_filter = ['foo', '=', 'bar']
        self.build_client(poll_time=0.01)
        self.client.unsubscribe([])
        self.client.subscribe([should_not_filter[:]])
        self.client.subscribe([should_filter[:]])
        self.client.unsubscribe([should_not_filter[:]])
        self.client.start()
        self.logs.add({'id': 123})
        self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
        self.assertTrue(self.event_received.wait(self.TEST_TIMEOUT))
        self.assertTrue(self.was_filter_used(should_filter))
        self.assertFalse(self.was_filter_used(should_not_filter))

    def test_run_forever(self):
        self.build_client()
        self.client.start()
        forever_thread = threading.Thread(target=self.client.run_forever)
        forever_thread.start()
        self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
        self.assertTrue(forever_thread.is_alive())
        self.client.close()
        forever_thread.join()
        # Already closed here, so remove it to stop tearDown closing it again.
        del self.client