Merge branch 'master' of git.curoverse.com:arvados into 13076-r-autogen-api
[arvados.git] / sdk / python / arvados / collection.py
index 1a427814cf4d5bc13ffbeca75f7c22c87134962c..8fb90c944396967e6863a38daee27ffe3cb8b9ec 100644 (file)
@@ -1,4 +1,12 @@
+# Copyright (C) The Arvados Authors. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
 from __future__ import absolute_import
+from future.utils import listitems, listvalues, viewkeys
+from builtins import str
+from past.builtins import basestring
+from builtins import object
 import functools
 import logging
 import os
@@ -26,6 +34,8 @@ from arvados.retry import retry_method
 _logger = logging.getLogger('arvados.collection')
 
 class CollectionBase(object):
+    """Abstract base class for Collection classes."""
+
     def __enter__(self):
         return self
 
@@ -83,6 +93,8 @@ class _WriterFile(_FileLikeObjectBase):
 
 
 class CollectionWriter(CollectionBase):
+    """Deprecated, use Collection instead."""
+
     def __init__(self, api_client=None, num_retries=0, replication=None):
         """Instantiate a CollectionWriter.
 
@@ -217,7 +229,11 @@ class CollectionWriter(CollectionBase):
         self.do_queued_work()
 
     def write(self, newdata):
-        if hasattr(newdata, '__iter__'):
+        if isinstance(newdata, bytes):
+            pass
+        elif isinstance(newdata, str):
+            newdata = newdata.encode()
+        elif hasattr(newdata, '__iter__'):
             for s in newdata:
                 self.write(s)
             return
@@ -257,7 +273,7 @@ class CollectionWriter(CollectionBase):
         return self._last_open
 
     def flush_data(self):
-        data_buffer = ''.join(self._data_buffer)
+        data_buffer = b''.join(self._data_buffer)
         if data_buffer:
             self._current_stream_locators.append(
                 self._my_keep().put(
@@ -347,11 +363,12 @@ class CollectionWriter(CollectionBase):
         sending manifest_text() to the API server's "create
         collection" endpoint.
         """
-        return self._my_keep().put(self.manifest_text(), copies=self.replication)
+        return self._my_keep().put(self.manifest_text().encode(),
+                                   copies=self.replication)
 
     def portable_data_hash(self):
-        stripped = self.stripped_manifest()
-        return hashlib.md5(stripped).hexdigest() + '+' + str(len(stripped))
+        stripped = self.stripped_manifest().encode()
+        return '{}+{}'.format(hashlib.md5(stripped).hexdigest(), len(stripped))
 
     def manifest_text(self):
         self.finish_current_stream()
@@ -373,8 +390,18 @@ class CollectionWriter(CollectionBase):
             ret += locators
         return ret
 
+    def save_new(self, name=None):
+        return self._api_client.collections().create(
+            ensure_unique_name=True,
+            body={
+                'name': name,
+                'manifest_text': self.manifest_text(),
+            }).execute(num_retries=self.num_retries)
+
 
 class ResumableCollectionWriter(CollectionWriter):
