17755: Merge branch 'main' into 17755-add-singularity-to-compute-image
[arvados.git] / sdk / python / arvados / keep.py
index 71e101cf4c5073d40e78f73c0bf46a9ff231f937..9dfe0436dec9bdf22eb71ad9bfe2e8a201ee3ab6 100644 (file)
@@ -4,6 +4,7 @@
 
 from __future__ import absolute_import
 from __future__ import division
 
 from __future__ import absolute_import
 from __future__ import division
+import copy
 from future import standard_library
 from future.utils import native_str
 standard_library.install_aliases()
 from future import standard_library
 from future.utils import native_str
 standard_library.install_aliases()
@@ -375,9 +376,11 @@ class KeepClient(object):
                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
                     if self.insecure:
                         curl.setopt(pycurl.SSL_VERIFYPEER, 0)
                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
                     if self.insecure:
                         curl.setopt(pycurl.SSL_VERIFYPEER, 0)
+                    else:
+                        curl.setopt(pycurl.CAINFO, arvados.util.ca_certs_path())
                     if method == "HEAD":
                         curl.setopt(pycurl.NOBODY, True)
                     if method == "HEAD":
                         curl.setopt(pycurl.NOBODY, True)
-                    self._setcurltimeouts(curl, timeout)
+                    self._setcurltimeouts(curl, timeout, method=="HEAD")
 
                     try:
                         curl.perform()
 
                     try:
                         curl.perform()
@@ -421,6 +424,10 @@ class KeepClient(object):
                 _logger.info("HEAD %s: %s bytes",
                          self._result['status_code'],
                          self._result.get('content-length'))
                 _logger.info("HEAD %s: %s bytes",
                          self._result['status_code'],
                          self._result.get('content-length'))
+                if self._result['headers'].get('x-keep-locator'):
+                    # This is a response to a remote block copy request, return
+                    # the local copy block locator.
+                    return self._result['headers'].get('x-keep-locator')
                 return True
 
             _logger.info("GET %s: %s bytes in %s msec (%.3f MiB/sec)",
                 return True
 
             _logger.info("GET %s: %s bytes in %s msec (%.3f MiB/sec)",
@@ -440,7 +447,9 @@ class KeepClient(object):
                 return None
             return self._result['body']
 
                 return None
             return self._result['body']
 
-        def put(self, hash_s, body, timeout=None):
+        def put(self, hash_s, body, timeout=None, headers={}):
+            put_headers = copy.copy(self.put_headers)
+            put_headers.update(headers)
             url = self.root + hash_s
             _logger.debug("Request: PUT %s", url)
             curl = self._get_user_agent()
             url = self.root + hash_s
             _logger.debug("Request: PUT %s", url)
             curl = self._get_user_agent()
@@ -464,11 +473,13 @@ class KeepClient(object):
                     curl.setopt(pycurl.INFILESIZE, len(body))
                     curl.setopt(pycurl.READFUNCTION, body_reader.read)
                     curl.setopt(pycurl.HTTPHEADER, [
                     curl.setopt(pycurl.INFILESIZE, len(body))
                     curl.setopt(pycurl.READFUNCTION, body_reader.read)
                     curl.setopt(pycurl.HTTPHEADER, [
-                        '{}: {}'.format(k,v) for k,v in self.put_headers.items()])
+                        '{}: {}'.format(k,v) for k,v in put_headers.items()])
                     curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
                     if self.insecure:
                         curl.setopt(pycurl.SSL_VERIFYPEER, 0)
                     curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
                     if self.insecure:
                         curl.setopt(pycurl.SSL_VERIFYPEER, 0)
+                    else:
+                        curl.setopt(pycurl.CAINFO, arvados.util.ca_certs_path())
                     self._setcurltimeouts(curl, timeout)
                     try:
                         curl.perform()
                     self._setcurltimeouts(curl, timeout)
                     try:
                         curl.perform()
@@ -512,7 +523,7 @@ class KeepClient(object):
                 self.upload_counter.add(len(body))
             return True
 
                 self.upload_counter.add(len(body))
             return True
 
-        def _setcurltimeouts(self, curl, timeouts):
+        def _setcurltimeouts(self, curl, timeouts, ignore_bandwidth=False):
             if not timeouts:
                 return
             elif isinstance(timeouts, tuple):
             if not timeouts:
                 return
             elif isinstance(timeouts, tuple):
@@ -525,8 +536,9 @@ class KeepClient(object):
                 conn_t, xfer_t = (timeouts, timeouts)
                 bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
             curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(conn_t*1000))
                 conn_t, xfer_t = (timeouts, timeouts)
                 bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
             curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(conn_t*1000))
