18870: Need to declare NODES as array
[arvados.git] / sdk / python / tests / test_events.py
index 7e8c84ec11279495d55fd47770378847886ae76e..f5192160f3e5fad01d080b0eb16bf834bdfd1ed6 100644 (file)
@@ -1,21 +1,33 @@
-import arvados
-import io
+# Copyright (C) The Arvados Authors. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import print_function
+from __future__ import absolute_import
+from __future__ import division
+from future import standard_library
+standard_library.install_aliases()
+from builtins import range
+from builtins import object
 import logging
 import mock
-import Queue
-import run_test_server
+import queue
+import sys
 import threading
 import time
 import unittest
 
-import arvados_testutil
+import arvados
+from . import arvados_testutil as tutil
+from . import run_test_server
+
 
 class WebsocketTest(run_test_server.TestCaseWithServers):
     MAIN_SERVER = {}
 
     TIME_PAST = time.time()-3600
     TIME_FUTURE = time.time()+3600
-    MOCK_WS_URL = 'wss://[{}]/'.format(arvados_testutil.TEST_HOST)
+    MOCK_WS_URL = 'wss://[{}]/'.format(tutil.TEST_HOST)
 
     TEST_TIMEOUT = 10.0
 
@@ -33,7 +45,7 @@ class WebsocketTest(run_test_server.TestCaseWithServers):
 
     def _test_subscribe(self, poll_fallback, expect_type, start_time=None, expected=1):
         run_test_server.authorize_with('active')
-        events = Queue.Queue(100)
+        events = queue.Queue(100)
 
         # Create ancestor before subscribing.
         # When listening with start_time in the past, this should also be retrieved.
@@ -51,6 +63,14 @@ class WebsocketTest(run_test_server.TestCaseWithServers):
             last_log_id=(1 if start_time else None))
         self.assertIsInstance(self.ws, expect_type)
         self.assertEqual(200, events.get(True, 5)['status'])
+
+        if hasattr(self.ws, '_skip_old_events'):
+            # Avoid race by waiting for the first "find ID threshold"
+            # poll to finish.
+            deadline = time.time() + 10
+            while not self.ws._skip_old_events:
+                self.assertLess(time.time(), deadline)
+                time.sleep(0.1)
         human = arvados.api('v1').humans().create(body={}).execute()
 
         want_uuids = []
@@ -63,7 +83,7 @@ class WebsocketTest(run_test_server.TestCaseWithServers):
             log_object_uuids.append(events.get(True, 5)['object_uuid'])
 
         if expected < 2:
-            with self.assertRaises(Queue.Empty):
+            with self.assertRaises(queue.Empty):
                 # assertEqual just serves to show us what unexpected
                 # thing comes out of the queue when the assertRaises
                 # fails; when the test passes, this assertEqual
@@ -89,10 +109,12 @@ class WebsocketTest(run_test_server.TestCaseWithServers):
         error_mock = mock.MagicMock()
         error_mock.resp.status = 0
         error_mock._get_reason.return_value = "testing"
-        api_mock.logs().list().execute.side_effect = (arvados.errors.ApiError(error_mock, ""),
-                                                      {"items": [{"id": 1}], "items_available": 1},
-                                                      arvados.errors.ApiError(error_mock, ""),
-                                                      {"items": [{"id": 1}], "items_available": 1})
+        api_mock.logs().list().execute.side_effect = (
+            arvados.errors.ApiError(error_mock, b""),
+            {"items": [{"id": 1}], "items_available": 1},
+            arvados.errors.ApiError(error_mock, b""),
+            {"items": [{"id": 1}], "items_available": 1},
+        )
         pc = arvados.events.PollClient(api_mock, [], on_ev, 15, None)
         pc.start()
         while len(n) < 2:
@@ -143,18 +165,25 @@ class WebsocketTest(run_test_server.TestCaseWithServers):
         return time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime(t))
 
     def localiso(self, t):
