10224: Choose a recent-event threshold without querying the entire event history.
[arvados.git] / sdk / python / arvados / keep.py
index cd39f83703f4b341e07201a67c0afb58a11574ae..db7835be3746f8f67eddd61d2aac505356e601f4 100644 (file)
@@ -9,8 +9,10 @@ import Queue
 import re
 import socket
 import ssl
 import re
 import socket
 import ssl
+import sys
 import threading
 import timer
 import threading
 import timer
+import urlparse
 
 import arvados
 import arvados.config as config
 
 import arvados
 import arvados.config as config
@@ -22,6 +24,17 @@ _logger = logging.getLogger('arvados.keep')
 global_client_object = None
 
 
 global_client_object = None
 
 
+# Monkey patch TCP constants when not available (apple). Values sourced from:
+# http://www.opensource.apple.com/source/xnu/xnu-2422.115.4/bsd/netinet/tcp.h
+if sys.platform == 'darwin':
+    if not hasattr(socket, 'TCP_KEEPALIVE'):
+        socket.TCP_KEEPALIVE = 0x010
+    if not hasattr(socket, 'TCP_KEEPINTVL'):
+        socket.TCP_KEEPINTVL = 0x101
+    if not hasattr(socket, 'TCP_KEEPCNT'):
+        socket.TCP_KEEPCNT = 0x102
+
+
 class KeepLocator(object):
     EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
     HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')
 class KeepLocator(object):
     EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
     HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')
@@ -212,7 +225,6 @@ class KeepBlockCache(object):
                 self._cache.insert(0, n)
                 return n, True
 
                 self._cache.insert(0, n)
                 return n, True
 
-
 class Counter(object):
     def __init__(self, v=0):
         self._lk = threading.Lock()
 class Counter(object):
     def __init__(self, v=0):
         self._lk = threading.Lock()
@@ -238,76 +250,6 @@ class KeepClient(object):
     DEFAULT_TIMEOUT = (2, 256, 32768)
     DEFAULT_PROXY_TIMEOUT = (20, 256, 32768)
 
     DEFAULT_TIMEOUT = (2, 256, 32768)
     DEFAULT_PROXY_TIMEOUT = (20, 256, 32768)
 
-    class ThreadLimiter(object):
-        """Limit the number of threads writing to Keep at once.
-
-        This ensures that only a number of writer threads that could
-        potentially achieve the desired replication level run at once.
-        Once the desired replication level is achieved, queued threads
-        are instructed not to run.
-
-        Should be used in a "with" block.
-        """
-        def __init__(self, want_copies, max_service_replicas):
-            self._started = 0
-            self._want_copies = want_copies
-            self._done = 0
-            self._response = None
-            self._start_lock = threading.Condition()
-            if (not max_service_replicas) or (max_service_replicas >= want_copies):
-                max_threads = 1
-            else:
-                max_threads = math.ceil(float(want_copies) / max_service_replicas)
-            _logger.debug("Limiter max threads is %d", max_threads)
-            self._todo_lock = threading.Semaphore(max_threads)
-            self._done_lock = threading.Lock()
-            self._local = threading.local()
-
-        def __enter__(self):
-            self._start_lock.acquire()
-            if getattr(self._local, 'sequence', None) is not None:
-                # If the calling thread has used set_sequence(N), then
-                # we wait here until N other threads have started.
-                while self._started < self._local.sequence:
-                    self._start_lock.wait()
-            self._todo_lock.acquire()
-            self._started += 1
-            self._start_lock.notifyAll()
-            self._start_lock.release()
-            return self
-
-        def __exit__(self, type, value, traceback):
-            self._todo_lock.release()
-
-        def set_sequence(self, sequence):
-            self._local.sequence = sequence
-
-        def shall_i_proceed(self):
-            """
-            Return true if the current thread should write to Keep.
-            Return false otherwise.
-            """
-            with self._done_lock:
-                return (self._done < self._want_copies)
-
-        def save_response(self, response_body, replicas_stored):
-            """
-            Records a response body (a locator, possibly signed) returned by
-            the Keep server, and the number of replicas it stored.
-            """
-            with self._done_lock:
-                self._done += replicas_stored
-                self._response = response_body
-
-        def response(self):
-            """Return the body from the response to a PUT request."""
-            with self._done_lock:
-                return self._response
-
-        def done(self):
-            """Return the total number of replicas successfully stored."""
-            with self._done_lock:
-                return self._done
 
     class KeepService(object):
         """Make requests to a single Keep service, and track results.
 
     class KeepService(object):
         """Make requests to a single Keep service, and track results.
