11308: Merge branch 'master' into 11308-python3
[arvados.git] / sdk / python / tests / arvados_testutil.py
1 from future import standard_library
2 standard_library.install_aliases()
3 from builtins import str
4 from builtins import range
5 from builtins import object
6 import arvados
7 import contextlib
8 import errno
9 import hashlib
10 import http.client
11 import httplib2
12 import io
13 import mock
14 import os
15 import pycurl
16 import queue
17 import shutil
18 import sys
19 import tempfile
20 import unittest
21
22 if sys.version_info >= (3, 0):
23     from io import StringIO
24 else:
25     from cStringIO import StringIO
26
27 # Use this hostname when you want to make sure the traffic will be
28 # instantly refused.  100::/64 is a dedicated black hole.
29 TEST_HOST = '100::'
30
31 skip_sleep = mock.patch('time.sleep', lambda n: None)  # clown'll eat me
32
33 def queue_with(items):
34     """Return a thread-safe iterator that yields the given items.
35
36     +items+ can be given as an array or an iterator. If an iterator is
37     given, it will be consumed to fill the queue before queue_with()
38     returns.
39     """
40     q = queue.Queue()
41     for val in items:
42         q.put(val)
43     return lambda *args, **kwargs: q.get(block=False)
44
45 # fake_httplib2_response and mock_responses
46 # mock calls to httplib2.Http.request()
47 def fake_httplib2_response(code, **headers):
48     headers.update(status=str(code),
49                    reason=http.client.responses.get(code, "Unknown Response"))
50     return httplib2.Response(headers)
51
52 def mock_responses(body, *codes, **headers):
53     if not isinstance(body, bytes) and hasattr(body, 'encode'):
54         body = body.encode()
55     return mock.patch('httplib2.Http.request', side_effect=queue_with((
56         (fake_httplib2_response(code, **headers), body) for code in codes)))
57
58 def mock_api_responses(api_client, body, codes, headers={}):
59     if not isinstance(body, bytes) and hasattr(body, 'encode'):
60         body = body.encode()
61     return mock.patch.object(api_client._http, 'request', side_effect=queue_with((
62         (fake_httplib2_response(code, **headers), body) for code in codes)))
63
64 def str_keep_locator(s):
65     return '{}+{}'.format(hashlib.md5(s if isinstance(s, bytes) else s.encode()).hexdigest(), len(s))
66
67 @contextlib.contextmanager
68 def redirected_streams(stdout=None, stderr=None):
69     if stdout == StringIO:
70         stdout = StringIO()
71     if stderr == StringIO:
72         stderr = StringIO()
73     orig_stdout, sys.stdout = sys.stdout, stdout or sys.stdout
74     orig_stderr, sys.stderr = sys.stderr, stderr or sys.stderr
75     try:
76         yield (stdout, stderr)
77     finally:
78         sys.stdout = orig_stdout
79         sys.stderr = orig_stderr
80
81
82 class VersionChecker(object):
83     def assertVersionOutput(self, out, err):
84         if sys.version_info >= (3, 0):
85             self.assertEqual(err.getvalue(), '')
86             v = out.getvalue()
87         else:
88             # Python 2 writes version info on stderr.
89             self.assertEqual(out.getvalue(), '')
90             v = err.getvalue()
91         self.assertRegex(v, r"[0-9]+\.[0-9]+\.[0-9]+$\n")
92
93
94 class FakeCurl(object):
95     @classmethod
96     def make(cls, code, body=b'', headers={}):
97         if not isinstance(body, bytes) and hasattr(body, 'encode'):
98             body = body.encode()
99         return mock.Mock(spec=cls, wraps=cls(code, body, headers))
100
101     def __init__(self, code=200, body=b'', headers={}):
102         self._opt = {}
103         self._got_url = None
104         self._writer = None
105         self._headerfunction = None
106         self._resp_code = code
107         self._resp_body = body
108         self._resp_headers = headers
109
110     def getopt(self, opt):
111         return self._opt.get(str(opt), None)
112
113     def setopt(self, opt, val):
114         self._opt[str(opt)] = val
115         if opt == pycurl.WRITEFUNCTION:
116             self._writer = val
117         elif opt == pycurl.HEADERFUNCTION:
118             self._headerfunction = val
119
120     def perform(self):
121         if not isinstance(self._resp_code, int):
122             raise self._resp_code
123         if self.getopt(pycurl.URL) is None:
124             raise ValueError
125         if self._writer is None:
126             raise ValueError
127         if self._headerfunction:
128             self._headerfunction("HTTP/1.1 {} Status".format(self._resp_code))
129             for k, v in self._resp_headers.items():
130                 self._headerfunction(k + ': ' + str(v))
131         if type(self._resp_body) is not bool:
132             self._writer(self._resp_body)
133
134     def close(self):
135         pass
136
137     def reset(self):
138         """Prevent fake UAs from going back into the user agent pool."""
139         raise Exception
140
141     def getinfo(self, opt):
142         if opt == pycurl.RESPONSE_CODE:
143             return self._resp_code
144         raise Exception
145
146 def mock_keep_responses(body, *codes, **headers):
147     """Patch pycurl to return fake responses and raise exceptions.
148
149     body can be a string to return as the response body; an exception
150     to raise when perform() is called; or an iterable that returns a
151     sequence of such values.
152     """
153     cm = mock.MagicMock()
154     if isinstance(body, tuple):
155         codes = list(codes)
156         codes.insert(0, body)
157         responses = [
158             FakeCurl.make(code=code, body=b, headers=headers)
159             for b, code in codes
160         ]
161     else:
162         responses = [
163             FakeCurl.make(code=code, body=body, headers=headers)
164             for code in codes
165         ]
166     cm.side_effect = queue_with(responses)
167     cm.responses = responses
168     return mock.patch('pycurl.Curl', cm)
169
170
171 class MockStreamReader(object):
172     def __init__(self, name='.', *data):
173         self._name = name
174         self._data = b''.join([
175             b if isinstance(b, bytes) else b.encode()
176             for b in data])
177         self._data_locators = [str_keep_locator(d) for d in data]
178         self.num_retries = 0
179
180     def name(self):
181         return self._name
182
183     def readfrom(self, start, size, num_retries=None):
184         return self._data[start:start + size]
185
186 class ApiClientMock(object):
187     def api_client_mock(self):
188         return mock.MagicMock(name='api_client_mock')
189
190     def mock_keep_services(self, api_mock=None, status=200, count=12,
191                            service_type='disk',
192                            service_host=None,
193                            service_port=None,
194                            service_ssl_flag=False,
195                            additional_services=[],
196                            read_only=False):
197         if api_mock is None:
198             api_mock = self.api_client_mock()
199         body = {
200             'items_available': count,
201             'items': [{
202                 'uuid': 'zzzzz-bi6l4-{:015x}'.format(i),
203                 'owner_uuid': 'zzzzz-tpzed-000000000000000',
204                 'service_host': service_host or 'keep0x{:x}'.format(i),
205                 'service_port': service_port or 65535-i,
206                 'service_ssl_flag': service_ssl_flag,
207                 'service_type': service_type,
208                 'read_only': read_only,
209             } for i in range(0, count)] + additional_services
210         }
211         self._mock_api_call(api_mock.keep_services().accessible, status, body)
212         return api_mock
213
214     def _mock_api_call(self, mock_method, code, body):
215         mock_method = mock_method().execute
216         if code == 200:
217             mock_method.return_value = body
218         else:
219             mock_method.side_effect = arvados.errors.ApiError(
220                 fake_httplib2_response(code), b"{}")
221
222
223 class ArvadosBaseTestCase(unittest.TestCase):
224     # This class provides common utility functions for our tests.
225
226     def setUp(self):
227         self._tempdirs = []
228
229     def tearDown(self):
230         for workdir in self._tempdirs:
231             shutil.rmtree(workdir, ignore_errors=True)
232
233     def make_tmpdir(self):
234         self._tempdirs.append(tempfile.mkdtemp())
235         return self._tempdirs[-1]
236
237     def data_file(self, filename):
238         try:
239             basedir = os.path.dirname(__file__)
240         except NameError:
241             basedir = '.'
242         return open(os.path.join(basedir, 'data', filename))
243
244     def build_directory_tree(self, tree):
245         tree_root = self.make_tmpdir()
246         for leaf in tree:
247             path = os.path.join(tree_root, leaf)
248             try:
249                 os.makedirs(os.path.dirname(path))
250             except OSError as error:
251                 if error.errno != errno.EEXIST:
252                     raise
253             with open(path, 'w') as tmpfile:
254                 tmpfile.write(leaf)
255         return tree_root
256
257     def make_test_file(self, text=b"test"):
258         testfile = tempfile.NamedTemporaryFile()
259         testfile.write(text)
260         testfile.flush()
261         return testfile
262
263 if sys.version_info < (3, 0):
264     # There is no assert[Not]Regex that works in both Python 2 and 3,
265     # so we backport Python 3 style to Python 2.
266     def assertRegex(self, *args, **kwargs):
267         return self.assertRegexpMatches(*args, **kwargs)
268     def assertNotRegex(self, *args, **kwargs):
269         return self.assertNotRegexpMatches(*args, **kwargs)
270     unittest.TestCase.assertRegex = assertRegex
271     unittest.TestCase.assertNotRegex = assertNotRegex