4363: Merge branch 'master' into 4363-less-filename-munging
[arvados.git] / sdk / python / arvados / keep.py
index 1fd8fa5d2c640b1291cd99c643bf8d90ac5bd287..f4c85969c3904cbba5ecb2e6dfd0b47218555a88 100644 (file)
@@ -21,15 +21,15 @@ import ssl
 import socket
 import requests
 
-_logger = logging.getLogger('arvados.keep')
-global_client_object = None
-
 import arvados
 import arvados.config as config
 import arvados.errors
 import arvados.retry as retry
 import arvados.util
 
+_logger = logging.getLogger('arvados.keep')
+global_client_object = None
+
 class KeepLocator(object):
     EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
     HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')
@@ -466,9 +466,12 @@ class KeepClient(object):
                 if not proxy.endswith('/'):
                     proxy += '/'
                 self.api_token = api_token
-                self.service_roots = [proxy]
+                self._keep_services = [{
+                    'uuid': 'proxy',
+                    '_service_root': proxy,
+                    }]
                 self.using_proxy = True
-                self.static_service_roots = True
+                self._static_services_list = True
             else:
                 # It's important to avoid instantiating an API client
                 # unless we actually need one, for testing's sake.
@@ -476,9 +479,9 @@ class KeepClient(object):
                     api_client = arvados.api('v1')
                 self.api_client = api_client
                 self.api_token = api_client.api_token
-                self.service_roots = None
+                self._keep_services = None
                 self.using_proxy = None
-                self.static_service_roots = False
+                self._static_services_list = False
 
     def current_timeout(self):
         """Return the appropriate timeout to use for this client: the proxy
@@ -489,9 +492,9 @@ class KeepClient(object):
         # KeepService, not a KeepClient. See #4488.
         return self.proxy_timeout if self.using_proxy else self.timeout
 
-    def build_service_roots(self, force_rebuild=False):
-        if (self.static_service_roots or
-              (self.service_roots and not force_rebuild)):
+    def build_services_list(self, force_rebuild=False):
+        if (self._static_services_list or
+              (self._keep_services and not force_rebuild)):
             return
         with self.lock:
             try:
@@ -499,68 +502,47 @@ class KeepClient(object):
             except Exception:  # API server predates Keep services.
                 keep_services = self.api_client.keep_disks().list()
 
-            keep_services = keep_services.execute().get('items')
-            if not keep_services:
+            self._keep_services = keep_services.execute().get('items')
+            if not self._keep_services:
                 raise arvados.errors.NoKeepServersError()
 
             self.using_proxy = any(ks.get('service_type') == 'proxy'
-                                   for ks in keep_services)
-
-            roots = ("{}://[{}]:{:d}/".format(
-                        'https' if ks['service_ssl_flag'] else 'http',
-                         ks['service_host'],
-                         ks['service_port'])
-                     for ks in keep_services)
-            self.service_roots = sorted(set(roots))
-            _logger.debug(str(self.service_roots))
-
-    def shuffled_service_roots(self, hash, force_rebuild=False):
-        self.build_service_roots(force_rebuild)
-
-        # Build an ordering with which to query the Keep servers based on the
-        # contents of the hash.
-        # "hash" is a hex-encoded number at least 8 digits
-        # (32 bits) long
-
-        # seed used to calculate the next keep server from 'pool'
-        # to be added to 'pseq'
-        seed = hash
-
-        # Keep servers still to be added to the ordering
-        pool = self.service_roots[:]
-
-        # output probe sequence
-        pseq = []
-
-        # iterate while there are servers left to be assigned
-        while len(pool) > 0:
-            if len(seed) < 8:
-                # ran out of digits in the seed
-                if len(pseq) < len(hash) / 4:
-                    # the number of servers added to the probe sequence is less
-                    # than the number of 4-digit slices in 'hash' so refill the
-                    # seed with the last 4 digits and then append the contents
-                    # of 'hash'.
-                    seed = hash[-4:] + hash
-                else:
-                    # refill the seed with the contents of 'hash'
-                    seed += hash
-
-            # Take the next 8 digits (32 bytes) and interpret as an integer,
-            # then modulus with the size of the remaining pool to get the next
-            # selected server.
-            probe = int(seed[0:8], 16) % len(pool)
-
-            # Append the selected server to the probe sequence and remove it
-            # from the pool.
-            pseq += [pool[probe]]
-            pool = pool[:probe] + pool[probe+1:]
-
-            # Remove the digits just used from the seed
-            seed = seed[8:]
-        _logger.debug(str(pseq))
-        return pseq
+                                   for ks in self._keep_services)
+
+            # Precompute the base URI for each service.
+            for r in self._keep_services:
+                r['_service_root'] = "{}://[{}]:{:d}/".format(
+                    'https' if r['service_ssl_flag'] else 'http',
+                    r['service_host'],
+                    r['service_port'])
+            _logger.debug(str(self._keep_services))
+
+    def _service_weight(self, data_hash, service_uuid):
+        """Compute the weight of a Keep service endpoint for a data
+        block with a known hash.
+
+        The weight is the hex digest of md5(h + u), where h is data_hash
+        and u is the last 15 characters of the service endpoint's UUID.
+        """
+        return hashlib.md5(data_hash + service_uuid[-15:]).hexdigest()
 
+    def weighted_service_roots(self, data_hash, force_rebuild=False):
+        """Return a list of Keep service root URIs, in the order in
+        which they should be probed when reading or writing data with
+        the given hash.
+        """
+        self.build_services_list(force_rebuild)
+
+        # Sort the available services by weight (heaviest first) for
+        # this data_hash, and return their service_roots (base URIs)
+        # in that order.
+        sorted_roots = [
+            svc['_service_root'] for svc in sorted(
+                self._keep_services,
+                reverse=True,
+                key=lambda svc: self._service_weight(data_hash, svc['uuid']))]
+        _logger.debug(data_hash + ': ' + str(sorted_roots))
+        return sorted_roots
 
     def map_new_services(self, roots_map, md5_s, force_rebuild, **headers):
         # roots_map is a dictionary, mapping Keep service root strings
@@ -568,7 +550,7 @@ class KeepClient(object):
         # new ones to roots_map.  Return the current list of local
         # root strings.
         headers.setdefault('Authorization', "OAuth2 %s" % (self.api_token,))
-        local_roots = self.shuffled_service_roots(md5_s, force_rebuild)
+        local_roots = self.weighted_service_roots(md5_s, force_rebuild)
         for root in local_roots:
             if root not in roots_map:
                 roots_map[root] = self.KeepService(root, **headers)