@@ -370,14 +312,16 @@ class KeepClient(object):
             """Because pycurl doesn't have CURLOPT_TCP_KEEPALIVE"""
             s = socket.socket(family, socktype, protocol)
             s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
             """Because pycurl doesn't have CURLOPT_TCP_KEEPALIVE"""
             s = socket.socket(family, socktype, protocol)
             s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
-            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 75)
+            # Setting TCP_KEEPIDLE throws an invalid protocol error on Mac, so only set it when the socket module defines it.
+            if hasattr(socket, 'TCP_KEEPIDLE'):
+                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 75)
             s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 75)
             return s
 
             s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 75)
             return s
 
-        def get(self, locator, timeout=None):
+        def get(self, locator, method="GET", timeout=None):
             # locator is a KeepLocator object.
             url = self.root + str(locator)
             # locator is a KeepLocator object.
             url = self.root + str(locator)
-            _logger.debug("Request: GET %s", url)
+            _logger.debug("Request: %s %s", method, url)
             curl = self._get_user_agent()
             ok = None
             try:
             curl = self._get_user_agent()
             ok = None
             try:
@@ -391,7 +335,10 @@ class KeepClient(object):
                         '{}: {}'.format(k,v) for k,v in self.get_headers.iteritems()])
                     curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
                         '{}: {}'.format(k,v) for k,v in self.get_headers.iteritems()])
                     curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
+                    if method == "HEAD":
+                        curl.setopt(pycurl.NOBODY, True)
                     self._setcurltimeouts(curl, timeout)
                     self._setcurltimeouts(curl, timeout)
+
                     try:
                         curl.perform()
                     except Exception as e:
                     try:
                         curl.perform()
                     except Exception as e:
@@ -402,6 +349,7 @@ class KeepClient(object):
                         'headers': self._headers,
                         'error': False,
                     }
                         'headers': self._headers,
                         'error': False,
                     }