+    """Deprecated, use Collection instead."""
+
     STATE_PROPS = ['_current_stream_files', '_current_stream_length',
                    '_current_stream_locators', '_current_stream_name',
                    '_current_file_name', '_current_file_pos', '_close_file',
@@ -419,7 +446,7 @@ class ResumableCollectionWriter(CollectionWriter):
         return writer
 
     def check_dependencies(self):
-        for path, orig_stat in self._dependencies.items():
+        for path, orig_stat in listitems(self._dependencies):
             if not S_ISREG(orig_stat[ST_MODE]):
                 raise errors.StaleWriterStateError("{} not file".format(path))
             try:
@@ -613,7 +640,12 @@ class RichCollectionBase(CollectionBase):
         :path:
           path to a file in the collection
         :mode:
-          one of "r", "r+", "w", "w+", "a", "a+"
+          a string consisting of "r", "w", or "a", optionally followed
+          by "b" or "t", optionally followed by "+".
+          :"b":
+            binary mode: write() accepts bytes, read() returns bytes.
+          :"t":
+            text mode (default): write() accepts strings, read() returns strings.
           :"r":
             opens for reading
           :"r+":
@@ -625,33 +657,28 @@ class RichCollectionBase(CollectionBase):
             the end of the file.  Writing does not affect the file pointer for
             reading.
         """
-        mode = mode.replace("b", "")
-        if len(mode) == 0 or mode[0] not in ("r", "w", "a"):
-            raise errors.ArgumentError("Bad mode '%s'" % mode)
-        create = (mode != "r")
 
-        if create and not self.writable():
-            raise IOError(errno.EROFS, "Collection is read only")
+        if not re.search(r'^[rwa][bt]?\+?$', mode):
+            raise errors.ArgumentError("Invalid mode {!r}".format(mode))
 
-        if create:
-            arvfile = self.find_or_create(path, FILE)
-        else:
+        if mode[0] == 'r' and '+' not in mode:
+            fclass = ArvadosFileReader
             arvfile = self.find(path)
+        elif not self.writable():
+            raise IOError(errno.EROFS, "Collection is read only")
+        else:
+            fclass = ArvadosFileWriter
+            arvfile = self.find_or_create(path, FILE)
 
         if arvfile is None:
             raise IOError(errno.ENOENT, "File not found", path)
         if not isinstance(arvfile, ArvadosFile):
             raise IOError(errno.EISDIR, "Is a directory", path)
 
-        if mode[0] == "w":
+        if mode[0] == 'w':
             arvfile.truncate(0)
 
-        name = os.path.basename(path)
-
-        if mode == "r":
-            return ArvadosFileReader(arvfile, num_retries=self.num_retries)
-        else:
-            return ArvadosFileWriter(arvfile, mode, num_retries=self.num_retries)
+        return fclass(arvfile, mode=mode, num_retries=self.num_retries)
 
     def modified(self):
         """Determine if the collection has been modified since last commited."""
@@ -673,7 +700,7 @@ class RichCollectionBase(CollectionBase):
         if value == self._committed:
             return
         if value:
-            for k,v in self._items.items():
+            for k,v in listitems(self._items):
                 v.set_committed(True)
             self._committed = True
         else:
@@ -684,7 +711,7 @@ class RichCollectionBase(CollectionBase):
     @synchronized
     def __iter__(self):
         """Iterate over names of files and collections contained in this collection."""
-        return iter(self._items.keys())
+        return iter(viewkeys(self._items))
 
     @synchronized
     def __getitem__(self, k):
@@ -721,12 +748,12 @@ class RichCollectionBase(CollectionBase):
     @synchronized
     def values(self):
         """Get a list of files and collection objects directly contained in this collection."""
-        return self._items.values()
+        return listvalues(self._items)
 
     @synchronized
     def items(self):
         """Get a list of (name, object) tuples directly contained in this collection."""
-        return self._items.items()
+        return listitems(self._items)
 
     def exists(self, path):
         """Test if there is a file or collection at `path`."""
@@ -759,7 +786,7 @@ class RichCollectionBase(CollectionBase):
             item.remove(pathcomponents[1])
 
     def _clonefrom(self, source):
-        for k,v in source.items():
+        for k,v in listitems(source):
             self._items[k] = v.clone(self, k)
 
     def clone(self):
@@ -1075,8 +1102,8 @@ class RichCollectionBase(CollectionBase):
             # then return API server's PDH response.
             return self._portable_data_hash
         else:
-            stripped = self.portable_manifest_text()
-            return hashlib.md5(stripped).hexdigest() + '+' + str(len(stripped))
+            stripped = self.portable_manifest_text().encode()
+            return '{}+{}'.format(hashlib.md5(stripped).hexdigest(), len(stripped))
 
     @synchronized
     def subscribe(self, callback):
@@ -1117,7 +1144,7 @@ class RichCollectionBase(CollectionBase):
     @synchronized
     def flush(self):
         """Flush bufferblocks to Keep."""
-        for e in self.values():
+        for e in listvalues(self):
             e.flush()
 
 
@@ -1172,8 +1199,9 @@ class Collection(RichCollectionBase):
         """Collection constructor.
 
         :manifest_locator_or_text:
-          One of Arvados collection UUID, block locator of
-          a manifest, raw manifest text, or None (to create an empty collection).
+          An Arvados collection UUID, portable data hash, raw manifest
+          text, or (if creating an empty collection) None.
+
         :parent:
           the parent Collection, may be None.
 
@@ -1312,65 +1340,25 @@ class Collection(RichCollectionBase):
         # it.  If instantiation fails, we'll fall back to the except
         # clause, just like any other Collection lookup
         # failure. Return an exception, or None if successful.
-        try:
-            self._remember_api_response(self._my_api().collections().get(
-                uuid=self._manifest_locator).execute(
-                    num_retries=self.num_retries))
-            self._manifest_text = self._api_response['manifest_text']
-            self._portable_data_hash = self._api_response['portable_data_hash']
-            # If not overriden via kwargs, we should try to load the
-            # replication_desired from the API server
-            if self.replication_desired is None:
-                self.replication_desired = self._api_response.get('replication_desired', None)
-            return None
-        except Exception as e:
-            return e
-
-    def _populate_from_keep(self):
-        # Retrieve a manifest directly from Keep. This has a chance of
-        # working if [a] the locator includes a permission signature
-        # or [b] the Keep services are operating in world-readable
-        # mode. Return an exception, or None if successful.
-        try:
-            self._manifest_text = self._my_keep().get(
-                self._manifest_locator, num_retries=self.num_retries)
-        except Exception as e:
-            return e
+        self._remember_api_response(self._my_api().collections().get(
+            uuid=self._manifest_locator).execute(
+                num_retries=self.num_retries))
+        self._manifest_text = self._api_response['manifest_text']
+        self._portable_data_hash = self._api_response['portable_data_hash']
+        # If not overriden via kwargs, we should try to load the
+        # replication_desired from the API server
+        if self.replication_desired is None:
+            self.replication_desired = self._api_response.get('replication_desired', None)
 
     def _populate(self):
-        if self._manifest_locator is None and self._manifest_text is None:
-            return
-        error_via_api = None
-        error_via_keep = None
-        should_try_keep = ((self._manifest_text is None) and
-                           arvados.util.keep_locator_pattern.match(
-                               self._manifest_locator))
-        if ((self._manifest_text is None) and
-            arvados.util.signed_locator_pattern.match(self._manifest_locator)):
-            error_via_keep = self._populate_from_keep()
         if self._manifest_text is None:
-            error_via_api = self._populate_from_api_server()
-            if error_via_api is not None and not should_try_keep:
-                raise error_via_api
-        if ((self._manifest_text is None) and
-            not error_via_keep and
-            should_try_keep):
-            # Looks like a keep locator, and we didn't already try keep above
-            error_via_keep = self._populate_from_keep()
-        if self._manifest_text is None:
-            # Nothing worked!
-            raise errors.NotFoundError(
-                ("Failed to retrieve collection '{}' " +
-                 "from either API server ({}) or Keep ({})."
-                 ).format(
-                    self._manifest_locator,
-                    error_via_api,
-                    error_via_keep))
-        # populate
+            if self._manifest_locator is None:
+                return
+            else:
+                self._populate_from_api_server()
         self._baseline_manifest = self._manifest_text
         self._import_manifest(self._manifest_text)
 
-
     def _has_collection_uuid(self):
         return self._manifest_locator is not None and re.match(arvados.util.collection_uuid_pattern, self._manifest_locator)
 
@@ -1549,6 +1537,10 @@ class Collection(RichCollectionBase):
 
         return text
 
+    _token_re = re.compile(r'(\S+)(\s+|$)')
+    _block_re = re.compile(r'[0-9a-f]{32}\+(\d+)(\+\S+)*')
+    _segment_re = re.compile(r'(\d+):(\d+):(\S+)')
+
     @synchronized
     def _import_manifest(self, manifest_text):
         """Import a manifest into a `Collection`.
@@ -1567,7 +1559,7 @@ class Collection(RichCollectionBase):
         stream_name = None
         state = STREAM_NAME
 
-        for token_and_separator in re.finditer(r'(\S+)(\s+|$)', manifest_text):
+        for token_and_separator in self._token_re.finditer(manifest_text):
             tok = token_and_separator.group(1)
             sep = token_and_separator.group(2)
 
@@ -1582,19 +1574,19 @@ class Collection(RichCollectionBase):
                 continue
 
             if state == BLOCKS:
-                block_locator = re.match(r'[0-9a-f]{32}\+(\d+)(\+\S+)*', tok)
+                block_locator = self._block_re.match(tok)
                 if block_locator:
-                    blocksize = long(block_locator.group(1))
+                    blocksize = int(block_locator.group(1))
                     blocks.append(Range(tok, streamoffset, blocksize, 0))
                     streamoffset += blocksize
                 else:
                     state = SEGMENTS
 
             if state == SEGMENTS:
-                file_segment = re.search(r'^(\d+):(\d+):(\S+)', tok)
+                file_segment = self._segment_re.match(tok)
                 if file_segment:
-                    pos = long(file_segment.group(1))
-                    size = long(file_segment.group(2))
+                    pos = int(file_segment.group(1))
+                    size = int(file_segment.group(2))
                     name = file_segment.group(3).replace('\\040', ' ')
                     filepath = os.path.join(stream_name, name)
                     afile = self.find_or_create(filepath, FILE)
@@ -1671,9 +1663,8 @@ class Subcollection(RichCollectionBase):
 class CollectionReader(Collection):
     """A read-only collection object.
 
-    Initialize from an api collection record locator, a portable data hash of a
-    manifest, or raw manifest text.  See `Collection` constructor for detailed
-    options.
+    Initialize from a collection UUID or portable data hash, or raw
+    manifest text.  See `Collection` constructor for detailed options.
 
     """
     def __init__(self, manifest_locator_or_text, *args, **kwargs):