-            curl.setopt(pycurl.LOW_SPEED_TIME, int(math.ceil(xfer_t)))
-            curl.setopt(pycurl.LOW_SPEED_LIMIT, int(math.ceil(bandwidth_bps)))
+            if not ignore_bandwidth:
+                curl.setopt(pycurl.LOW_SPEED_TIME, int(math.ceil(xfer_t)))
+                curl.setopt(pycurl.LOW_SPEED_LIMIT, int(math.ceil(bandwidth_bps)))
 
         def _headerfunction(self, header_line):
             if isinstance(header_line, bytes):
 
         def _headerfunction(self, header_line):
             if isinstance(header_line, bytes):
@@ -550,18 +562,30 @@ class KeepClient(object):
 
 
     class KeepWriterQueue(queue.Queue):
 
 
     class KeepWriterQueue(queue.Queue):
-        def __init__(self, copies):
+        def __init__(self, copies, classes=[]):
             queue.Queue.__init__(self) # Old-style superclass
             self.wanted_copies = copies
             queue.Queue.__init__(self) # Old-style superclass
             self.wanted_copies = copies
+            self.wanted_storage_classes = classes
             self.successful_copies = 0
             self.successful_copies = 0
+            self.confirmed_storage_classes = {}
             self.response = None
             self.response = None
-            self.successful_copies_lock = threading.Lock()
-            self.pending_tries = copies
+            self.storage_classes_tracking = True
+            self.queue_data_lock = threading.RLock()
+            self.pending_tries = max(copies, len(classes))
             self.pending_tries_notification = threading.Condition()
 
             self.pending_tries_notification = threading.Condition()
 
-        def write_success(self, response, replicas_nr):
-            with self.successful_copies_lock:
+        def write_success(self, response, replicas_nr, classes_confirmed):
+            with self.queue_data_lock:
                 self.successful_copies += replicas_nr
                 self.successful_copies += replicas_nr
+                if classes_confirmed is None:
+                    self.storage_classes_tracking = False
+                elif self.storage_classes_tracking:
+                    for st_class, st_copies in classes_confirmed.items():
+                        try:
+                            self.confirmed_storage_classes[st_class] += st_copies
+                        except KeyError:
+                            self.confirmed_storage_classes[st_class] = st_copies
+                    self.pending_tries = max(self.wanted_copies - self.successful_copies, len(self.pending_classes()))
                 self.response = response
             with self.pending_tries_notification:
                 self.pending_tries_notification.notify_all()
                 self.response = response
             with self.pending_tries_notification:
                 self.pending_tries_notification.notify_all()
@@ -572,13 +596,31 @@ class KeepClient(object):
                 self.pending_tries_notification.notify()
 
         def pending_copies(self):
                 self.pending_tries_notification.notify()
 
         def pending_copies(self):
-            with self.successful_copies_lock:
+            with self.queue_data_lock:
                 return self.wanted_copies - self.successful_copies
 
                 return self.wanted_copies - self.successful_copies
 
+        def satisfied_classes(self):
+            """Return the wanted storage classes that are already confirmed.
+
+            Returns None when storage class tracking has been disabled
+            (write_success() received a response without class
+            confirmations), which tells the outer loop to stop expecting
+            storage classes.  Otherwise returns a list, in no particular
+            order, of the wanted classes that pending_classes() no longer
+            reports.
+            """
+            # Hold the lock for the whole computation.  queue_data_lock is
+            # an RLock, so the nested acquisition inside pending_classes()
+            # is safe.  Releasing it before calling pending_classes() would
+            # let a concurrent write_success() disable tracking in between,
+            # making pending_classes() return [] and every wanted class
+            # look satisfied.
+            with self.queue_data_lock:
+                if not self.storage_classes_tracking:
+                    return None
+                return list(set(self.wanted_storage_classes) - set(self.pending_classes()))
+
+        def pending_classes(self):
+            """Return the wanted storage classes still lacking confirmed copies.
+
+            The result is a copy of wanted_storage_classes minus every
+            class whose confirmed copy count has reached wanted_copies.
+            An empty list is returned when tracking is disabled or when
+            wanted_storage_classes is None.
+            """
+            with self.queue_data_lock:
+                tracking = self.storage_classes_tracking
+                if (not tracking) or (self.wanted_storage_classes is None):
+                    return []
+                # Work on a copy so the wanted list itself is never modified.
+                remaining = copy.copy(self.wanted_storage_classes)
+                for name, confirmed in self.confirmed_storage_classes.items():
+                    if name not in remaining:
+                        continue
+                    if confirmed >= self.wanted_copies:
+                        remaining.remove(name)
+                return remaining
+
         def get_next_task(self):
             with self.pending_tries_notification:
                 while True:
         def get_next_task(self):
             with self.pending_tries_notification:
                 while True:
