1 # Copyright (C) The Arvados Authors. All rights reserved.
3 # SPDX-License-Identifier: Apache-2.0
5 from __future__ import division
6 from future import standard_library
7 standard_library.install_aliases()
8 from builtins import str
19 from . import arvados_testutil as tutil
# Debug switch: the raw value of the ARVADOS_DEBUG environment variable,
# or None when that variable is not set.
_debug = os.environ.get('ARVADOS_DEBUG')
class StubKeepServers(tutil.ApiClientMock):
    """Test mixin that runs a stub Keep server for the duration of a test.

    setUp() starts a Server on a free port in a background thread and
    points self.api_client at it; tearDown() stops the server again.
    """

    def setUp(self):
        super(StubKeepServers, self).setUp()
        # Bind directly to port 0 so the OS assigns a free port atomically.
        # Probing with a throwaway socket and re-binding afterwards leaves a
        # window in which another process can take the port.
        self.server = Server(('0.0.0.0', 0), Handler)
        self.port = self.server.server_address[1]
        self.thread = threading.Thread(target=self.server.serve_forever)
        self.thread.daemon = True # Exit thread if main proc exits
        self.thread.start()
        self.api_client = self.mock_keep_services(
            count=1,
            service_host='localhost',
            service_port=self.port,
        )

    def tearDown(self):
        # Stop the serve_forever() loop, then release the listening socket
        # so that running many tests does not leak file descriptors. The
        # socket is closed directly instead of calling server_close():
        # since Python 3.7, ThreadingMixIn.server_close() joins in-flight
        # request threads, which may still be sleeping in an injected delay.
        self.server.shutdown()
        self.server.socket.close()
        self.thread.join()
        super(StubKeepServers, self).tearDown()
class Server(socketserver.ThreadingMixIn, http.server.HTTPServer, object):
    """Threaded HTTP server holding an in-memory block store for tests.

    ``store`` maps block hashes to block data. ``delays`` holds the number
    of seconds to sleep at each named checkpoint of request handling, and
    ``bandwidth`` is an optional transfer-rate limit in bytes per second
    (None means unlimited).
    """

    allow_reuse_address = 1

    def __init__(self, *args, **kwargs):
        self.store = {}
        self.delays = {
            # before reading request headers
            'request': 0,
            # before reading request body
            'request_body': 0,
            # before setting response status and headers
            'response': 0,
            # before sending response body
            'response_body': 0,
            # before returning from handler (thus setting response EOF)
            'response_close': 0,
            # after writing over 1s worth of data at self.bandwidth
            'mid_write': 0,
            # after reading over 1s worth of data at self.bandwidth
            'mid_read': 0,
        }
        self.bandwidth = None
        super(Server, self).__init__(*args, **kwargs)

    def setdelays(self, **kwargs):
        """In future requests, induce delays at the given checkpoints.

        Each keyword must name an existing checkpoint in ``self.delays``;
        its value is the number of seconds to sleep there.

        Raises KeyError, without changing any delay, if a keyword is not a
        known checkpoint.
        """
        # Validate every key up front: dict.get() never raises, so the old
        # per-key .get() "check" silently accepted misspelled checkpoints.
        unknown = set(kwargs) - set(self.delays)
        if unknown:
            raise KeyError('unknown delay checkpoint(s): {}'.format(
                ', '.join(sorted(unknown))))
        self.delays.update(kwargs)

    def setbandwidth(self, bandwidth):
        """Limit future transfers to ``bandwidth`` bytes per second.

        Until this is called, transfers run as fast as possible.
        """
        self.bandwidth = float(bandwidth)

    def _sleep_at_least(self, seconds):
        """Sleep for given time, even if signals are received.

        Non-positive durations return immediately.
        """
        wake = time.time() + seconds
        todo = seconds
        # Keep sleeping until the deadline, in case sleep() returns early.
        while todo > 0:
            time.sleep(todo)
            todo = wake - time.time()

    def _do_delay(self, k):
        """Sleep for the delay configured at checkpoint ``k``.

        Raises KeyError if ``k`` is not a known checkpoint.
        """
        self._sleep_at_least(self.delays[k])
