Merge branch '19316-oj-safe-load'
[arvados.git] / sdk / python / tests / keepstub.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from __future__ import division
6 from future import standard_library
7 standard_library.install_aliases()
8 from builtins import str
9 import http.server
10 import hashlib
11 import os
12 import re
13 import socket
14 import socketserver
15 import sys
16 import threading
17 import time
18
19 from . import arvados_testutil as tutil
20
21 _debug = os.environ.get('ARVADOS_DEBUG', None)
22
23
class StubKeepServers(tutil.ApiClientMock):
    """Test fixture that runs a stub Keep server for the duration of a test.

    setUp() starts a Server (using Handler) on an OS-assigned port in a
    background thread, records the port in self.port, and stores the result
    of mock_keep_services() pointing at that port in self.api_client.
    tearDown() stops the server again.
    """

    def setUp(self):
        """Start the stub server thread and build the mocked API client."""
        super(StubKeepServers, self).setUp()
        # Let the server itself bind to port 0 and ask it which port the OS
        # assigned.  Probing for a free port with a throwaway socket and
        # re-binding afterwards leaves a window in which another process
        # can grab the port.
        self.server = Server(('0.0.0.0', 0), Handler)
        self.port = self.server.server_address[1]
        self.thread = threading.Thread(target=self.server.serve_forever)
        self.thread.daemon = True # Exit thread if main proc exits
        self.thread.start()
        self.api_client = self.mock_keep_services(
            count=1,
            service_host='localhost',
            service_port=self.port,
        )

    def tearDown(self):
        """Stop the serve_forever() loop and release the listening socket."""
        # shutdown() returns once serve_forever() has left its loop, so the
        # join below only waits for the thread function to return.
        self.server.shutdown()
        self.thread.join()
        # Close just the listening socket so it is not leaked from one test
        # to the next.  server_close() is deliberately not used: with
        # ThreadingMixIn on newer Python versions it may also wait for
        # request handler threads, which can still be sleeping in an
        # injected delay or waiting on a kept-alive connection.
        self.server.socket.close()
        super(StubKeepServers, self).tearDown()
45
46
class Server(socketserver.ThreadingMixIn, http.server.HTTPServer, object):
    """Threaded HTTP server holding the stub's state and delay settings.

    Attributes:
      store      dict used by the request handler to keep uploaded blocks
      delays     dict mapping checkpoint name -> seconds to sleep there
      bandwidth  maximum bytes per second as a float, or None for no limit
    """

    allow_reuse_address = 1

    def __init__(self, *args, **kwargs):
        """Initialise state, then construct the underlying HTTP server.

        All arguments are passed through to the base class constructor.
        """
        self.store = {}
        self.delays = {
            # before reading request headers
            'request': 0,
            # before reading request body
            'request_body': 0,
            # before setting response status and headers
            'response': 0,
            # before sending response body
            'response_body': 0,
            # before returning from handler (thus setting response EOF)
            'response_close': 0,
            # after writing over 1s worth of data at self.bandwidth
            'mid_write': 0,
            # after reading over 1s worth of data at self.bandwidth
            'mid_read': 0,
        }
        self.bandwidth = None
        super(Server, self).__init__(*args, **kwargs)

    def setdelays(self, **kwargs):
        """In future requests, induce delays at the given checkpoints.

        Each keyword is a checkpoint name (a key of self.delays) and each
        value is the number of seconds to sleep at that checkpoint.

        Raises KeyError if any keyword is not a known checkpoint.  In that
        case no delay is changed, so a misspelled checkpoint cannot be
        silently accepted and then never applied.
        """
        unknown = sorted(k for k in kwargs if k not in self.delays)
        if unknown:
            raise KeyError(
                'unknown delay checkpoint(s): {}'.format(', '.join(unknown)))
        self.delays.update(kwargs)

    def setbandwidth(self, bandwidth):
        """For future requests, set the maximum bandwidth (number of bytes per
        second) to operate at. If setbandwidth is never called, function at
        maximum bandwidth possible"""
        self.bandwidth = float(bandwidth)

    def _sleep_at_least(self, seconds):
        """Sleep for given time, even if signals are received."""
        wake = time.time() + seconds
        todo = seconds
        # Keep sleeping for whatever is left until the wake-up time has
        # passed; a zero or negative duration skips the loop entirely.
        while todo > 0:
            time.sleep(todo)
            todo = wake - time.time()

    def _do_delay(self, k):
        """Sleep for the delay configured for checkpoint k.

        Raises KeyError if k is not a key of self.delays.
        """
        self._sleep_at_least(self.delays[k])