+
                 ok = retry.check_http_response_success(self._result['status_code'])
                 if not ok:
                     self._result['error'] = arvados.errors.HttpError(
                 ok = retry.check_http_response_success(self._result['status_code'])
                 if not ok:
                     self._result['error'] = arvados.errors.HttpError(
@@ -425,11 +373,18 @@ class KeepClient(object):
                 _logger.debug("Request fail: GET %s => %s: %s",
                               url, type(self._result['error']), str(self._result['error']))
                 return None
                 _logger.debug("Request fail: GET %s => %s: %s",
                               url, type(self._result['error']), str(self._result['error']))
                 return None
+            if method == "HEAD":
+                _logger.info("HEAD %s: %s bytes",
+                         self._result['status_code'],
+                         self._result.get('content-length'))
+                return True
+
             _logger.info("GET %s: %s bytes in %s msec (%.3f MiB/sec)",
                          self._result['status_code'],
                          len(self._result['body']),
                          t.msecs,
                          (len(self._result['body'])/(1024.0*1024))/t.secs if t.secs > 0 else 0)
             _logger.info("GET %s: %s bytes in %s msec (%.3f MiB/sec)",
                          self._result['status_code'],
                          len(self._result['body']),
                          t.msecs,
                          (len(self._result['body'])/(1024.0*1024))/t.secs if t.secs > 0 else 0)
+
             if self.download_counter:
                 self.download_counter.add(len(self._result['body']))
             resp_md5 = hashlib.md5(self._result['body']).hexdigest()
             if self.download_counter:
                 self.download_counter.add(len(self._result['body']))
             resp_md5 = hashlib.md5(self._result['body']).hexdigest()
@@ -540,68 +495,130 @@ class KeepClient(object):
             self._lastheadername = name
             self._headers[name] = value
             # Returning None implies all bytes were written
             self._lastheadername = name
             self._headers[name] = value
             # Returning None implies all bytes were written
-
-
+    
+
+    class KeepWriterQueue(Queue.Queue):
+        def __init__(self, copies):
+            Queue.Queue.__init__(self) # Old-style superclass
+            self.wanted_copies = copies
+            self.successful_copies = 0
+            self.response = None
+            self.successful_copies_lock = threading.Lock()
+            self.pending_tries = copies
+            self.pending_tries_notification = threading.Condition()
+        
+        def write_success(self, response, replicas_nr):
+            with self.successful_copies_lock:
+                self.successful_copies += replicas_nr
+                self.response = response
+        
+        def write_fail(self, ks, status_code):
+            with self.pending_tries_notification:
+                self.pending_tries += 1
+                self.pending_tries_notification.notify()
+        
+        def pending_copies(self):
+            with self.successful_copies_lock:
+                return self.wanted_copies - self.successful_copies
+    
+    
+    class KeepWriterThreadPool(object):
+        def __init__(self, data, data_hash, copies, max_service_replicas, timeout=None):
+            self.total_task_nr = 0
+            self.wanted_copies = copies
+            if (not max_service_replicas) or (max_service_replicas >= copies):
+                num_threads = 1
+            else:
+                num_threads = int(math.ceil(float(copies) / max_service_replicas))
+            _logger.debug("Pool max threads is %d", num_threads)
+            self.workers = []
+            self.queue = KeepClient.KeepWriterQueue(copies)
+            # Create workers
+            for _ in range(num_threads):
+                w = KeepClient.KeepWriterThread(self.queue, data, data_hash, timeout)
+                self.workers.append(w)
+        
+        def add_task(self, ks, service_root):
+            self.queue.put((ks, service_root))
+            self.total_task_nr += 1
+        
+        def done(self):
+            return self.queue.successful_copies
+        
+        def join(self):
+            # Start workers
+            for worker in self.workers:
+                worker.start()
+            # Wait for finished work
+            self.queue.join()
+            with self.queue.pending_tries_notification:
+                self.queue.pending_tries_notification.notify_all()
+            for worker in self.workers:
+                worker.join()
+        
+        def response(self):
+            return self.queue.response
+    
+    
     class KeepWriterThread(threading.Thread):
     class KeepWriterThread(threading.Thread):
-        """
-        Write a blob of data to the given Keep server. On success, call
-        save_response() of the given ThreadLimiter to save the returned
-        locator.
-        """
-        def __init__(self, keep_service, **kwargs):
+        def __init__(self, queue, data, data_hash, timeout=None):
             super(KeepClient.KeepWriterThread, self).__init__()
             super(KeepClient.KeepWriterThread, self).__init__()
-            self.service = keep_service
-            self.args = kwargs
-            self._success = False
-
-        def success(self):
-            return self._success
-
+            self.timeout = timeout
+            self.queue = queue
+            self.data = data
+            self.data_hash = data_hash
+        
         def run(self):
         def run(self):
-            limiter = self.args['thread_limiter']
-            sequence = self.args['thread_sequence']
-            if sequence is not None:
-                limiter.set_sequence(sequence)
-            with limiter:
-                if not limiter.shall_i_proceed():
-                    # My turn arrived, but the job has been done without
-                    # me.
-                    return
-                self.run_with_limiter(limiter)
-
-        def run_with_limiter(self, limiter):
-            if self.service.finished():
-                return
-            _logger.debug("KeepWriterThread %s proceeding %s+%i %s",
-                          str(threading.current_thread()),
-                          self.args['data_hash'],
-                          len(self.args['data']),
-                          self.args['service_root'])
-            self._success = bool(self.service.put(
-                self.args['data_hash'],
-                self.args['data'],
-                timeout=self.args.get('timeout', None)))
-            result = self.service.last_result()
-            if self._success:
-                _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
-                              str(threading.current_thread()),
-                              self.args['data_hash'],
-                              len(self.args['data']),
-                              self.args['service_root'])
-                # Tick the 'done' counter for the number of replica
-                # reported stored by the server, for the case that
-                # we're talking to a proxy or other backend that
-                # stores to multiple copies for us.
-                try:
-                    replicas_stored = int(result['headers']['x-keep-replicas-stored'])
-                except (KeyError, ValueError):
-                    replicas_stored = 1
-                limiter.save_response(result['body'].strip(), replicas_stored)
-            elif result.get('status_code', None):
-                _logger.debug("Request fail: PUT %s => %s %s",
-                              self.args['data_hash'],
-                              result['status_code'],
-                              result['body'])
+            while not self.queue.empty():
+                if self.queue.pending_copies() > 0:
+                    # Avoid overreplication, wait for some needed re-attempt
+                    with self.queue.pending_tries_notification:
+                        if self.queue.pending_tries <= 0:
+                            self.queue.pending_tries_notification.wait()
+                            continue # try again when awake
+                        self.queue.pending_tries -= 1
+
+                    # Get to work
+                    try:
+                        service, service_root = self.queue.get_nowait()
+                    except Queue.Empty:
+                        continue
+                    if service.finished():
+                        self.queue.task_done()
+                        continue
+                    success = bool(service.put(self.data_hash,
+                                                self.data,
+                                                timeout=self.timeout))
+                    result = service.last_result()
+                    if success:
+                        _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
+                                      str(threading.current_thread()),
+                                      self.data_hash,
+                                      len(self.data),
+                                      service_root)
+                        try:
+                            replicas_stored = int(result['headers']['x-keep-replicas-stored'])
+                        except (KeyError, ValueError):
+                            replicas_stored = 1
+                        
+                        self.queue.write_success(result['body'].strip(), replicas_stored)
+                    else:
+                        if result.get('status_code', None):
+                            _logger.debug("Request fail: PUT %s => %s %s",
+                                          self.data_hash,
+                                          result['status_code'],
+                                          result['body'])
+                        self.queue.write_fail(service, result.get('status_code', None)) # Schedule a re-attempt with next service
+                    # Mark as done so the queue can be join()ed
+                    self.queue.task_done()
+                else:
+                    # Remove the task from the queue anyway
+                    try:
+                        self.queue.get_nowait()
+                        # Mark as done so the queue can be join()ed
+                        self.queue.task_done()
+                    except Queue.Empty:
+                        continue
 
 
     def __init__(self, api_client=None, proxy=None,
 
 
     def __init__(self, api_client=None, proxy=None,
@@ -619,8 +636,9 @@ class KeepClient(object):
         :proxy:
           If specified, this KeepClient will send requests to this Keep
           proxy.  Otherwise, KeepClient will fall back to the setting of the
         :proxy:
           If specified, this KeepClient will send requests to this Keep
           proxy.  Otherwise, KeepClient will fall back to the setting of the
-          ARVADOS_KEEP_PROXY configuration setting.  If you want to ensure
-          KeepClient does not use a proxy, pass in an empty string.
+          ARVADOS_KEEP_SERVICES or ARVADOS_KEEP_PROXY configuration settings.
+          If you want to ensure KeepClient does not use a proxy, pass in
+          an empty string.
 
         :timeout:
           The initial timeout (in seconds) for HTTP requests to Keep
 
         :timeout:
           The initial timeout (in seconds) for HTTP requests to Keep
@@ -663,7 +681,10 @@ class KeepClient(object):
         """
         self.lock = threading.Lock()
         if proxy is None:
         """
         self.lock = threading.Lock()
         if proxy is None:
-            proxy = config.get('ARVADOS_KEEP_PROXY')
+            if config.get('ARVADOS_KEEP_SERVICES'):
+                proxy = config.get('ARVADOS_KEEP_SERVICES')
+            else:
+                proxy = config.get('ARVADOS_KEEP_PROXY')
         if api_token is None:
             if api_client is None:
                 api_token = config.get('ARVADOS_API_TOKEN')
         if api_token is None:
             if api_client is None:
                 api_token = config.get('ARVADOS_API_TOKEN')
@@ -694,15 +715,21 @@ class KeepClient(object):
             self.num_retries = num_retries
             self.max_replicas_per_service = None
             if proxy:
             self.num_retries = num_retries
             self.max_replicas_per_service = None
             if proxy:
-                if not proxy.endswith('/'):
-                    proxy += '/'
+                proxy_uris = proxy.split()
+                for i in range(len(proxy_uris)):
+                    if not proxy_uris[i].endswith('/'):
+                        proxy_uris[i] += '/'
+                    # URL validation
+                    url = urlparse.urlparse(proxy_uris[i])
+                    if not (url.scheme and url.netloc):
+                        raise arvados.errors.ArgumentError("Invalid proxy URI: {}".format(proxy_uris[i]))
                 self.api_token = api_token
                 self._gateway_services = {}
                 self._keep_services = [{
                 self.api_token = api_token
                 self._gateway_services = {}
                 self._keep_services = [{
-                    'uuid': 'proxy',
+                    'uuid': "00000-bi6l4-%015d" % idx,
                     'service_type': 'proxy',
                     'service_type': 'proxy',
-                    '_service_root': proxy,
-                    }]
+                    '_service_root': uri,
+                    } for idx, uri in enumerate(proxy_uris)]
                 self._writable_services = self._keep_services
                 self.using_proxy = True
                 self._static_services_list = True
                 self._writable_services = self._keep_services
                 self.using_proxy = True
                 self._static_services_list = True
@@ -870,8 +897,15 @@ class KeepClient(object):
         else:
             return None
 
         else:
             return None
 
+    @retry.retry_method
+    def head(self, loc_s, num_retries=None):
+        return self._get_or_head(loc_s, method="HEAD", num_retries=num_retries)
+
     @retry.retry_method
     def get(self, loc_s, num_retries=None):
     @retry.retry_method
     def get(self, loc_s, num_retries=None):
+        return self._get_or_head(loc_s, method="GET", num_retries=num_retries)
+
+    def _get_or_head(self, loc_s, method="GET", num_retries=None):
         """Get data from Keep.
 
         This method fetches one or more blocks of data from Keep.  It
         """Get data from Keep.
 
         This method fetches one or more blocks of data from Keep.  It
@@ -897,11 +931,12 @@ class KeepClient(object):
         self.get_counter.add(1)
 
         locator = KeepLocator(loc_s)
         self.get_counter.add(1)
 
         locator = KeepLocator(loc_s)
-        slot, first = self.block_cache.reserve_cache(locator.md5sum)
-        if not first:
-            self.hits_counter.add(1)
-            v = slot.get()
-            return v
+        if method == "GET":
+            slot, first = self.block_cache.reserve_cache(locator.md5sum)
+            if not first:
+                self.hits_counter.add(1)
+                v = slot.get()
+                return v
 
         self.misses_counter.add(1)
 
 
         self.misses_counter.add(1)
 
@@ -951,16 +986,20 @@ class KeepClient(object):
                                for root in sorted_roots
                                if roots_map[root].usable()]
             for keep_service in services_to_try:
                                for root in sorted_roots
                                if roots_map[root].usable()]
             for keep_service in services_to_try:
