Merge branch '21659-gh-workflow-tests' into main. Closes #21659
[arvados.git] / sdk / python / tests / test_api.py
index 20c4f346a9363f501e6ed305eb3a48aa7612c9e1..0f85e5520c821dcaa7bf6690e7702cb857e3ac54 100644 (file)
@@ -7,13 +7,16 @@ from builtins import str
 from builtins import range
 import arvados
 import collections
+import contextlib
 import httplib2
 import itertools
 import json
+import logging
 import mimetypes
 import os
 import socket
 import string
+import sys
 import unittest
 import urllib.parse as urlparse
 
@@ -27,11 +30,10 @@ from arvados.api import (
     normalize_api_kwargs,
     api_kwargs_from_config,
     OrderedJsonModel,
-    RETRY_DELAY_INITIAL,
-    RETRY_DELAY_BACKOFF,
-    RETRY_COUNT,
+    _googleapiclient_log_lock,
 )
-from .arvados_testutil import fake_httplib2_response, queue_with
+from .arvados_testutil import fake_httplib2_response, mock_api_responses, queue_with
+import httplib2.error
 
 if not mimetypes.inited:
     mimetypes.init()
@@ -39,6 +41,7 @@ if not mimetypes.inited:
 class ArvadosApiTest(run_test_server.TestCaseWithServers):
     MAIN_SERVER = {}
     ERROR_HEADERS = {'Content-Type': mimetypes.types_map['.json']}
