-import arvados
-import io
+# Copyright (C) The Arvados Authors. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import json
import logging
import mock
-import Queue
-import run_test_server
+import queue
+import sys
import threading
import time
import unittest
-import arvados_testutil
+import websockets.exceptions as ws_exc
+
+import arvados
+from . import arvados_testutil as tutil
+from . import run_test_server
+
+class FakeWebsocketClient:
+ """Fake self-contained version of websockets.sync.client.ClientConnection
+
+ This provides enough of the API to test EventClient. It loosely mimics
+ the Arvados WebSocket API by acknowledging subscribe messages. You can use
+ `mock_wrapper` to test calls. You can set `_check_lock` to test that the
+ given lock is acquired before `send` is called.
+ """
+
+ def __init__(self):
+ self._check_lock = None
+ self._closed = threading.Event()
+ self._messages = queue.Queue()
+
+ def mock_wrapper(self):
+ wrapper = mock.Mock(wraps=self)
+ wrapper.__iter__ = lambda _: self.__iter__()
+ return wrapper
+
+ def __iter__(self):
+ while True:
+ msg = self._messages.get()
+ self._messages.task_done()
+ if isinstance(msg, Exception):
+ raise msg
+ else:
+ yield msg
+
+ def close(self, code=1000, reason=''):
+ if not self._closed.is_set():
+ self._closed.set()
+ self.force_disconnect()
+
+ def force_disconnect(self):
+ self._messages.put(ws_exc.ConnectionClosed(None, None))
+
+ def send(self, msg):
+ if self._check_lock is not None and self._check_lock.acquire(blocking=False):
+ self._check_lock.release()
+ raise AssertionError(f"called ws_client.send() without lock")
+ elif self._closed.is_set():
+ raise ws_exc.ConnectionClosed(None, None)
+ try:
+ msg = json.loads(msg)
+ except ValueError:
+ status = 400
+ else:
+ status = 200
+ self._messages.put(json.dumps({'status': status}))
+
class WebsocketTest(run_test_server.TestCaseWithServers):
MAIN_SERVER = {}
TIME_PAST = time.time()-3600
TIME_FUTURE = time.time()+3600
- MOCK_WS_URL = 'wss://[{}]/'.format(arvados_testutil.TEST_HOST)
+ MOCK_WS_URL = 'wss://[{}]/'.format(tutil.TEST_HOST)
TEST_TIMEOUT = 10.0
def _test_subscribe(self, poll_fallback, expect_type, start_time=None, expected=1):
run_test_server.authorize_with('active')
- events = Queue.Queue(100)
+ events = queue.Queue(100)
# Create ancestor before subscribing.
# When listening with start_time in the past, this should also be retrieved.
last_log_id=(1 if start_time else None))
self.assertIsInstance(self.ws, expect_type)
self.assertEqual(200, events.get(True, 5)['status'])
+
+ if hasattr(self.ws, '_skip_old_events'):
+ # Avoid race by waiting for the first "find ID threshold"
+ # poll to finish.
+ deadline = time.time() + 10
+ while not self.ws._skip_old_events:
+ self.assertLess(time.time(), deadline)
+ time.sleep(0.1)
human = arvados.api('v1').humans().create(body={}).execute()
want_uuids = []
log_object_uuids.append(events.get(True, 5)['object_uuid'])
if expected < 2:
- with self.assertRaises(Queue.Empty):
+ with self.assertRaises(queue.Empty):
# assertEqual just serves to show us what unexpected
# thing comes out of the queue when the assertRaises
# fails; when the test passes, this assertEqual
error_mock = mock.MagicMock()
error_mock.resp.status = 0
error_mock._get_reason.return_value = "testing"
- api_mock.logs().list().execute.side_effect = (arvados.errors.ApiError(error_mock, ""),
- {"items": [{"id": 1}], "items_available": 1},
- arvados.errors.ApiError(error_mock, ""),
- {"items": [{"id": 1}], "items_available": 1})
+ api_mock.logs().list().execute.side_effect = (
+ arvados.errors.ApiError(error_mock, b""),
+ {"items": [{"id": 1}], "items_available": 1},
+ arvados.errors.ApiError(error_mock, b""),
+ {"items": [{"id": 1}], "items_available": 1},
+ )
pc = arvados.events.PollClient(api_mock, [], on_ev, 15, None)
pc.start()
while len(n) < 2:
return time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime(t))
def localiso(self, t):
- return time.strftime('%Y-%m-%dT%H:%M:%S', time.localtime(t)) + self.isotz(-time.timezone/60)
+ return time.strftime('%Y-%m-%dT%H:%M:%S', time.localtime(t)) + self.isotz(-time.timezone//60)
def isotz(self, offset):
"""Convert minutes-east-of-UTC to RFC3339- and ISO-compatible time zone designator"""
- return '{:+03d}:{:02d}'.format(offset/60, offset%60)
+ return '{:+03d}:{:02d}'.format(offset//60, offset%60)
- # Test websocket reconnection on (un)execpted close
+ # Test websocket reconnection on (un)expected close
def _test_websocket_reconnect(self, close_unexpected):
run_test_server.authorize_with('active')
- events = Queue.Queue(100)
+ events = queue.Queue(100)
- logstream = io.BytesIO()
+ logstream = tutil.StringIO()
rootLogger = logging.getLogger()
streamHandler = logging.StreamHandler(logstream)
rootLogger.addHandler(streamHandler)
# expect an event
self.assertIn(human['uuid'], events.get(True, 5)['object_uuid'])
- with self.assertRaises(Queue.Empty):
+ with self.assertRaises(queue.Empty):
self.assertEqual(events.get(True, 2), None)
# close (im)properly
if close_unexpected:
- self.ws.ec.close_connection()
+ self.ws._client.close()
else:
self.ws.close()
event = events.get(True, 5)
if event.get('object_uuid') != None:
log_object_uuids.append(event['object_uuid'])
- with self.assertRaises(Queue.Empty):
+ with self.assertRaises(queue.Empty):
self.assertEqual(events.get(True, 2), None)
self.assertNotIn(human['uuid'], log_object_uuids)
self.assertIn(human2['uuid'], log_object_uuids)
else:
- with self.assertRaises(Queue.Empty):
+ with self.assertRaises(queue.Empty):
self.assertEqual(events.get(True, 2), None)
# verify log message to ensure that an (un)expected close
self._test_websocket_reconnect(False)
# Test websocket reconnection retry
- @mock.patch('arvados.events._EventClient.connect')
- def test_websocket_reconnect_retry(self, event_client_connect):
- event_client_connect.side_effect = [None, Exception('EventClient.connect error'), None]
-
- logstream = io.BytesIO()
+ @mock.patch('arvados.events.ws_client.connect')
+ def test_websocket_reconnect_retry(self, ws_conn):
+ logstream = tutil.StringIO()
rootLogger = logging.getLogger()
streamHandler = logging.StreamHandler(logstream)
rootLogger.addHandler(streamHandler)
-
- run_test_server.authorize_with('active')
- events = Queue.Queue(100)
-
- filters = [['object_uuid', 'is_a', 'arvados#human']]
- self.ws = arvados.events.subscribe(
- arvados.api('v1'), filters,
- events.put_nowait,
- poll_fallback=False,
- last_log_id=None)
- self.assertIsInstance(self.ws, arvados.events.EventClient)
-
- # simulate improper close
- self.ws.on_closed()
-
- # verify log messages to ensure retry happened
- log_messages = logstream.getvalue()
- found = log_messages.find("Error 'EventClient.connect error' during websocket reconnect.")
- self.assertNotEqual(found, -1)
- rootLogger.removeHandler(streamHandler)
-
- @mock.patch('arvados.events._EventClient')
- def test_subscribe_method(self, websocket_client):
- filters = [['object_uuid', 'is_a', 'arvados#human']]
- client = arvados.events.EventClient(
- self.MOCK_WS_URL, [], lambda event: None, None)
- client.subscribe(filters[:], 99)
- websocket_client().subscribe.assert_called_with(filters, 99)
-
- @mock.patch('arvados.events._EventClient')
- def test_unsubscribe(self, websocket_client):
- filters = [['object_uuid', 'is_a', 'arvados#human']]
- client = arvados.events.EventClient(
- self.MOCK_WS_URL, filters[:], lambda event: None, None)
- client.unsubscribe(filters[:])
- websocket_client().unsubscribe.assert_called_with(filters)
-
- @mock.patch('arvados.events._EventClient')
+ try:
+ msg_event, wss_client, self.ws = self.fake_client(ws_conn)
+ self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for setup callback")
+ msg_event.clear()
+ ws_conn.side_effect = [Exception('EventClient.connect error'), wss_client]
+ wss_client.force_disconnect()
+ self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for reconnect callback")
+ # verify log messages to ensure retry happened
+ self.assertIn("Error 'EventClient.connect error' during websocket reconnect.", logstream.getvalue())
+ self.assertEqual(ws_conn.call_count, 3)
+ finally:
+ rootLogger.removeHandler(streamHandler)
+
+ @mock.patch('arvados.events.ws_client.connect')
def test_run_forever_survives_reconnects(self, websocket_client):
- connected = threading.Event()
- websocket_client().connect.side_effect = connected.set
client = arvados.events.EventClient(
self.MOCK_WS_URL, [], lambda event: None, None)
forever_thread = threading.Thread(target=client.run_forever)
forever_thread.start()
# Simulate an unexpected disconnect, and wait for reconnect.
- close_thread = threading.Thread(target=client.on_closed)
- close_thread.start()
- self.assertTrue(connected.wait(timeout=self.TEST_TIMEOUT))
- close_thread.join()
- run_forever_alive = forever_thread.is_alive()
- client.close()
- forever_thread.join()
- self.assertTrue(run_forever_alive)
- self.assertEqual(2, websocket_client().connect.call_count)
+ try:
+ client.on_closed()
+ self.assertTrue(forever_thread.is_alive())
+ self.assertEqual(2, websocket_client.call_count)
+ finally:
+ client.close()
+ forever_thread.join()
+
+ @staticmethod
+ def fake_client(conn_patch, filters=None, url=MOCK_WS_URL):
+ """Set up EventClient test infrastructure
+
+ Given a patch of `arvados.events.ws_client.connect`,
+ this returns a 3-tuple:
+
+ * `msg_event` is a `threading.Event` that is set as the test client
+ event callback. You can wait for this event to confirm that a
+ sent message has been acknowledged and processed.
+
+ * `mock_client` is a `mock.Mock` wrapper around `FakeWebsocketClient`.
+ Use this to assert `EventClient` calls the right methods. It tests
+ that `EventClient` acquires a lock before calling `send`.
+
+ * `client` is the `EventClient` that uses `mock_client` under the hood
+ that you exercise methods of.
+
+ Other arguments are passed to initialize `EventClient`.
+ """
+ msg_event = threading.Event()
+ fake_client = FakeWebsocketClient()
+ mock_client = fake_client.mock_wrapper()
+ conn_patch.return_value = mock_client
+ client = arvados.events.EventClient(url, filters, lambda _: msg_event.set())
+ fake_client._check_lock = client._subscribe_lock
+ return msg_event, mock_client, client
+
+ @mock.patch('arvados.events.ws_client.connect')
+ def test_subscribe_locking(self, ws_conn):
+ f = [['created_at', '>=', '2023-12-01T00:00:00.000Z']]
+ msg_event, wss_client, self.ws = self.fake_client(ws_conn)
+ self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for setup callback")
+ msg_event.clear()
+ wss_client.send.reset_mock()
+ self.ws.subscribe(f)
+ self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for subscribe callback")
+ wss_client.send.assert_called()
+ (msg,), _ = wss_client.send.call_args
+ self.assertEqual(
+ json.loads(msg),
+ {'method': 'subscribe', 'filters': f},
+ )
+
+ @mock.patch('arvados.events.ws_client.connect')
+ def test_unsubscribe_locking(self, ws_conn):
+ f = [['created_at', '>=', '2023-12-01T01:00:00.000Z']]
+ msg_event, wss_client, self.ws = self.fake_client(ws_conn, f)
+ self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for setup callback")
+ msg_event.clear()
+ wss_client.send.reset_mock()
+ self.ws.unsubscribe(f)
+ self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for unsubscribe callback")
+ wss_client.send.assert_called()
+ (msg,), _ = wss_client.send.call_args
+ self.assertEqual(
+ json.loads(msg),
+ {'method': 'unsubscribe', 'filters': f},
+ )
+
+ @mock.patch('arvados.events.ws_client.connect')
+ def test_resubscribe_locking(self, ws_conn):
+ f = [['created_at', '>=', '2023-12-01T02:00:00.000Z']]
+ msg_event, wss_client, self.ws = self.fake_client(ws_conn, f)
+ self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for setup callback")
+ msg_event.clear()
+ wss_client.send.reset_mock()
+ wss_client.force_disconnect()
+ self.assertTrue(msg_event.wait(timeout=1), "timed out waiting for resubscribe callback")
+ wss_client.send.assert_called()
+ (msg,), _ = wss_client.send.call_args
+ self.assertEqual(
+ json.loads(msg),
+ {'method': 'subscribe', 'filters': f},
+ )
class PollClientTestCase(unittest.TestCase):
def __init__(self):
self.logs = []
self.lock = threading.Lock()
+ self.api_called = threading.Event()
def add(self, log):
with self.lock:
self.logs.append(log)
def return_list(self, num_retries=None):
+ self.api_called.set()
+ args, kwargs = self.list_func.call_args_list[-1]
+ filters = kwargs.get('filters', [])
+ if not any(True for f in filters if f[0] == 'id' and f[1] == '>'):
+ # No 'id' filter was given -- this must be the probe
+ # to determine the most recent id.
+ return {'items': [{'id': 1}], 'items_available': 1}
with self.lock:
retval = self.logs
self.logs = []
self.logs = self.MockLogs()
self.arv = mock.MagicMock(name='arvados.api()')
self.arv.logs().list().execute.side_effect = self.logs.return_list
- self.callback_called = threading.Event()
+ # our MockLogs object's "execute" stub will need to inspect
+ # the call history to determine X in
+ # ....logs().list(filters=X).execute():
+ self.logs.list_func = self.arv.logs().list
+ self.status_ok = threading.Event()
+ self.event_received = threading.Event()
self.recv_events = []
def tearDown(self):
self.client.close(timeout=None)
def callback(self, event):
- self.recv_events.append(event)
- self.callback_called.set()
+ if event.get('status') == 200:
+ self.status_ok.set()
+ else:
+ self.recv_events.append(event)
+ self.event_received.set()
def build_client(self, filters=None, callback=None, last_log_id=None, poll_time=99):
if filters is None:
self.logs.add({'id': 123})
self.build_client(poll_time=.01)
self.client.start()
- self.assertTrue(self.callback_called.wait(self.TEST_TIMEOUT))
- self.callback_called.clear()
+ self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
+ self.assertTrue(self.event_received.wait(self.TEST_TIMEOUT))
+ self.event_received.clear()
self.logs.add(test_log.copy())
- self.assertTrue(self.callback_called.wait(self.TEST_TIMEOUT))
- self.client.close(timeout=None)
+ self.assertTrue(self.event_received.wait(self.TEST_TIMEOUT))
self.assertIn(test_log, self.recv_events)
def test_subscribe(self):
client_filter = ['kind', '=', 'arvados#test']
self.build_client()
+ self.client.unsubscribe([])
self.client.subscribe([client_filter[:]])
self.client.start()
- self.assertTrue(self.callback_called.wait(self.TEST_TIMEOUT))
- self.client.close(timeout=None)
+ self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
+ self.assertTrue(self.logs.api_called.wait(self.TEST_TIMEOUT))
self.assertTrue(self.was_filter_used(client_filter))
def test_unsubscribe(self):
- client_filter = ['kind', '=', 'arvados#test']
- self.build_client()
- self.client.subscribe([client_filter[:]])
- self.client.unsubscribe([client_filter[:]])
+ should_filter = ['foo', '=', 'foo']
+ should_not_filter = ['foo', '=', 'bar']
+ self.build_client(poll_time=0.01)
+ self.client.unsubscribe([])
+ self.client.subscribe([should_not_filter[:]])
+ self.client.subscribe([should_filter[:]])
+ self.client.unsubscribe([should_not_filter[:]])
self.client.start()
- self.client.close(timeout=None)
- self.assertFalse(self.was_filter_used(client_filter))
+ self.logs.add({'id': 123})
+ self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
+ self.assertTrue(self.event_received.wait(self.TEST_TIMEOUT))
+ self.assertTrue(self.was_filter_used(should_filter))
+ self.assertFalse(self.was_filter_used(should_not_filter))
def test_run_forever(self):
self.build_client()
self.client.start()
forever_thread = threading.Thread(target=self.client.run_forever)
forever_thread.start()
- self.assertTrue(self.callback_called.wait(self.TEST_TIMEOUT))
+ self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
self.assertTrue(forever_thread.is_alive())
self.client.close()
forever_thread.join()
+ del self.client