-                blob = keep_service.get(locator, timeout=self.current_timeout(num_retries-tries_left))
+                blob = keep_service.get(locator, method=method, timeout=self.current_timeout(num_retries-tries_left))
                 if blob is not None:
                     break
             loop.save_result((blob, len(services_to_try)))
 
         # Always cache the result, then return it if we succeeded.
                 if blob is not None:
                     break
             loop.save_result((blob, len(services_to_try)))
 
         # Always cache the result, then return it if we succeeded.
-        slot.set(blob)
-        self.block_cache.cap_cache()
+        if method == "GET":
+            slot.set(blob)
+            self.block_cache.cap_cache()
         if loop.success():
         if loop.success():
-            return blob
+            if method == "HEAD":
+                return True
+            else:
+                return blob
 
         # Q: Including 403 is necessary for the Keep tests to continue
         # passing, but maybe they should expect KeepReadError instead?
 
         # Q: Including 403 is necessary for the Keep tests to continue
         # passing, but maybe they should expect KeepReadError instead?
@@ -1014,7 +1053,7 @@ class KeepClient(object):
 
         headers = {}
         # Tell the proxy how many copies we want it to store
 
         headers = {}
         # Tell the proxy how many copies we want it to store
-        headers['X-Keep-Desired-Replication'] = str(copies)
+        headers['X-Keep-Desired-Replicas'] = str(copies)
         roots_map = {}
         loop = retry.RetryLoop(num_retries, self._check_loop_result,
                                backoff_start=2)
         roots_map = {}
         loop = retry.RetryLoop(num_retries, self._check_loop_result,
                                backoff_start=2)
