Merge branch 'main' into 15397-remove-obsolete-apis
[arvados.git] / sdk / python / tests / test_events.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import json
6 import logging
7 import queue
8 import sys
9 import threading
10 import time
11 import unittest
12
13 from unittest import mock
14
15 import websockets.exceptions as ws_exc
16
17 import arvados
18 from . import arvados_testutil as tutil
19 from . import run_test_server
20
class FakeWebsocketClient:
    """Fake self-contained version of websockets.sync.client.ClientConnection

    This provides enough of the API to test EventClient. It loosely mimics
    the Arvados WebSocket API by acknowledging subscribe messages. You can use
    `mock_wrapper` to test calls. You can set `_check_lock` to test that the
    given lock is acquired before `send` is called.
    """

    def __init__(self):
        # Optional lock that `send` expects its caller to be holding.
        self._check_lock = None
        # Set once `close` has been called; later `send` calls then fail.
        self._closed = threading.Event()
        # Items consumed by `__iter__`: message strings are yielded,
        # exception instances are raised.
        self._messages = queue.Queue()

    def mock_wrapper(self):
        """Return a `mock.Mock` wrapping this client so calls can be asserted."""
        wrapper = mock.Mock(wraps=self)
        # Iteration is wired up explicitly so that looping over the wrapper
        # reads from this client's message queue.
        wrapper.__iter__ = lambda _: self.__iter__()
        return wrapper

    def __iter__(self):
        """Yield queued messages forever; raise any queued exception."""
        while True:
            msg = self._messages.get()
            self._messages.task_done()
            if isinstance(msg, Exception):
                raise msg
            yield msg

    def close(self, code=1000, reason=''):
        """Mark the connection closed and wake the reader, at most once.

        `code` and `reason` are accepted but not used.
        """
        if not self._closed.is_set():
            self._closed.set()
            self.force_disconnect()

    def force_disconnect(self):
        """Queue a ConnectionClosed error for `__iter__` to raise."""
        self._messages.put(ws_exc.ConnectionClosed(None, None))

    def send(self, msg):
        """Acknowledge `msg` by queueing a JSON status message.

        The queued status is 200 when `msg` parses as JSON and 400 otherwise.

        Raises AssertionError when `_check_lock` is set but could be acquired
        here (meaning the caller is not holding it), and ConnectionClosed
        when `close` has already been called.
        """
        if self._check_lock is not None and self._check_lock.acquire(blocking=False):
            # We got the lock, so the caller did not hold it.
            self._check_lock.release()
            raise AssertionError("called ws_client.send() without lock")
        if self._closed.is_set():
            raise ws_exc.ConnectionClosed(None, None)
        try:
            # Only validity matters here; the parsed value is not used.
            json.loads(msg)
        except ValueError:
            status = 400
        else:
            status = 200
        self._messages.put(json.dumps({'status': status}))
70
71
class WebsocketTest(run_test_server.TestCaseWithServers):
    # Tests for arvados.events.subscribe, EventClient and PollClient.

    # NOTE(review): presumably consumed by TestCaseWithServers to configure
    # the main API server for this class -- confirm in run_test_server.
    MAIN_SERVER = {}

    # Timestamps one hour before and after the moment this class body is
    # executed; used to build `created_at` filters.
    TIME_PAST = time.time()-3600
    TIME_FUTURE = time.time()+3600
    # URL handed to EventClient in the tests that patch ws_client.connect.
    MOCK_WS_URL = 'wss://[{}]/'.format(tutil.TEST_HOST)

    # Not referenced by the tests visible in this class, which pass literal
    # timeouts instead.
    TEST_TIMEOUT = 10.0
80
    def setUp(self):
        # Start every test without a client; tearDown closes self.ws only
        # when a test has assigned one.
        self.ws = None
