8104: OPENSOCKETFUNCTION accepts calls from pycurl 7.21.
[arvados.git] / sdk / python / arvados / keep.py
index 776c9b2a7d7159f6fbd8c99b20db55b2047aefb0..218f9b1355f330c24984ad4dcb6943995f0589a2 100644 (file)
@@ -1,4 +1,5 @@
 import cStringIO
+import collections
 import datetime
 import hashlib
 import logging
@@ -277,6 +278,7 @@ class KeepClient(object):
             self._result = {'error': None}
             self._usable = True
             self._session = None
+            self._socket = None
             self.get_headers = {'Accept': 'application/octet-stream'}
             self.get_headers.update(headers)
             self.put_headers = headers
@@ -296,26 +298,39 @@ class KeepClient(object):
 
         def _get_user_agent(self):
             try:
-                return self._user_agent_pool.get(False)
+                return self._user_agent_pool.get(block=False)
             except Queue.Empty:
                 return pycurl.Curl()
 
         def _put_user_agent(self, ua):
             try:
                 ua.reset()
-                self._user_agent_pool.put(ua, False)
+                self._user_agent_pool.put(ua, block=False)
             except:
                 ua.close()
 
-        @staticmethod
-        def _socket_open(family, socktype, protocol, address=None):
+        def _socket_open(self, *args, **kwargs):
+            if len(args) + len(kwargs) == 2:
+                return self._socket_open_pycurl_7_21_5(*args, **kwargs)
+            else:
+                return self._socket_open_pycurl_7_19_3(*args, **kwargs)
+
+        def _socket_open_pycurl_7_19_3(self, family, socktype, protocol, address=None):
+            return self._socket_open_pycurl_7_21_5(
+                purpose=None,
+                address=collections.namedtuple(
+                    'Address', ['family', 'socktype', 'protocol', 'addr'],
+                )(family, socktype, protocol, address))
+
+        def _socket_open_pycurl_7_21_5(self, purpose, address):
             """Because pycurl doesn't have CURLOPT_TCP_KEEPALIVE"""
-            s = socket.socket(family, socktype, protocol)
+            s = socket.socket(address.family, address.socktype, address.protocol)
             s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
             # Will throw invalid protocol error on mac. This test prevents that.
             if hasattr(socket, 'TCP_KEEPIDLE'):
                 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 75)
             s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 75)
+            self._socket = s
             return s
 
         def get(self, locator, method="GET", timeout=None):
@@ -329,7 +344,8 @@ class KeepClient(object):
                     self._headers = {}
                     response_body = cStringIO.StringIO()
                     curl.setopt(pycurl.NOSIGNAL, 1)
-                    curl.setopt(pycurl.OPENSOCKETFUNCTION, self._socket_open)
+                    curl.setopt(pycurl.OPENSOCKETFUNCTION,
+                                lambda *args, **kwargs: self._socket_open(*args, **kwargs))
                     curl.setopt(pycurl.URL, url.encode('utf-8'))
                     curl.setopt(pycurl.HTTPHEADER, [
                         '{}: {}'.format(k,v) for k,v in self.get_headers.iteritems()])
@@ -343,6 +359,10 @@ class KeepClient(object):
                         curl.perform()
                     except Exception as e:
                         raise arvados.errors.HttpError(0, str(e))
+                    finally:
+                        if self._socket:
+                            self._socket.close()
+                            self._socket = None
                     self._result = {
                         'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
                         'body': response_body.getvalue(),
@@ -407,7 +427,8 @@ class KeepClient(object):
                     body_reader = cStringIO.StringIO(body)
                     response_body = cStringIO.StringIO()
                     curl.setopt(pycurl.NOSIGNAL, 1)
-                    curl.setopt(pycurl.OPENSOCKETFUNCTION, self._socket_open)
+                    curl.setopt(pycurl.OPENSOCKETFUNCTION,
+                                lambda *args, **kwargs: self._socket_open(*args, **kwargs))
                     curl.setopt(pycurl.URL, url.encode('utf-8'))
                     # Using UPLOAD tells cURL to wait for a "go ahead" from the
                     # Keep server (in the form of a HTTP/1.1 "100 Continue"
@@ -427,6 +448,10 @@ class KeepClient(object):
                         curl.perform()
                     except Exception as e:
                         raise arvados.errors.HttpError(0, str(e))
+                    finally:
+                        if self._socket:
+                            self._socket.close()
+                            self._socket = None
                     self._result = {
                         'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
                         'body': response_body.getvalue(),
@@ -527,13 +552,24 @@ class KeepClient(object):
             with self.pending_tries_notification:
                 while True:
                     if self.pending_copies() < 1:
+                        # This notify_all() should be redundant:
+                        # write_success() already called notify_all()
+                        # when pending_copies() dropped below 1, so no
+                        # other thread can still be in wait() here.
+                        # It is cheap insurance against deadlock,
+                        # though, so we do it anyway:
+                        self.pending_tries_notification.notify_all()
                         # Drain the queue and then raise Queue.Empty
                         while True:
                             self.get_nowait()
                             self.task_done()
                     elif self.pending_tries > 0:
+                        service, service_root = self.get_nowait()
+                        if service.finished():
+                            self.task_done()
+                            continue
                         self.pending_tries -= 1
-                        return self.get_nowait()
+                        return service, service_root
                     elif self.empty():
                         self.pending_tries_notification.notify_all()
                         raise Queue.Empty
@@ -576,6 +612,8 @@ class KeepClient(object):
     
     
     class KeepWriterThread(threading.Thread):
+        TaskFailed = RuntimeError()
+
         def __init__(self, queue, data, data_hash, timeout=None):
             super(KeepClient.KeepWriterThread, self).__init__()
             self.timeout = timeout
@@ -593,7 +631,8 @@ class KeepClient(object):
                 try:
                     locator, copies = self.do_task(service, service_root)
                 except Exception as e:
-                    _logger.exception("Exception in KeepWriterThread")
+                    if e is not self.TaskFailed:
+                        _logger.exception("Exception in KeepWriterThread")
                     self.queue.write_fail(service)
                 else:
                     self.queue.write_success(locator, copies)
@@ -601,8 +640,6 @@ class KeepClient(object):
                     self.queue.task_done()
 
         def do_task(self, service, service_root):
-            if service.finished():
-                return
             success = bool(service.put(self.data_hash,
                                         self.data,
                                         timeout=self.timeout))
@@ -614,7 +651,7 @@ class KeepClient(object):
                                   self.data_hash,
                                   result['status_code'],
                                   result['body'])
-                raise RuntimeError()
+                raise self.TaskFailed
 
             _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
                           str(threading.current_thread()),