@@ -1028,30 +1067,22 @@ class KeepClient(object):
                 loop.save_result(error)
                 continue
 
                 loop.save_result(error)
                 continue
 
-            thread_limiter = KeepClient.ThreadLimiter(
-                copies - done, self.max_replicas_per_service)
-            threads = []
+            writer_pool = KeepClient.KeepWriterThreadPool(data=data, 
+                                                        data_hash=data_hash,
+                                                        copies=copies - done,
+                                                        max_service_replicas=self.max_replicas_per_service,
+                                                        timeout=self.current_timeout(num_retries - tries_left))
             for service_root, ks in [(root, roots_map[root])
                                      for root in sorted_roots]:
                 if ks.finished():
                     continue
             for service_root, ks in [(root, roots_map[root])
                                      for root in sorted_roots]:
                 if ks.finished():
                     continue
-                t = KeepClient.KeepWriterThread(
-                    ks,
-                    data=data,
-                    data_hash=data_hash,
-                    service_root=service_root,
-                    thread_limiter=thread_limiter,
-                    timeout=self.current_timeout(num_retries-tries_left),
-                    thread_sequence=len(threads))
-                t.start()
-                threads.append(t)
-            for t in threads:
-                t.join()
-            done += thread_limiter.done()
-            loop.save_result((done >= copies, len(threads)))
+                writer_pool.add_task(ks, service_root)
+            writer_pool.join()
+            done += writer_pool.done()
+            loop.save_result((done >= copies, writer_pool.total_task_nr))
 
         if loop.success():
 
         if loop.success():
-            return thread_limiter.response()
+            return writer_pool.response()
         if not roots_map:
             raise arvados.errors.KeepWriteError(
                 "failed to write {}: no Keep services available ({})".format(
         if not roots_map:
             raise arvados.errors.KeepWriteError(
                 "failed to write {}: no Keep services available ({})".format(
@@ -1062,7 +1093,7 @@ class KeepClient(object):
                               if roots_map[key].last_result()['error'])
             raise arvados.errors.KeepWriteError(
                 "failed to write {} (wanted {} copies but wrote {})".format(
                               if roots_map[key].last_result()['error'])
             raise arvados.errors.KeepWriteError(
                 "failed to write {} (wanted {} copies but wrote {})".format(
-                    data_hash, copies, thread_limiter.done()), service_errors, label="service")
+                    data_hash, copies, writer_pool.done()), service_errors, label="service")
 
     def local_store_put(self, data, copies=1, num_retries=None):
         """A stub for put().
 
     def local_store_put(self, data, copies=1, num_retries=None):
         """A stub for put().