2853: Merge branch 'master' into 2853-rendezvous
[arvados.git] / sdk / python / arvados / keep.py
index 22bf327e79dfd1942e5f452c1de944707818d0cd..23d8c20db546c8e1f22abe2661a270c8ea614fb2 100644 (file)
@@ -1,6 +1,4 @@
 import gflags
 import gflags
-import httplib
-import httplib2
 import logging
 import os
 import pprint
 import logging
 import os
 import pprint
@@ -21,9 +19,7 @@ import timer
 import datetime
 import ssl
 import socket
 import datetime
 import ssl
 import socket
-
-_logger = logging.getLogger('arvados.keep')
-global_client_object = None
+import requests
 
 import arvados
 import arvados.config as config
 
 import arvados
 import arvados.config as config
@@ -31,6 +27,9 @@ import arvados.errors
 import arvados.retry as retry
 import arvados.util
 
 import arvados.retry as retry
 import arvados.util
 
+_logger = logging.getLogger('arvados.keep')
+global_client_object = None
+
 class KeepLocator(object):
     EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
     HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')
 class KeepLocator(object):
     EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
     HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')
@@ -165,7 +164,7 @@ class KeepBlockCache(object):
             self.ready.set()
 
         def size(self):
             self.ready.set()
 
         def size(self):
-            if self.content == None:
+            if self.content is None:
                 return 0
             else:
                 return len(self.content)
                 return 0
             else:
                 return len(self.content)
@@ -210,6 +209,14 @@ class KeepBlockCache(object):
             self._cache_lock.release()
 
 class KeepClient(object):
             self._cache_lock.release()
 
 class KeepClient(object):
+
+    # Default Keep server connection timeout:  2 seconds
+    # Default Keep server read timeout:      300 seconds
+    # Default Keep proxy connection timeout:  20 seconds
+    # Default Keep proxy read timeout:       300 seconds
+    DEFAULT_TIMEOUT = (2, 300)
+    DEFAULT_PROXY_TIMEOUT = (20, 300)
+
     class ThreadLimiter(object):
         """
         Limit the number of threads running at a given time to
     class ThreadLimiter(object):
         """
         Limit the number of threads running at a given time to
@@ -269,7 +276,7 @@ class KeepClient(object):
 
     class KeepService(object):
         # Make requests to a single Keep service, and track results.
 
     class KeepService(object):
         # Make requests to a single Keep service, and track results.
-        HTTP_ERRORS = (httplib2.HttpLib2Error, httplib.HTTPException,
+        HTTP_ERRORS = (requests.exceptions.RequestException,
                        socket.error, ssl.SSLError)
 
         def __init__(self, root, **headers):
                        socket.error, ssl.SSLError)
 
         def __init__(self, root, **headers):
@@ -288,19 +295,19 @@ class KeepClient(object):
 
         def last_status(self):
             try:
 
         def last_status(self):
             try:
-                return int(self.last_result[0].status)
-            except (AttributeError, IndexError, ValueError):
+                return self.last_result.status_code
+            except AttributeError:
                 return None
 
                 return None
 
-        def get(self, http, locator):
-            # http is an httplib2.Http object.
+        def get(self, locator, timeout=None):
             # locator is a KeepLocator object.
             url = self.root + str(locator)
             _logger.debug("Request: GET %s", url)
             try:
                 with timer.Timer() as t:
             # locator is a KeepLocator object.
             url = self.root + str(locator)
             _logger.debug("Request: GET %s", url)
             try:
                 with timer.Timer() as t:
-                    result = http.request(url.encode('utf-8'), 'GET',
-                                          headers=self.get_headers)
+                    result = requests.get(url.encode('utf-8'),
+                                          headers=self.get_headers,
+                                          timeout=timeout)
             except self.HTTP_ERRORS as e:
                 _logger.debug("Request fail: GET %s => %s: %s",
                               url, type(e), str(e))
             except self.HTTP_ERRORS as e:
                 _logger.debug("Request fail: GET %s => %s: %s",
                               url, type(e), str(e))