-        return time.strftime('%Y-%m-%dT%H:%M:%S', time.localtime(t)) + self.isotz(-time.timezone/60)
+        lt = time.localtime(t)
+        # time.timezone is the standard-time offset; when t falls in DST
+        # the local clock is time.altzone seconds west of UTC instead.
+        tzsecs = time.altzone if lt.tm_isdst > 0 else time.timezone
+        return time.strftime('%Y-%m-%dT%H:%M:%S', lt) + self.isotz(-tzsecs//60)
 
     def isotz(self, offset):
         """Convert minutes-east-of-UTC to RFC3339- and ISO-compatible time zone designator"""
-        return '{:+03d}:{:02d}'.format(offset/60, offset%60)
+        # Split the magnitude, not the signed value: floor division turns
+        # -210 (UTC-03:30) into -4h+30m, i.e. "-04:30".
+        hours, minutes = divmod(abs(offset), 60)
+        return '{}{:02d}:{:02d}'.format('-' if offset < 0 else '+', hours, minutes)
 
-    # Test websocket reconnection on (un)execpted close
+    # Test websocket reconnection on (un)expected close
     def _test_websocket_reconnect(self, close_unexpected):
         run_test_server.authorize_with('active')
-        events = Queue.Queue(100)
+        events = queue.Queue(100)
 
-        logstream = io.BytesIO()
+        logstream = tutil.StringIO()
         rootLogger = logging.getLogger()
         streamHandler = logging.StreamHandler(logstream)
         rootLogger.addHandler(streamHandler)
@@ -174,7 +196,7 @@ class WebsocketTest(run_test_server.TestCaseWithServers):
 
         # expect an event
         self.assertIn(human['uuid'], events.get(True, 5)['object_uuid'])
-        with self.assertRaises(Queue.Empty):
+        with self.assertRaises(queue.Empty):
             self.assertEqual(events.get(True, 2), None)
 
         # close (im)properly
@@ -193,12 +215,12 @@ class WebsocketTest(run_test_server.TestCaseWithServers):
                 event = events.get(True, 5)
                 if event.get('object_uuid') != None:
                     log_object_uuids.append(event['object_uuid'])
-            with self.assertRaises(Queue.Empty):
+            with self.assertRaises(queue.Empty):
                 self.assertEqual(events.get(True, 2), None)
             self.assertNotIn(human['uuid'], log_object_uuids)
             self.assertIn(human2['uuid'], log_object_uuids)
         else:
-            with self.assertRaises(Queue.Empty):
+            with self.assertRaises(queue.Empty):
                 self.assertEqual(events.get(True, 2), None)
 
         # verify log message to ensure that an (un)expected close
@@ -222,13 +244,13 @@ class WebsocketTest(run_test_server.TestCaseWithServers):
     def test_websocket_reconnect_retry(self, event_client_connect):
         event_client_connect.side_effect = [None, Exception('EventClient.connect error'), None]
 
-        logstream = io.BytesIO()
+        logstream = tutil.StringIO()
         rootLogger = logging.getLogger()
         streamHandler = logging.StreamHandler(logstream)
         rootLogger.addHandler(streamHandler)
 
         run_test_server.authorize_with('active')
-        events = Queue.Queue(100)
+        events = queue.Queue(100)
 
         filters = [['object_uuid', 'is_a', 'arvados#human']]
         self.ws = arvados.events.subscribe(
@@ -291,12 +313,20 @@ class PollClientTestCase(unittest.TestCase):
         def __init__(self):
             self.logs = []
             self.lock = threading.Lock()
+            self.api_called = threading.Event()
 
         def add(self, log):
             with self.lock:
                 self.logs.append(log)
 
         def return_list(self, num_retries=None):
+            self.api_called.set()
+            args, kwargs = self.list_func.call_args_list[-1]
+            filters = kwargs.get('filters', [])
+            if not any(True for f in filters if f[0] == 'id' and f[1] == '>'):
+                # No 'id' filter was given -- this must be the probe
+                # to determine the most recent id.
+                return {'items': [{'id': 1}], 'items_available': 1}
             with self.lock:
                 retval = self.logs
                 self.logs = []
@@ -306,7 +336,12 @@ class PollClientTestCase(unittest.TestCase):
         self.logs = self.MockLogs()
         self.arv = mock.MagicMock(name='arvados.api()')
         self.arv.logs().list().execute.side_effect = self.logs.return_list
-        self.callback_called = threading.Event()
+        # our MockLogs object's "execute" stub will need to inspect
+        # the call history to determine X in
+        # ....logs().list(filters=X).execute():
+        self.logs.list_func = self.arv.logs().list
+        self.status_ok = threading.Event()
+        self.event_received = threading.Event()
         self.recv_events = []
 
     def tearDown(self):
@@ -314,8 +349,11 @@ class PollClientTestCase(unittest.TestCase):
             self.client.close(timeout=None)
 
     def callback(self, event):
-        self.recv_events.append(event)
-        self.callback_called.set()
+        if event.get('status') == 200:
+            self.status_ok.set()
+        else:
+            self.recv_events.append(event)
+            self.event_received.set()
 
     def build_client(self, filters=None, callback=None, last_log_id=None, poll_time=99):
         if filters is None:
@@ -334,37 +372,45 @@ class PollClientTestCase(unittest.TestCase):
         self.logs.add({'id': 123})
         self.build_client(poll_time=.01)
         self.client.start()
-        self.assertTrue(self.callback_called.wait(self.TEST_TIMEOUT))
-        self.callback_called.clear()
+        self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
+        self.assertTrue(self.event_received.wait(self.TEST_TIMEOUT))
+        self.event_received.clear()
         self.logs.add(test_log.copy())
-        self.assertTrue(self.callback_called.wait(self.TEST_TIMEOUT))
-        self.client.close(timeout=None)
+        self.assertTrue(self.event_received.wait(self.TEST_TIMEOUT))
         self.assertIn(test_log, self.recv_events)
 
     def test_subscribe(self):
         client_filter = ['kind', '=', 'arvados#test']
         self.build_client()
+        self.client.unsubscribe([])
         self.client.subscribe([client_filter[:]])
         self.client.start()
-        self.assertTrue(self.callback_called.wait(self.TEST_TIMEOUT))
-        self.client.close(timeout=None)
+        self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
+        self.assertTrue(self.logs.api_called.wait(self.TEST_TIMEOUT))
         self.assertTrue(self.was_filter_used(client_filter))
 
     def test_unsubscribe(self):
-        client_filter = ['kind', '=', 'arvados#test']
-        self.build_client()
-        self.client.subscribe([client_filter[:]])
-        self.client.unsubscribe([client_filter[:]])
+        should_filter = ['foo', '=', 'foo']
+        should_not_filter = ['foo', '=', 'bar']
+        self.build_client(poll_time=0.01)
+        self.client.unsubscribe([])
+        self.client.subscribe([should_not_filter[:]])
+        self.client.subscribe([should_filter[:]])
+        self.client.unsubscribe([should_not_filter[:]])
         self.client.start()
-        self.client.close(timeout=None)
-        self.assertFalse(self.was_filter_used(client_filter))
+        self.logs.add({'id': 123})
+        self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
+        self.assertTrue(self.event_received.wait(self.TEST_TIMEOUT))
+        self.assertTrue(self.was_filter_used(should_filter))
+        self.assertFalse(self.was_filter_used(should_not_filter))
 
     def test_run_forever(self):
         self.build_client()
         self.client.start()
         forever_thread = threading.Thread(target=self.client.run_forever)
         forever_thread.start()
-        self.assertTrue(self.callback_called.wait(self.TEST_TIMEOUT))
+        self.assertTrue(self.status_ok.wait(self.TEST_TIMEOUT))
         self.assertTrue(forever_thread.is_alive())
         self.client.close()
         forever_thread.join()
+        del self.client