11308: Raise exception on invalid/unsupported open() mode.
[arvados.git] / sdk / python / tests / keepstub.py
1 from __future__ import division
2 from future import standard_library
3 standard_library.install_aliases()
4 from builtins import str
5 import http.server
6 import hashlib
7 import os
8 import re
9 import socketserver
10 import sys
11 import time
12
# Value of the ARVADOS_DEBUG environment variable (None when unset).  When it
# is truthy, Handler re-raises exceptions it would otherwise swallow and logs
# each request.
_debug = os.environ.get('ARVADOS_DEBUG')
14
15
class Server(socketserver.ThreadingMixIn, http.server.HTTPServer, object):
    """Threaded HTTP server holding in-memory state shared by its handlers.

    Attributes:
      store: dict that request handlers use to keep uploaded data, keyed
        by the data's MD5 hex digest.  Starts empty.
      delays: dict mapping checkpoint name to the number of seconds to
        sleep when a request reaches that checkpoint.  All default to 0.
      bandwidth: maximum transfer rate in bytes per second as a float, or
        None (the default) for no throttling.
    """

    allow_reuse_address = 1

    def __init__(self, *args, **kwargs):
        """Set up default state, then initialize the underlying HTTP server.

        All arguments are passed through to the base class constructor.
        """
        self.store = {}
        self.delays = {
            # before reading request headers
            'request': 0,
            # before reading request body
            'request_body': 0,
            # before setting response status and headers
            'response': 0,
            # before sending response body
            'response_body': 0,
            # before returning from handler (thus setting response EOF)
            'response_close': 0,
            # after writing over 1s worth of data at self.bandwidth
            'mid_write': 0,
            # after reading over 1s worth of data at self.bandwidth
            'mid_read': 0,
        }
        self.bandwidth = None
        super(Server, self).__init__(*args, **kwargs)

    def setdelays(self, **kwargs):
        """In future requests, induce delays at the given checkpoints.

        Each keyword is a checkpoint name (a key of self.delays) and each
        value is the delay in seconds.

        Raises:
          KeyError: if any keyword is not a known checkpoint.  In that
            case no delay is changed.
        """
        # Validate every name before mutating anything, so a misspelled
        # checkpoint fails loudly instead of being silently stored and
        # never consulted.
        unknown = sorted(k for k in kwargs if k not in self.delays)
        if unknown:
            raise KeyError(
                "unknown delay checkpoint(s): {}".format(', '.join(unknown)))
        self.delays.update(kwargs)

    def setbandwidth(self, bandwidth):
        """For future requests, set the maximum bandwidth (number of bytes per
        second) to operate at. If setbandwidth is never called, function at
        maximum bandwidth possible"""
        self.bandwidth = float(bandwidth)

    def _sleep_at_least(self, seconds):
        """Sleep for given time, even if signals are received."""
        wake = time.time() + seconds
        todo = seconds
        # Re-check the clock after each sleep and sleep again for whatever
        # is left; a non-positive duration skips the loop entirely.
        while todo > 0:
            time.sleep(todo)
            todo = wake - time.time()

    def _do_delay(self, k):
        """Sleep for the delay configured for checkpoint k.

        Raises KeyError if k is not a key of self.delays.
        """
        self._sleep_at_least(self.delays[k])