-                    if self.pending_copies() < 1:
+                    if self.pending_copies() < 1 and len(self.pending_classes()) == 0:
                         # This notify_all() is unnecessary --
                         # write_success() already called notify_all()
                         # when pending<1 became true, so it's not
                         # This notify_all() is unnecessary --
                         # write_success() already called notify_all()
                         # when pending<1 became true, so it's not
@@ -605,16 +647,15 @@ class KeepClient(object):
 
 
     class KeepWriterThreadPool(object):
 
 
     class KeepWriterThreadPool(object):
-        def __init__(self, data, data_hash, copies, max_service_replicas, timeout=None):
+        def __init__(self, data, data_hash, copies, max_service_replicas, timeout=None, classes=[]):
             self.total_task_nr = 0
             self.total_task_nr = 0
-            self.wanted_copies = copies
             if (not max_service_replicas) or (max_service_replicas >= copies):
                 num_threads = 1
             else:
                 num_threads = int(math.ceil(1.0*copies/max_service_replicas))
             _logger.debug("Pool max threads is %d", num_threads)
             self.workers = []
             if (not max_service_replicas) or (max_service_replicas >= copies):
                 num_threads = 1
             else:
                 num_threads = int(math.ceil(1.0*copies/max_service_replicas))
             _logger.debug("Pool max threads is %d", num_threads)
             self.workers = []
-            self.queue = KeepClient.KeepWriterQueue(copies)
+            self.queue = KeepClient.KeepWriterQueue(copies, classes)
             # Create workers
             for _ in range(num_threads):
                 w = KeepClient.KeepWriterThread(self.queue, data, data_hash, timeout)
             # Create workers
             for _ in range(num_threads):
                 w = KeepClient.KeepWriterThread(self.queue, data, data_hash, timeout)
@@ -625,7 +666,7 @@ class KeepClient(object):
             self.total_task_nr += 1
 
         def done(self):
             self.total_task_nr += 1
 
         def done(self):
-            return self.queue.successful_copies
+            return self.queue.successful_copies, self.queue.satisfied_classes()
 
         def join(self):
             # Start workers
 
         def join(self):
             # Start workers
@@ -639,7 +680,7 @@ class KeepClient(object):
 
 
     class KeepWriterThread(threading.Thread):
 
 
     class KeepWriterThread(threading.Thread):
-        TaskFailed = RuntimeError()
+        class TaskFailed(RuntimeError): pass
 
         def __init__(self, queue, data, data_hash, timeout=None):
             super(KeepClient.KeepWriterThread, self).__init__()
 
         def __init__(self, queue, data, data_hash, timeout=None):
             super(KeepClient.KeepWriterThread, self).__init__()
@@ -656,20 +697,26 @@ class KeepClient(object):
                 except queue.Empty:
                     return
                 try:
                 except queue.Empty:
                     return
                 try:
-                    locator, copies = self.do_task(service, service_root)
+                    locator, copies, classes = self.do_task(service, service_root)
                 except Exception as e:
                 except Exception as e:
-                    if e is not self.TaskFailed:
+                    if not isinstance(e, self.TaskFailed):
                         _logger.exception("Exception in KeepWriterThread")
                     self.queue.write_fail(service)
                 else:
                         _logger.exception("Exception in KeepWriterThread")
                     self.queue.write_fail(service)
                 else:
