10111: Merge branch 'master' into 10111-cr-provenance-graph
[arvados.git] / sdk / python / arvados / keep.py
index 5416379ad03bac620d669f10a53769a980523dbb..5b4770c4d0dca8824c268448296d0658c8ba04d8 100644 (file)
@@ -1,4 +1,5 @@
 import cStringIO
+import collections
 import datetime
 import hashlib
 import logging
@@ -12,6 +13,7 @@ import ssl
 import sys
 import threading
 import timer
+import urlparse
 
 import arvados
 import arvados.config as config
@@ -295,21 +297,33 @@ class KeepClient(object):
 
         def _get_user_agent(self):
             try:
-                return self._user_agent_pool.get(False)
+                return self._user_agent_pool.get(block=False)
             except Queue.Empty:
                 return pycurl.Curl()
 
         def _put_user_agent(self, ua):
             try:
                 ua.reset()
-                self._user_agent_pool.put(ua, False)
+                self._user_agent_pool.put(ua, block=False)
             except:
                 ua.close()
 
-        @staticmethod
-        def _socket_open(family, socktype, protocol, address=None):
+        def _socket_open(self, *args, **kwargs):
+            if len(args) + len(kwargs) == 2:
+                return self._socket_open_pycurl_7_21_5(*args, **kwargs)
+            else:
+                return self._socket_open_pycurl_7_19_3(*args, **kwargs)
+
+        def _socket_open_pycurl_7_19_3(self, family, socktype, protocol, address=None):
+            return self._socket_open_pycurl_7_21_5(
+                purpose=None,
+                address=collections.namedtuple(
+                    'Address', ['family', 'socktype', 'protocol', 'addr'],
+                )(family, socktype, protocol, address))
+
+        def _socket_open_pycurl_7_21_5(self, purpose, address):
             """Because pycurl doesn't have CURLOPT_TCP_KEEPALIVE"""
-            s = socket.socket(family, socktype, protocol)
+            s = socket.socket(address.family, address.socktype, address.protocol)
             s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
             # Will throw invalid protocol error on mac. This test prevents that.
             if hasattr(socket, 'TCP_KEEPIDLE'):
@@ -510,8 +524,10 @@ class KeepClient(object):
             with self.successful_copies_lock:
                 self.successful_copies += replicas_nr
                 self.response = response
+            with self.pending_tries_notification:
+                self.pending_tries_notification.notify_all()
         
-        def write_fail(self, ks, status_code):
+        def write_fail(self, ks):
             with self.pending_tries_notification:
                 self.pending_tries += 1
                 self.pending_tries_notification.notify()
@@ -519,8 +535,36 @@ class KeepClient(object):
         def pending_copies(self):
             with self.successful_copies_lock:
                 return self.wanted_copies - self.successful_copies
-    
-    
+
+        def get_next_task(self):
+            with self.pending_tries_notification:
+                while True:
+                    if self.pending_copies() < 1:
+                        # This notify_all() is unnecessary --
+                        # write_success() already called notify_all()
+                        # when pending<1 became true, so it's not
+                        # possible for any other thread to be in
+                        # wait() now -- but it's cheap insurance
+                        # against deadlock so we do it anyway:
+                        self.pending_tries_notification.notify_all()
+                        # Drain the queue and then raise Queue.Empty
+                        while True:
+                            self.get_nowait()
+                            self.task_done()
+                    elif self.pending_tries > 0:
+                        service, service_root = self.get_nowait()
+                        if service.finished():
+                            self.task_done()
+                            continue
+                        self.pending_tries -= 1
+                        return service, service_root
+                    elif self.empty():
+                        self.pending_tries_notification.notify_all()
+                        raise Queue.Empty
+                    else:
+                        self.pending_tries_notification.wait()
+
+
     class KeepWriterThreadPool(object):
         def __init__(self, data, data_hash, copies, max_service_replicas, timeout=None):
             self.total_task_nr = 0
@@ -550,74 +594,64 @@ class KeepClient(object):
                 worker.start()
             # Wait for finished work
             self.queue.join()
-            with self.queue.pending_tries_notification:
-                self.queue.pending_tries_notification.notify_all()
-            for worker in self.workers:
-                worker.join()
         
         def response(self):
             return self.queue.response
     
     
     class KeepWriterThread(threading.Thread):
+        TaskFailed = RuntimeError()
+
         def __init__(self, queue, data, data_hash, timeout=None):
             super(KeepClient.KeepWriterThread, self).__init__()
             self.timeout = timeout
             self.queue = queue
             self.data = data
             self.data_hash = data_hash
-        
+            self.daemon = True
+
         def run(self):