63
64
class Handler(http.server.BaseHTTPRequestHandler, object):
    """Request handler that stores and serves data blocks kept on the server.

    PUT stores the request body in self.server.store under its MD5 hex
    digest; GET and HEAD look a block up by the first 32-hex-digit string
    found in the request path.  At each checkpoint the handler calls
    self.server._do_delay(), and body transfers are throttled to
    self.server.bandwidth when that is set.
    """

    protocol_version = 'HTTP/1.1'

    # Chunk size for the throttled loops when a quarter second's worth of
    # data at the configured bandwidth is less than one byte, or when no
    # bandwidth is configured.
    _DEFAULT_CHUNK_BYTES = 32768

    def _throttle_params(self):
        """Return (chunk_size, outage_threshold) for the throttled loops.

        With a bandwidth configured, chunks are a quarter second's worth of
        data and the mid-transfer outage happens once more than one
        second's worth has been moved.  With no bandwidth there is no
        "one second's worth", so the default chunk size is used and the
        outage happens after the first chunk.
        """
        bandwidth = self.server.bandwidth
        if bandwidth is None:
            return self._DEFAULT_CHUNK_BYTES, 0
        return int(bandwidth / 4) or self._DEFAULT_CHUNK_BYTES, bandwidth

    def _send_empty_response(self, status):
        """Send a complete, body-less response with the given status code.

        Headers must be terminated with end_headers(); otherwise they are
        not flushed to the client on Python 3.
        """
        self.send_response(status)
        self.send_header('Connection', 'close')
        self.send_header('Content-length', '0')
        self.end_headers()

    def _lookup_requested_block(self):
        """Return the hash named in self.path if that block is stored.

        Otherwise send a 422 (no hash in the path) or 404 (hash not in the
        store) response and return None.
        """
        match = re.search(r'[0-9a-f]{32}', self.path)
        if not match:
            self._send_empty_response(422)
            return None
        datahash = match.group(0)
        if datahash not in self.server.store:
            self._send_empty_response(404)
            return None
        return datahash

    def wfile_bandwidth_write(self, data_to_write):
        """Write data_to_write to self.wfile, honoring throttling settings.

        If neither a bandwidth nor a 'mid_write' delay is configured, the
        data is written in one call.  Otherwise it is written in chunks,
        sleeping after each chunk to stay at the configured bandwidth (if
        any), and the 'mid_write' delay is induced once partway through.

        Always returns None.
        """
        bandwidth = self.server.bandwidth
        if bandwidth is None and self.server.delays['mid_write'] == 0:
            self.wfile.write(data_to_write)
            return None
        chunk_size, outage_threshold = self._throttle_params()
        outage_happened = False
        num_bytes = len(data_to_write)
        num_sent_bytes = 0
        # target_time is when the bytes sent so far *should* have finished
        # going out; sleeping until then enforces the average rate.
        target_time = time.time()
        while num_sent_bytes < num_bytes:
            if num_sent_bytes > outage_threshold and not outage_happened:
                self.server._do_delay('mid_write')
                # The outage must not be "caught up" by later chunks.
                target_time += self.server.delays['mid_write']
                outage_happened = True
            chunk = data_to_write[num_sent_bytes:num_sent_bytes+chunk_size]
            self.wfile.write(chunk)
            num_sent_bytes += len(chunk)
            if bandwidth is not None:
                target_time += len(chunk) / bandwidth
                self.server._sleep_at_least(target_time - time.time())
        return None

    def rfile_bandwidth_read(self, bytes_to_read):
        """Read up to bytes_to_read bytes from self.rfile and return them.

        If neither a bandwidth nor a 'mid_read' delay is configured, this
        is a single read.  Otherwise data is read in chunks, sleeping after
        each chunk to stay at the configured bandwidth (if any), and the
        'mid_read' delay is induced once partway through.  Reading stops
        early if the stream reaches EOF.
        """
        bandwidth = self.server.bandwidth
        if bandwidth is None and self.server.delays['mid_read'] == 0:
            return self.rfile.read(bytes_to_read)
        chunk_size, outage_threshold = self._throttle_params()
        chunks = []
        outage_happened = False
        bytes_read = 0
        target_time = time.time()
        while bytes_to_read > bytes_read:
            if bytes_read > outage_threshold and not outage_happened:
                self.server._do_delay('mid_read')
                target_time += self.server.delays['mid_read']
                outage_happened = True
            chunk = self.rfile.read(
                min(chunk_size, bytes_to_read - bytes_read))
            if not chunk:
                # EOF before the announced length: return what we have.
                break
            chunks.append(chunk)
            bytes_read += len(chunk)
            if bandwidth is not None:
                target_time += len(chunk) / bandwidth
                self.server._sleep_at_least(target_time - time.time())
        return b''.join(chunks)

    def finish(self, *args, **kwargs):
        """Run the base class finish(), swallowing errors unless debugging."""
        try:
            return super(Handler, self).finish(*args, **kwargs)
        except Exception:
            if _debug:
                raise

    def handle(self, *args, **kwargs):
        """Run the base class handle(), swallowing errors unless debugging."""
        try:
            return super(Handler, self).handle(*args, **kwargs)
        except Exception:
            if _debug:
                raise

    def handle_one_request(self, *args, **kwargs):
        """Apply the 'request' delay, then handle a single request."""
        # Reset per request so do_PUT can tell whether handle_expect_100
        # already answered an "Expect: 100-continue" header.
        self._sent_continue = False
        self.server._do_delay('request')
        return super(Handler, self).handle_one_request(*args, **kwargs)

    def handle_expect_100(self):
        """Apply the 'request_body' delay, then send 100 Continue."""
        self.server._do_delay('request_body')
        self._sent_continue = True
        return super(Handler, self).handle_expect_100()

    def do_GET(self):
        """Serve the stored block whose hash appears in the request path.

        Responds 422 if the path contains no 32-hex-digit hash and 404 if
        the hash is not in the store.
        """
        self.server._do_delay('response')
        datahash = self._lookup_requested_block()
        if datahash is None:
            return
        self.send_response(200)
        self.send_header('Connection', 'close')
        self.send_header('Content-type', 'application/octet-stream')
        self.end_headers()
        self.server._do_delay('response_body')
        self.wfile_bandwidth_write(self.server.store[datahash])
        self.server._do_delay('response_close')

    def do_HEAD(self):
        """Like do_GET, but send only headers, including Content-length."""
        self.server._do_delay('response')
        datahash = self._lookup_requested_block()
        if datahash is None:
            return
        self.send_response(200)
        self.send_header('Connection', 'close')
        self.send_header('Content-type', 'application/octet-stream')
        self.send_header('Content-length', str(len(self.server.store[datahash])))
        self.end_headers()
        self.server._do_delay('response_close')
        self.close_connection = True

    def do_PUT(self):
        """Store the request body and respond with "<md5 hex>+<size>\\n"."""
        if not self._sent_continue and self.headers.get('expect') == '100-continue':
            # The comments at https://bugs.python.org/issue1491
            # implies that Python 2.7 BaseHTTPRequestHandler was
            # patched to support 100 Continue, but reading the actual
            # code that ships in Debian it clearly is not, so we need
            # to send the response on the socket directly.
            self.server._do_delay('request_body')
            self.wfile.write("{} {} {}\r\n\r\n".format(
                self.protocol_version, 100, "Continue").encode())
        data = self.rfile_bandwidth_read(
            int(self.headers.get('content-length')))
        datahash = hashlib.md5(data).hexdigest()
        self.server.store[datahash] = data
        resp = '{}+{}\n'.format(datahash, len(data)).encode()
        self.server._do_delay('response')
        self.send_response(200)
        self.send_header('Connection', 'close')
        self.send_header('Content-type', 'text/plain')
        self.send_header('Content-length', str(len(resp)))
        self.end_headers()
        self.server._do_delay('response_body')
        self.wfile_bandwidth_write(resp)
        self.server._do_delay('response_close')
        self.close_connection = True

    def log_request(self, *args, **kwargs):
        """Log the request only when debugging is enabled."""
        if _debug:
            super(Handler, self).log_request(*args, **kwargs)