18870: Need to declare NODES as array
[arvados.git] / sdk / python / tests / keepstub.py
index 965bf299b86d9bb431e82dec94c46317922a1222..6be8d8b6465b0720bf594341536ac4b0cfb3c212 100644 (file)
@@ -1,15 +1,49 @@
+# Copyright (C) The Arvados Authors. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
 from __future__ import division
 from future import standard_library
 standard_library.install_aliases()
 from builtins import str
-from past.utils import old_div
 import http.server
 import hashlib
 import os
 import re
+import socket
 import socketserver
+import sys
+import threading
 import time
 
+from . import arvados_testutil as tutil
+
+_debug = os.environ.get('ARVADOS_DEBUG', None)  # truthy: log requests and re-raise handler errors
+
+
+class StubKeepServers(tutil.ApiClientMock):
+    """Per-test stub Keep server, registered as the only mocked Keep service."""
+    def setUp(self):
+        super(StubKeepServers, self).setUp()
+        # Let the server bind port 0 itself: probing then rebinding a port races.
+        self.server = Server(('0.0.0.0', 0), Handler)
+        self.port = self.server.server_address[1]
+        self.thread = threading.Thread(target=self.server.serve_forever)
+        self.thread.daemon = True # Exit thread if main proc exits
+        self.thread.start()
+        self.api_client = self.mock_keep_services(
+            count=1,
+            service_host='localhost',
+            service_port=self.port,
+        )
+
+    def tearDown(self):
+        self.server.shutdown()
+        self.server.server_close()  # shutdown() only stops serve_forever()
+        self.thread.join()
+        super(StubKeepServers, self).tearDown()
+
+
 class Server(socketserver.ThreadingMixIn, http.server.HTTPServer, object):
 
     allow_reuse_address = 1
@@ -38,7 +72,7 @@ class Server(socketserver.ThreadingMixIn, http.server.HTTPServer, object):
     def setdelays(self, **kwargs):
         """In future requests, induce delays at the given checkpoints."""
         for (k, v) in kwargs.items():
-            self.delays.get(k) # NameError if unknown key
+            self.delays[k]  # KeyError if unknown key (.get() never raised)
             self.delays[k] = v
 
     def setbandwidth(self, bandwidth):
@@ -60,11 +94,14 @@ class Server(socketserver.ThreadingMixIn, http.server.HTTPServer, object):
 
 
 class Handler(http.server.BaseHTTPRequestHandler, object):
+
+    protocol_version = 'HTTP/1.1'  # keep-alive by default; the do_* handlers send 'Connection: close'
+
     def wfile_bandwidth_write(self, data_to_write):
-        if self.server.bandwidth == None and self.server.delays['mid_write'] == 0:
+        if self.server.bandwidth is None and self.server.delays['mid_write'] == 0:  # unthrottled, no outage
             self.wfile.write(data_to_write)
         else:
-            BYTES_PER_WRITE = int(old_div(self.server.bandwidth,4.0)) or 32768
+            BYTES_PER_WRITE = int(self.server.bandwidth/4) or 32768  # NOTE(review): TypeError if bandwidth is None
             outage_happened = False
             num_bytes = len(data_to_write)
             num_sent_bytes = 0
@@ -72,7 +109,7 @@ class Handler(http.server.BaseHTTPRequestHandler, object):
             while num_sent_bytes < num_bytes:
                 if num_sent_bytes > self.server.bandwidth and not outage_happened:
                     self.server._do_delay('mid_write')
-                    target_time += self.delays['mid_write']
+                    target_time += self.server.delays['mid_write']  # delays live on the Server (see setdelays)
                     outage_happened = True
                 num_write_bytes = min(BYTES_PER_WRITE,
                     num_bytes - num_sent_bytes)
@@ -80,36 +117,56 @@ class Handler(http.server.BaseHTTPRequestHandler, object):
                     num_sent_bytes:num_sent_bytes+num_write_bytes])
                 num_sent_bytes += num_write_bytes
                 if self.server.bandwidth is not None:
-                    target_time += old_div(num_write_bytes, self.server.bandwidth)
+                    target_time += num_write_bytes / self.server.bandwidth  # true division (__future__ import)
                     self.server._sleep_at_least(target_time - time.time())
         return None
 
     def rfile_bandwidth_read(self, bytes_to_read):
-        if self.server.bandwidth == None and self.server.delays['mid_read'] == 0:
+        if self.server.bandwidth is None and self.server.delays['mid_read'] == 0:  # unthrottled, no outage
             return self.rfile.read(bytes_to_read)
         else:
-            BYTES_PER_READ = int(old_div(self.server.bandwidth,4.0)) or 32768
-            data = ''
+            BYTES_PER_READ = int(self.server.bandwidth/4) or 32768  # NOTE(review): TypeError if bandwidth is None
+            data = b''  # bytes, to match rfile.read() and hashlib.md5() in do_PUT
             outage_happened = False
             bytes_read = 0
             target_time = time.time()
             while bytes_to_read > bytes_read:
                 if bytes_read > self.server.bandwidth and not outage_happened:
                     self.server._do_delay('mid_read')