94
95
class Handler(http.server.BaseHTTPRequestHandler, object):
    """Request handler for the stub Keep server.

    Blocks are kept in self.server.store, keyed by the md5 hex digest of
    their content.  At various checkpoints the handler sleeps for the
    duration configured in self.server.delays, and body transfers are paced
    to self.server.bandwidth bytes per second when that is not None.
    """

    protocol_version = 'HTTP/1.1'

    # Chunk size (bytes) for paced I/O when no bandwidth limit is set, or
    # when a quarter second's worth of data rounds down to zero bytes.
    _DEFAULT_CHUNK_BYTES = 32768

    def _pacing_params(self):
        """Return (bandwidth, chunk_bytes, outage_after) for a paced transfer.

        bandwidth is a snapshot of self.server.bandwidth (None = unlimited).
        chunk_bytes is the size of each individual read/write.
        outage_after is the byte count that must be exceeded before the
        one-off 'mid_read'/'mid_write' delay is applied: one second's worth
        of data when a bandwidth is set.  With no bandwidth limit there is
        no such amount, so the delay is applied once the first chunk has
        been transferred.
        """
        bandwidth = self.server.bandwidth
        if bandwidth is None:
            return None, self._DEFAULT_CHUNK_BYTES, 0
        chunk_bytes = int(bandwidth / 4) or self._DEFAULT_CHUNK_BYTES
        return bandwidth, chunk_bytes, bandwidth

    def _send_empty_response(self, status):
        """Send a complete body-less response with the given status code.

        Headers are only buffered until end_headers() is called, so a bare
        send_response() would never put anything on the wire.
        """
        self.send_response(status)
        self.send_header('Connection', 'close')
        self.send_header('Content-length', '0')
        self.end_headers()

    def wfile_bandwidth_write(self, data_to_write):
        """Write data_to_write to the client, honouring bandwidth and delays.

        With no bandwidth limit and no 'mid_write' delay the data is written
        in one call.  Otherwise it is written in chunks, sleeping after each
        chunk to keep to the configured bandwidth (if any), and applying the
        'mid_write' delay once part-way through the transfer.

        Returns None.
        """
        if self.server.bandwidth is None and self.server.delays['mid_write'] == 0:
            self.wfile.write(data_to_write)
            return None
        bandwidth, chunk_bytes, outage_after = self._pacing_params()
        outage_happened = False
        num_bytes = len(data_to_write)
        num_sent_bytes = 0
        target_time = time.time()
        while num_sent_bytes < num_bytes:
            if num_sent_bytes > outage_after and not outage_happened:
                self.server._do_delay('mid_write')
                # Push the pacing schedule back by the outage so the
                # bandwidth sleeps below do not try to make up for it.
                target_time += self.server.delays['mid_write']
                outage_happened = True
            num_write_bytes = min(chunk_bytes, num_bytes - num_sent_bytes)
            self.wfile.write(
                data_to_write[num_sent_bytes:num_sent_bytes + num_write_bytes])
            num_sent_bytes += num_write_bytes
            if bandwidth is not None:
                target_time += num_write_bytes / bandwidth
                self.server._sleep_at_least(target_time - time.time())
        return None

    def rfile_bandwidth_read(self, bytes_to_read):
        """Read bytes_to_read bytes from the client, honouring bandwidth and
        delays, and return them as a bytes object.

        With no bandwidth limit and no 'mid_read' delay the data is read in
        one call.  Otherwise it is read in chunks, sleeping after each chunk
        to keep to the configured bandwidth (if any), and applying the
        'mid_read' delay once part-way through the transfer.
        """
        if self.server.bandwidth is None and self.server.delays['mid_read'] == 0:
            return self.rfile.read(bytes_to_read)
        bandwidth, chunk_bytes, outage_after = self._pacing_params()
        chunks = []
        outage_happened = False
        bytes_read = 0
        target_time = time.time()
        while bytes_to_read > bytes_read:
            if bytes_read > outage_after and not outage_happened:
                self.server._do_delay('mid_read')
                # Push the pacing schedule back by the outage so the
                # bandwidth sleeps below do not try to make up for it.
                target_time += self.server.delays['mid_read']
                outage_happened = True
            next_bytes_to_read = min(chunk_bytes, bytes_to_read - bytes_read)
            chunks.append(self.rfile.read(next_bytes_to_read))
            # Progress is counted in bytes requested, so the loop also
            # terminates if the client closes the connection early.
            bytes_read += next_bytes_to_read
            if bandwidth is not None:
                target_time += next_bytes_to_read / bandwidth
                self.server._sleep_at_least(target_time - time.time())
        return b''.join(chunks)

    def finish(self, *args, **kwargs):
        """Run the base class finish(), hiding errors unless debugging."""
        try:
            return super(Handler, self).finish(*args, **kwargs)
        except Exception:
            if _debug:
                raise

    def handle(self, *args, **kwargs):
        """Run the base class handle(), hiding errors unless debugging."""
        try:
            return super(Handler, self).handle(*args, **kwargs)
        except Exception:
            if _debug:
                raise

    def handle_one_request(self, *args, **kwargs):
        """Apply the 'request' delay, then handle a single request."""
        # Reset per request; set again by handle_expect_100() if the base
        # class sends the interim 100 response for this request.
        self._sent_continue = False
        self.server._do_delay('request')
        return super(Handler, self).handle_one_request(*args, **kwargs)

    def handle_expect_100(self):
        """Apply the 'request_body' delay, then let the base class respond
        to an "Expect: 100-continue" request header."""
        self.server._do_delay('request_body')
        self._sent_continue = True
        return super(Handler, self).handle_expect_100()

    def do_GET(self):
        """Return the stored block whose md5 hex digest appears in the path.

        Responds 422 if the path contains no 32-digit lowercase hex string,
        404 if no block is stored under that digest, and otherwise 200 with
        the block as the body.  The 200 response carries no Content-length;
        the end of the body is marked by closing the connection.
        """
        self.server._do_delay('response')
        r = re.search(r'[0-9a-f]{32}', self.path)
        if not r:
            return self._send_empty_response(422)
        datahash = r.group(0)
        if datahash not in self.server.store:
            return self._send_empty_response(404)
        self.send_response(200)
        self.send_header('Connection', 'close')
        self.send_header('Content-type', 'application/octet-stream')
        self.end_headers()
        self.server._do_delay('response_body')
        self.wfile_bandwidth_write(self.server.store[datahash])
        self.server._do_delay('response_close')

    def do_HEAD(self):
        """Like do_GET, but send only headers, including the Content-length
        of the stored block."""
        self.server._do_delay('response')
        r = re.search(r'[0-9a-f]{32}', self.path)
        if not r:
            return self._send_empty_response(422)
        datahash = r.group(0)
        if datahash not in self.server.store:
            return self._send_empty_response(404)
        self.send_response(200)
        self.send_header('Connection', 'close')
        self.send_header('Content-type', 'application/octet-stream')
        self.send_header('Content-length', str(len(self.server.store[datahash])))
        self.end_headers()
        self.server._do_delay('response_close')
        self.close_connection = True

    def do_PUT(self):
        """Store the request body and respond with "<md5 hexdigest>+<size>\\n".

        Responds 411 if the request has no Content-length header and 400 if
        its value is not a non-negative integer.
        """
        raw_length = self.headers.get('content-length')
        if raw_length is None:
            return self._send_empty_response(411)
        try:
            content_length = int(raw_length)
        except ValueError:
            content_length = -1
        if content_length < 0:
            return self._send_empty_response(400)
        if not self._sent_continue and self.headers.get('expect') == '100-continue':
            # The comments at https://bugs.python.org/issue1491
            # implies that Python 2.7 BaseHTTPRequestHandler was
            # patched to support 100 Continue, but reading the actual
            # code that ships in Debian it clearly is not, so we need
            # to send the response on the socket directly.
            self.server._do_delay('request_body')
            self.wfile.write("{} {} {}\r\n\r\n".format(
                self.protocol_version, 100, "Continue").encode())
        data = self.rfile_bandwidth_read(content_length)
        datahash = hashlib.md5(data).hexdigest()
        self.server.store[datahash] = data
        resp = '{}+{}\n'.format(datahash, len(data)).encode()
        self.server._do_delay('response')
        self.send_response(200)
        self.send_header('Connection', 'close')
        self.send_header('Content-type', 'text/plain')
        self.send_header('Content-length', str(len(resp)))
        self.end_headers()
        self.server._do_delay('response_body')
        self.wfile_bandwidth_write(resp)
        self.server._do_delay('response_close')
        self.close_connection = True

    def log_request(self, *args, **kwargs):
        """Log the request only when debugging is enabled."""
        if _debug:
            super(Handler, self).log_request(*args, **kwargs)