11507: Cleanup
[arvados.git] / sdk / python / arvados / keep.py
index 776c9b2a7d7159f6fbd8c99b20db55b2047aefb0..5b4770c4d0dca8824c268448296d0658c8ba04d8 100644 (file)
@@ -1,4 +1,5 @@
 import cStringIO
+import collections
 import datetime
 import hashlib
 import logging
@@ -296,21 +297,41 @@ class KeepClient(object):
 
         def _get_user_agent(self):
             try:
-                return self._user_agent_pool.get(False)
+                return self._user_agent_pool.get(block=False)
             except Queue.Empty:
                 return pycurl.Curl()
 
         def _put_user_agent(self, ua):
             try:
                 ua.reset()
-                self._user_agent_pool.put(ua, False)
+                self._user_agent_pool.put(ua, block=False)
             except:
                 ua.close()
 
-        @staticmethod
-        def _socket_open(family, socktype, protocol, address=None):
+        # Mirrors the address tuple pycurl >= 7.21.5 hands to the socket-open
+        # callback. Built once here instead of on every old-style call.
+        _Address = collections.namedtuple(
+            'Address', ['family', 'socktype', 'protocol', 'addr'])
+
+        def _socket_open(self, *args, **kwargs):
+            """Socket-open callback accepting both pycurl calling conventions.
+
+            Two arguments means the pycurl >= 7.21.5 form (purpose, address);
+            anything else is the older (family, socktype, protocol[, address]).
+            """
+            if len(args) + len(kwargs) == 2:
+                return self._socket_open_pycurl_7_21_5(*args, **kwargs)
+            return self._socket_open_pycurl_7_19_3(*args, **kwargs)
+
+        def _socket_open_pycurl_7_19_3(self, family, socktype, protocol, address=None):
+            # Adapt the old positional form to the new-style address tuple.
+            return self._socket_open_pycurl_7_21_5(
+                purpose=None,
+                address=self._Address(family, socktype, protocol, address))
+
+        def _socket_open_pycurl_7_21_5(self, purpose, address):
             """Because pycurl doesn't have CURLOPT_TCP_KEEPALIVE"""
-            s = socket.socket(family, socktype, protocol)
+            s = socket.socket(address.family, address.socktype, address.protocol)
             s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
             # Will throw invalid protocol error on mac. This test prevents that.
             if hasattr(socket, 'TCP_KEEPIDLE'):
@@ -527,13 +540,24 @@ class KeepClient(object):
             with self.pending_tries_notification:
                 while True:
                     if self.pending_copies() < 1:
+                        # This notify_all() is unnecessary --
+                        # write_success() already called notify_all()
+                        # when pending<1 became true, so it's not
+                        # possible for any other thread to be in
+                        # wait() now -- but it's cheap insurance
+                        # against deadlock so we do it anyway:
+                        self.pending_tries_notification.notify_all()
                         # Drain the queue and then raise Queue.Empty
                         while True:
                             self.get_nowait()
                             self.task_done()
                     elif self.pending_tries > 0:
+                        service, service_root = self.get_nowait()
+                        if service.finished():  # skip without spending a try
+                            self.task_done()
+                            continue
                         self.pending_tries -= 1
-                        return self.get_nowait()
+                        return service, service_root
                     elif self.empty():
                         self.pending_tries_notification.notify_all()
                         raise Queue.Empty
@@ -576,6 +600,9 @@ class KeepClient(object):
     
     
     class KeepWriterThread(threading.Thread):
+        class TaskFailed(RuntimeError):
+            """A rejected PUT; run() handles it without logging a traceback."""
+
         def __init__(self, queue, data, data_hash, timeout=None):
             super(KeepClient.KeepWriterThread, self).__init__()
             self.timeout = timeout
@@ -593,7 +620,8 @@ class KeepClient(object):
                 try:
                     locator, copies = self.do_task(service, service_root)
                 except Exception as e:
-                    _logger.exception("Exception in KeepWriterThread")
+                    if not isinstance(e, self.TaskFailed):
+                        _logger.exception("Exception in KeepWriterThread")
                     self.queue.write_fail(service)
                 else:
                     self.queue.write_success(locator, copies)
@@ -601,8 +629,6 @@ class KeepClient(object):
                     self.queue.task_done()
 
         def do_task(self, service, service_root):
-            if service.finished():
-                return
             success = bool(service.put(self.data_hash,
                                        self.data,
                                        timeout=self.timeout))
@@ -614,7 +640,7 @@ class KeepClient(object):
                                   self.data_hash,
                                   result['status_code'],
                                   result['body'])
-                raise RuntimeError()
+                raise self.TaskFailed()
 
             _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
                           str(threading.current_thread()),