-                    target_time += self.delays['mid_read']
+                    target_time += self.server.delays['mid_read']  # delays live on the Server (see setdelays)
                     outage_happened = True
                 next_bytes_to_read = min(BYTES_PER_READ,
                     bytes_to_read - bytes_read)
                 data += self.rfile.read(next_bytes_to_read)
                 bytes_read += next_bytes_to_read
                 if self.server.bandwidth is not None:
-                    target_time += old_div(next_bytes_to_read, self.server.bandwidth)
+                    target_time += next_bytes_to_read / self.server.bandwidth  # true division (__future__ import)
                     self.server._sleep_at_least(target_time - time.time())
         return data
 
+    def finish(self, *args, **kwargs):
+        try:
+            return super(Handler, self).finish(*args, **kwargs)
+        except Exception:  # e.g. "Broken pipe" after the client times out
+            if _debug:
+                raise
+
     def handle(self, *args, **kwargs):
+        try:
+            return super(Handler, self).handle(*args, **kwargs)
+        except Exception:  # e.g. client went away; re-raised only when debugging
+            if _debug:
+                raise
+
+    def handle_one_request(self, *args, **kwargs):
+        self._sent_continue = False  # per request; set by handle_expect_100()
         self.server._do_delay('request')
-        return super(Handler, self).handle(*args, **kwargs)
+        return super(Handler, self).handle_one_request(*args, **kwargs)
+
+    def handle_expect_100(self):
+        self.server._do_delay('request_body')  # stall before "100 Continue"
+        self._sent_continue = True
+        return super(Handler, self).handle_expect_100()
 
     def do_GET(self):
         self.server._do_delay('response')
@@ -120,6 +177,7 @@ class Handler(http.server.BaseHTTPRequestHandler, object):
         if datahash not in self.server.store:
             return self.send_response(404)
         self.send_response(200)
+        self.send_header('Connection', 'close')
         self.send_header('Content-type', 'application/octet-stream')
         self.end_headers()
         self.server._do_delay('response_body')
@@ -135,44 +193,39 @@ class Handler(http.server.BaseHTTPRequestHandler, object):
         if datahash not in self.server.store:
-            return self.send_response(404)
+            return self.send_error(404)  # a bare send_response() is never flushed to the client
         self.send_response(200)
+        self.send_header('Connection', 'close')
         self.send_header('Content-type', 'application/octet-stream')
         self.send_header('Content-length', str(len(self.server.store[datahash])))
         self.end_headers()
         self.server._do_delay('response_close')
+        self.close_connection = True
 
     def do_PUT(self):
-        self.server._do_delay('request_body')
-        # The comments at https://bugs.python.org/issue1491 implies that Python
-        # 2.7 BaseHTTPRequestHandler was patched to support 100 Continue, but
-        # reading the actual code that ships in Debian it clearly is not, so we
-        # need to send the response on the socket directly.
-        self.wfile_bandwidth_write("%s %d %s\r\n\r\n" %
-                         (self.protocol_version, 100, "Continue"))
-        data = self.rfile_bandwidth_read(int(self.headers.getheader('content-length')))
+        if not self._sent_continue and self.headers.get('expect') == '100-continue':
+            # The comments at https://bugs.python.org/issue1491
+            # implies that Python 2.7 BaseHTTPRequestHandler was
+            # patched to support 100 Continue, but reading the actual
+            # code that ships in Debian it clearly is not, so we need
+            # to send the response on the socket directly.
+            self.server._do_delay('request_body')
+            self.wfile.write("{} {} {}\r\n\r\n".format(
+                self.protocol_version, 100, "Continue").encode())
+        data = self.rfile_bandwidth_read(
+            int(self.headers.get('content-length')))
         datahash = hashlib.md5(data).hexdigest()
         self.server.store[datahash] = data
+        resp = '{}+{}\n'.format(datahash, len(data)).encode()
         self.server._do_delay('response')
         self.send_response(200)
+        self.send_header('Connection', 'close')
         self.send_header('Content-type', 'text/plain')
+        self.send_header('Content-length', len(resp))
         self.end_headers()
         self.server._do_delay('response_body')
-        self.wfile_bandwidth_write(datahash + '+' + str(len(data)))
+        self.wfile_bandwidth_write(resp)
         self.server._do_delay('response_close')
+        self.close_connection = True
 
     def log_request(self, *args, **kwargs):
-        if os.environ.get('ARVADOS_DEBUG', None):
+        if _debug:  # log requests only when ARVADOS_DEBUG is set
             super(Handler, self).log_request(*args, **kwargs)
-
-    def finish(self, *args, **kwargs):
-        """Ignore exceptions, notably "Broken pipe" when client times out."""
-        try:
-            return super(Handler, self).finish(*args, **kwargs)
-        except:
-            pass
-
-    def handle_one_request(self, *args, **kwargs):
-        """Ignore exceptions, notably "Broken pipe" when client times out."""
-        try:
-            return super(Handler, self).handle_one_request(*args, **kwargs)
-        except:
-            pass