-                    self.queue.write_success(locator, copies)
+                    self.queue.write_success(locator, copies, classes)
                 finally:
                     self.queue.task_done()
 
         def do_task(self, service, service_root):
                 finally:
                     self.queue.task_done()
 
         def do_task(self, service, service_root):
+            classes = self.queue.pending_classes()
+            headers = {}
+            if len(classes) > 0:
+                classes.sort()
+                headers['X-Keep-Storage-Classes'] = ', '.join(classes)
             success = bool(service.put(self.data_hash,
                                         self.data,
             success = bool(service.put(self.data_hash,
                                         self.data,
-                                        timeout=self.timeout))
+                                        timeout=self.timeout,
+                                        headers=headers))
             result = service.last_result()
 
             if not success:
             result = service.last_result()
 
             if not success:
@@ -678,7 +725,7 @@ class KeepClient(object):
                                   self.data_hash,
                                   result['status_code'],
                                   result['body'])
                                   self.data_hash,
                                   result['status_code'],
                                   result['body'])
-                raise self.TaskFailed
+                raise self.TaskFailed()
 
             _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
                           str(threading.current_thread()),
 
             _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
                           str(threading.current_thread()),
@@ -690,7 +737,18 @@ class KeepClient(object):
             except (KeyError, ValueError):
                 replicas_stored = 1
 
             except (KeyError, ValueError):
                 replicas_stored = 1
 
-            return result['body'].strip(), replicas_stored
+            classes_confirmed = {}
+            try:
+                scch = result['headers']['x-keep-storage-classes-confirmed']
+                for confirmation in scch.replace(' ', '').split(','):
+                    if '=' in confirmation:
+                        stored_class, stored_copies = confirmation.split('=')[:2]
+                        classes_confirmed[stored_class] = int(stored_copies)
+            except (KeyError, ValueError):
+                # Storage classes confirmed header missing or corrupt
+                classes_confirmed = None
+
+            return result['body'].strip(), replicas_stored, classes_confirmed
 
 
     def __init__(self, api_client=None, proxy=None,
 
 
     def __init__(self, api_client=None, proxy=None,
@@ -783,9 +841,12 @@ class KeepClient(object):
         self.get_counter = Counter()
         self.hits_counter = Counter()
         self.misses_counter = Counter()
         self.get_counter = Counter()
         self.hits_counter = Counter()
         self.misses_counter = Counter()
+        self._storage_classes_unsupported_warning = False
+        self._default_classes = []
 
         if local_store:
             self.local_store = local_store
 
         if local_store:
             self.local_store = local_store
+            self.head = self.local_store_head
             self.get = self.local_store_get
             self.put = self.local_store_put
         else:
             self.get = self.local_store_get
             self.put = self.local_store_put
         else:
@@ -822,6 +883,12 @@ class KeepClient(object):
                 self._writable_services = None
                 self.using_proxy = None
                 self._static_services_list = False
                 self._writable_services = None
                 self.using_proxy = None
                 self._static_services_list = False
+                try:
+                    self._default_classes = [
+                        k for k, v in self.api_client.config()['StorageClasses'].items() if v['Default']]
+                except KeyError:
+                    # We're talking to an old cluster
+                    pass
 
     def current_timeout(self, attempt_number):
         """Return the appropriate timeout to use for this client.
 
     def current_timeout(self, attempt_number):
         """Return the appropriate timeout to use for this client.
@@ -975,6 +1042,11 @@ class KeepClient(object):
         else:
             return None
 
         else:
             return None
 
+    def refresh_signature(self, loc):
+        """Ask Keep to get the remote block and return its local signature.
+
+        Sends a HEAD request for loc with an X-Keep-Signature header of
+        the form "local, <UTC ISO timestamp>Z" and returns whatever
+        head() returns.
+        """
+        timestamp = datetime.datetime.utcnow().isoformat("T") + 'Z'
+        signature_headers = {'X-Keep-Signature': 'local, {}'.format(timestamp)}
+        return self.head(loc, headers=signature_headers)
+
     @retry.retry_method
     def head(self, loc_s, **kwargs):
         return self._get_or_head(loc_s, method="HEAD", **kwargs)
     @retry.retry_method
     def head(self, loc_s, **kwargs):
         return self._get_or_head(loc_s, method="HEAD", **kwargs)
@@ -983,7 +1055,7 @@ class KeepClient(object):
     def get(self, loc_s, **kwargs):
         return self._get_or_head(loc_s, method="GET", **kwargs)
 
     def get(self, loc_s, **kwargs):
         return self._get_or_head(loc_s, method="GET", **kwargs)
 
-    def _get_or_head(self, loc_s, method="GET", num_retries=None, request_id=None):
+    def _get_or_head(self, loc_s, method="GET", num_retries=None, request_id=None, headers=None):
         """Get data from Keep.
 
         This method fetches one or more blocks of data from Keep.  It
         """Get data from Keep.
 
         This method fetches one or more blocks of data from Keep.  It
@@ -1024,11 +1096,11 @@ class KeepClient(object):
 
             self.misses_counter.add(1)
 
 
             self.misses_counter.add(1)
 
-            headers = {
-                'X-Request-Id': (request_id or
-                                 (hasattr(self, 'api_client') and self.api_client.request_id) or
-                                 arvados.util.new_request_id()),
-            }
+            if headers is None:
+                headers = {}
+            headers['X-Request-Id'] = (request_id or
+                                        (hasattr(self, 'api_client') and self.api_client.request_id) or
+                                        arvados.util.new_request_id())
 
             # If the locator has hints specifying a prefix (indicating a
             # remote keepproxy) or the UUID of a local gateway service,
 
             # If the locator has hints specifying a prefix (indicating a
             # remote keepproxy) or the UUID of a local gateway service,
@@ -1085,10 +1157,7 @@ class KeepClient(object):
 
             # Always cache the result, then return it if we succeeded.
             if loop.success():
 
             # Always cache the result, then return it if we succeeded.
             if loop.success():
-                if method == "HEAD":
-                    return True
-                else:
-                    return blob
+                return blob
         finally:
             if slot is not None:
                 slot.set(blob)
         finally:
             if slot is not None:
                 slot.set(blob)
@@ -1109,10 +1178,10 @@ class KeepClient(object):
                 "{} not found".format(loc_s), service_errors)
         else:
             raise arvados.errors.KeepReadError(
                 "{} not found".format(loc_s), service_errors)
         else:
             raise arvados.errors.KeepReadError(
-                "failed to read {}".format(loc_s), service_errors, label="service")
+                "failed to read {} after {}".format(loc_s, loop.attempts_str()), service_errors, label="service")
 
     @retry.retry_method
 
     @retry.retry_method