83
84     def tearDown(self):
85         try:
86             if self.ws:
87                 self.ws.close()
88         except Exception as e:
89             print("Error in teardown: ", e)
90         super(WebsocketTest, self).tearDown()
91         run_test_server.reset()
92
93     def _test_subscribe(self, poll_fallback, expect_type, start_time=None, expected=1):
94         run_test_server.authorize_with('active')
95         events = queue.Queue(100)
96
97         # Create ancestor before subscribing.
98         # When listening with start_time in the past, this should also be retrieved.
99         # However, when start_time is omitted in subscribe, this should not be fetched.
100         ancestor = arvados.api('v1').collections().create(body={}).execute()
101
102         filters = [['object_uuid', 'is_a', 'arvados#collection']]
103         if start_time:
104             filters.append(['created_at', '>=', start_time])
105
106         self.ws = arvados.events.subscribe(
107             arvados.api('v1'), filters,
108             events.put_nowait,
109             poll_fallback=poll_fallback,
110             last_log_id=(1 if start_time else None))
111         self.assertIsInstance(self.ws, expect_type)
112         self.assertEqual(200, events.get(True, 5)['status'])
113
114         if hasattr(self.ws, '_skip_old_events'):
115             # Avoid race by waiting for the first "find ID threshold"
116             # poll to finish.
117             deadline = time.time() + 10
118             while not self.ws._skip_old_events:
119                 self.assertLess(time.time(), deadline)
120                 time.sleep(0.1)
121         collection = arvados.api('v1').collections().create(body={}).execute()
122
123         want_uuids = []
124         if expected > 0:
125             want_uuids.append(collection['uuid'])
126         if expected > 1:
127             want_uuids.append(ancestor['uuid'])
128         log_object_uuids = []
129         while set(want_uuids) - set(log_object_uuids):
130             log_object_uuids.append(events.get(True, 5)['object_uuid'])
131
132         if expected < 2:
133             with self.assertRaises(queue.Empty):
134                 # assertEqual just serves to show us what unexpected
135                 # thing comes out of the queue when the assertRaises
136                 # fails; when the test passes, this assertEqual
137                 # doesn't get called.
138                 self.assertEqual(events.get(True, 2), None)
139
140     def test_subscribe_websocket(self):
141         self._test_subscribe(
142             poll_fallback=False, expect_type=arvados.events.EventClient, expected=1)
143
144     @mock.patch('arvados.events.EventClient.__init__')
145     def test_subscribe_poll(self, event_client_constr):
146         event_client_constr.side_effect = Exception('All is well')
147         self._test_subscribe(
148             poll_fallback=0.25, expect_type=arvados.events.PollClient, expected=1)
149
150     def test_subscribe_poll_retry(self):
151         api_mock = mock.MagicMock()
152         n = []
153         def on_ev(ev):
154             n.append(ev)
155
156         error_mock = mock.MagicMock()
157         error_mock.resp.status = 0
158         error_mock._get_reason.return_value = "testing"
159         api_mock.logs().list().execute.side_effect = (
160             arvados.errors.ApiError(error_mock, b""),
161             {"items": [{"id": 1}], "items_available": 1},
162             arvados.errors.ApiError(error_mock, b""),
163             {"items": [{"id": 1}], "items_available": 1},
164         )
165         pc = arvados.events.PollClient(api_mock, [], on_ev, 15, None)
166         pc.start()
167         while len(n) < 2:
168             time.sleep(.1)
169         pc.close()
170
171     def test_subscribe_websocket_with_start_time_past(self):
172         self._test_subscribe(
173             poll_fallback=False, expect_type=arvados.events.EventClient,
174             start_time=self.localiso(self.TIME_PAST),
175             expected=2)
176
177     @mock.patch('arvados.events.EventClient.__init__')
178     def test_subscribe_poll_with_start_time_past(self, event_client_constr):
179         event_client_constr.side_effect = Exception('All is well')
180         self._test_subscribe(
181             poll_fallback=0.25, expect_type=arvados.events.PollClient,
182             start_time=self.localiso(self.TIME_PAST),
183             expected=2)
184
185     def test_subscribe_websocket_with_start_time_future(self):
186         self._test_subscribe(
187             poll_fallback=False, expect_type=arvados.events.EventClient,
188             start_time=self.localiso(self.TIME_FUTURE),
189             expected=0)
190
191     @mock.patch('arvados.events.EventClient.__init__')
192     def test_subscribe_poll_with_start_time_future(self, event_client_constr):
193         event_client_constr.side_effect = Exception('All is well')
194         self._test_subscribe(
195             poll_fallback=0.25, expect_type=arvados.events.PollClient,
196             start_time=self.localiso(self.TIME_FUTURE),
197             expected=0)
198
199     def test_subscribe_websocket_with_start_time_past_utc(self):
200         self._test_subscribe(
201             poll_fallback=False, expect_type=arvados.events.EventClient,
202             start_time=self.utciso(self.TIME_PAST),
203             expected=2)
204
205     def test_subscribe_websocket_with_start_time_future_utc(self):
206         self._test_subscribe(
207             poll_fallback=False, expect_type=arvados.events.EventClient,
208             start_time=self.utciso(self.TIME_FUTURE),
209             expected=0)
210
211     def utciso(self, t):
212         return time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime(t))
213
214     def localiso(self, t):
215         return time.strftime('%Y-%m-%dT%H:%M:%S', time.localtime(t)) + self.isotz(-time.timezone//60)
216
217     def isotz(self, offset):
218         """Convert minutes-east-of-UTC to RFC3339- and ISO-compatible time zone designator"""
219         return '{:+03d}:{:02d}'.format(offset//60, offset%60)
220
221     # Test websocket reconnection on (un)expected close
222     def _test_websocket_reconnect(self, close_unexpected):
223         run_test_server.authorize_with('active')
224         events = queue.Queue(100)
225
226         logstream = tutil.StringIO()
227         rootLogger = logging.getLogger()
228         streamHandler = logging.StreamHandler(logstream)
229         rootLogger.addHandler(streamHandler)
230
231         filters = [['object_uuid', 'is_a', 'arvados#collection']]
232         filters.append(['created_at', '>=', self.localiso(self.TIME_PAST)])
233         self.ws = arvados.events.subscribe(
234             arvados.api('v1'), filters,
235             events.put_nowait,
236             poll_fallback=False,
237             last_log_id=None)
238         self.assertIsInstance(self.ws, arvados.events.EventClient)
239         self.assertEqual(200, events.get(True, 5)['status'])
240
241         # create obj
242         collection = arvados.api('v1').collections().create(body={}).execute()
243
244         # expect an event
245         self.assertIn(collection['uuid'], events.get(True, 5)['object_uuid'])
246         with self.assertRaises(queue.Empty):
247             self.assertEqual(events.get(True, 2), None)
248
249         # close (im)properly
250         if close_unexpected:
251             self.ws._client.close()
252         else:
253             self.ws.close()
254
255         # create one more obj
256         collection2 = arvados.api('v1').collections().create(body={}).execute()
257
258         # (un)expect the object creation event
259         if close_unexpected:
260             log_object_uuids = []
261             for i in range(0, 2):
262                 event = events.get(True, 5)
263                 if event.get('object_uuid') != None:
264                     log_object_uuids.append(event['object_uuid'])
265             with self.assertRaises(queue.Empty):
266                 self.assertEqual(events.get(True, 2), None)
267             self.assertNotIn(collection['uuid'], log_object_uuids)
268             self.assertIn(collection2['uuid'], log_object_uuids)
269         else:
270             with self.assertRaises(queue.Empty):
271                 self.assertEqual(events.get(True, 2), None)
272
273         # verify log message to ensure that an (un)expected close
274         log_messages = logstream.getvalue()
275         closeLogFound = log_messages.find("Unexpected close. Reconnecting.")
276         retryLogFound = log_messages.find("Error during websocket reconnect. Will retry")
277         if close_unexpected:
278             self.assertNotEqual(closeLogFound, -1)
279         else:
280             self.assertEqual(closeLogFound, -1)
281         rootLogger.removeHandler(streamHandler)
282
283     def test_websocket_reconnect_on_unexpected_close(self):
284         self._test_websocket_reconnect(True)
285
286     def test_websocket_no_reconnect_on_close_by_user(self):
287         self._test_websocket_reconnect(False)
288
289     # Test websocket reconnection retry
290     @mock.patch('arvados.events.ws_client.connect')
291     def test_websocket_reconnect_retry(self, ws_conn):
292         logstream = tutil.StringIO()
293         rootLogger = logging.getLogger()
294         streamHandler = logging.StreamHandler(logstream)
295         rootLogger.addHandler(streamHandler)
296         try:
297             msg_event, wss_client, self.ws = self.fake_client(ws_conn)
298             self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for setup callback")
299             msg_event.clear()
300             ws_conn.side_effect = [Exception('EventClient.connect error'), wss_client]
301             wss_client.force_disconnect()
302             self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for reconnect callback")
303             # verify log messages to ensure retry happened
304             self.assertIn("Error 'EventClient.connect error' during websocket reconnect.", logstream.getvalue())
305             self.assertEqual(ws_conn.call_count, 3)
306         finally:
307             rootLogger.removeHandler(streamHandler)
308
309     @mock.patch('arvados.events.ws_client.connect')
310     def test_run_forever_survives_reconnects(self, websocket_client):
311         client = arvados.events.EventClient(
312             self.MOCK_WS_URL, [], lambda event: None, None)
313         forever_thread = threading.Thread(target=client.run_forever)
314         forever_thread.start()
315         # Simulate an unexpected disconnect, and wait for reconnect.
316         try:
317             client.on_closed()
318             self.assertTrue(forever_thread.is_alive())
319             self.assertEqual(2, websocket_client.call_count)
320         finally:
321             client.close()
322             forever_thread.join()
323
324     @staticmethod
325     def fake_client(conn_patch, filters=None, url=MOCK_WS_URL):
326         """Set up EventClient test infrastructure
327
328         Given a patch of `arvados.events.ws_client.connect`,
329         this returns a 3-tuple:
330
331         * `msg_event` is a `threading.Event` that is set as the test client
332           event callback. You can wait for this event to confirm that a
333           sent message has been acknowledged and processed.
334
335         * `mock_client` is a `mock.Mock` wrapper around `FakeWebsocketClient`.
336           Use this to assert `EventClient` calls the right methods. It tests
337           that `EventClient` acquires a lock before calling `send`.
338
339         * `client` is the `EventClient` that uses `mock_client` under the hood
340           that you exercise methods of.
341
342         Other arguments are passed to initialize `EventClient`.
343         """
344         msg_event = threading.Event()
345         fake_client = FakeWebsocketClient()
346         mock_client = fake_client.mock_wrapper()
347         conn_patch.return_value = mock_client
348         client = arvados.events.EventClient(url, filters, lambda _: msg_event.set())
349         fake_client._check_lock = client._subscribe_lock
350         return msg_event, mock_client, client
351
352     @mock.patch('arvados.events.ws_client.connect')
353     def test_subscribe_locking(self, ws_conn):
354         f = [['created_at', '>=', '2023-12-01T00:00:00.000Z']]
355         msg_event, wss_client, self.ws = self.fake_client(ws_conn)
356         self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for setup callback")
357         msg_event.clear()
358         wss_client.send.reset_mock()
359         self.ws.subscribe(f)
360         self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for subscribe callback")
361         wss_client.send.assert_called()
362         (msg,), _ = wss_client.send.call_args
363         self.assertEqual(
364             json.loads(msg),
365             {'method': 'subscribe', 'filters': f},
366         )
367
368     @mock.patch('arvados.events.ws_client.connect')
369     def test_unsubscribe_locking(self, ws_conn):
370         f = [['created_at', '>=', '2023-12-01T01:00:00.000Z']]
371         msg_event, wss_client, self.ws = self.fake_client(ws_conn, f)
372         self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for setup callback")
373         msg_event.clear()
374         wss_client.send.reset_mock()
375         self.ws.unsubscribe(f)
376         self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for unsubscribe callback")
377         wss_client.send.assert_called()
378         (msg,), _ = wss_client.send.call_args
379         self.assertEqual(
380             json.loads(msg),
381             {'method': 'unsubscribe', 'filters': f},
382         )
383
384     @mock.patch('arvados.events.ws_client.connect')
385     def test_resubscribe_locking(self, ws_conn):
386         f = [['created_at', '>=', '2023-12-01T02:00:00.000Z']]
387         msg_event, wss_client, self.ws = self.fake_client(ws_conn, f)
388         self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for setup callback")
389         msg_event.clear()
390         wss_client.send.reset_mock()
391         wss_client.force_disconnect()
392         self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for resubscribe callback")
393         wss_client.send.assert_called()
394         (msg,), _ = wss_client.send.call_args
395         self.assertEqual(
396             json.loads(msg),
397             {'method': 'subscribe', 'filters': f},
398         )
399
400
class PollClientTestCase(unittest.TestCase):
    # Tests for arvados.events.PollClient against a mocked API client.

    # Upper bound, in seconds, for every wait on a threading.Event below.
    TEST_TIMEOUT = 10.0

    class MockLogs(object):
        # Stand-in for the logs API: tests queue log records with add(), and
        # return_list() serves as the stubbed execute().

        def __init__(self):
            self.logs = []
            self.lock = threading.Lock()
            self.api_called = threading.Event()

        def add(self, log):
            with self.lock:
                self.logs.append(log)

        def return_list(self, num_retries=None):
            self.api_called.set()
            _, kwargs = self.list_func.call_args_list[-1]
            filters = kwargs.get('filters', [])
            filters_by_id = any(f[0] == 'id' and f[1] == '>' for f in filters)
            if not filters_by_id:
                # No 'id' filter was given -- this must be the probe
                # to determine the most recent id.
                return {'items': [{'id': 1}], 'items_available': 1}
            # Hand over everything queued so far and start a fresh list.
            with self.lock:
                queued, self.logs = self.logs, []
            return {'items': queued, 'items_available': len(queued)}

    def setUp(self):
        self.logs = self.MockLogs()
        self.arv = mock.MagicMock(name='arvados.api()')
        self.arv.logs().list().execute.side_effect = self.logs.return_list
        # The MockLogs "execute" stub has to look at the call history to
        # find X in ....logs().list(filters=X).execute():
        self.logs.list_func = self.arv.logs().list
        self.status_ok = threading.Event()
        self.event_received = threading.Event()
        self.recv_events = []

    def tearDown(self):
        if hasattr(self, 'client'):
            self.client.close(timeout=None)

    def callback(self, event):
        # The status-200 message only sets status_ok; every other event is
        # recorded and announced through event_received.
        if event.get('status') == 200:
            self.status_ok.set()
            return
        self.recv_events.append(event)
        self.event_received.set()

    def build_client(self, filters=None, callback=None, last_log_id=None, poll_time=99):
        # Create self.client, defaulting to no filters and to self.callback.
        filters = [] if filters is None else filters
        callback = self.callback if callback is None else callback
        self.client = arvados.events.PollClient(
            self.arv, filters, callback, poll_time, last_log_id)

    def was_filter_used(self, target):
        # True when any recorded logs().list(...) call had `target` among
        # its `filters` keyword argument.
        for call in self.arv.logs().list.call_args_list:
            if target in call[-1].get('filters', []):
                return True
        return False

    def test_callback(self):
        expected_log = {'id': 12345, 'testkey': 'testtext'}
        self.logs.add({'id': 123})
        self.build_client(poll_time=.01)
        self.client.start()
        self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
        self.assertTrue(self.event_received.wait(self.TEST_TIMEOUT))
        self.event_received.clear()
        # Queue a copy so the comparison below uses an untouched original.
        self.logs.add(expected_log.copy())
        self.assertTrue(self.event_received.wait(self.TEST_TIMEOUT))
        self.assertIn(expected_log, self.recv_events)

    def test_subscribe(self):
        wanted_filter = ['kind', '=', 'arvados#test']
        self.build_client()
        self.client.unsubscribe([])
        self.client.subscribe([wanted_filter[:]])
        self.client.start()
        self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
        self.assertTrue(self.logs.api_called.wait(self.TEST_TIMEOUT))
        self.assertTrue(self.was_filter_used(wanted_filter))

    def test_unsubscribe(self):
        kept_filter = ['foo', '=', 'foo']
        dropped_filter = ['foo', '=', 'bar']
        self.build_client(poll_time=0.01)
        self.client.unsubscribe([])
        self.client.subscribe([dropped_filter[:]])
        self.client.subscribe([kept_filter[:]])
        self.client.unsubscribe([dropped_filter[:]])
        self.client.start()
        self.logs.add({'id': 123})
        self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
        self.assertTrue(self.event_received.wait(self.TEST_TIMEOUT))
        self.assertTrue(self.was_filter_used(kept_filter))
        self.assertFalse(self.was_filter_used(dropped_filter))

    def test_run_forever(self):
        self.build_client()
        self.client.start()
        runner = threading.Thread(target=self.client.run_forever)
        runner.start()
        self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
        self.assertTrue(runner.is_alive())
        self.client.close()
        runner.join()
        # The client is already closed; dropping the attribute keeps
        # tearDown from closing it a second time.
        del self.client