Merge branch '19316-oj-safe-load'
[arvados.git] / sdk / python / tests / arvados_testutil.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import str
8 from builtins import range
9 from builtins import object
10 import arvados
11 import contextlib
12 import errno
13 import hashlib
14 import http.client
15 import httplib2
16 import io
17 import mock
18 import os
19 import pycurl
20 import queue
21 import shutil
22 import sys
23 import tempfile
24 import unittest
25
26 if sys.version_info >= (3, 0):
27     from io import StringIO, BytesIO
28 else:
29     from cStringIO import StringIO
30     BytesIO = StringIO
31
32 # Use this hostname when you want to make sure the traffic will be
33 # instantly refused.  100::/64 is a dedicated black hole.
34 TEST_HOST = '100::'
35
36 skip_sleep = mock.patch('time.sleep', lambda n: None)  # clown'll eat me
37
38 def queue_with(items):
39     """Return a thread-safe iterator that yields the given items.
40
41     +items+ can be given as an array or an iterator. If an iterator is
42     given, it will be consumed to fill the queue before queue_with()
43     returns.
44     """
45     q = queue.Queue()
46     for val in items:
47         q.put(val)
48     return lambda *args, **kwargs: q.get(block=False)
49
50 # fake_httplib2_response and mock_responses
51 # mock calls to httplib2.Http.request()
52 def fake_httplib2_response(code, **headers):
53     headers.update(status=str(code),
54                    reason=http.client.responses.get(code, "Unknown Response"))
55     return httplib2.Response(headers)
56
57 def mock_responses(body, *codes, **headers):
58     if not isinstance(body, bytes) and hasattr(body, 'encode'):
59         body = body.encode()
60     return mock.patch('httplib2.Http.request', side_effect=queue_with((
61         (fake_httplib2_response(code, **headers), body) for code in codes)))
62
63 def mock_api_responses(api_client, body, codes, headers={}):
64     if not isinstance(body, bytes) and hasattr(body, 'encode'):
65         body = body.encode()
66     return mock.patch.object(api_client._http, 'request', side_effect=queue_with((
67         (fake_httplib2_response(code, **headers), body) for code in codes)))
68
69 def str_keep_locator(s):
70     return '{}+{}'.format(hashlib.md5(s if isinstance(s, bytes) else s.encode()).hexdigest(), len(s))
71
72 @contextlib.contextmanager
73 def redirected_streams(stdout=None, stderr=None):
74     if stdout == StringIO:
75         stdout = StringIO()
76     if stderr == StringIO:
77         stderr = StringIO()
78     orig_stdout, sys.stdout = sys.stdout, stdout or sys.stdout
79     orig_stderr, sys.stderr = sys.stderr, stderr or sys.stderr
80     try:
81         yield (stdout, stderr)
82     finally:
83         sys.stdout = orig_stdout
84         sys.stderr = orig_stderr
85
86
87 class VersionChecker(object):
88     def assertVersionOutput(self, out, err):
89         if sys.version_info >= (3, 0):
90             self.assertEqual(err.getvalue(), '')
91             v = out.getvalue()
92         else:
93             # Python 2 writes version info on stderr.
94             self.assertEqual(out.getvalue(), '')
95             v = err.getvalue()
96         self.assertRegex(v, r"[0-9]+\.[0-9]+\.[0-9]+(\.dev[0-9]+)?$\n")
97
98
99 class FakeCurl(object):
100     @classmethod
101     def make(cls, code, body=b'', headers={}):
102         if not isinstance(body, bytes) and hasattr(body, 'encode'):
103             body = body.encode()
104         return mock.Mock(spec=cls, wraps=cls(code, body, headers))
105
106     def __init__(self, code=200, body=b'', headers={}):
107         self._opt = {}
108         self._got_url = None
109         self._writer = None
110         self._headerfunction = None
111         self._resp_code = code
112         self._resp_body = body
113         self._resp_headers = headers
114
115     def getopt(self, opt):
116         return self._opt.get(str(opt), None)
117
118     def setopt(self, opt, val):
119         self._opt[str(opt)] = val
120         if opt == pycurl.WRITEFUNCTION:
121             self._writer = val
122         elif opt == pycurl.HEADERFUNCTION:
123             self._headerfunction = val
124
125     def perform(self):
126         if not isinstance(self._resp_code, int):
127             raise self._resp_code
128         if self.getopt(pycurl.URL) is None:
129             raise ValueError
130         if self._writer is None:
131             raise ValueError
132         if self._headerfunction:
133             self._headerfunction("HTTP/1.1 {} Status".format(self._resp_code))
134             for k, v in self._resp_headers.items():
135                 self._headerfunction(k + ': ' + str(v))
136         if type(self._resp_body) is not bool:
137             self._writer(self._resp_body)
138
139     def close(self):
140         pass
141
142     def reset(self):
143         """Prevent fake UAs from going back into the user agent pool."""
144         raise Exception
145
146     def getinfo(self, opt):
147         if opt == pycurl.RESPONSE_CODE:
148             return self._resp_code
149         raise Exception
150
151 def mock_keep_responses(body, *codes, **headers):
152     """Patch pycurl to return fake responses and raise exceptions.
153
154     body can be a string to return as the response body; an exception
155     to raise when perform() is called; or an iterable that returns a
156     sequence of such values.
157     """
158     cm = mock.MagicMock()
159     if isinstance(body, tuple):
160         codes = list(codes)
161         codes.insert(0, body)
162         responses = [
163             FakeCurl.make(code=code, body=b, headers=headers)
164             for b, code in codes
165         ]
166     else:
167         responses = [
168             FakeCurl.make(code=code, body=body, headers=headers)
169             for code in codes
170         ]
171     cm.side_effect = queue_with(responses)
172     cm.responses = responses
173     return mock.patch('pycurl.Curl', cm)
174
175
176 class MockStreamReader(object):
177     def __init__(self, name='.', *data):
178         self._name = name
179         self._data = b''.join([
180             b if isinstance(b, bytes) else b.encode()
181             for b in data])
182         self._data_locators = [str_keep_locator(d) for d in data]
183         self.num_retries = 0
184
185     def name(self):
186         return self._name
187
188     def readfrom(self, start, size, num_retries=None):
189         return self._data[start:start + size]
190
191 class ApiClientMock(object):
192     def api_client_mock(self):
193         api_mock = mock.MagicMock(name='api_client_mock')
194         api_mock.config.return_value = {
195             'StorageClasses': {
196                 'default': {'Default': True}
197             }
198         }
199         return api_mock
200
201     def mock_keep_services(self, api_mock=None, status=200, count=12,
202                            service_type='disk',
203                            service_host=None,
204                            service_port=None,
205                            service_ssl_flag=False,
206                            additional_services=[],
207                            read_only=False):
208         if api_mock is None:
209             api_mock = self.api_client_mock()
210         body = {
211             'items_available': count,
212             'items': [{
213                 'uuid': 'zzzzz-bi6l4-{:015x}'.format(i),
214                 'owner_uuid': 'zzzzz-tpzed-000000000000000',
215                 'service_host': service_host or 'keep0x{:x}'.format(i),
216                 'service_port': service_port or 65535-i,
217                 'service_ssl_flag': service_ssl_flag,
218                 'service_type': service_type,
219                 'read_only': read_only,
220             } for i in range(0, count)] + additional_services
221         }
222         self._mock_api_call(api_mock.keep_services().accessible, status, body)
223         return api_mock
224
225     def _mock_api_call(self, mock_method, code, body):
226         mock_method = mock_method().execute
227         if code == 200:
228             mock_method.return_value = body
229         else:
230             mock_method.side_effect = arvados.errors.ApiError(
231                 fake_httplib2_response(code), b"{}")
232
233
234 class ArvadosBaseTestCase(unittest.TestCase):
235     # This class provides common utility functions for our tests.
236
237     def setUp(self):
238         self._tempdirs = []
239
240     def tearDown(self):
241         for workdir in self._tempdirs:
242             shutil.rmtree(workdir, ignore_errors=True)
243
244     def make_tmpdir(self):
245         self._tempdirs.append(tempfile.mkdtemp())
246         return self._tempdirs[-1]
247
248     def data_file(self, filename):
249         try:
250             basedir = os.path.dirname(__file__)
251         except NameError:
252             basedir = '.'
253         return open(os.path.join(basedir, 'data', filename))
254
255     def build_directory_tree(self, tree):
256         tree_root = self.make_tmpdir()
257         for leaf in tree:
258             path = os.path.join(tree_root, leaf)
259             try:
260                 os.makedirs(os.path.dirname(path))
261             except OSError as error:
262                 if error.errno != errno.EEXIST:
263                     raise
264             with open(path, 'w') as tmpfile:
265                 tmpfile.write(leaf)
266         return tree_root
267
268     def make_test_file(self, text=b"test"):
269         testfile = tempfile.NamedTemporaryFile()
270         testfile.write(text)
271         testfile.flush()
272         return testfile
273
274 if sys.version_info < (3, 0):
275     # There is no assert[Not]Regex that works in both Python 2 and 3,
276     # so we backport Python 3 style to Python 2.
277     def assertRegex(self, *args, **kwargs):
278         return self.assertRegexpMatches(*args, **kwargs)
279     def assertNotRegex(self, *args, **kwargs):
280         return self.assertNotRegexpMatches(*args, **kwargs)
281     unittest.TestCase.assertRegex = assertRegex
282     unittest.TestCase.assertNotRegex = assertNotRegex