-    def put(self, data, copies=2, num_retries=None, request_id=None):
+    def put(self, data, copies=2, num_retries=None, request_id=None, classes=None):
         """Save data in Keep.
 
         This method will get a list of Keep services from the API server, and
         """Save data in Keep.
 
         This method will get a list of Keep services from the API server, and
@@ -1129,8 +1198,12 @@ class KeepClient(object):
           *each* Keep server if it returns temporary failures, with
           exponential backoff.  The default value is set when the
           KeepClient is initialized.
           *each* Keep server if it returns temporary failures, with
           exponential backoff.  The default value is set when the
           KeepClient is initialized.
+        * classes: An optional list of storage class names where copies should
+          be written.
         """
 
         """
 
+        classes = classes or self._default_classes
+
         if not isinstance(data, bytes):
             data = data.encode()
 
         if not isinstance(data, bytes):
             data = data.encode()
 
@@ -1151,7 +1224,8 @@ class KeepClient(object):
         roots_map = {}
         loop = retry.RetryLoop(num_retries, self._check_loop_result,
                                backoff_start=2)
         roots_map = {}
         loop = retry.RetryLoop(num_retries, self._check_loop_result,
                                backoff_start=2)
-        done = 0
+        done_copies = 0
+        done_classes = []
         for tries_left in loop:
             try:
                 sorted_roots = self.map_new_services(
         for tries_left in loop:
             try:
                 sorted_roots = self.map_new_services(
@@ -1163,19 +1237,39 @@ class KeepClient(object):
                 loop.save_result(error)
                 continue
 
                 loop.save_result(error)
                 continue
 
+            pending_classes = []
+            if done_classes is not None:
+                pending_classes = list(set(classes) - set(done_classes))
             writer_pool = KeepClient.KeepWriterThreadPool(data=data,
                                                         data_hash=data_hash,
             writer_pool = KeepClient.KeepWriterThreadPool(data=data,
                                                         data_hash=data_hash,
-                                                        copies=copies - done,
+                                                        copies=copies - done_copies,
                                                         max_service_replicas=self.max_replicas_per_service,
                                                         max_service_replicas=self.max_replicas_per_service,
-                                                        timeout=self.current_timeout(num_retries - tries_left))
+                                                        timeout=self.current_timeout(num_retries - tries_left),
+                                                        classes=pending_classes)
             for service_root, ks in [(root, roots_map[root])
                                      for root in sorted_roots]:
                 if ks.finished():
                     continue
                 writer_pool.add_task(ks, service_root)
             writer_pool.join()
             for service_root, ks in [(root, roots_map[root])
                                      for root in sorted_roots]:
                 if ks.finished():
                     continue
                 writer_pool.add_task(ks, service_root)
             writer_pool.join()
-            done += writer_pool.done()
-            loop.save_result((done >= copies, writer_pool.total_task_nr))
+            pool_copies, pool_classes = writer_pool.done()
+            done_copies += pool_copies
+            if (done_classes is not None) and (pool_classes is not None):
+                done_classes += pool_classes
+                loop.save_result(
+                    (done_copies >= copies and set(done_classes) == set(classes),
+                    writer_pool.total_task_nr))
+            else:
+                # Old keepstore contacted without storage classes support:
+                # success is determined only by successful copies.
+                #
+                # Disable storage classes tracking from this point forward.
+                if not self._storage_classes_unsupported_warning:
+                    self._storage_classes_unsupported_warning = True
+                    _logger.warning("X-Keep-Storage-Classes header not supported by the cluster")
+                done_classes = None
+                loop.save_result(
+                    (done_copies >= copies, writer_pool.total_task_nr))
 
         if loop.success():
             return writer_pool.response()
 
         if loop.success():
             return writer_pool.response()
@@ -1188,10 +1282,10 @@ class KeepClient(object):
                               for key in sorted_roots
                               if roots_map[key].last_result()['error'])
             raise arvados.errors.KeepWriteError(
                               for key in sorted_roots
                               if roots_map[key].last_result()['error'])
             raise arvados.errors.KeepWriteError(
-                "failed to write {} (wanted {} copies but wrote {})".format(
-                    data_hash, copies, writer_pool.done()), service_errors, label="service")
+                "failed to write {} after {} (wanted {} copies but wrote {})".format(
+                    data_hash, loop.attempts_str(), (copies, classes), writer_pool.done()), service_errors, label="service")
 
 
-    def local_store_put(self, data, copies=1, num_retries=None):
+    def local_store_put(self, data, copies=1, num_retries=None, classes=[]):
         """A stub for put().
 
         This method is used in place of the real put() method when
         """A stub for put().
 
         This method is used in place of the real put() method when
