Merge branch 'master' into 4823-python-sdk-writable-collection-api
[arvados.git] / sdk / python / arvados / keep.py
index 22bf327e79dfd1942e5f452c1de944707818d0cd..cc7dbd4161d1586e0c530c2d90a97c5c099670d0 100644 (file)
@@ -1,6 +1,4 @@
 import gflags
-import httplib
-import httplib2
 import logging
 import os
 import pprint
@@ -21,9 +19,7 @@ import timer
 import datetime
 import ssl
 import socket
-
-_logger = logging.getLogger('arvados.keep')
-global_client_object = None
+import requests
 
 import arvados
 import arvados.config as config
@@ -31,6 +27,9 @@ import arvados.errors
 import arvados.retry as retry
 import arvados.util
 
+_logger = logging.getLogger('arvados.keep')
+global_client_object = None
+
 class KeepLocator(object):
     EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
     HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')
@@ -59,6 +58,9 @@ class KeepLocator(object):
                              self.permission_hint()] + self.hints
             if s is not None)
 
+    def stripped(self):
+        return "%s+%i" % (self.md5sum, self.size)
+
     def _make_hex_prop(name, length):
         # Build and return a new property with the given name that
         # must be a hex string of the given length.
@@ -165,15 +167,14 @@ class KeepBlockCache(object):
             self.ready.set()
 
         def size(self):
-            if self.content == None:
+            if self.content is None:
                 return 0
             else:
                 return len(self.content)
 
     def cap_cache(self):
         '''Cap the cache size to self.cache_max'''
-        self._cache_lock.acquire()
-        try:
+        with self._cache_lock:
             # Select all slots except those where ready.is_set() and content is
             # None (that means there was an error reading the block).
             self._cache = [c for c in self._cache if not (c.ready.is_set() and c.content is None)]
@@ -184,32 +185,45 @@ class KeepBlockCache(object):
                         del self._cache[i]
                         break
                 sm = sum([slot.size() for slot in self._cache])
-        finally:
-            self._cache_lock.release()
+
+    def _get(self, locator):
+        # Test if the locator is already in the cache
+        for i in xrange(0, len(self._cache)):
+            if self._cache[i].locator == locator:
+                n = self._cache[i]
+                if i != 0:
+                    # move it to the front
+                    del self._cache[i]
+                    self._cache.insert(0, n)
+                return n
+        return None
+
+    def get(self, locator):
+        with self._cache_lock:
+            return self._get(locator)
 
     def reserve_cache(self, locator):
         '''Reserve a cache slot for the specified locator,
         or return the existing slot.'''
-        self._cache_lock.acquire()
-        try:
-            # Test if the locator is already in the cache
-            for i in xrange(0, len(self._cache)):
-                if self._cache[i].locator == locator:
-                    n = self._cache[i]
-                    if i != 0:
-                        # move it to the front
-                        del self._cache[i]
-                        self._cache.insert(0, n)
-                    return n, False
-
-            # Add a new cache slot for the locator
-            n = KeepBlockCache.CacheSlot(locator)
-            self._cache.insert(0, n)
-            return n, True
-        finally:
-            self._cache_lock.release()
+        with self._cache_lock:
+            n = self._get(locator)
+            if n is not None:
+                return n, False
+            else:
+                # Add a new cache slot for the locator
+                n = KeepBlockCache.CacheSlot(locator)
+                self._cache.insert(0, n)
+                return n, True
 
 class KeepClient(object):
+
+    # Default Keep server connection timeout:  2 seconds
+    # Default Keep server read timeout:      300 seconds
+    # Default Keep proxy connection timeout:  20 seconds
+    # Default Keep proxy read timeout:       300 seconds
+    DEFAULT_TIMEOUT = (2, 300)
+    DEFAULT_PROXY_TIMEOUT = (20, 300)
+
     class ThreadLimiter(object):
         """
         Limit the number of threads running at a given time to
@@ -269,13 +283,14 @@ class KeepClient(object):
 
     class KeepService(object):
         # Make requests to a single Keep service, and track results.
-        HTTP_ERRORS = (httplib2.HttpLib2Error, httplib.HTTPException,
+        HTTP_ERRORS = (requests.exceptions.RequestException,
                        socket.error, ssl.SSLError)
 
-        def __init__(self, root, **headers):
+        def __init__(self, root, session, **headers):
             self.root = root
             self.last_result = None
             self.success_flag = None
+            self.session = session
             self.get_headers = {'Accept': 'application/octet-stream'}
             self.get_headers.update(headers)
             self.put_headers = headers
@@ -288,19 +303,19 @@ class KeepClient(object):
 
         def last_status(self):
             try:
-                return int(self.last_result[0].status)
-            except (AttributeError, IndexError, ValueError):
+                return self.last_result.status_code
+            except AttributeError:
                 return None
 
-        def get(self, http, locator):
-            # http is an httplib2.Http object.
+        def get(self, locator, timeout=None):
             # locator is a KeepLocator object.
             url = self.root + str(locator)
             _logger.debug("Request: GET %s", url)
             try:
                 with timer.Timer() as t:
-                    result = http.request(url.encode('utf-8'), 'GET',
-                                          headers=self.get_headers)
+                    result = self.session.get(url.encode('utf-8'),
+                                          headers=self.get_headers,
+                                          timeout=timeout)
             except self.HTTP_ERRORS as e:
                 _logger.debug("Request fail: GET %s => %s: %s",
                               url, type(e), str(e))
@@ -308,23 +323,26 @@ class KeepClient(object):
             else:
                 self.last_result = result
                 self.success_flag = retry.check_http_response_success(result)
-                content = result[1]
+                content = result.content
                 _logger.info("%s response: %s bytes in %s msec (%.3f MiB/sec)",
                              self.last_status(), len(content), t.msecs,
-                             (len(content)/(1024.0*1024))/t.secs)
+                             (len(content)/(1024.0*1024))/t.secs if t.secs > 0 else 0)
                 if self.success_flag:
                     resp_md5 = hashlib.md5(content).hexdigest()
                     if resp_md5 == locator.md5sum:
                         return content
-                    _logger.warning("Checksum fail: md5(%s) = %s", url, md5)
+                    _logger.warning("Checksum fail: md5(%s) = %s",
+                                    url, resp_md5)
             return None
 
-        def put(self, http, hash_s, body):
+        def put(self, hash_s, body, timeout=None):
             url = self.root + hash_s
             _logger.debug("Request: PUT %s", url)
             try:
-                result = http.request(url.encode('utf-8'), 'PUT',
-                                      headers=self.put_headers, body=body)
+                result = self.session.put(url.encode('utf-8'),
+                                      data=body,
+                                      headers=self.put_headers,
+                                      timeout=timeout)
             except self.HTTP_ERRORS as e:
                 _logger.debug("Request fail: PUT %s => %s: %s",
                               url, type(e), str(e))
@@ -361,37 +379,42 @@ class KeepClient(object):
         def run_with_limiter(self, limiter):
             if self.service.finished():
                 return
-            _logger.debug("KeepWriterThread %s proceeding %s %s",
+            _logger.debug("KeepWriterThread %s proceeding %s+%i %s",
                           str(threading.current_thread()),
                           self.args['data_hash'],
+                          len(self.args['data']),
                           self.args['service_root'])
-            h = httplib2.Http(timeout=self.args.get('timeout', None))
             self._success = bool(self.service.put(
-                    h, self.args['data_hash'], self.args['data']))
+                self.args['data_hash'],
+                self.args['data'],
+                timeout=self.args.get('timeout', None)))
             status = self.service.last_status()
             if self._success:
-                resp, body = self.service.last_result
-                _logger.debug("KeepWriterThread %s succeeded %s %s",
+                result = self.service.last_result
+                _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
                               str(threading.current_thread()),
                               self.args['data_hash'],
+                              len(self.args['data']),
                               self.args['service_root'])
                 # Tick the 'done' counter for the number of replica
                 # reported stored by the server, for the case that
                 # we're talking to a proxy or other backend that
                 # stores to multiple copies for us.
                 try:
-                    replicas_stored = int(resp['x-keep-replicas-stored'])
+                    replicas_stored = int(result.headers['x-keep-replicas-stored'])
                 except (KeyError, ValueError):
                     replicas_stored = 1
-                limiter.save_response(body.strip(), replicas_stored)
+                limiter.save_response(result.text.strip(), replicas_stored)
             elif status is not None:
                 _logger.debug("Request fail: PUT %s => %s %s",
                               self.args['data_hash'], status,
-                              self.service.last_result[1])
+                              self.service.last_result.text)
 
 
-    def __init__(self, api_client=None, proxy=None, timeout=300,
-                 api_token=None, local_store=None, block_cache=None):
+    def __init__(self, api_client=None, proxy=None,
+                 timeout=DEFAULT_TIMEOUT, proxy_timeout=DEFAULT_PROXY_TIMEOUT,
+                 api_token=None, local_store=None, block_cache=None,
+                 num_retries=0, session=None):
         """Initialize a new KeepClient.
 
         Arguments:
@@ -402,8 +425,14 @@ class KeepClient(object):
           Keep proxy.  Otherwise, KeepClient will fall back to the setting
           of the ARVADOS_KEEP_PROXY configuration setting.  If you want to
           ensure KeepClient does not use a proxy, pass in an empty string.
-        * timeout: The timeout for all HTTP requests, in seconds.  Default
-          300.
+        * timeout: The timeout (in seconds) for HTTP requests to Keep
+          non-proxy servers.  A tuple of two floats is interpreted as
+          (connection_timeout, read_timeout): see
+          http://docs.python-requests.org/en/latest/user/advanced/#timeouts.
+          Default: (2, 300).
+        * proxy_timeout: The timeout (in seconds) for HTTP requests to
+          Keep proxies. A tuple of two floats is interpreted as
+          (connection_timeout, read_timeout). Default: (20, 300).
         * api_token: If you're not using an API client, but only talking
           directly to a Keep proxy, this parameter specifies an API token
           to authenticate Keep requests.  It is an error to specify both
@@ -415,12 +444,18 @@ class KeepClient(object):
           environment variable.  If you want to ensure KeepClient does not
           use local storage, pass in an empty string.  This is primarily
           intended to mock a server for testing.
+        * num_retries: The default number of times to retry failed requests.
+          This will be used as the default num_retries value when get() and
+          put() are called.  Default 0.
         """
         self.lock = threading.Lock()
         if proxy is None:
             proxy = config.get('ARVADOS_KEEP_PROXY')
         if api_token is None:
-            api_token = config.get('ARVADOS_API_TOKEN')
+            if api_client is None:
+                api_token = config.get('ARVADOS_API_TOKEN')
+            else:
+                api_token = api_client.api_token
         elif api_client is not None:
             raise ValueError(
                 "can't build KeepClient with both API client and token")
@@ -428,21 +463,26 @@ class KeepClient(object):
             local_store = os.environ.get('KEEP_LOCAL_STORE')
 
         self.block_cache = block_cache if block_cache else KeepBlockCache()
+        self.timeout = timeout
+        self.proxy_timeout = proxy_timeout
 
         if local_store:
             self.local_store = local_store
             self.get = self.local_store_get
             self.put = self.local_store_put
         else:
-            self.timeout = timeout
-
+            self.num_retries = num_retries
+            self.session = session if session is not None else requests.Session()
             if proxy:
                 if not proxy.endswith('/'):
                     proxy += '/'
                 self.api_token = api_token
-                self.service_roots = [proxy]
+                self._keep_services = [{
+                    'uuid': 'proxy',
+                    '_service_root': proxy,
+                    }]
                 self.using_proxy = True
-                self.static_service_roots = True
+                self._static_services_list = True
             else:
                 # It's important to avoid instantiating an API client
                 # unless we actually need one, for testing's sake.
@@ -450,13 +490,22 @@ class KeepClient(object):
                     api_client = arvados.api('v1')
                 self.api_client = api_client
                 self.api_token = api_client.api_token
-                self.service_roots = None
+                self._keep_services = None
                 self.using_proxy = None
-                self.static_service_roots = False
+                self._static_services_list = False
 
-    def build_service_roots(self, force_rebuild=False):
-        if (self.static_service_roots or
-              (self.service_roots and not force_rebuild)):
+    def current_timeout(self):
+        """Return the appropriate timeout to use for this client: the proxy
+        timeout setting if the backend service is currently a proxy,
+        the regular timeout setting otherwise.
+        """
+        # TODO(twp): the timeout should be a property of a
+        # KeepService, not a KeepClient. See #4488.
+        return self.proxy_timeout if self.using_proxy else self.timeout
+
+    def build_services_list(self, force_rebuild=False):
+        if (self._static_services_list or
+              (self._keep_services and not force_rebuild)):
             return
         with self.lock:
             try:
@@ -464,68 +513,47 @@ class KeepClient(object):
             except Exception:  # API server predates Keep services.
                 keep_services = self.api_client.keep_disks().list()
 
-            keep_services = keep_services.execute().get('items')
-            if not keep_services:
+            self._keep_services = keep_services.execute().get('items')
+            if not self._keep_services:
                 raise arvados.errors.NoKeepServersError()
 
-            self.using_proxy = (keep_services[0].get('service_type') ==
-                                'proxy')
-
-            roots = (("http%s://%s:%d/" %
-                      ('s' if f['service_ssl_flag'] else '',
-                       f['service_host'],
-                       f['service_port']))
-                     for f in keep_services)
-            self.service_roots = sorted(set(roots))
-            _logger.debug(str(self.service_roots))
-
-    def shuffled_service_roots(self, hash, force_rebuild=False):
-        self.build_service_roots(force_rebuild)
-
-        # Build an ordering with which to query the Keep servers based on the
-        # contents of the hash.
-        # "hash" is a hex-encoded number at least 8 digits
-        # (32 bits) long
-
-        # seed used to calculate the next keep server from 'pool'
-        # to be added to 'pseq'
-        seed = hash
-
-        # Keep servers still to be added to the ordering
-        pool = self.service_roots[:]
-
-        # output probe sequence
-        pseq = []
-
-        # iterate while there are servers left to be assigned
-        while len(pool) > 0:
-            if len(seed) < 8:
-                # ran out of digits in the seed
-                if len(pseq) < len(hash) / 4:
-                    # the number of servers added to the probe sequence is less
-                    # than the number of 4-digit slices in 'hash' so refill the
-                    # seed with the last 4 digits and then append the contents
-                    # of 'hash'.
-                    seed = hash[-4:] + hash
-                else:
-                    # refill the seed with the contents of 'hash'
-                    seed += hash
-
-            # Take the next 8 digits (32 bytes) and interpret as an integer,
-            # then modulus with the size of the remaining pool to get the next
-            # selected server.
-            probe = int(seed[0:8], 16) % len(pool)
-
-            # Append the selected server to the probe sequence and remove it
-            # from the pool.
-            pseq += [pool[probe]]
-            pool = pool[:probe] + pool[probe+1:]
-
-            # Remove the digits just used from the seed
-            seed = seed[8:]
-        _logger.debug(str(pseq))
-        return pseq
+            self.using_proxy = any(ks.get('service_type') == 'proxy'
+                                   for ks in self._keep_services)
+
+            # Precompute the base URI for each service.
+            for r in self._keep_services:
+                r['_service_root'] = "{}://[{}]:{:d}/".format(
+                    'https' if r['service_ssl_flag'] else 'http',
+                    r['service_host'],
+                    r['service_port'])
+            _logger.debug(str(self._keep_services))
 
+    def _service_weight(self, data_hash, service_uuid):
+        """Compute the weight of a Keep service endpoint for a data
+        block with a known hash.
+
+        The weight is md5(h + u) where u is the last 15 characters of
+        the service endpoint's UUID.
+        """
+        return hashlib.md5(data_hash + service_uuid[-15:]).hexdigest()
+
+    def weighted_service_roots(self, data_hash, force_rebuild=False):
+        """Return an array of Keep service endpoints, in the order in
+        which they should be probed when reading or writing data with
+        the given hash.
+        """
+        self.build_services_list(force_rebuild)
+
+        # Sort the available services by weight (heaviest first) for
+        # this data_hash, and return their service_roots (base URIs)
+        # in that order.
+        sorted_roots = [
+            svc['_service_root'] for svc in sorted(
+                self._keep_services,
+                reverse=True,
+                key=lambda svc: self._service_weight(data_hash, svc['uuid']))]
+        _logger.debug(data_hash + ': ' + str(sorted_roots))
+        return sorted_roots
 
     def map_new_services(self, roots_map, md5_s, force_rebuild, **headers):
         # roots_map is a dictionary, mapping Keep service root strings
@@ -533,10 +561,10 @@ class KeepClient(object):
         # new ones to roots_map.  Return the current list of local
         # root strings.
         headers.setdefault('Authorization', "OAuth2 %s" % (self.api_token,))
-        local_roots = self.shuffled_service_roots(md5_s, force_rebuild)
+        local_roots = self.weighted_service_roots(md5_s, force_rebuild)
         for root in local_roots:
             if root not in roots_map:
-                roots_map[root] = self.KeepService(root, **headers)
+                roots_map[root] = self.KeepService(root, self.session, **headers)
         return local_roots
 
     @staticmethod
@@ -557,7 +585,8 @@ class KeepClient(object):
         else:
             return None
 
-    def get(self, loc_s, num_retries=0):
+    @retry.retry_method
+    def get(self, loc_s, num_retries=None, cache_only=False):
         """Get data from Keep.
 
         This method fetches one or more blocks of data from Keep.  It
@@ -574,13 +603,23 @@ class KeepClient(object):
           *each* Keep server if it returns temporary failures, with
           exponential backoff.  Note that, in each loop, the method may try
           to fetch data from every available Keep service, along with any
-          that are named in location hints in the locator.  Default 0.
+          that are named in location hints in the locator.  The default value
+          is set when the KeepClient is initialized.
+        * cache_only: If true, return the block data only if already present in
+          cache, otherwise return None.
         """
         if ',' in loc_s:
             return ''.join(self.get(x) for x in loc_s.split(','))
         locator = KeepLocator(loc_s)
         expect_hash = locator.md5sum
 
+        if cache_only:
+            # block_cache.get() yields None when no slot exists for the hash.
+            slot = self.block_cache.get(expect_hash)
+            if slot is not None and slot.ready.is_set():
+                return slot.get()
+            return None
+
         slot, first = self.block_cache.reserve_cache(expect_hash)
         if not first:
             v = slot.get()
@@ -594,7 +633,7 @@ class KeepClient(object):
         hint_roots = ['http://keep.{}.arvadosapi.com/'.format(hint[2:])
                       for hint in locator.hints if hint.startswith('K@')]
         # Map root URLs their KeepService objects.
-        roots_map = {root: self.KeepService(root) for root in hint_roots}
+        roots_map = {root: self.KeepService(root, self.session) for root in hint_roots}
         blob = None
         loop = retry.RetryLoop(num_retries, self._check_loop_result,
                                backoff_start=2)
@@ -612,9 +651,8 @@ class KeepClient(object):
             services_to_try = [roots_map[root]
                                for root in (local_roots + hint_roots)
                                if roots_map[root].usable()]
-            http = httplib2.Http(timeout=self.timeout)
             for keep_service in services_to_try:
-                blob = keep_service.get(http, locator)
+                blob = keep_service.get(locator, timeout=self.current_timeout())
                 if blob is not None:
                     break
             loop.save_result((blob, len(services_to_try)))
@@ -625,19 +663,30 @@ class KeepClient(object):
         if loop.success():
             return blob
 
-        # No servers fulfilled the request.  Count how many responded
-        # "not found;" if the ratio is high enough (currently 75%), report
-        # Not Found; otherwise a generic error.
+        try:
+            all_roots = local_roots + hint_roots
+        except NameError:
+            # We never successfully fetched local_roots.
+            all_roots = hint_roots
         # Q: Including 403 is necessary for the Keep tests to continue
         # passing, but maybe they should expect KeepReadError instead?
-        not_founds = sum(1 for ks in roots_map.values()
-                         if ks.last_status() in set([403, 404, 410]))
-        if roots_map and ((float(not_founds) / len(roots_map)) >= .75):
-            raise arvados.errors.NotFoundError(loc_s)
+        not_founds = sum(1 for key in all_roots
+                         if roots_map[key].last_status() in {403, 404, 410})
+        service_errors = ((key, roots_map[key].last_result)
+                          for key in all_roots)
+        if not roots_map:
+            raise arvados.errors.KeepReadError(
+                "failed to read {}: no Keep services available ({})".format(
+                    loc_s, loop.last_result()))
+        elif not_founds == len(all_roots):
+            raise arvados.errors.NotFoundError(
+                "{} not found".format(loc_s), service_errors)
         else:
-            raise arvados.errors.KeepReadError(loc_s)
+            raise arvados.errors.KeepReadError(
+                "failed to read {}".format(loc_s), service_errors)
 
-    def put(self, data, copies=2, num_retries=0):
+    @retry.retry_method
+    def put(self, data, copies=2, num_retries=None):
         """Save data in Keep.
 
         This method will get a list of Keep services from the API server, and
@@ -652,7 +701,8 @@ class KeepClient(object):
           Default 2.
         * num_retries: The number of times to retry PUT requests to
           *each* Keep server if it returns temporary failures, with
-          exponential backoff.  Default 0.
+          exponential backoff.  The default value is set when the
+          KeepClient is initialized.
         """
         data_hash = hashlib.md5(data).hexdigest()
         if copies < 1:
@@ -685,7 +735,7 @@ class KeepClient(object):
                     data_hash=data_hash,
                     service_root=service_root,
                     thread_limiter=thread_limiter,
-                    timeout=self.timeout)
+                    timeout=self.current_timeout())
                 t.start()
                 threads.append(t)
             for t in threads:
@@ -694,11 +744,30 @@ class KeepClient(object):
 
         if loop.success():
             return thread_limiter.response()
-        raise arvados.errors.KeepWriteError(
-            "Write fail for %s: wanted %d but wrote %d" %
-            (data_hash, copies, thread_limiter.done()))
+        if not roots_map:
+            raise arvados.errors.KeepWriteError(
+                "failed to write {}: no Keep services available ({})".format(
+                    data_hash, loop.last_result()))
+        else:
+            service_errors = ((key, roots_map[key].last_result)
+                              for key in local_roots
+                              if not roots_map[key].success_flag)
+            raise arvados.errors.KeepWriteError(
+                "failed to write {} (wanted {} copies but wrote {})".format(
+                    data_hash, copies, thread_limiter.done()), service_errors)
+
+    def local_store_put(self, data, copies=1, num_retries=None):
+        """A stub for put().
 
-    def local_store_put(self, data):
+        This method is used in place of the real put() method when
+        using local storage (see constructor's local_store argument).
+
+        copies and num_retries arguments are ignored: they are here
+        only for the sake of offering the same call signature as
+        put().
+
+        Data stored this way can be retrieved via local_store_get().
+        """
         md5 = hashlib.md5(data).hexdigest()
         locator = '%s+%d' % (md5, len(data))
         with open(os.path.join(self.local_store, md5 + '.tmp'), 'w') as f:
@@ -707,7 +776,8 @@ class KeepClient(object):
                   os.path.join(self.local_store, md5))
         return locator
 
-    def local_store_get(self, loc_s):
+    def local_store_get(self, loc_s, num_retries=None):
+        """Companion to local_store_put()."""
         try:
             locator = KeepLocator(loc_s)
         except ValueError:
@@ -717,3 +787,6 @@ class KeepClient(object):
             return ''
         with open(os.path.join(self.local_store, locator.md5sum), 'r') as f:
             return f.read()
+
+    def is_cached(self, locator):
+        return self.block_cache.reserve_cache(locator)  # NOTE(review): yields (slot, is_new) and may reserve a slot