Merge branch '19982-spot-instance' refs #19982
[arvados.git] / sdk / python / tests / keepstub.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import http.server
6 import hashlib
7 import os
8 import re
9 import socket
10 import socketserver
11 import sys
12 import threading
13 import time
14
15 from . import arvados_testutil as tutil
16
17 _debug = os.environ.get('ARVADOS_DEBUG', None)
18
class StubKeepServers(tutil.ApiClientMock):
    """Test mixin that runs an in-process stub Keep server for each test.

    setUp() starts a Server/Handler pair on a free port in a background
    thread and stores the client returned by mock_keep_services() (pointed
    at that port) in self.api_client.  tearDown() stops the server again.
    """

    def setUp(self):
        super().setUp()
        # Bind to port 0 so the OS picks a free port for the server socket
        # itself.  Probing with a throwaway socket and then rebinding leaves
        # a window in which another process can grab the port.
        self.server = Server(('0.0.0.0', 0), Handler)
        self.port = self.server.server_address[1]
        self.thread = threading.Thread(target=self.server.serve_forever)
        self.thread.daemon = True  # Exit thread if main proc exits
        self.thread.start()
        self.api_client = self.mock_keep_services(
            count=1,
            service_host='localhost',
            service_port=self.port,
        )

    def tearDown(self):
        # Stop serve_forever() and wait for its thread to exit.
        self.server.shutdown()
        self.thread.join()
        # Close the listening socket so each test does not leak a file
        # descriptor.  server_close() is deliberately not used: with
        # ThreadingMixIn it may wait for in-flight handler threads, which
        # can be sleeping in an induced delay.
        self.server.socket.close()
        super().tearDown()
39
40
class Server(socketserver.ThreadingMixIn, http.server.HTTPServer):
    """Threaded HTTP server holding an in-memory block store.

    Tests call setdelays() and setbandwidth() to make the server slow at
    specific points of the request/response cycle.
    """

    allow_reuse_address = True

    def __init__(self, *args, **kwargs):
        # Stored blocks, keyed by the md5 hex digest of their contents.
        self.store = {}
        # Seconds to sleep at each checkpoint; change via setdelays().
        self.delays = {
            # before reading request headers
            'request': 0,
            # before reading request body
            'request_body': 0,
            # before setting response status and headers
            'response': 0,
            # before sending response body
            'response_body': 0,
            # before returning from handler (thus setting response EOF)
            'response_close': 0,
            # after writing over 1s worth of data at self.bandwidth
            'mid_write': 0,
            # after reading over 1s worth of data at self.bandwidth
            'mid_read': 0,
        }
        # Maximum transfer rate in bytes/second, or None for no limit.
        self.bandwidth = None
        super().__init__(*args, **kwargs)

    def setdelays(self, **kwargs):
        """In future requests, induce delays at the given checkpoints.

        Each keyword must name a checkpoint in self.delays; its value is the
        number of seconds to sleep at that checkpoint.

        Raises:
            KeyError: if any keyword is not a known checkpoint.  In that case
                no delay is changed.
        """
        # dict.get() never raises, so validate explicitly (and up front, so
        # a typo cannot leave the delays half-updated).
        unknown = sorted(set(kwargs) - set(self.delays))
        if unknown:
            raise KeyError(
                'unknown delay checkpoint(s): {}'.format(', '.join(unknown)))
        self.delays.update(kwargs)

    def setbandwidth(self, bandwidth):
        """For future requests, set the maximum bandwidth (bytes per second).

        Pass None to remove the limit again.  If setbandwidth is never
        called, the server operates at the maximum bandwidth possible.
        """
        self.bandwidth = None if bandwidth is None else float(bandwidth)

    def _sleep_at_least(self, seconds):
        """Sleep for at least `seconds` (returns at once if <= 0).

        Re-sleeps if time.sleep() returns early.  Uses the monotonic clock so
        wall-clock adjustments cannot shorten or stretch the delay.
        """
        wake = time.monotonic() + seconds
        remaining = seconds
        while remaining > 0:
            time.sleep(remaining)
            remaining = wake - time.monotonic()

    def _do_delay(self, k):
        """Sleep for the delay currently configured at checkpoint `k`."""
        self._sleep_at_least(self.delays[k])
88
89
class Handler(http.server.BaseHTTPRequestHandler):
    """Request handler for Server.

    PUT stores the request body under its md5 hex digest.  GET and HEAD
    serve stored blocks.  All methods apply the server's configured delays
    and bandwidth limit.
    """

    protocol_version = 'HTTP/1.1'

    # Chunk size for throttled transfers when the bandwidth does not yield a
    # usable one (no limit, or a limit under 4 bytes/second).
    _DEFAULT_CHUNK_SIZE = 32768

    def _throttle_settings(self):
        """Return (chunk_size, outage_threshold) for a throttled transfer.

        With a bandwidth limit, each chunk is a quarter second's worth of
        data, and the mid-transfer delay fires once more than one second's
        worth has been transferred.  Without a limit, the delay fires after
        the first chunk.
        """
        bandwidth = self.server.bandwidth
        if bandwidth is None:
            return self._DEFAULT_CHUNK_SIZE, 0
        return int(bandwidth / 4) or self._DEFAULT_CHUNK_SIZE, bandwidth

    def wfile_bandwidth_write(self, data_to_write):
        """Send data_to_write to the client.

        Honors the server's bandwidth limit and its 'mid_write' delay, which
        is applied at most once per call.  Returns None.
        """
        if self.server.bandwidth is None and self.server.delays['mid_write'] == 0:
            self.wfile.write(data_to_write)
            return None
        chunk_size, outage_threshold = self._throttle_settings()
        bandwidth = self.server.bandwidth
        total = len(data_to_write)
        sent = 0
        outage_happened = False
        target_time = time.monotonic()
        while sent < total:
            if sent > outage_threshold and not outage_happened:
                self.server._do_delay('mid_write')
                target_time += self.server.delays['mid_write']
                outage_happened = True
            chunk = data_to_write[sent:sent + chunk_size]
            self.wfile.write(chunk)
            sent += len(chunk)
            if bandwidth is not None:
                # Pace against a cumulative schedule so per-chunk overhead
                # does not add up to extra delay.
                target_time += len(chunk) / bandwidth
                self.server._sleep_at_least(target_time - time.monotonic())
        return None

    def rfile_bandwidth_read(self, bytes_to_read):
        """Read up to bytes_to_read bytes of request body from the client.

        Honors the server's bandwidth limit and its 'mid_read' delay, which
        is applied at most once per call.  Returns fewer bytes if the client
        closes the connection early.
        """
        if self.server.bandwidth is None and self.server.delays['mid_read'] == 0:
            return self.rfile.read(bytes_to_read)
        chunk_size, outage_threshold = self._throttle_settings()
        bandwidth = self.server.bandwidth
        chunks = []
        bytes_read = 0
        outage_happened = False
        target_time = time.monotonic()
        while bytes_read < bytes_to_read:
            if bytes_read > outage_threshold and not outage_happened:
                self.server._do_delay('mid_read')
                target_time += self.server.delays['mid_read']
                outage_happened = True
            chunk = self.rfile.read(min(chunk_size, bytes_to_read - bytes_read))
            if not chunk:
                # EOF: stop instead of sleeping through the rest of the
                # expected transfer while reading nothing.
                break
            chunks.append(chunk)
            bytes_read += len(chunk)
            if bandwidth is not None:
                target_time += len(chunk) / bandwidth
                self.server._sleep_at_least(target_time - time.monotonic())
        return b''.join(chunks)

    def finish(self, *args, **kwargs):
        """Finish the request.

        Errors, such as those caused by a client that already hung up, are
        suppressed unless ARVADOS_DEBUG is set.
        """
        try:
            return super().finish(*args, **kwargs)
        except Exception:
            if _debug:
                raise

    def handle(self, *args, **kwargs):
        """Handle the connection.

        Errors are suppressed unless ARVADOS_DEBUG is set.
        """
        try:
            return super().handle(*args, **kwargs)
        except Exception:
            if _debug:
                raise

    def handle_one_request(self, *args, **kwargs):
        """Apply the 'request' delay, then read and dispatch one request."""
        self._sent_continue = False
        self.server._do_delay('request')
        return super().handle_one_request(*args, **kwargs)

    def handle_expect_100(self):
        """Apply the 'request_body' delay before sending 100 Continue."""
        self.server._do_delay('request_body')
        self._sent_continue = True
        return super().handle_expect_100()

    def _send_empty_response(self, code):
        """Send a complete, bodiless response with status `code` and close.

        send_response() only buffers the status line; end_headers() is what
        writes it to the client, so it must always be called.
        """
        self.send_response(code)
        self.send_header('Connection', 'close')
        self.send_header('Content-length', '0')
        self.end_headers()
        self.close_connection = True

    def _find_block(self):
        """Return the stored block hash named in self.path.

        If the path contains no hash, sends a 422 response and returns None.
        If the hash is not in the store, sends a 404 response and returns
        None.
        """
        match = re.search(r'[0-9a-f]{32}', self.path)
        if not match:
            self._send_empty_response(422)
            return None
        datahash = match.group(0)
        if datahash not in self.server.store:
            self._send_empty_response(404)
            return None
        return datahash

    def do_GET(self):
        """Serve a stored block.

        No Content-Length is sent, so the client only sees the end of the
        body when the connection closes, which happens after the
        'response_close' delay.
        """
        self.server._do_delay('response')
        datahash = self._find_block()
        if datahash is None:
            return
        self.send_response(200)
        self.send_header('Connection', 'close')
        self.send_header('Content-type', 'application/octet-stream')
        self.end_headers()
        self.server._do_delay('response_body')
        self.wfile_bandwidth_write(self.server.store[datahash])
        self.server._do_delay('response_close')

    def do_HEAD(self):
        """Report whether a block is stored, and its size, without a body."""
        self.server._do_delay('response')
        datahash = self._find_block()
        if datahash is None:
            return
        self.send_response(200)
        self.send_header('Connection', 'close')
        self.send_header('Content-type', 'application/octet-stream')
        self.send_header('Content-length', str(len(self.server.store[datahash])))
        self.end_headers()
        self.server._do_delay('response_close')
        self.close_connection = True

    def do_PUT(self):
        """Store the request body under its md5 hex digest.

        Replies with '<hash>+<size>'.  Replies 411 if the request has no
        Content-Length.
        """
        expect = self.headers.get('expect', '')
        if not self._sent_continue and expect.lower() == '100-continue':
            # BaseHTTPRequestHandler.parse_request() only calls
            # handle_expect_100() when both the request and the server speak
            # HTTP/1.1.  For any other request carrying this header, send the
            # interim response on the socket directly.
            self.server._do_delay('request_body')
            self.wfile.write("{} {} {}\r\n\r\n".format(
                self.protocol_version, 100, "Continue").encode())
        content_length = self.headers.get('content-length')
        if content_length is None:
            self._send_empty_response(411)
            return
        data = self.rfile_bandwidth_read(int(content_length))
        datahash = hashlib.md5(data).hexdigest()
        self.server.store[datahash] = data
        resp = '{}+{}\n'.format(datahash, len(data)).encode()
        self.server._do_delay('response')
        self.send_response(200)
        self.send_header('Connection', 'close')
        self.send_header('Content-type', 'text/plain')
        self.send_header('Content-length', str(len(resp)))
        self.end_headers()
        self.server._do_delay('response_body')
        self.wfile_bandwidth_write(resp)
        self.server._do_delay('response_close')
        self.close_connection = True

    def log_request(self, *args, **kwargs):
        """Log requests only when ARVADOS_DEBUG is set."""
        if _debug:
            super().log_request(*args, **kwargs)