@@ -1223,5 +1317,17 @@ class KeepClient(object):
         with open(os.path.join(self.local_store, locator.md5sum), 'rb') as f:
             return f.read()
 
         with open(os.path.join(self.local_store, locator.md5sum), 'rb') as f:
             return f.read()
 
+    def local_store_head(self, loc_s, num_retries=None, headers=None):
+        """Companion to local_store_put(); stands in for head().
+
+        Arguments:
+        * loc_s: the locator string of the block to look up.
+        * num_retries: unused here.
+        * headers: unused here.  It is accepted because
+          refresh_signature() calls self.head(loc, headers=...), and in
+          local store mode self.head is bound to this method; without
+          the keyword that call would raise TypeError.
+
+        Returns True if the locator names the empty block, or if a file
+        named after its md5sum exists in the local store directory;
+        returns None otherwise.  Raises arvados.errors.NotFoundError
+        when KeepLocator() rejects loc_s with ValueError.
+        """
+        try:
+            locator = KeepLocator(loc_s)
+        except ValueError:
+            raise arvados.errors.NotFoundError(
+                "Invalid data locator: '%s'" % loc_s)
+        if locator.md5sum == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
+            return True
+        if os.path.exists(os.path.join(self.local_store, locator.md5sum)):
+            return True
+        # Block not present: keep the original falsy (None) result.
+        return None
+
     def is_cached(self, locator):
         return self.block_cache.reserve_cache(expect_hash)
     def is_cached(self, locator):
         return self.block_cache.reserve_cache(expect_hash)