@@ -308,7 +315,7 @@ class KeepClient(object):
             else:
                 self.last_result = result
                 self.success_flag = retry.check_http_response_success(result)
             else:
                 self.last_result = result
                 self.success_flag = retry.check_http_response_success(result)
-                content = result[1]
+                content = result.content
                 _logger.info("%s response: %s bytes in %s msec (%.3f MiB/sec)",
                              self.last_status(), len(content), t.msecs,
                              (len(content)/(1024.0*1024))/t.secs)
                 _logger.info("%s response: %s bytes in %s msec (%.3f MiB/sec)",
                              self.last_status(), len(content), t.msecs,
                              (len(content)/(1024.0*1024))/t.secs)
@@ -316,15 +323,18 @@ class KeepClient(object):
                     resp_md5 = hashlib.md5(content).hexdigest()
                     if resp_md5 == locator.md5sum:
                         return content
                     resp_md5 = hashlib.md5(content).hexdigest()
                     if resp_md5 == locator.md5sum:
                         return content
-                    _logger.warning("Checksum fail: md5(%s) = %s", url, md5)
+                    _logger.warning("Checksum fail: md5(%s) = %s",
+                                    url, resp_md5)
             return None
 
             return None
 
-        def put(self, http, hash_s, body):
+        def put(self, hash_s, body, timeout=None):
             url = self.root + hash_s
             _logger.debug("Request: PUT %s", url)
             try:
             url = self.root + hash_s
             _logger.debug("Request: PUT %s", url)
             try:
-                result = http.request(url.encode('utf-8'), 'PUT',
-                                      headers=self.put_headers, body=body)
+                result = requests.put(url.encode('utf-8'),
+                                      data=body,
+                                      headers=self.put_headers,
+                                      timeout=timeout)
             except self.HTTP_ERRORS as e:
                 _logger.debug("Request fail: PUT %s => %s: %s",
                               url, type(e), str(e))
             except self.HTTP_ERRORS as e:
                 _logger.debug("Request fail: PUT %s => %s: %s",
                               url, type(e), str(e))
@@ -365,12 +375,13 @@ class KeepClient(object):
                           str(threading.current_thread()),
                           self.args['data_hash'],
                           self.args['service_root'])
                           str(threading.current_thread()),
                           self.args['data_hash'],
                           self.args['service_root'])
-            h = httplib2.Http(timeout=self.args.get('timeout', None))
             self._success = bool(self.service.put(
             self._success = bool(self.service.put(
-                    h, self.args['data_hash'], self.args['data']))
+                self.args['data_hash'],
+                self.args['data'],
+                timeout=self.args.get('timeout', None)))
             status = self.service.last_status()
             if self._success:
             status = self.service.last_status()
             if self._success:
-                resp, body = self.service.last_result
+                result = self.service.last_result
                 _logger.debug("KeepWriterThread %s succeeded %s %s",
                               str(threading.current_thread()),
                               self.args['data_hash'],
                 _logger.debug("KeepWriterThread %s succeeded %s %s",
                               str(threading.current_thread()),
                               self.args['data_hash'],
@@ -380,18 +391,20 @@ class KeepClient(object):
                 # we're talking to a proxy or other backend that
                 # stores to multiple copies for us.
                 try:
                 # we're talking to a proxy or other backend that
                 # stores to multiple copies for us.
                 try:
-                    replicas_stored = int(resp['x-keep-replicas-stored'])
+                    replicas_stored = int(result.headers['x-keep-replicas-stored'])
                 except (KeyError, ValueError):
                     replicas_stored = 1
                 except (KeyError, ValueError):
                     replicas_stored = 1
-                limiter.save_response(body.strip(), replicas_stored)
+                limiter.save_response(result.text.strip(), replicas_stored)
             elif status is not None:
                 _logger.debug("Request fail: PUT %s => %s %s",
                               self.args['data_hash'], status,
             elif status is not None:
                 _logger.debug("Request fail: PUT %s => %s %s",
                               self.args['data_hash'], status,
-                              self.service.last_result[1])
+                              self.service.last_result.text)
 
 
 
 
-    def __init__(self, api_client=None, proxy=None, timeout=300,
-                 api_token=None, local_store=None, block_cache=None):
+    def __init__(self, api_client=None, proxy=None,
+                 timeout=DEFAULT_TIMEOUT, proxy_timeout=DEFAULT_PROXY_TIMEOUT,
+                 api_token=None, local_store=None, block_cache=None,
+                 num_retries=0):
         """Initialize a new KeepClient.
 
         Arguments:
         """Initialize a new KeepClient.
 
         Arguments:
@@ -402,8 +415,14 @@ class KeepClient(object):
           Keep proxy.  Otherwise, KeepClient will fall back to the setting
           of the ARVADOS_KEEP_PROXY configuration setting.  If you want to
           ensure KeepClient does not use a proxy, pass in an empty string.
           Keep proxy.  Otherwise, KeepClient will fall back to the setting
           of the ARVADOS_KEEP_PROXY configuration setting.  If you want to
           ensure KeepClient does not use a proxy, pass in an empty string.
-        * timeout: The timeout for all HTTP requests, in seconds.  Default
-          300.
+        * timeout: The timeout (in seconds) for HTTP requests to Keep
+          non-proxy servers.  A tuple of two floats is interpreted as
+          (connection_timeout, read_timeout): see
+          http://docs.python-requests.org/en/latest/user/advanced/#timeouts.
+          Default: (2, 300).
+        * proxy_timeout: The timeout (in seconds) for HTTP requests to
+          Keep proxies. A tuple of two floats is interpreted as
+          (connection_timeout, read_timeout). Default: (20, 300).
         * api_token: If you're not using an API client, but only talking
           directly to a Keep proxy, this parameter specifies an API token
           to authenticate Keep requests.  It is an error to specify both
         * api_token: If you're not using an API client, but only talking
           directly to a Keep proxy, this parameter specifies an API token
           to authenticate Keep requests.  It is an error to specify both
@@ -415,12 +434,18 @@ class KeepClient(object):
           environment variable.  If you want to ensure KeepClient does not
           use local storage, pass in an empty string.  This is primarily
           intended to mock a server for testing.
           environment variable.  If you want to ensure KeepClient does not
           use local storage, pass in an empty string.  This is primarily
           intended to mock a server for testing.
+        * num_retries: The default number of times to retry failed requests.
+          This will be used as the default num_retries value when get() and
+          put() are called.  Default 0.
         """
         self.lock = threading.Lock()
         if proxy is None:
             proxy = config.get('ARVADOS_KEEP_PROXY')
         if api_token is None:
         """
         self.lock = threading.Lock()
         if proxy is None:
             proxy = config.get('ARVADOS_KEEP_PROXY')
         if api_token is None:
-            api_token = config.get('ARVADOS_API_TOKEN')
+            if api_client is None:
+                api_token = config.get('ARVADOS_API_TOKEN')
+            else:
+                api_token = api_client.api_token
         elif api_client is not None:
             raise ValueError(
                 "can't build KeepClient with both API client and token")
         elif api_client is not None:
             raise ValueError(
                 "can't build KeepClient with both API client and token")
@@ -428,21 +453,25 @@ class KeepClient(object):
             local_store = os.environ.get('KEEP_LOCAL_STORE')
 
         self.block_cache = block_cache if block_cache else KeepBlockCache()
             local_store = os.environ.get('KEEP_LOCAL_STORE')
 
         self.block_cache = block_cache if block_cache else KeepBlockCache()
+        self.timeout = timeout
+        self.proxy_timeout = proxy_timeout
 
         if local_store:
             self.local_store = local_store
             self.get = self.local_store_get
             self.put = self.local_store_put
         else:
 
         if local_store:
             self.local_store = local_store
             self.get = self.local_store_get
             self.put = self.local_store_put
         else:
-            self.timeout = timeout
-
+            self.num_retries = num_retries
             if proxy:
                 if not proxy.endswith('/'):
                     proxy += '/'
                 self.api_token = api_token
             if proxy:
                 if not proxy.endswith('/'):
                     proxy += '/'
                 self.api_token = api_token
-                self.service_roots = [proxy]
+                self._keep_services = [{
+                    'uuid': 'proxy',
+                    '_service_root': proxy,
+                    }]
                 self.using_proxy = True
                 self.using_proxy = True
-                self.static_service_roots = True
+                self._static_services_list = True
             else:
                 # It's important to avoid instantiating an API client
                 # unless we actually need one, for testing's sake.
             else:
                 # It's important to avoid instantiating an API client
                 # unless we actually need one, for testing's sake.
@@ -450,13 +479,22 @@ class KeepClient(object):
                     api_client = arvados.api('v1')
                 self.api_client = api_client
                 self.api_token = api_client.api_token
                     api_client = arvados.api('v1')
                 self.api_client = api_client
                 self.api_token = api_client.api_token
-                self.service_roots = None
+                self._keep_services = None
                 self.using_proxy = None
                 self.using_proxy = None
-                self.static_service_roots = False
+                self._static_services_list = False
+
+    def current_timeout(self):
+        """Return the timeout to pass to Keep requests: self.proxy_timeout
+        when self.using_proxy is truthy, self.timeout otherwise (this
+        includes using_proxy being None, i.e. no services list built yet).
+        """
+        # TODO(twp): the timeout should be a property of a
+        # KeepService, not a KeepClient. See #4488.
+        return self.proxy_timeout if self.using_proxy else self.timeout
 
 
-    def build_service_roots(self, force_rebuild=False):
-        if (self.static_service_roots or
-              (self.service_roots and not force_rebuild)):
+    def build_services_list(self, force_rebuild=False):
+        if (self._static_services_list or
+              (self._keep_services and not force_rebuild)):
             return
         with self.lock:
             try:
             return
         with self.lock:
             try:
@@ -464,68 +502,47 @@ class KeepClient(object):
             except Exception:  # API server predates Keep services.
                 keep_services = self.api_client.keep_disks().list()
 
             except Exception:  # API server predates Keep services.
                 keep_services = self.api_client.keep_disks().list()
 
-            keep_services = keep_services.execute().get('items')
-            if not keep_services:
+            self._keep_services = keep_services.execute().get('items')
+            if not self._keep_services:
                 raise arvados.errors.NoKeepServersError()
 
                 raise arvados.errors.NoKeepServersError()
 
-            self.using_proxy = (keep_services[0].get('service_type') ==
-                                'proxy')
-
-            roots = (("http%s://%s:%d/" %
-                      ('s' if f['service_ssl_flag'] else '',
-                       f['service_host'],
-                       f['service_port']))
-                     for f in keep_services)
-            self.service_roots = sorted(set(roots))
-            _logger.debug(str(self.service_roots))
-
-    def shuffled_service_roots(self, hash, force_rebuild=False):
-        self.build_service_roots(force_rebuild)
-
-        # Build an ordering with which to query the Keep servers based on the
-        # contents of the hash.
-        # "hash" is a hex-encoded number at least 8 digits
-        # (32 bits) long
-
-        # seed used to calculate the next keep server from 'pool'
-        # to be added to 'pseq'
-        seed = hash
-
-        # Keep servers still to be added to the ordering
-        pool = self.service_roots[:]
-
-        # output probe sequence
-        pseq = []
-
-        # iterate while there are servers left to be assigned
-        while len(pool) > 0:
-            if len(seed) < 8:
-                # ran out of digits in the seed
-                if len(pseq) < len(hash) / 4:
-                    # the number of servers added to the probe sequence is less
-                    # than the number of 4-digit slices in 'hash' so refill the
-                    # seed with the last 4 digits and then append the contents
-                    # of 'hash'.
-                    seed = hash[-4:] + hash
-                else:
-                    # refill the seed with the contents of 'hash'
-                    seed += hash
-
-            # Take the next 8 digits (32 bytes) and interpret as an integer,
-            # then modulus with the size of the remaining pool to get the next
-            # selected server.
-            probe = int(seed[0:8], 16) % len(pool)
-
-            # Append the selected server to the probe sequence and remove it
-            # from the pool.
-            pseq += [pool[probe]]
-            pool = pool[:probe] + pool[probe+1:]
-
-            # Remove the digits just used from the seed
-            seed = seed[8:]
-        _logger.debug(str(pseq))
-        return pseq
+            self.using_proxy = any(ks.get('service_type') == 'proxy'
+                                   for ks in self._keep_services)
+
+            # Precompute the base URI for each service.
+            for r in self._keep_services:
+                r['_service_root'] = "{}://[{}]:{:d}/".format(
+                    'https' if r['service_ssl_flag'] else 'http',
+                    r['service_host'],
+                    r['service_port'])
+            _logger.debug(str(self._keep_services))
 
 
+    def _service_weight(self, hash, service_uuid):
+        """Compute the weight of a Keep service endpoint for a data
+        block with a known hash.
+
+        The weight is the hex digest string of md5(hash + u), where u
+        is the last 15 characters of the service endpoint's UUID.
+        """
+        return hashlib.md5(hash + service_uuid[-15:]).hexdigest()
+
+    def weighted_service_roots(self, hash, force_rebuild=False):
+        """Return a list of Keep service root URIs (strings), in the
+        order in which they should be probed when reading or writing
+        data with the given hash.
+        """
+        self.build_services_list(force_rebuild)
+
+        # Sort the available services by their _service_weight for this
+        # hash (heaviest first, compared as hex digest strings), and
+        # return their '_service_root' values (base URIs) in that order.
+        sorted_roots = [
+            svc['_service_root'] for svc in sorted(
+                self._keep_services,
+                reverse=True,
+                key=lambda svc: self._service_weight(hash, svc['uuid']))]
+        _logger.debug(hash + ': ' + str(sorted_roots))
+        return sorted_roots
 
     def map_new_services(self, roots_map, md5_s, force_rebuild, **headers):
         # roots_map is a dictionary, mapping Keep service root strings
 
     def map_new_services(self, roots_map, md5_s, force_rebuild, **headers):
         # roots_map is a dictionary, mapping Keep service root strings
@@ -533,7 +550,7 @@ class KeepClient(object):
         # new ones to roots_map.  Return the current list of local
         # root strings.
         headers.setdefault('Authorization', "OAuth2 %s" % (self.api_token,))
         # new ones to roots_map.  Return the current list of local
         # root strings.
         headers.setdefault('Authorization', "OAuth2 %s" % (self.api_token,))
-        local_roots = self.shuffled_service_roots(md5_s, force_rebuild)
+        local_roots = self.weighted_service_roots(md5_s, force_rebuild)
         for root in local_roots:
             if root not in roots_map:
                 roots_map[root] = self.KeepService(root, **headers)
         for root in local_roots:
             if root not in roots_map:
                 roots_map[root] = self.KeepService(root, **headers)
@@ -557,7 +574,8 @@ class KeepClient(object):
         else:
             return None
 
         else:
             return None
 
-    def get(self, loc_s, num_retries=0):
+    @retry.retry_method
+    def get(self, loc_s, num_retries=None):
         """Get data from Keep.
 
         This method fetches one or more blocks of data from Keep.  It
         """Get data from Keep.
 
         This method fetches one or more blocks of data from Keep.  It
@@ -574,7 +592,8 @@ class KeepClient(object):
           *each* Keep server if it returns temporary failures, with
           exponential backoff.  Note that, in each loop, the method may try
           to fetch data from every available Keep service, along with any
           *each* Keep server if it returns temporary failures, with
           exponential backoff.  Note that, in each loop, the method may try
           to fetch data from every available Keep service, along with any
-          that are named in location hints in the locator.  Default 0.
+          that are named in location hints in the locator.  The default value
+          is set when the KeepClient is initialized.
         """
         if ',' in loc_s:
             return ''.join(self.get(x) for x in loc_s.split(','))
         """
         if ',' in loc_s:
             return ''.join(self.get(x) for x in loc_s.split(','))
@@ -612,9 +631,8 @@ class KeepClient(object):
             services_to_try = [roots_map[root]
                                for root in (local_roots + hint_roots)
                                if roots_map[root].usable()]
             services_to_try = [roots_map[root]
                                for root in (local_roots + hint_roots)
                                if roots_map[root].usable()]
-            http = httplib2.Http(timeout=self.timeout)
             for keep_service in services_to_try:
             for keep_service in services_to_try:
-                blob = keep_service.get(http, locator)
+                blob = keep_service.get(locator, timeout=self.current_timeout())
                 if blob is not None:
                     break
             loop.save_result((blob, len(services_to_try)))
                 if blob is not None:
                     break
             loop.save_result((blob, len(services_to_try)))
@@ -637,7 +655,8 @@ class KeepClient(object):
         else:
             raise arvados.errors.KeepReadError(loc_s)
 
         else:
             raise arvados.errors.KeepReadError(loc_s)
 
-    def put(self, data, copies=2, num_retries=0):
+    @retry.retry_method
+    def put(self, data, copies=2, num_retries=None):
         """Save data in Keep.
 
         This method will get a list of Keep services from the API server, and
         """Save data in Keep.
 
         This method will get a list of Keep services from the API server, and
@@ -652,7 +671,8 @@ class KeepClient(object):
           Default 2.
         * num_retries: The number of times to retry PUT requests to
           *each* Keep server if it returns temporary failures, with
           Default 2.
         * num_retries: The number of times to retry PUT requests to
           *each* Keep server if it returns temporary failures, with
-          exponential backoff.  Default 0.
+          exponential backoff.  The default value is set when the
+          KeepClient is initialized.
         """
         data_hash = hashlib.md5(data).hexdigest()
         if copies < 1:
         """
         data_hash = hashlib.md5(data).hexdigest()
         if copies < 1:
@@ -685,7 +705,7 @@ class KeepClient(object):
                     data_hash=data_hash,
                     service_root=service_root,
                     thread_limiter=thread_limiter,
                     data_hash=data_hash,
                     service_root=service_root,
                     thread_limiter=thread_limiter,
-                    timeout=self.timeout)
+                    timeout=self.current_timeout())
                 t.start()
                 threads.append(t)
             for t in threads:
                 t.start()
                 threads.append(t)
             for t in threads:
@@ -698,7 +718,10 @@ class KeepClient(object):
             "Write fail for %s: wanted %d but wrote %d" %
             (data_hash, copies, thread_limiter.done()))
 
             "Write fail for %s: wanted %d but wrote %d" %
             (data_hash, copies, thread_limiter.done()))
 
-    def local_store_put(self, data):
+    # Local storage methods need no-op num_retries arguments to keep
+    # integration tests happy.  With better isolation they could
+    # probably be removed again.
+    def local_store_put(self, data, num_retries=0):
         md5 = hashlib.md5(data).hexdigest()
         locator = '%s+%d' % (md5, len(data))
         with open(os.path.join(self.local_store, md5 + '.tmp'), 'w') as f:
         md5 = hashlib.md5(data).hexdigest()
         locator = '%s+%d' % (md5, len(data))
         with open(os.path.join(self.local_store, md5 + '.tmp'), 'w') as f:
@@ -707,7 +730,7 @@ class KeepClient(object):
                   os.path.join(self.local_store, md5))
         return locator
 
                   os.path.join(self.local_store, md5))
         return locator
 
-    def local_store_get(self, loc_s):
+    def local_store_get(self, loc_s, num_retries=0):
         try:
             locator = KeepLocator(loc_s)
         except ValueError:
         try:
             locator = KeepLocator(loc_s)
         except ValueError: