X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/84260dab5182907cae91849acd652c138c2d5095..e67d0f5d43c56f78694ea4a5f93acec5c93cd0fb:/sdk/python/tests/keepstub.py diff --git a/sdk/python/tests/keepstub.py b/sdk/python/tests/keepstub.py index f074f8d6cf..6be8d8b646 100644 --- a/sdk/python/tests/keepstub.py +++ b/sdk/python/tests/keepstub.py @@ -1,11 +1,50 @@ -import BaseHTTPServer +# Copyright (C) The Arvados Authors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import division +from future import standard_library +standard_library.install_aliases() +from builtins import str +import http.server import hashlib import os import re -import SocketServer +import socket +import socketserver +import sys +import threading import time -class Server(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer, object): +from . import arvados_testutil as tutil + +_debug = os.environ.get('ARVADOS_DEBUG', None) + + +class StubKeepServers(tutil.ApiClientMock): + + def setUp(self): + super(StubKeepServers, self).setUp() + sock = socket.socket() + sock.bind(('0.0.0.0', 0)) + self.port = sock.getsockname()[1] + sock.close() + self.server = Server(('0.0.0.0', self.port), Handler) + self.thread = threading.Thread(target=self.server.serve_forever) + self.thread.daemon = True # Exit thread if main proc exits + self.thread.start() + self.api_client = self.mock_keep_services( + count=1, + service_host='localhost', + service_port=self.port, + ) + + def tearDown(self): + self.server.shutdown() + super(StubKeepServers, self).tearDown() + + +class Server(socketserver.ThreadingMixIn, http.server.HTTPServer, object): allow_reuse_address = 1 @@ -32,8 +71,8 @@ class Server(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer, object): def setdelays(self, **kwargs): """In future requests, induce delays at the given checkpoints.""" - for (k, v) in kwargs.iteritems(): - self.delays.get(k) # NameError if unknown key + for (k, v) in kwargs.items(): + self.delays.get(k) # NameError if unknown key self.delays[k] = v def setbandwidth(self, bandwidth): @@ -54,12 +93,15 @@ class Server(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer, object): self._sleep_at_least(self.delays[k]) -class Handler(BaseHTTPServer.BaseHTTPRequestHandler, object): +class Handler(http.server.BaseHTTPRequestHandler, object): + + protocol_version = 'HTTP/1.1' + def wfile_bandwidth_write(self, data_to_write): - if self.server.bandwidth == None and self.server.delays['mid_write'] == 0: + if self.server.bandwidth is None and self.server.delays['mid_write'] == 0: self.wfile.write(data_to_write) else: - BYTES_PER_WRITE = int(self.server.bandwidth/4.0) or 32768 + BYTES_PER_WRITE = int(self.server.bandwidth/4) or 32768 outage_happened = False num_bytes = len(data_to_write) num_sent_bytes = 0 @@ -67,7 +109,7 @@ class Handler(BaseHTTPServer.BaseHTTPRequestHandler, object): while num_sent_bytes < num_bytes: if num_sent_bytes > self.server.bandwidth and not outage_happened: self.server._do_delay('mid_write') - target_time += self.delays['mid_write'] + target_time += self.server.delays['mid_write'] outage_happened = True num_write_bytes = min(BYTES_PER_WRITE, num_bytes - num_sent_bytes) @@ -80,18 +122,18 @@ class Handler(BaseHTTPServer.BaseHTTPRequestHandler, object): return None def rfile_bandwidth_read(self, bytes_to_read): - if self.server.bandwidth == None and self.server.delays['mid_read'] == 0: + if self.server.bandwidth is None and self.server.delays['mid_read'] == 0: return self.rfile.read(bytes_to_read) else: - BYTES_PER_READ = int(self.server.bandwidth/4.0) or 32768 - data = '' + BYTES_PER_READ = int(self.server.bandwidth/4) or 32768 + data = b'' outage_happened = False bytes_read = 0 target_time = time.time() while bytes_to_read > bytes_read: if bytes_read > self.server.bandwidth and not outage_happened: self.server._do_delay('mid_read') - target_time += self.delays['mid_read'] + target_time += self.server.delays['mid_read'] outage_happened = True next_bytes_to_read = min(BYTES_PER_READ, bytes_to_read - bytes_read) @@ -102,9 +144,29 @@ class Handler(BaseHTTPServer.BaseHTTPRequestHandler, object): self.server._sleep_at_least(target_time - time.time()) return data + def finish(self, *args, **kwargs): + try: + return super(Handler, self).finish(*args, **kwargs) + except Exception as err: + if _debug: + raise + def handle(self, *args, **kwargs): + try: + return super(Handler, self).handle(*args, **kwargs) + except: + if _debug: + raise + + def handle_one_request(self, *args, **kwargs): + self._sent_continue = False self.server._do_delay('request') - return super(Handler, self).handle(*args, **kwargs) + return super(Handler, self).handle_one_request(*args, **kwargs) + + def handle_expect_100(self): + self.server._do_delay('request_body') + self._sent_continue = True + return super(Handler, self).handle_expect_100() def do_GET(self): self.server._do_delay('response') @@ -115,45 +177,55 @@ class Handler(BaseHTTPServer.BaseHTTPRequestHandler, object): if datahash not in self.server.store: return self.send_response(404) self.send_response(200) + self.send_header('Connection', 'close') self.send_header('Content-type', 'application/octet-stream') self.end_headers() self.server._do_delay('response_body') self.wfile_bandwidth_write(self.server.store[datahash]) self.server._do_delay('response_close') + def do_HEAD(self): + self.server._do_delay('response') + r = re.search(r'[0-9a-f]{32}', self.path) + if not r: + return self.send_response(422) + datahash = r.group(0) + if datahash not in self.server.store: + return self.send_response(404) + self.send_response(200) + self.send_header('Connection', 'close') + self.send_header('Content-type', 'application/octet-stream') + self.send_header('Content-length', str(len(self.server.store[datahash]))) + self.end_headers() + self.server._do_delay('response_close') + self.close_connection = True + def do_PUT(self): - self.server._do_delay('request_body') - # The comments at https://bugs.python.org/issue1491 implies that Python - # 2.7 BaseHTTPRequestHandler was patched to support 100 Continue, but - # reading the actual code that ships in Debian it clearly is not, so we - # need to send the response on the socket directly. - self.wfile_bandwidth_write("%s %d %s\r\n\r\n" % - (self.protocol_version, 100, "Continue")) - data = self.rfile_bandwidth_read(int(self.headers.getheader('content-length'))) + if not self._sent_continue and self.headers.get('expect') == '100-continue': + # The comments at https://bugs.python.org/issue1491 + # implies that Python 2.7 BaseHTTPRequestHandler was + # patched to support 100 Continue, but reading the actual + # code that ships in Debian it clearly is not, so we need + # to send the response on the socket directly. + self.server._do_delay('request_body') + self.wfile.write("{} {} {}\r\n\r\n".format( + self.protocol_version, 100, "Continue").encode()) + data = self.rfile_bandwidth_read( + int(self.headers.get('content-length'))) datahash = hashlib.md5(data).hexdigest() self.server.store[datahash] = data + resp = '{}+{}\n'.format(datahash, len(data)).encode() self.server._do_delay('response') self.send_response(200) + self.send_header('Connection', 'close') self.send_header('Content-type', 'text/plain') + self.send_header('Content-length', len(resp)) self.end_headers() self.server._do_delay('response_body') - self.wfile_bandwidth_write(datahash + '+' + str(len(data))) + self.wfile_bandwidth_write(resp) self.server._do_delay('response_close') + self.close_connection = True def log_request(self, *args, **kwargs): - if os.environ.get('ARVADOS_DEBUG', None): + if _debug: super(Handler, self).log_request(*args, **kwargs) - - def finish(self, *args, **kwargs): - """Ignore exceptions, notably "Broken pipe" when client times out.""" - try: - return super(Handler, self).finish(*args, **kwargs) - except: - pass - - def handle_one_request(self, *args, **kwargs): - """Ignore exceptions, notably "Broken pipe" when client times out.""" - try: - return super(Handler, self).handle_one_request(*args, **kwargs) - except: - pass