class Handler(http.server.BaseHTTPRequestHandler, object):
    """HTTP request handler implementing a minimal Keep block service.

    GET and HEAD look up a block by the first run of 32 lowercase hex
    digits in the request path. PUT stores the request body under its md5
    hex digest and replies with "<hash>+<size>". Every request honors the
    delay checkpoints and bandwidth limit configured on ``self.server``.
    """

    protocol_version = 'HTTP/1.1'

    # Chunk size used when there is no bandwidth limit, or when the limit is
    # so low that a quarter second's worth of data rounds down to 0 bytes.
    _DEFAULT_CHUNK_SIZE = 32768

    def _chunk_size(self):
        """Return the number of bytes to move per throttled read/write."""
        bandwidth = self.server.bandwidth
        if bandwidth is None:
            return self._DEFAULT_CHUNK_SIZE
        # Roughly a quarter second's worth of data at the configured rate.
        return int(bandwidth / 4) or self._DEFAULT_CHUNK_SIZE

    def _outage_threshold(self):
        """Return how many bytes must be transferred before the one-time
        'mid_write'/'mid_read' stall is injected.

        With a bandwidth limit this is one second's worth of data. Without
        one it is 0: the stall happens once the first chunk has been
        transferred, so a body that fits in a single chunk is not stalled.
        """
        bandwidth = self.server.bandwidth
        return 0 if bandwidth is None else bandwidth

    def _send_status_only(self, code):
        """Send a complete response with status ``code`` and no body.

        send_response() only buffers the status line; end_headers() is what
        actually flushes it, so it must always be called.
        """
        self.send_response(code)
        self.send_header('Connection', 'close')
        self.send_header('Content-length', '0')
        self.end_headers()

    def wfile_bandwidth_write(self, data_to_write):
        """Send ``data_to_write`` to the client, honoring throttling.

        With no bandwidth limit and no 'mid_write' delay, the data is written
        in a single call. Otherwise it is written in chunks. The handler
        sleeps as needed to stay at ``self.server.bandwidth`` (if set) and
        stalls once for ``delays['mid_write']`` seconds after the outage
        threshold is passed.

        Returns None.
        """
        if self.server.bandwidth is None and self.server.delays['mid_write'] == 0:
            self.wfile.write(data_to_write)
            return None
        chunk_size = self._chunk_size()
        outage_happened = False
        num_bytes = len(data_to_write)
        num_sent_bytes = 0
        target_time = time.time()
        while num_sent_bytes < num_bytes:
            if not outage_happened and num_sent_bytes > self._outage_threshold():
                self.server._do_delay('mid_write')
                # Shift the schedule so the stall is not "caught up" by
                # writing faster than the bandwidth limit afterwards.
                target_time += self.server.delays['mid_write']
                outage_happened = True
            num_write_bytes = min(chunk_size, num_bytes - num_sent_bytes)
            self.wfile.write(
                data_to_write[num_sent_bytes:num_sent_bytes + num_write_bytes])
            num_sent_bytes += num_write_bytes
            if self.server.bandwidth is not None:
                target_time += num_write_bytes / self.server.bandwidth
                self.server._sleep_at_least(target_time - time.time())
        return None

    def rfile_bandwidth_read(self, bytes_to_read):
        """Read up to ``bytes_to_read`` bytes from the client, honoring
        throttling the same way as wfile_bandwidth_write ('mid_read' delay).

        Returns the bytes read. The result is shorter than requested only if
        the client closed the connection before sending everything.
        """
        if self.server.bandwidth is None and self.server.delays['mid_read'] == 0:
            return self.rfile.read(bytes_to_read)
        chunk_size = self._chunk_size()
        chunks = []
        outage_happened = False
        bytes_read = 0
        target_time = time.time()
        while bytes_read < bytes_to_read:
            if not outage_happened and bytes_read > self._outage_threshold():
                self.server._do_delay('mid_read')
                target_time += self.server.delays['mid_read']
                outage_happened = True
            chunk = self.rfile.read(min(chunk_size, bytes_to_read - bytes_read))
            if not chunk:
                # EOF: the client went away before sending the whole body.
                break
            chunks.append(chunk)
            # Count what actually arrived, not what was requested.
            bytes_read += len(chunk)
            if self.server.bandwidth is not None:
                target_time += len(chunk) / self.server.bandwidth
                self.server._sleep_at_least(target_time - time.time())
        return b''.join(chunks)

    def finish(self, *args, **kwargs):
        """Finish the request; errors are only propagated when debugging."""
        try:
            return super(Handler, self).finish(*args, **kwargs)
        except Exception:
            # Presumably the client has already hung up (e.g. after a
            # timeout); only surface the error when ARVADOS_DEBUG is set.
            if _debug:
                raise

    def handle(self, *args, **kwargs):
        """Handle the connection; errors are only propagated when debugging."""
        try:
            return super(Handler, self).handle(*args, **kwargs)
        except Exception:
            if _debug:
                raise

    def handle_one_request(self, *args, **kwargs):
        """Apply the 'request' delay before reading each request."""
        self._sent_continue = False
        self.server._do_delay('request')
        return super(Handler, self).handle_one_request(*args, **kwargs)

    def handle_expect_100(self):
        """Apply the 'request_body' delay before sending 100 Continue."""
        self.server._do_delay('request_body')
        self._sent_continue = True
        return super(Handler, self).handle_expect_100()

    def do_GET(self):
        """Return the stored block named in the path (422/404 on failure)."""
        self.server._do_delay('response')
        r = re.search(r'[0-9a-f]{32}', self.path)
        if not r:
            return self._send_status_only(422)
        datahash = r.group(0)
        if datahash not in self.server.store:
            return self._send_status_only(404)
        self.send_response(200)
        self.send_header('Connection', 'close')
        self.send_header('Content-type', 'application/octet-stream')
        self.end_headers()
        self.server._do_delay('response_body')
        self.wfile_bandwidth_write(self.server.store[datahash])
        self.server._do_delay('response_close')

    def do_HEAD(self):
        """Like do_GET, but send only the headers (with Content-length)."""
        self.server._do_delay('response')
        r = re.search(r'[0-9a-f]{32}', self.path)
        if not r:
            return self._send_status_only(422)
        datahash = r.group(0)
        if datahash not in self.server.store:
            return self._send_status_only(404)
        self.send_response(200)
        self.send_header('Connection', 'close')
        self.send_header('Content-type', 'application/octet-stream')
        self.send_header('Content-length', str(len(self.server.store[datahash])))
        self.end_headers()
        self.server._do_delay('response_close')
        self.close_connection = True

    def do_PUT(self):
        """Store the request body under its md5 hex digest.

        Replies 200 with "<hash>+<size>\\n", or 411 if the request has no
        Content-Length (chunked uploads are not supported).
        """
        length_header = self.headers.get('content-length')
        if length_header is None:
            return self._send_status_only(411)
        expect = (self.headers.get('expect') or '').lower()
        if not self._sent_continue and expect == '100-continue':
            # The comments at https://bugs.python.org/issue1491
            # imply that Python 2.7 BaseHTTPRequestHandler was
            # patched to support 100 Continue, but reading the actual
            # code that ships in Debian it clearly is not, so we need
            # to send the response on the socket directly.
            self.server._do_delay('request_body')
            self.wfile.write("{} {} {}\r\n\r\n".format(
                self.protocol_version, 100, "Continue").encode())
        data = self.rfile_bandwidth_read(int(length_header))
        datahash = hashlib.md5(data).hexdigest()
        self.server.store[datahash] = data
        resp = '{}+{}\n'.format(datahash, len(data)).encode()
        self.server._do_delay('response')
        self.send_response(200)
        self.send_header('Connection', 'close')
        self.send_header('Content-type', 'text/plain')
        self.send_header('Content-length', str(len(resp)))
        self.end_headers()
        self.server._do_delay('response_body')
        self.wfile_bandwidth_write(resp)
        self.server._do_delay('response_close')
        self.close_connection = True

    def log_request(self, *args, **kwargs):
        """Log requests only when ARVADOS_DEBUG is set."""
        if _debug:
            super(Handler, self).log_request(*args, **kwargs)