Merge branch '21773-keep-service-discovery'
[arvados.git] / sdk / python / tests / arvados_testutil.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import arvados
6 import contextlib
7 import errno
8 import hashlib
9 import http.client
10 import httplib2
11 import io
12 import os
13 import pycurl
14 import queue
15 import shutil
16 import sys
17 import tempfile
18 import unittest
19
20 from io import StringIO, BytesIO
21 from unittest import mock
22
23 # Use this hostname when you want to make sure the traffic will be
24 # instantly refused.  100::/64 is a dedicated black hole.
25 TEST_HOST = '100::'
26
27 skip_sleep = mock.patch('time.sleep', lambda n: None)  # clown'll eat me
28
29 def queue_with(items):
30     """Return a thread-safe iterator that yields the given items.
31
32     +items+ can be given as an array or an iterator. If an iterator is
33     given, it will be consumed to fill the queue before queue_with()
34     returns.
35     """
36     q = queue.Queue()
37     for val in items:
38         q.put(val)
39     return lambda *args, **kwargs: q.get(block=False)
40
41 # fake_httplib2_response and mock_responses
42 # mock calls to httplib2.Http.request()
43 def fake_httplib2_response(code, **headers):
44     headers.update(status=str(code),
45                    reason=http.client.responses.get(code, "Unknown Response"))
46     return httplib2.Response(headers)
47
48 def mock_responses(body, *codes, **headers):
49     if not isinstance(body, bytes) and hasattr(body, 'encode'):
50         body = body.encode()
51     return mock.patch('httplib2.Http.request', side_effect=queue_with((
52         (fake_httplib2_response(code, **headers), body) for code in codes)))
53
54 def mock_api_responses(api_client, body, codes, headers={}, method='request'):
55     if not isinstance(body, bytes) and hasattr(body, 'encode'):
56         body = body.encode()
57     return mock.patch.object(api_client._http, method, side_effect=queue_with((
58         (fake_httplib2_response(code, **headers), body) for code in codes)))
59
60 def str_keep_locator(s):
61     return '{}+{}'.format(hashlib.md5(s if isinstance(s, bytes) else s.encode()).hexdigest(), len(s))
62
63 @contextlib.contextmanager
64 def redirected_streams(stdout=None, stderr=None):
65     if stdout == StringIO:
66         stdout = StringIO()
67     if stderr == StringIO:
68         stderr = StringIO()
69     orig_stdout, sys.stdout = sys.stdout, stdout or sys.stdout
70     orig_stderr, sys.stderr = sys.stderr, stderr or sys.stderr
71     try:
72         yield (stdout, stderr)
73     finally:
74         sys.stdout = orig_stdout
75         sys.stderr = orig_stderr
76
77
78 class VersionChecker(object):
79     def assertVersionOutput(self, out, err):
80         self.assertEqual(err.getvalue(), '')
81         v = out.getvalue()
82         self.assertRegex(v, r"[0-9]+\.[0-9]+\.[0-9]+(\.dev[0-9]+)?$\n")
83
84
85 class FakeCurl(object):
86     @classmethod
87     def make(cls, code, body=b'', headers={}):
88         if not isinstance(body, bytes) and hasattr(body, 'encode'):
89             body = body.encode()
90         return mock.Mock(spec=cls, wraps=cls(code, body, headers))
91
92     def __init__(self, code=200, body=b'', headers={}):
93         self._opt = {}
94         self._got_url = None
95         self._writer = None
96         self._headerfunction = None
97         self._resp_code = code
98         self._resp_body = body
99         self._resp_headers = headers
100
101     def getopt(self, opt):
102         return self._opt.get(str(opt), None)
103
104     def setopt(self, opt, val):
105         self._opt[str(opt)] = val
106         if opt == pycurl.WRITEFUNCTION:
107             self._writer = val
108         elif opt == pycurl.HEADERFUNCTION:
109             self._headerfunction = val
110
111     def perform(self):
112         if not isinstance(self._resp_code, int):
113             raise self._resp_code
114         if self.getopt(pycurl.URL) is None:
115             raise ValueError
116         if self._writer is None:
117             raise ValueError
118         if self._headerfunction:
119             self._headerfunction("HTTP/1.1 {} Status".format(self._resp_code))
120             for k, v in self._resp_headers.items():
121                 self._headerfunction(k + ': ' + str(v))
122         if type(self._resp_body) is not bool:
123             self._writer(self._resp_body)
124
125     def close(self):
126         pass
127
128     def reset(self):
129         """Prevent fake UAs from going back into the user agent pool."""
130         raise Exception
131
132     def getinfo(self, opt):
133         if opt == pycurl.RESPONSE_CODE:
134             return self._resp_code
135         raise Exception
136
137
138 def mock_keep_responses(body, *codes, **headers):
139     """Patch pycurl to return fake responses and raise exceptions.
140
141     body can be a string to return as the response body; an exception
142     to raise when perform() is called; or an iterable that returns a
143     sequence of such values.
144     """
145     cm = mock.MagicMock()
146     if isinstance(body, tuple):
147         codes = list(codes)
148         codes.insert(0, body)
149         responses = [
150             FakeCurl.make(code=code, body=b, headers=headers)
151             for b, code in codes
152         ]
153     else:
154         responses = [
155             FakeCurl.make(code=code, body=body, headers=headers)
156             for code in codes
157         ]
158     cm.side_effect = queue_with(responses)
159     cm.responses = responses
160     return mock.patch('pycurl.Curl', cm)
161
162
163 class ApiClientMock(object):
164     def api_client_mock(self):
165         api_mock = mock.MagicMock(name='api_client_mock')
166         api_mock.config.return_value = {
167             'StorageClasses': {
168                 'default': {'Default': True}
169             }
170         }
171         return api_mock
172
173     def mock_keep_services(self, api_mock=None, status=200, count=12,
174                            service_type='disk',
175                            service_host=None,
176                            service_port=None,
177                            service_ssl_flag=False,
178                            additional_services=[],
179                            read_only=False):
180         if api_mock is None:
181             api_mock = self.api_client_mock()
182         body = {
183             'items_available': count,
184             'items': [{
185                 'uuid': 'zzzzz-bi6l4-{:015x}'.format(i),
186                 'owner_uuid': 'zzzzz-tpzed-000000000000000',
187                 'service_host': service_host or 'keep0x{:x}'.format(i),
188                 'service_port': service_port or 65535-i,
189                 'service_ssl_flag': service_ssl_flag,
190                 'service_type': service_type,
191                 'read_only': read_only,
192             } for i in range(0, count)] + additional_services
193         }
194         self._mock_api_call(api_mock.keep_services().accessible, status, body)
195         return api_mock
196
197     def _mock_api_call(self, mock_method, code, body):
198         mock_method = mock_method().execute
199         if code == 200:
200             mock_method.return_value = body
201         else:
202             mock_method.side_effect = arvados.errors.ApiError(
203                 fake_httplib2_response(code), b"{}")
204
205
206 class ArvadosBaseTestCase(unittest.TestCase):
207     # This class provides common utility functions for our tests.
208
209     def setUp(self):
210         self._tempdirs = []
211
212     def tearDown(self):
213         for workdir in self._tempdirs:
214             shutil.rmtree(workdir, ignore_errors=True)
215
216     def make_tmpdir(self):
217         self._tempdirs.append(tempfile.mkdtemp())
218         return self._tempdirs[-1]
219
220     def data_file(self, filename):
221         try:
222             basedir = os.path.dirname(__file__)
223         except NameError:
224             basedir = '.'
225         return open(os.path.join(basedir, 'data', filename))
226
227     def build_directory_tree(self, tree):
228         tree_root = self.make_tmpdir()
229         for leaf in tree:
230             path = os.path.join(tree_root, leaf)
231             try:
232                 os.makedirs(os.path.dirname(path))
233             except OSError as error:
234                 if error.errno != errno.EEXIST:
235                     raise
236             with open(path, 'w') as tmpfile:
237                 tmpfile.write(leaf)
238         return tree_root
239
240     def make_test_file(self, text=b"test"):
241         testfile = tempfile.NamedTemporaryFile()
242         testfile.write(text)
243         testfile.flush()
244         return testfile
245
246 def binary_compare(a, b):
247     if len(a) != len(b):
248         return False
249     for i in range(0, len(a)):
250         if a[i] != b[i]:
251             return False
252     return True
253
254 class DiskCacheBase:
255     def make_block_cache(self, disk_cache):
256         self.disk_cache_dir = tempfile.mkdtemp() if disk_cache else None
257         block_cache = arvados.keep.KeepBlockCache(disk_cache=disk_cache,
258                                                   disk_cache_dir=self.disk_cache_dir)
259         return block_cache
260
261     def tearDown(self):
262         if self.disk_cache_dir:
263             shutil.rmtree(self.disk_cache_dir)