11308: Fix string/bytes confusion for Python 3
[arvados.git] / sdk / python / tests / arvados_testutil.py
1 #!/usr/bin/env python
2
3 from future import standard_library
4 standard_library.install_aliases()
5 from builtins import str
6 from builtins import range
7 from builtins import object
8 import arvados
9 import contextlib
10 import errno
11 import hashlib
12 import http.client
13 import httplib2
14 import io
15 import mock
16 import os
17 import pycurl
18 import queue
19 import shutil
20 import sys
21 import tempfile
22 import unittest
23
24 if sys.version_info >= (3, 0):
25     from io import StringIO
26 else:
27     from cStringIO import StringIO
28
29 # Use this hostname when you want to make sure the traffic will be
30 # instantly refused.  100::/64 is a dedicated black hole.
31 TEST_HOST = '100::'
32
33 skip_sleep = mock.patch('time.sleep', lambda n: None)  # clown'll eat me
34
35 def queue_with(items):
36     """Return a thread-safe iterator that yields the given items.
37
38     +items+ can be given as an array or an iterator. If an iterator is
39     given, it will be consumed to fill the queue before queue_with()
40     returns.
41     """
42     q = queue.Queue()
43     for val in items:
44         q.put(val)
45     return lambda *args, **kwargs: q.get(block=False)
46
47 # fake_httplib2_response and mock_responses
48 # mock calls to httplib2.Http.request()
49 def fake_httplib2_response(code, **headers):
50     headers.update(status=str(code),
51                    reason=http.client.responses.get(code, "Unknown Response"))
52     return httplib2.Response(headers)
53
54 def mock_responses(body, *codes, **headers):
55     if not isinstance(body, bytes) and hasattr(body, 'encode'):
56         body = body.encode()
57     return mock.patch('httplib2.Http.request', side_effect=queue_with((
58         (fake_httplib2_response(code, **headers), body) for code in codes)))
59
60 def mock_api_responses(api_client, body, codes, headers={}):
61     if not isinstance(body, bytes) and hasattr(body, 'encode'):
62         body = body.encode()
63     return mock.patch.object(api_client._http, 'request', side_effect=queue_with((
64         (fake_httplib2_response(code, **headers), body) for code in codes)))
65
66 def str_keep_locator(s):
67     return '{}+{}'.format(hashlib.md5(s if isinstance(s, bytes) else s.encode()).hexdigest(), len(s))
68
69 @contextlib.contextmanager
70 def redirected_streams(stdout=None, stderr=None):
71     if stdout == StringIO:
72         stdout = StringIO()
73     if stderr == StringIO:
74         stderr = StringIO()
75     orig_stdout, sys.stdout = sys.stdout, stdout or sys.stdout
76     orig_stderr, sys.stderr = sys.stderr, stderr or sys.stderr
77     try:
78         yield (stdout, stderr)
79     finally:
80         sys.stdout = orig_stdout
81         sys.stderr = orig_stderr
82
83
84 class VersionChecker(object):
85     def assertVersionOutput(self, out, err):
86         if sys.version_info >= (3, 0):
87             self.assertEqual(err.getvalue(), '')
88             v = out.getvalue()
89         else:
90             # Python 2 writes version info on stderr.
91             self.assertEqual(out.getvalue(), '')
92             v = err.getvalue()
93         self.assertRegexpMatches(v, "[0-9]+\.[0-9]+\.[0-9]+$\n")
94
95
96 class FakeCurl(object):
97     @classmethod
98     def make(cls, code, body=b'', headers={}):
99         if not isinstance(body, bytes) and hasattr(body, 'encode'):
100             body = body.encode()
101         return mock.Mock(spec=cls, wraps=cls(code, body, headers))
102
103     def __init__(self, code=200, body=b'', headers={}):
104         self._opt = {}
105         self._got_url = None
106         self._writer = None
107         self._headerfunction = None
108         self._resp_code = code
109         self._resp_body = body
110         self._resp_headers = headers
111
112     def getopt(self, opt):
113         return self._opt.get(str(opt), None)
114
115     def setopt(self, opt, val):
116         self._opt[str(opt)] = val
117         if opt == pycurl.WRITEFUNCTION:
118             self._writer = val
119         elif opt == pycurl.HEADERFUNCTION:
120             self._headerfunction = val
121
122     def perform(self):
123         if not isinstance(self._resp_code, int):
124             raise self._resp_code
125         if self.getopt(pycurl.URL) is None:
126             raise ValueError
127         if self._writer is None:
128             raise ValueError
129         if self._headerfunction:
130             self._headerfunction("HTTP/1.1 {} Status".format(self._resp_code))
131             for k, v in self._resp_headers.items():
132                 self._headerfunction(k + ': ' + str(v))
133         if type(self._resp_body) is not bool:
134             self._writer(self._resp_body)
135
136     def close(self):
137         pass
138
139     def reset(self):
140         """Prevent fake UAs from going back into the user agent pool."""
141         raise Exception
142
143     def getinfo(self, opt):
144         if opt == pycurl.RESPONSE_CODE:
145             return self._resp_code
146         raise Exception
147
148 def mock_keep_responses(body, *codes, **headers):
149     """Patch pycurl to return fake responses and raise exceptions.
150
151     body can be a string to return as the response body; an exception
152     to raise when perform() is called; or an iterable that returns a
153     sequence of such values.
154     """
155     cm = mock.MagicMock()
156     if isinstance(body, tuple):
157         codes = list(codes)
158         codes.insert(0, body)
159         responses = [
160             FakeCurl.make(code=code, body=b, headers=headers)
161             for b, code in codes
162         ]
163     else:
164         responses = [
165             FakeCurl.make(code=code, body=body, headers=headers)
166             for code in codes
167         ]
168     cm.side_effect = queue_with(responses)
169     cm.responses = responses
170     return mock.patch('pycurl.Curl', cm)
171
172
173 class MockStreamReader(object):
174     def __init__(self, name='.', *data):
175         self._name = name
176         self._data = b''.join([
177             b if isinstance(b, bytes) else b.encode()
178             for b in data])
179         self._data_locators = [str_keep_locator(d) for d in data]
180         self.num_retries = 0
181
182     def name(self):
183         return self._name
184
185     def readfrom(self, start, size, num_retries=None):
186         return self._data[start:start + size]
187
188 class ApiClientMock(object):
189     def api_client_mock(self):
190         return mock.MagicMock(name='api_client_mock')
191
192     def mock_keep_services(self, api_mock=None, status=200, count=12,
193                            service_type='disk',
194                            service_host=None,
195                            service_port=None,
196                            service_ssl_flag=False,
197                            additional_services=[],
198                            read_only=False):
199         if api_mock is None:
200             api_mock = self.api_client_mock()
201         body = {
202             'items_available': count,
203             'items': [{
204                 'uuid': 'zzzzz-bi6l4-{:015x}'.format(i),
205                 'owner_uuid': 'zzzzz-tpzed-000000000000000',
206                 'service_host': service_host or 'keep0x{:x}'.format(i),
207                 'service_port': service_port or 65535-i,
208                 'service_ssl_flag': service_ssl_flag,
209                 'service_type': service_type,
210                 'read_only': read_only,
211             } for i in range(0, count)] + additional_services
212         }
213         self._mock_api_call(api_mock.keep_services().accessible, status, body)
214         return api_mock
215
216     def _mock_api_call(self, mock_method, code, body):
217         mock_method = mock_method().execute
218         if code == 200:
219             mock_method.return_value = body
220         else:
221             mock_method.side_effect = arvados.errors.ApiError(
222                 fake_httplib2_response(code), b"{}")
223
224
225 class ArvadosBaseTestCase(unittest.TestCase):
226     # This class provides common utility functions for our tests.
227
228     def setUp(self):
229         self._tempdirs = []
230
231     def tearDown(self):
232         for workdir in self._tempdirs:
233             shutil.rmtree(workdir, ignore_errors=True)
234
235     def make_tmpdir(self):
236         self._tempdirs.append(tempfile.mkdtemp())
237         return self._tempdirs[-1]
238
239     def data_file(self, filename):
240         try:
241             basedir = os.path.dirname(__file__)
242         except NameError:
243             basedir = '.'
244         return open(os.path.join(basedir, 'data', filename))
245
246     def build_directory_tree(self, tree):
247         tree_root = self.make_tmpdir()
248         for leaf in tree:
249             path = os.path.join(tree_root, leaf)
250             try:
251                 os.makedirs(os.path.dirname(path))
252             except OSError as error:
253                 if error.errno != errno.EEXIST:
254                     raise
255             with open(path, 'w') as tmpfile:
256                 tmpfile.write(leaf)
257         return tree_root
258
259     def make_test_file(self, text=b"test"):
260         testfile = tempfile.NamedTemporaryFile()
261         testfile.write(text)
262         testfile.flush()
263         return testfile