15954: Merge branch 'master'
[arvados.git] / sdk / python / tests / arvados_testutil.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import str
8 from builtins import range
9 from builtins import object
10 import arvados
11 import contextlib
12 import errno
13 import hashlib
14 import http.client
15 import httplib2
16 import io
17 import mock
18 import os
19 import pycurl
20 import queue
21 import shutil
22 import sys
23 import tempfile
24 import unittest
25
26 if sys.version_info >= (3, 0):
27     from io import StringIO, BytesIO
28 else:
29     from cStringIO import StringIO
30     BytesIO = StringIO
31
32 # Use this hostname when you want to make sure the traffic will be
33 # instantly refused.  100::/64 is a dedicated black hole.
34 TEST_HOST = '100::'
35
36 skip_sleep = mock.patch('time.sleep', lambda n: None)  # clown'll eat me
37
38 def queue_with(items):
39     """Return a thread-safe iterator that yields the given items.
40
41     +items+ can be given as an array or an iterator. If an iterator is
42     given, it will be consumed to fill the queue before queue_with()
43     returns.
44     """
45     q = queue.Queue()
46     for val in items:
47         q.put(val)
48     return lambda *args, **kwargs: q.get(block=False)
49
50 # fake_httplib2_response and mock_responses
51 # mock calls to httplib2.Http.request()
52 def fake_httplib2_response(code, **headers):
53     headers.update(status=str(code),
54                    reason=http.client.responses.get(code, "Unknown Response"))
55     return httplib2.Response(headers)
56
57 def mock_responses(body, *codes, **headers):
58     if not isinstance(body, bytes) and hasattr(body, 'encode'):
59         body = body.encode()
60     return mock.patch('httplib2.Http.request', side_effect=queue_with((
61         (fake_httplib2_response(code, **headers), body) for code in codes)))
62
63 def mock_api_responses(api_client, body, codes, headers={}):
64     if not isinstance(body, bytes) and hasattr(body, 'encode'):
65         body = body.encode()
66     return mock.patch.object(api_client._http, 'request', side_effect=queue_with((
67         (fake_httplib2_response(code, **headers), body) for code in codes)))
68
69 def str_keep_locator(s):
70     return '{}+{}'.format(hashlib.md5(s if isinstance(s, bytes) else s.encode()).hexdigest(), len(s))
71
72 @contextlib.contextmanager
73 def redirected_streams(stdout=None, stderr=None):
74     if stdout == StringIO:
75         stdout = StringIO()
76     if stderr == StringIO:
77         stderr = StringIO()
78     orig_stdout, sys.stdout = sys.stdout, stdout or sys.stdout
79     orig_stderr, sys.stderr = sys.stderr, stderr or sys.stderr
80     try:
81         yield (stdout, stderr)
82     finally:
83         sys.stdout = orig_stdout
84         sys.stderr = orig_stderr
85
86
87 class VersionChecker(object):
88     def assertVersionOutput(self, out, err):
89         if sys.version_info >= (3, 0):
90             self.assertEqual(err.getvalue(), '')
91             v = out.getvalue()
92         else:
93             # Python 2 writes version info on stderr.
94             self.assertEqual(out.getvalue(), '')
95             v = err.getvalue()
96         self.assertRegex(v, r"[0-9]+\.[0-9]+\.[0-9]+(\.dev[0-9]+)?$\n")
97
98
99 class FakeCurl(object):
100     @classmethod
101     def make(cls, code, body=b'', headers={}):
102         if not isinstance(body, bytes) and hasattr(body, 'encode'):
103             body = body.encode()
104         return mock.Mock(spec=cls, wraps=cls(code, body, headers))
105
106     def __init__(self, code=200, body=b'', headers={}):
107         self._opt = {}
108         self._got_url = None
109         self._writer = None
110         self._headerfunction = None
111         self._resp_code = code
112         self._resp_body = body
113         self._resp_headers = headers
114
115     def getopt(self, opt):
116         return self._opt.get(str(opt), None)
117
118     def setopt(self, opt, val):
119         self._opt[str(opt)] = val
120         if opt == pycurl.WRITEFUNCTION:
121             self._writer = val
122         elif opt == pycurl.HEADERFUNCTION:
123             self._headerfunction = val
124
125     def perform(self):
126         if not isinstance(self._resp_code, int):
127             raise self._resp_code
128         if self.getopt(pycurl.URL) is None:
129             raise ValueError
130         if self._writer is None:
131             raise ValueError
132         if self._headerfunction:
133             self._headerfunction("HTTP/1.1 {} Status".format(self._resp_code))
134             for k, v in self._resp_headers.items():
135                 self._headerfunction(k + ': ' + str(v))
136         if type(self._resp_body) is not bool:
137             self._writer(self._resp_body)
138
139     def close(self):
140         pass
141
142     def reset(self):
143         """Prevent fake UAs from going back into the user agent pool."""
144         raise Exception
145
146     def getinfo(self, opt):
147         if opt == pycurl.RESPONSE_CODE:
148             return self._resp_code
149         raise Exception
150
151 def mock_keep_responses(body, *codes, **headers):
152     """Patch pycurl to return fake responses and raise exceptions.
153
154     body can be a string to return as the response body; an exception
155     to raise when perform() is called; or an iterable that returns a
156     sequence of such values.
157     """
158     cm = mock.MagicMock()
159     if isinstance(body, tuple):
160         codes = list(codes)
161         codes.insert(0, body)
162         responses = [
163             FakeCurl.make(code=code, body=b, headers=headers)
164             for b, code in codes
165         ]
166     else:
167         responses = [
168             FakeCurl.make(code=code, body=body, headers=headers)
169             for code in codes
170         ]
171     cm.side_effect = queue_with(responses)
172     cm.responses = responses
173     return mock.patch('pycurl.Curl', cm)
174
175
176 class MockStreamReader(object):
177     def __init__(self, name='.', *data):
178         self._name = name
179         self._data = b''.join([
180             b if isinstance(b, bytes) else b.encode()
181             for b in data])
182         self._data_locators = [str_keep_locator(d) for d in data]
183         self.num_retries = 0
184
185     def name(self):
186         return self._name
187
188     def readfrom(self, start, size, num_retries=None):
189         return self._data[start:start + size]
190
191 class ApiClientMock(object):
192     def api_client_mock(self):
193         return mock.MagicMock(name='api_client_mock')
194
195     def mock_keep_services(self, api_mock=None, status=200, count=12,
196                            service_type='disk',
197                            service_host=None,
198                            service_port=None,
199                            service_ssl_flag=False,
200                            additional_services=[],
201                            read_only=False):
202         if api_mock is None:
203             api_mock = self.api_client_mock()
204         body = {
205             'items_available': count,
206             'items': [{
207                 'uuid': 'zzzzz-bi6l4-{:015x}'.format(i),
208                 'owner_uuid': 'zzzzz-tpzed-000000000000000',
209                 'service_host': service_host or 'keep0x{:x}'.format(i),
210                 'service_port': service_port or 65535-i,
211                 'service_ssl_flag': service_ssl_flag,
212                 'service_type': service_type,
213                 'read_only': read_only,
214             } for i in range(0, count)] + additional_services
215         }
216         self._mock_api_call(api_mock.keep_services().accessible, status, body)
217         return api_mock
218
219     def _mock_api_call(self, mock_method, code, body):
220         mock_method = mock_method().execute
221         if code == 200:
222             mock_method.return_value = body
223         else:
224             mock_method.side_effect = arvados.errors.ApiError(
225                 fake_httplib2_response(code), b"{}")
226
227
228 class ArvadosBaseTestCase(unittest.TestCase):
229     # This class provides common utility functions for our tests.
230
231     def setUp(self):
232         self._tempdirs = []
233
234     def tearDown(self):
235         for workdir in self._tempdirs:
236             shutil.rmtree(workdir, ignore_errors=True)
237
238     def make_tmpdir(self):
239         self._tempdirs.append(tempfile.mkdtemp())
240         return self._tempdirs[-1]
241
242     def data_file(self, filename):
243         try:
244             basedir = os.path.dirname(__file__)
245         except NameError:
246             basedir = '.'
247         return open(os.path.join(basedir, 'data', filename))
248
249     def build_directory_tree(self, tree):
250         tree_root = self.make_tmpdir()
251         for leaf in tree:
252             path = os.path.join(tree_root, leaf)
253             try:
254                 os.makedirs(os.path.dirname(path))
255             except OSError as error:
256                 if error.errno != errno.EEXIST:
257                     raise
258             with open(path, 'w') as tmpfile:
259                 tmpfile.write(leaf)
260         return tree_root
261
262     def make_test_file(self, text=b"test"):
263         testfile = tempfile.NamedTemporaryFile()
264         testfile.write(text)
265         testfile.flush()
266         return testfile
267
268 if sys.version_info < (3, 0):
269     # There is no assert[Not]Regex that works in both Python 2 and 3,
270     # so we backport Python 3 style to Python 2.
271     def assertRegex(self, *args, **kwargs):
272         return self.assertRegexpMatches(*args, **kwargs)
273     def assertNotRegex(self, *args, **kwargs):
274         return self.assertNotRegexpMatches(*args, **kwargs)
275     unittest.TestCase.assertRegex = assertRegex
276     unittest.TestCase.assertNotRegex = assertNotRegex