+# Copyright (C) The Arvados Authors. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
from __future__ import division
from future import standard_library
standard_library.install_aliases()
import hashlib
import os
import re
+import socket
import socketserver
import sys
+import threading
import time
+from . import arvados_testutil as tutil
+
+_debug = os.environ.get('ARVADOS_DEBUG', None)  # truthy when ARVADOS_DEBUG is set to a non-empty value; gates request logging and exception re-raise in Handler
+
+
+class StubKeepServers(tutil.ApiClientMock):  # setUp/tearDown fixture: runs a stub HTTP Server/Handler in a background thread
+
+ def setUp(self):
+ super(StubKeepServers, self).setUp()
+ sock = socket.socket()
+ sock.bind(('0.0.0.0', 0))  # port 0: let the OS choose an unused port
+ self.port = sock.getsockname()[1]  # remember the port number the OS assigned
+ sock.close()  # NOTE(review): the port is released here and re-bound below; another process could claim it in between
+ self.server = Server(('0.0.0.0', self.port), Handler)
+ self.thread = threading.Thread(target=self.server.serve_forever)
+ self.thread.daemon = True # Exit thread if main proc exits
+ self.thread.start()
+ self.api_client = self.mock_keep_services(  # presumably a mock API client listing this one stub server — verify against tutil's mock_keep_services
+ count=1,
+ service_host='localhost',
+ service_port=self.port,
+ )
+
+ def tearDown(self):
+ self.server.shutdown()  # stops serve_forever(); NOTE(review): server_close() is never called here, so the listening socket is not explicitly closed — confirm intended
+ super(StubKeepServers, self).tearDown()
+
+
class Server(socketserver.ThreadingMixIn, http.server.HTTPServer, object):
allow_reuse_address = 1
def setdelays(self, **kwargs):
"""In future requests, induce delays at the given checkpoints."""
for (k, v) in kwargs.items():
- self.delays.get(k) # NameError if unknown key
+ self.delays.get(k)  # NOTE(review): meant to reject unknown keys, but .get() on a dict returns None instead of raising — confirm the type of self.delays
self.delays[k] = v
def setbandwidth(self, bandwidth):
class Handler(http.server.BaseHTTPRequestHandler, object):
+
+ protocol_version = 'HTTP/1.1'
+
def wfile_bandwidth_write(self, data_to_write):
- if self.server.bandwidth == None and self.server.delays['mid_write'] == 0:
+ if self.server.bandwidth is None and self.server.delays['mid_write'] == 0:  # fast path: no bandwidth limit and no mid_write delay, so write in one call. NOTE(review): in the visible lines the else branch divides self.server.bandwidth, which is still None when only mid_write is non-zero — confirm
self.wfile.write(data_to_write)
else:
BYTES_PER_WRITE = int(self.server.bandwidth/4) or 32768
while num_sent_bytes < num_bytes:
if num_sent_bytes > self.server.bandwidth and not outage_happened:
self.server._do_delay('mid_write')
- target_time += self.delays['mid_write']
+ target_time += self.server.delays['mid_write']  # delays is read from the server object, matching the check at the top of this method
outage_happened = True
num_write_bytes = min(BYTES_PER_WRITE,
num_bytes - num_sent_bytes)
return None
def rfile_bandwidth_read(self, bytes_to_read):
- if self.server.bandwidth == None and self.server.delays['mid_read'] == 0:
+ if self.server.bandwidth is None and self.server.delays['mid_read'] == 0:  # fast path: no bandwidth limit and no mid_read delay, so read in one call. NOTE(review): in the visible lines the else branch divides self.server.bandwidth, which is still None when only mid_read is non-zero — confirm
return self.rfile.read(bytes_to_read)
else:
BYTES_PER_READ = int(self.server.bandwidth/4) or 32768
while bytes_to_read > bytes_read:
if bytes_read > self.server.bandwidth and not outage_happened:
self.server._do_delay('mid_read')
- target_time += self.delays['mid_read']
+ target_time += self.server.delays['mid_read']  # delays is read from the server object, matching the check at the top of this method
outage_happened = True
next_bytes_to_read = min(BYTES_PER_READ,
bytes_to_read - bytes_read)
self.server._sleep_at_least(target_time - time.time())
return data
+ def finish(self, *args, **kwargs):  # defer to the base finish(); its errors are swallowed (method returns None) unless ARVADOS_DEBUG is set
+ try:
+ return super(Handler, self).finish(*args, **kwargs)
+ except Exception as err:  # 'err' is bound but not used in the visible code
+ if _debug:
+ raise  # surface the original exception only in debug mode
+
def handle(self, *args, **kwargs):
+ try:
+ return super(Handler, self).handle(*args, **kwargs)
+ except:  # bare except: catches BaseException too (e.g. SystemExit), not just Exception
+ if _debug:
+ raise  # errors are dropped silently unless ARVADOS_DEBUG is set
+
+ def handle_one_request(self, *args, **kwargs):  # per-request hook: apply the 'request' delay, then defer to the base class
+ self._sent_continue = False  # reset for each request; handle_expect_100() sets it to True
self.server._do_delay('request')
- return super(Handler, self).handle(*args, **kwargs)
+ return super(Handler, self).handle_one_request(*args, **kwargs)
+
+ def handle_expect_100(self):  # apply the 'request_body' delay, record that 100-continue was handled, then defer to the base class
+ self.server._do_delay('request_body')
+ self._sent_continue = True  # do_PUT checks this flag before sending its own '100 Continue'
+ return super(Handler, self).handle_expect_100()
def do_GET(self):
self.server._do_delay('response')
if datahash not in self.server.store:
return self.send_response(404)
self.send_response(200)
+ self.send_header('Connection', 'close')  # ask the client to close after this response even though protocol_version is 'HTTP/1.1'
self.send_header('Content-type', 'application/octet-stream')
self.end_headers()
self.server._do_delay('response_body')
if datahash not in self.server.store:
return self.send_response(404)
self.send_response(200)
+ self.send_header('Connection', 'close')  # same as above: one response per connection
self.send_header('Content-type', 'application/octet-stream')
self.send_header('Content-length', str(len(self.server.store[datahash])))
self.end_headers()
self.server._do_delay('response_close')
-
- def handle_expect_100(self):
- self.server._do_delay('request_body')
+ self.close_connection = True  # flag read by the base request handler to stop serving this connection
def do_PUT(self):
- if sys.version_info < (3, 0):
+ if not self._sent_continue and self.headers.get('expect') == '100-continue':  # manual fallback: client sent 'Expect: 100-continue' but handle_expect_100() did not run for this request
# The comments at https://bugs.python.org/issue1491
# implies that Python 2.7 BaseHTTPRequestHandler was
# patched to support 100 Continue, but reading the actual
# to send the response on the socket directly.
self.server._do_delay('request_body')
self.wfile.write("{} {} {}\r\n\r\n".format(
- self.protocol_version, 100, "Continue"))
+ self.protocol_version, 100, "Continue").encode())  # encode(): the status line is written to wfile as bytes
data = self.rfile_bandwidth_read(
int(self.headers.get('content-length')))
datahash = hashlib.md5(data).hexdigest()
self.server.store[datahash] = data
+ resp = '{}+{}\n'.format(datahash, len(data)).encode()  # body is '<md5 hex>+<data length>\n', pre-encoded so Content-length below counts bytes
self.server._do_delay('response')
self.send_response(200)
+ self.send_header('Connection', 'close')  # one response per connection
self.send_header('Content-type', 'text/plain')
+ self.send_header('Content-length', len(resp))  # NOTE(review): int passed here, while another handler in this class wraps the length in str() — confirm send_header accepts both on all supported Pythons
self.end_headers()
self.server._do_delay('response_body')
- self.wfile_bandwidth_write(datahash + '+' + str(len(data)))
+ self.wfile_bandwidth_write(resp)
self.server._do_delay('response_close')
+ self.close_connection = True  # flag read by the base request handler to stop serving this connection
def log_request(self, *args, **kwargs):
- if os.environ.get('ARVADOS_DEBUG', None):
+ if _debug:  # module-level flag read from ARVADOS_DEBUG once at import, instead of on every request
super(Handler, self).log_request(*args, **kwargs)
-
- def finish(self, *args, **kwargs):
- """Ignore exceptions, notably "Broken pipe" when client times out."""
- try:
- return super(Handler, self).finish(*args, **kwargs)
- except:
- pass
-
- def handle_one_request(self, *args, **kwargs):
- """Ignore exceptions, notably "Broken pipe" when client times out."""
- try:
- return super(Handler, self).handle_one_request(*args, **kwargs)
- except:
- pass