+    RETRIED_4XX = frozenset([408, 409, 423])  # 408 Request Timeout, 409 Conflict, 423 Locked: expected to be retried
 
     def api_error_response(self, code, *errors):
         return (fake_httplib2_response(code, **self.ERROR_HEADERS),
@@ -150,6 +153,57 @@ class ArvadosApiTest(run_test_server.TestCaseWithServers):
         self.assertEqual(api._http.timeout, 1234,
             "Requested timeout value was 1234")
 
+    def test_4xx_retried(self):
+        """Statuses in RETRIED_4XX are retried until the request succeeds."""
+        api = arvados.api('v1')
+        for status in self.RETRIED_4XX:
+            label = f'retried #{status}'
+            with self.subTest(label), mock.patch('time.sleep'):
+                want = {'username': label}
+                # Two retryable failures, then a 200; the caller should only see `want`.
+                responses = mock_api_responses(
+                    api, json.dumps(want), [status, status, 200],
+                    self.ERROR_HEADERS, 'orig_http_request',
+                )
+                with responses:
+                    got = api.users().current().execute()
+                self.assertEqual(got, want)
+
+    def test_4xx_not_retried(self):
+        """Non-retryable 4xx statuses raise ApiError without any retry."""
+        api = arvados.api('v1', num_retries=3)
+        # googleapiclient does retry a 403 *if* the response JSON flags the request
+        # as rate-limited; an empty JSON body like the one below must not be retried.
+        for status in [400, 401, 403, 404, 422]:
+            with self.subTest(f'error {status}'), mock.patch('time.sleep'):
+                # Any retry would receive the queued 200 and execute() would succeed.
+                responses = mock_api_responses(
+                    api, b'{}', [status, 200],
+                    self.ERROR_HEADERS, 'orig_http_request',
+                )
+                with responses, self.assertRaises(arvados.errors.ApiError) as raised:
+                    api.users().current().execute()
+                error_response = raised.exception.args[0]
+                # The status attribute and the 'status' key must both carry the code.
+                self.assertEqual(error_response.status, status)
+                self.assertEqual(error_response.get('status'), str(status))
+
+    def test_4xx_raised_after_retry_exhaustion(self):
+        """A retryable 4xx is raised as ApiError once retries run out."""
+        api = arvados.api('v1', num_retries=1)
+        for status in self.RETRIED_4XX:
+            with self.subTest(f'failed {status}'), mock.patch('time.sleep'):
+                responses = mock_api_responses(
+                    api, b'{}', [status, status, status, 200],
+                    self.ERROR_HEADERS, 'orig_http_request',
+                )
+                with responses, self.assertRaises(arvados.errors.ApiError) as raised:
+                    api.users().current().execute()
+                # The original 4xx code must survive the retry handling.
+                error_response = raised.exception.args[0]
+                self.assertEqual(error_response.status, status)
+                self.assertEqual(error_response.get('status'), str(status))
+
     def test_ordered_json_model(self):
         mock_responses = {
             'arvados.humans.get': (
@@ -340,8 +394,119 @@ class ArvadosApiTest(run_test_server.TestCaseWithServers):
                 args[arg_index] = arg_value
                 api_client(*args, insecure=True)
 
-
-class RetryREST(unittest.TestCase):
+    def test_initial_retry_logs(self):
+        """The first api_client call surfaces googleapiclient retry sleeps at INFO."""
+        # Free the lock in case an earlier client took it; RuntimeError just
+        # means it was never acquired, which is the state we want anyway.
+        with contextlib.suppress(RuntimeError):
+            _googleapiclient_log_lock.release()
+        http_logger = logging.getLogger('googleapiclient.http')
+        spy = mock.Mock(wraps=http_logger)
+        spy.handlers = logging.getLogger('googleapiclient').handlers
+        spy.level = logging.NOTSET
+        with mock.patch('logging.getLogger', return_value=spy), \
+             mock.patch('time.sleep'), \
+             self.assertLogs(http_logger, 'INFO') as captured, \
+             contextlib.suppress(httplib2.error.ServerNotFoundError):
+            # test.invalid cannot resolve, so the single allowed retry sleeps first.
+            api_client('v1', 'https://test.invalid/', 'NoToken', num_retries=1)
+        # Both the logger setup calls and the cleanup calls must have been made.
+        spy.addFilter.assert_called()
+        spy.addHandler.assert_called()
+        spy.setLevel.assert_called()
+        spy.removeHandler.assert_called()
+        spy.removeFilter.assert_called()
+        self.assertRegex(captured.output[0], r'^INFO:googleapiclient\.http:Sleeping \d')
+
+    def test_configured_logger_untouched(self):
+        """api_client must not reconfigure a googleapiclient.http logger that has handlers."""
+        # Release the lock as test_initial_retry_logs does; while it is held no logger
+        # setup happens at all, so the not-called checks below would pass vacuously.
+        with contextlib.suppress(RuntimeError):
+            _googleapiclient_log_lock.release()
+        real_logger = logging.getLogger('googleapiclient.http')
+        mock_logger = mock.Mock(wraps=real_logger)
+        mock_logger.handlers = [logging.StreamHandler()]  # "configured" under any test runner
+        with mock.patch('logging.getLogger', return_value=mock_logger), \
+             mock.patch('time.sleep'), \
+             contextlib.suppress(httplib2.error.ServerNotFoundError):
+            api_client('v1', 'https://test.invalid/', 'NoToken', num_retries=1)
+        for name in ['addFilter', 'addHandler', 'setLevel', 'removeHandler', 'removeFilter']:
+            getattr(mock_logger, name).assert_not_called()
+
+
+class ConstructNumRetriesTestCase(unittest.TestCase):
+    """Check which num_retries values reach googleapiclient's _retry_request."""
+    @staticmethod
+    def _fake_retry_request(http, num_retries, req_type, sleep, rand, uri, method, *args, **kwargs):
+        # Stand-in for _retry_request: a single attempt, no retries or sleeps.
+        return http.request(uri, method, *args, **kwargs)
+
+    @contextlib.contextmanager
+    def patch_retry(self):
+        # We have this dedicated context manager that goes through `sys.modules`
+        # instead of just using `mock.patch` because of the unfortunate
+        # `arvados.api` name collision.
+        orig_func = sys.modules['arvados.api']._orig_retry_request
+        expect_name = 'googleapiclient.http._retry_request'
+        self.assertEqual(
+            '{0.__module__}.{0.__name__}'.format(orig_func), expect_name,
+            f"test setup problem: {expect_name} not at arvados.api._orig_retry_request",
+        )
+        retry_mock = mock.Mock(wraps=self._fake_retry_request)
+        sys.modules['arvados.api']._orig_retry_request = retry_mock
+        try:
+            yield retry_mock
+        finally:
+            sys.modules['arvados.api']._orig_retry_request = orig_func
+
+    def _iter_num_retries(self, retry_mock):
+        # num_retries is _retry_request's second parameter, given positionally or by name.
+        for call in retry_mock.call_args_list:
+            yield call.args[1] if len(call.args) > 1 else call.kwargs['num_retries']
+
+    def test_default_num_retries(self):
+        with self.patch_retry() as retry_mock:
+            arvados.api('v1')
+        actual = set(self._iter_num_retries(retry_mock))
+        self.assertEqual(len(actual), 1)
+        self.assertGreater(actual.pop(), 6, "num_retries lower than expected")
+
+    def _test_calls(self, init_arg, call_args, expected):
+        with self.patch_retry() as retry_mock:
+            client = arvados.api('v1', num_retries=init_arg)
+            for num_retries in call_args:
+                client.users().current().execute(num_retries=num_retries)
+        actual = self._iter_num_retries(retry_mock)
+        # The constructor makes two requests with its num_retries argument:
+        # one for the discovery document, and one for the config.
+        self.assertEqual(next(actual, None), init_arg)
+        self.assertEqual(next(actual, None), init_arg)
+        self.assertEqual(list(actual), expected)
+
+    def test_discovery_num_retries(self):
+        for num_retries in [0, 5, 55]:
+            with self.subTest(f"num_retries={num_retries}"):
+                self._test_calls(num_retries, [], [])
+
+    def test_num_retries_called_le_init(self):
+        for n in [6, 10]:
+            with self.subTest(f"init_arg={n}"):
+                call_args = [n - 4, n - 2, n]
+                expected = [n] * 3
+                self._test_calls(n, call_args, expected)
+
+    def test_num_retries_called_ge_init(self):
+        for n in [0, 10]:
+            with self.subTest(f"init_arg={n}"):
+                call_args = [n, n + 4, n + 8]
+                self._test_calls(n, call_args, call_args)
+
+    def test_num_retries_called_mixed(self):
+        self._test_calls(5, [2, 6, 4, 8], [5, 6, 5, 8])
+
+
+class PreCloseSocketTestCase(unittest.TestCase):
     def setUp(self):
         self.api = arvados.api('v1')
         self.assertTrue(hasattr(self.api._http, 'orig_http_request'),
@@ -353,59 +518,6 @@ class RetryREST(unittest.TestCase):
         # All requests succeed by default. Tests override as needed.
         self.api._http.orig_http_request.return_value = self.request_success
 
-    @mock.patch('time.sleep')
-    def test_socket_error_retry_get(self, sleep):
-        self.api._http.orig_http_request.side_effect = (
-            socket.error('mock error'),
-            self.request_success,
-        )
-        self.assertEqual(self.api.users().current().execute(),
-                         self.mock_response)
-        self.assertGreater(self.api._http.orig_http_request.call_count, 1,
-                           "client got the right response without retrying")
-        self.assertEqual(sleep.call_args_list,
-                         [mock.call(RETRY_DELAY_INITIAL)])
-
-    @mock.patch('time.sleep')
-    def test_same_automatic_request_id_on_retry(self, sleep):
-        self.api._http.orig_http_request.side_effect = (
-            socket.error('mock error'),
-            self.request_success,
-        )
-        self.api.users().current().execute()
-        calls = self.api._http.orig_http_request.call_args_list
-        self.assertEqual(len(calls), 2)
-        self.assertEqual(
-            calls[0][1]['headers']['X-Request-Id'],
-            calls[1][1]['headers']['X-Request-Id'])
-        self.assertRegex(calls[0][1]['headers']['X-Request-Id'], r'^req-[a-z0-9]{20}$')
-
-    @mock.patch('time.sleep')
-    def test_provided_request_id_on_retry(self, sleep):
-        self.api.request_id='fake-request-id'
-        self.api._http.orig_http_request.side_effect = (
-            socket.error('mock error'),
-            self.request_success,
-        )
-        self.api.users().current().execute()
-        calls = self.api._http.orig_http_request.call_args_list
-        self.assertEqual(len(calls), 2)
-        for call in calls:
-            self.assertEqual(call[1]['headers']['X-Request-Id'], 'fake-request-id')
-
-    @mock.patch('time.sleep')
-    def test_socket_error_retry_delay(self, sleep):
-        self.api._http.orig_http_request.side_effect = socket.error('mock')
-        self.api._http._retry_count = 3
-        with self.assertRaises(socket.error):
-            self.api.users().current().execute()
-        self.assertEqual(self.api._http.orig_http_request.call_count, 4)
-        self.assertEqual(sleep.call_args_list, [
-            mock.call(RETRY_DELAY_INITIAL),
-            mock.call(RETRY_DELAY_INITIAL * RETRY_DELAY_BACKOFF),
-            mock.call(RETRY_DELAY_INITIAL * RETRY_DELAY_BACKOFF**2),
-        ])
-
     @mock.patch('time.time', side_effect=[i*2**20 for i in range(99)])
     def test_close_old_connections_non_retryable(self, sleep):
         self._test_connection_close(expect=1)
@@ -429,18 +541,6 @@ class RetryREST(unittest.TestCase):
         for c in mock_conns.values():
             self.assertEqual(c.close.call_count, expect)
 
-    @mock.patch('time.sleep')
-    def test_socket_error_no_retry_post(self, sleep):
-        self.api._http.orig_http_request.side_effect = (
-            socket.error('mock error'),
-            self.request_success,
-        )
-        with self.assertRaises(socket.error):
-            self.api.users().create(body={}).execute()
-        self.assertEqual(self.api._http.orig_http_request.call_count, 1,
-                         "client should try non-retryable method exactly once")
-        self.assertEqual(sleep.call_args_list, [])
-
 
 if __name__ == '__main__':
     unittest.main()