2800: Remove global state from KeepClient.
author Brett Smith <brett@curoverse.com>
Tue, 19 Aug 2014 14:17:57 +0000 (10:17 -0400)
committer Brett Smith <brett@curoverse.com>
Wed, 20 Aug 2014 18:18:10 +0000 (14:18 -0400)
This commit makes it possible to build and use a KeepClient that isn't
influenced by changes in outside state.  Reacting to changes in global
configuration has been pushed up to the simple Keep class, which builds a
new KeepClient whenever those settings change.

This commit changes the tests as little as possible to demonstrate
backward compatibility.  Refactoring the tests to build KeepClients
directly will come in the future.

sdk/python/arvados/api.py
sdk/python/arvados/keep.py
sdk/python/tests/test_keep_client.py

index e7348a1356077e1ea8a670c59a071787b8414e4e..1eb8f5161330b8588b1f52eb7e7fafdaba7c2000 100644 (file)
@@ -154,6 +154,7 @@ def api(version=None, cache=True, host=None, token=None, insecure=False, **kwarg
     kwargs['http'] = credentials.authorize(kwargs['http'])
 
     svc = apiclient.discovery.build('arvados', version, **kwargs)
+    svc.api_token = token
     kwargs['http'].cache = None
     if cache:
         conncache[connprofile] = svc
index 36678291d877c91559835524f410dd27ca711186..561d34cd4090cb66c9d1bdf20222cbb925e1a38f 100644 (file)
@@ -24,8 +24,8 @@ import ssl
 _logger = logging.getLogger('arvados.keep')
 global_client_object = None
 
-from api import *
-import config
+import arvados
+import arvados.config as config
 import arvados.errors
 import arvados.util
 
@@ -106,12 +106,31 @@ class KeepLocator(object):
         return self.perm_expiry <= as_of_dt
 
 
-class Keep:
-    @staticmethod
-    def global_client_object():
+class Keep(object):
+    """Simple interface to a global KeepClient object.
+
+    THIS CLASS IS DEPRECATED.  Please instantiate your own KeepClient with your
+    own API client.  The global KeepClient will build an API client from the
+    current Arvados configuration, which may not match the one you built.
+    """
+    _last_key = None
+
+    @classmethod
+    def global_client_object(cls):
         global global_client_object
-        if global_client_object == None:
+        # Previously, KeepClient would change its behavior at runtime based
+        # on these configuration settings.  We simulate that behavior here
+        # by checking the values and returning a new KeepClient if any of
+        # them have changed.
+        key = (config.get('ARVADOS_API_HOST'),
+               config.get('ARVADOS_API_TOKEN'),
+               config.flag_is_true('ARVADOS_API_HOST_INSECURE'),
+               config.get('ARVADOS_KEEP_PROXY'),
+               config.get('ARVADOS_EXTERNAL_CLIENT') == 'true',
+               os.environ.get('KEEP_LOCAL_STORE'))
+        if (global_client_object is None) or (cls._last_key != key):
             global_client_object = KeepClient()
+            cls._last_key = key
         return global_client_object
 
     @staticmethod
@@ -122,8 +141,8 @@ class Keep:
     def put(data, **kwargs):
         return Keep.global_client_object().put(data, **kwargs)
 
-class KeepClient(object):
 
+class KeepClient(object):
     class ThreadLimiter(object):
         """
         Limit the number of threads running at a given time to
@@ -180,14 +199,16 @@ class KeepClient(object):
             with self._done_lock:
                 return self._done
 
+
     class KeepWriterThread(threading.Thread):
         """
         Write a blob of data to the given Keep server. On success, call
         save_response() of the given ThreadLimiter to save the returned
         locator.
         """
-        def __init__(self, **kwargs):
+        def __init__(self, api_token, **kwargs):
             super(KeepClient.KeepWriterThread, self).__init__()
+            self._api_token = api_token
             self.args = kwargs
             self._success = False
 
@@ -209,8 +230,7 @@ class KeepClient(object):
                           self.args['service_root'])
             h = httplib2.Http(timeout=self.args.get('timeout', None))
             url = self.args['service_root'] + self.args['data_hash']
-            api_token = config.get('ARVADOS_API_TOKEN')
-            headers = {'Authorization': "OAuth2 %s" % api_token}
+            headers = {'Authorization': "OAuth2 %s" % (self._api_token,)}
 
             if self.args['using_proxy']:
                 # We're using a proxy, so tell the proxy how many copies we
@@ -259,51 +279,91 @@ class KeepClient(object):
                 _logger.debug("Request fail: PUT %s => %s: %s",
                                 url, type(e), str(e))
 
-    def __init__(self, **kwargs):
+
+    def __init__(self, api_client=None, proxy=None, timeout=60,
+                 api_token=None, local_store=None):
+        """Initialize a new KeepClient.
+
+        Arguments:
+        * api_client: The API client to use to find Keep services.  If not
+          provided, KeepClient will build one from available Arvados
+          configuration.
+        * proxy: If specified, this KeepClient will send requests to this
+          Keep proxy.  Otherwise, KeepClient will fall back to the setting
+          of the ARVADOS_KEEP_PROXY configuration setting.  If you want to
+          ensure KeepClient does not use a proxy, pass in an empty string.
+        * timeout: The timeout for all HTTP requests, in seconds.  Default
+          60.
+        * api_token: If you're not using an API client, but only talking
+          directly to a Keep proxy, this parameter specifies an API token
+          to authenticate Keep requests.  It is an error to specify both
+          api_client and api_token.  If you specify neither, KeepClient
+          will use one available from the Arvados configuration.
+        * local_store: If specified, this KeepClient will bypass Keep
+          services, and save data to the named directory.  If unspecified,
+          KeepClient will fall back to the setting of the $KEEP_LOCAL_STORE
+          environment variable.  If you want to ensure KeepClient does not
+          use local storage, pass in an empty string.  This is primarily
+          intended to mock a server for testing.
+        """
         self.lock = threading.Lock()
-        self.service_roots = None
-        self._cache_lock = threading.Lock()
-        self._cache = []
-        # default 256 megabyte cache
-        self.cache_max = 256 * 1024 * 1024
-        self.using_proxy = False
-        self.timeout = kwargs.get('timeout', 60)
+        if proxy is None:
+            proxy = config.get('ARVADOS_KEEP_PROXY')
+        if api_token is None:
+            api_token = config.get('ARVADOS_API_TOKEN')
+        elif api_client is not None:
+            raise ValueError(
+                "can't build KeepClient with both API client and token")
+        if local_store is None:
+            local_store = os.environ.get('KEEP_LOCAL_STORE')
+
+        if local_store:
+            self.local_store = local_store
+            self.get = self.local_store_get
+            self.put = self.local_store_put
+        else:
+            self.timeout = timeout
+            self.cache_max = 256 * 1024 * 1024  # Cache is 256MiB
+            self._cache = []
+            self._cache_lock = threading.Lock()
+            if proxy:
+                if not proxy.endswith('/'):
+                    proxy += '/'
+                self.api_token = api_token
+                self.service_roots = [proxy]
+                self.using_proxy = True
+            else:
+                # It's important to avoid instantiating an API client
+                # unless we actually need one, for testing's sake.
+                if api_client is None:
+                    api_client = arvados.api('v1')
+                self.api_client = api_client
+                self.api_token = api_client.api_token
+                self.service_roots = None
+                self.using_proxy = None
 
     def shuffled_service_roots(self, hash):
-        if self.service_roots == None:
-            self.lock.acquire()
+        if self.service_roots is None:
+            with self.lock:
+                try:
+                    keep_services = self.api_client.keep_services().accessible()
+                except Exception:  # API server predates Keep services.
+                    keep_services = self.api_client.keep_disks().list()
 
-            # Override normal keep disk lookup with an explict proxy
-            # configuration.
-            keep_proxy_env = config.get("ARVADOS_KEEP_PROXY")
-            if keep_proxy_env != None and len(keep_proxy_env) > 0:
+                keep_services = keep_services.execute().get('items')
+                if not keep_services:
+                    raise arvados.errors.NoKeepServersError()
 
-                if keep_proxy_env[-1:] != '/':
-                    keep_proxy_env += "/"
-                self.service_roots = [keep_proxy_env]
-                self.using_proxy = True
-            else:
-                try:
-                    try:
-                        keep_services = arvados.api().keep_services().accessible().execute()['items']
-                    except Exception:
-                        keep_services = arvados.api().keep_disks().list().execute()['items']
-
-                    if len(keep_services) == 0:
-                        raise arvados.errors.NoKeepServersError()
-
-                    if 'service_type' in keep_services[0] and keep_services[0]['service_type'] == 'proxy':
-                        self.using_proxy = True
-
-                    roots = (("http%s://%s:%d/" %
-                              ('s' if f['service_ssl_flag'] else '',
-                               f['service_host'],
-                               f['service_port']))
-                             for f in keep_services)
-                    self.service_roots = sorted(set(roots))
-                    _logger.debug(str(self.service_roots))
-                finally:
-                    self.lock.release()
+                self.using_proxy = (keep_services[0].get('service_type') ==
+                                    'proxy')
+
+                roots = (("http%s://%s:%d/" %
+                          ('s' if f['service_ssl_flag'] else '',
+                           f['service_host'],
+                           f['service_port']))
+                         for f in keep_services)
+                self.service_roots = sorted(set(roots))
+                _logger.debug(str(self.service_roots))
 
         # Build an ordering with which to query the Keep servers based on the
         # contents of the hash.
@@ -403,12 +463,11 @@ class KeepClient(object):
         finally:
             self._cache_lock.release()
 
-    def get(self, locator):
-        if re.search(r',', locator):
-            return ''.join(self.get(x) for x in locator.split(','))
-        if 'KEEP_LOCAL_STORE' in os.environ:
-            return KeepClient.local_store_get(locator)
-        expect_hash = re.sub(r'\+.*', '', locator)
+    def get(self, loc_s):
+        if ',' in loc_s:
+            return ''.join(self.get(x) for x in loc_s.split(','))
+        locator = KeepLocator(loc_s)
+        expect_hash = locator.md5sum
 
         slot, first = self.reserve_cache(expect_hash)
 
@@ -418,9 +477,8 @@ class KeepClient(object):
 
         try:
             for service_root in self.shuffled_service_roots(expect_hash):
-                url = service_root + locator
-                api_token = config.get('ARVADOS_API_TOKEN')
-                headers = {'Authorization': "OAuth2 %s" % api_token,
+                url = service_root + loc_s
+                headers = {'Authorization': "OAuth2 %s" % (self.api_token,),
                            'Accept': 'application/octet-stream'}
                 blob = self.get_url(url, headers, expect_hash)
                 if blob:
@@ -428,9 +486,10 @@ class KeepClient(object):
                     self.cap_cache()
                     return blob
 
-            for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
-                instance = location_hint.group(1)
-                url = 'http://keep.' + instance + '.arvadosapi.com/' + locator
+            for hint in locator.hints:
+                if not hint.startswith('K@'):
+                    continue
+                url = 'http://keep.' + hint[2:] + '.arvadosapi.com/' + loc_s
                 blob = self.get_url(url, {}, expect_hash)
                 if blob:
                     slot.set(blob)
@@ -456,9 +515,7 @@ class KeepClient(object):
                          len(content), t.msecs,
                          (len(content)/(1024*1024))/t.secs)
             if re.match(r'^2\d\d$', resp['status']):
-                m = hashlib.new('md5')
-                m.update(content)
-                md5 = m.hexdigest()
+                md5 = hashlib.md5(content).hexdigest()
                 if md5 == expect_hash:
                     return content
                 _logger.warning("Checksum fail: md5(%s) = %s", url, md5)
@@ -467,20 +524,17 @@ class KeepClient(object):
                          url, type(e), str(e))
         return None
 
-    def put(self, data, **kwargs):
-        if 'KEEP_LOCAL_STORE' in os.environ:
-            return KeepClient.local_store_put(data)
-        m = hashlib.new('md5')
-        m.update(data)
-        data_hash = m.hexdigest()
+    def put(self, data, copies=2):
+        data_hash = hashlib.md5(data).hexdigest()
         have_copies = 0
-        want_copies = kwargs.get('copies', 2)
+        want_copies = copies
         if not (want_copies > 0):
             return data_hash
         threads = []
         thread_limiter = KeepClient.ThreadLimiter(want_copies)
         for service_root in self.shuffled_service_roots(data_hash):
             t = KeepClient.KeepWriterThread(
+                self.api_token,
                 data=data,
                 data_hash=data_hash,
                 service_root=service_root,
@@ -502,7 +556,8 @@ class KeepClient(object):
                                     t.args['service_root'],
                                     t.args['data_hash'])
                     retry_with_args = t.args.copy()
-                    t_retry = KeepClient.KeepWriterThread(**retry_with_args)
+                    t_retry = KeepClient.KeepWriterThread(self.api_token,
+                                                          **retry_with_args)
                     t_retry.start()
                     threads_retry += [t_retry]
             for t in threads_retry:
@@ -519,26 +574,38 @@ class KeepClient(object):
     def sign_for_old_server(data_hash, data):
         return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
 
-
-    @staticmethod
-    def local_store_put(data):
-        m = hashlib.new('md5')
-        m.update(data)
-        md5 = m.hexdigest()
+    def local_store_put(self, data, copies=2):
+        """Store data in self.local_store and return its locator.
+
+        copies is accepted so this method can stand in for put(), whose
+        callers (e.g. Keep.put) may pass it; it is ignored because only
+        one file is written.
+        """
+        md5 = hashlib.md5(data).hexdigest()
         locator = '%s+%d' % (md5, len(data))
-        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
+        # Binary mode: the data was hashed as raw bytes, so it must be
+        # written without newline translation.
+        with open(os.path.join(self.local_store, md5 + '.tmp'), 'wb') as f:
             f.write(data)
-        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
-                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
+        os.rename(os.path.join(self.local_store, md5 + '.tmp'),
+                  os.path.join(self.local_store, md5))
         return locator
 
-    @staticmethod
-    def local_store_get(locator):
-        r = re.search('^([0-9a-f]{32,})', locator)
-        if not r:
+    def local_store_get(self, loc_s):
+        """Return the data stored in self.local_store for loc_s.
+
+        Like get(), loc_s may be several comma-separated locators; their
+        contents are concatenated.
+        """
+        if ',' in loc_s:
+            return ''.join(self.local_store_get(part)
+                           for part in loc_s.split(','))
+        try:
+            locator = KeepLocator(loc_s)
+        except ValueError:
             raise arvados.errors.NotFoundError(
-                "Invalid data locator: '%s'" % locator)
-        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
+                "Invalid data locator: '%s'" % loc_s)
+        if locator.md5sum == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
             return ''
-        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
+        with open(os.path.join(self.local_store, locator.md5sum), 'rb') as f:
             return f.read()
index 8706e2193b7c3f249db3778c03b7ee598269b364..9d3cecd9614ede083a770204f63736518a2b618d 100644 (file)
@@ -126,7 +126,7 @@ class KeepPermissionTestCase(unittest.TestCase):
 
         # Unauthenticated GET for a signed locator => NotFound
         # Unauthenticated GET for an unsigned locator => NotFound
-        del arvados.config.settings()["ARVADOS_API_TOKEN"]
+        arvados.keep.global_client_object.api_token = ''
         self.assertRaises(arvados.errors.NotFoundError,
                           arvados.Keep.get,
                           bar_locator)
@@ -192,7 +192,7 @@ class KeepOptionalPermission(unittest.TestCase):
             r'^acbd18db4cc2f85cedef654fccc4a4d8\+3\+A[a-f0-9]+@[a-f0-9]+$',
             'invalid locator from Keep.put("foo"): ' + signed_locator)
 
-        del arvados.config.settings()["ARVADOS_API_TOKEN"]
+        arvados.keep.global_client_object.api_token = ''
         self.assertEqual(arvados.Keep.get(signed_locator),
                          'foo',
                          'wrong content from Keep.get(md5("foo"))')
@@ -207,7 +207,7 @@ class KeepOptionalPermission(unittest.TestCase):
             r'^acbd18db4cc2f85cedef654fccc4a4d8\+3\+A[a-f0-9]+@[a-f0-9]+$',
             'invalid locator from Keep.put("foo"): ' + signed_locator)
 
-        del arvados.config.settings()["ARVADOS_API_TOKEN"]
+        arvados.keep.global_client_object.api_token = ''
         self.assertEqual(arvados.Keep.get("acbd18db4cc2f85cedef654fccc4a4d8"),
                          'foo',
                          'wrong content from Keep.get(md5("foo"))')