-            while not self.queue.empty():
-                if self.queue.pending_copies() > 0:
-                    # Avoid overreplication, wait for some needed re-attempt
-                    with self.queue.pending_tries_notification:
-                        if self.queue.pending_tries <= 0:
-                            self.queue.pending_tries_notification.wait()
-                            continue # try again when awake
-                        self.queue.pending_tries -= 1
-
-                    # Get to work
-                    try:
-                        service, service_root = self.queue.get_nowait()
-                    except Queue.Empty:
-                        continue
-                    if service.finished():
-                        self.queue.task_done()
-                        continue
-                    success = bool(service.put(self.data_hash,
-                                                self.data,
-                                                timeout=self.timeout))
-                    result = service.last_result()
-                    if success:
-                        _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
-                                      str(threading.current_thread()),
-                                      self.data_hash,
-                                      len(self.data),
-                                      service_root)
-                        try:
-                            replicas_stored = int(result['headers']['x-keep-replicas-stored'])
-                        except (KeyError, ValueError):
-                            replicas_stored = 1
-                        
-                        self.queue.write_success(result['body'].strip(), replicas_stored)
-                    else:
-                        if result.get('status_code', None):
-                            _logger.debug("Request fail: PUT %s => %s %s",
-                                          self.data_hash,
-                                          result['status_code'],
-                                          result['body'])
-                        self.queue.write_fail(service, result.get('status_code', None)) # Schedule a re-attempt with next service
-                    # Mark as done so the queue can be join()ed
-                    self.queue.task_done()
+            while True:
+                try:
+                    service, service_root = self.queue.get_next_task()
+                except Queue.Empty:
+                    return
+                try:
+                    locator, copies = self.do_task(service, service_root)
+                except Exception as e:
+                    if e is not self.TaskFailed:
+                        _logger.exception("Exception in KeepWriterThread")
+                    self.queue.write_fail(service)
                 else:
-                    # Remove the task from the queue anyways
-                    try:
-                        self.queue.get_nowait()
-                        # Mark as done so the queue can be join()ed
-                        self.queue.task_done()
-                    except Queue.Empty:
-                        continue
+                    self.queue.write_success(locator, copies)
+                finally:
+                    self.queue.task_done()
+
+        def do_task(self, service, service_root):
+            success = bool(service.put(self.data_hash,
+                                        self.data,
+                                        timeout=self.timeout))
+            result = service.last_result()
+
+            if not success:
+                if result.get('status_code', None):
+                    _logger.debug("Request fail: PUT %s => %s %s",
+                                  self.data_hash,
+                                  result['status_code'],
+                                  result['body'])
+                raise self.TaskFailed
+
+            _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
+                          str(threading.current_thread()),
+                          self.data_hash,
+                          len(self.data),
+                          service_root)
+            try:
+                replicas_stored = int(result['headers']['x-keep-replicas-stored'])
+            except (KeyError, ValueError):
+                replicas_stored = 1
+
+            return result['body'].strip(), replicas_stored
 
 
     def __init__(self, api_client=None, proxy=None,
@@ -635,8 +669,9 @@ class KeepClient(object):
         :proxy:
           If specified, this KeepClient will send requests to this Keep
           proxy.  Otherwise, KeepClient will fall back to the setting of the
-          ARVADOS_KEEP_PROXY configuration setting.  If you want to ensure
-          KeepClient does not use a proxy, pass in an empty string.
+          ARVADOS_KEEP_SERVICES or ARVADOS_KEEP_PROXY configuration settings.
+          If you want to ensure KeepClient does not use a proxy, pass in an
+          empty string.
 
         :timeout:
           The initial timeout (in seconds) for HTTP requests to Keep
@@ -678,11 +713,11 @@ class KeepClient(object):
           put() are called.  Default 0.
         """
         self.lock = threading.Lock()
-        if config.get('ARVADOS_KEEP_SERVICES'):
-            # ARVADOS_KEEP_SERVICES overrides proxy settings
-            proxy = config.get('ARVADOS_KEEP_SERVICES')
-        elif proxy is None:
-            proxy = config.get('ARVADOS_KEEP_PROXY')
+        if proxy is None:
+            if config.get('ARVADOS_KEEP_SERVICES'):
+                proxy = config.get('ARVADOS_KEEP_SERVICES')
+            else:
+                proxy = config.get('ARVADOS_KEEP_PROXY')
         if api_token is None:
             if api_client is None:
                 api_token = config.get('ARVADOS_API_TOKEN')
@@ -713,17 +748,21 @@ class KeepClient(object):
             self.num_retries = num_retries
             self.max_replicas_per_service = None
             if proxy:
-                proxy_uris = proxy.split(' ')
+                proxy_uris = proxy.split()
                 for i in range(len(proxy_uris)):
                     if not proxy_uris[i].endswith('/'):
                         proxy_uris[i] += '/'
+                    # URL validation
+                    url = urlparse.urlparse(proxy_uris[i])
+                    if not (url.scheme and url.netloc):
+                        raise arvados.errors.ArgumentError("Invalid proxy URI: {}".format(proxy_uris[i]))
                 self.api_token = api_token
                 self._gateway_services = {}
                 self._keep_services = [{
-                    'uuid': "00000-bi6l4-%015d" % proxy_uris.index(uri),
+                    'uuid': "00000-bi6l4-%015d" % idx,
                     'service_type': 'proxy',
                     '_service_root': uri,
-                    } for uri in proxy_uris]
+                    } for idx, uri in enumerate(proxy_uris)]
                 self._writable_services = self._keep_services
                 self.using_proxy = True
                 self._static_services_list = True