Merge branch '8784-dir-listings'
[arvados.git] / sdk / python / arvados / collection.py
index 0d88084340dbe227d81bef0a2537618ed8fd66c2..77312e4d4917a276f00b46e90e13ab13ba0d5ac4 100644 (file)
@@ -1,4 +1,5 @@
 from __future__ import absolute_import
+from future.utils import listitems, listvalues, viewkeys
 from builtins import str
 from past.builtins import basestring
 from builtins import object
@@ -220,7 +221,11 @@ class CollectionWriter(CollectionBase):
         self.do_queued_work()
 
     def write(self, newdata):
-        if hasattr(newdata, '__iter__'):
+        if isinstance(newdata, bytes):
+            pass
+        elif isinstance(newdata, str):
+            newdata = newdata.encode()
+        elif hasattr(newdata, '__iter__'):
             for s in newdata:
                 self.write(s)
             return
@@ -260,7 +265,7 @@ class CollectionWriter(CollectionBase):
         return self._last_open
 
     def flush_data(self):
-        data_buffer = ''.join(self._data_buffer)
+        data_buffer = b''.join(self._data_buffer)
         if data_buffer:
             self._current_stream_locators.append(
                 self._my_keep().put(
@@ -350,11 +355,12 @@ class CollectionWriter(CollectionBase):
         sending manifest_text() to the API server's "create
         collection" endpoint.
         """
-        return self._my_keep().put(self.manifest_text(), copies=self.replication)
+        return self._my_keep().put(self.manifest_text().encode(),
+                                   copies=self.replication)
 
     def portable_data_hash(self):
-        stripped = self.stripped_manifest()
-        return hashlib.md5(stripped).hexdigest() + '+' + str(len(stripped))
+        stripped = self.stripped_manifest().encode()
+        return '{}+{}'.format(hashlib.md5(stripped).hexdigest(), len(stripped))
 
     def manifest_text(self):
         self.finish_current_stream()
@@ -422,7 +428,7 @@ class ResumableCollectionWriter(CollectionWriter):
         return writer
 
     def check_dependencies(self):
-        for path, orig_stat in list(self._dependencies.items()):
+        for path, orig_stat in listitems(self._dependencies):
             if not S_ISREG(orig_stat[ST_MODE]):
                 raise errors.StaleWriterStateError("{} not file".format(path))
             try:
@@ -616,7 +622,12 @@ class RichCollectionBase(CollectionBase):
         :path:
           path to a file in the collection
         :mode:
-          one of "r", "r+", "w", "w+", "a", "a+"
+          a string consisting of "r", "w", or "a", optionally followed
+          by "b" or "t", optionally followed by "+".
+          :"b":
+            binary mode: write() accepts bytes, read() returns bytes.
+          :"t":
+            text mode (default): write() accepts strings, read() returns strings.
           :"r":
             opens for reading
           :"r+":
@@ -628,33 +639,28 @@ class RichCollectionBase(CollectionBase):
             the end of the file.  Writing does not affect the file pointer for
             reading.
         """
-        mode = mode.replace("b", "")
-        if len(mode) == 0 or mode[0] not in ("r", "w", "a"):
-            raise errors.ArgumentError("Bad mode '%s'" % mode)
-        create = (mode != "r")
 
-        if create and not self.writable():
-            raise IOError(errno.EROFS, "Collection is read only")
+        if not re.search(r'^[rwa][bt]?\+?$', mode):
+            raise errors.ArgumentError("Invalid mode {!r}".format(mode))
 
-        if create:
-            arvfile = self.find_or_create(path, FILE)
-        else:
+        if mode[0] == 'r' and '+' not in mode:
+            fclass = ArvadosFileReader
             arvfile = self.find(path)
+        elif not self.writable():
+            raise IOError(errno.EROFS, "Collection is read only")
+        else:
+            fclass = ArvadosFileWriter
+            arvfile = self.find_or_create(path, FILE)
 
         if arvfile is None:
             raise IOError(errno.ENOENT, "File not found", path)
         if not isinstance(arvfile, ArvadosFile):
             raise IOError(errno.EISDIR, "Is a directory", path)
 
-        if mode[0] == "w":
+        if mode[0] == 'w':
             arvfile.truncate(0)
 
-        name = os.path.basename(path)
-
-        if mode == "r":
-            return ArvadosFileReader(arvfile, num_retries=self.num_retries)
-        else:
-            return ArvadosFileWriter(arvfile, mode, num_retries=self.num_retries)
+        return fclass(arvfile, mode=mode, num_retries=self.num_retries)
 
     def modified(self):
         """Determine if the collection has been modified since last committed."""
@@ -676,7 +682,7 @@ class RichCollectionBase(CollectionBase):
         if value == self._committed:
             return
         if value:
-            for k,v in list(self._items.items()):
+            for k,v in listitems(self._items):
                 v.set_committed(True)
             self._committed = True
         else:
@@ -687,7 +693,7 @@ class RichCollectionBase(CollectionBase):
     @synchronized
     def __iter__(self):
         """Iterate over names of files and collections contained in this collection."""
-        return iter(list(self._items.keys()))
+        return iter(viewkeys(self._items))
 
     @synchronized
     def __getitem__(self, k):
@@ -719,17 +725,17 @@ class RichCollectionBase(CollectionBase):
     @synchronized
     def keys(self):
         """Get a list of names of files and collections directly contained in this collection."""
-        return list(self._items.keys())
+        return self._items.keys()
 
     @synchronized
     def values(self):
         """Get a list of files and collection objects directly contained in this collection."""
-        return list(self._items.values())
+        return listvalues(self._items)
 
     @synchronized
     def items(self):
         """Get a list of (name, object) tuples directly contained in this collection."""
-        return list(self._items.items())
+        return listitems(self._items)
 
     def exists(self, path):
         """Test if there is a file or collection at `path`."""
@@ -762,7 +768,7 @@ class RichCollectionBase(CollectionBase):
             item.remove(pathcomponents[1])
 
     def _clonefrom(self, source):
-        for k,v in list(source.items()):
+        for k,v in listitems(source):
             self._items[k] = v.clone(self, k)
 
     def clone(self):
@@ -1078,8 +1084,8 @@ class RichCollectionBase(CollectionBase):
             # then return API server's PDH response.
             return self._portable_data_hash
         else:
-            stripped = self.portable_manifest_text()
-            return hashlib.md5(stripped).hexdigest() + '+' + str(len(stripped))
+            stripped = self.portable_manifest_text().encode()
+            return '{}+{}'.format(hashlib.md5(stripped).hexdigest(), len(stripped))
 
     @synchronized
     def subscribe(self, callback):
@@ -1120,7 +1126,7 @@ class RichCollectionBase(CollectionBase):
     @synchronized
     def flush(self):
         """Flush bufferblocks to Keep."""
-        for e in list(self.values()):
+        for e in listvalues(self):
             e.flush()
 
 
@@ -1336,7 +1342,7 @@ class Collection(RichCollectionBase):
         # mode. Return an exception, or None if successful.
         try:
             self._manifest_text = self._my_keep().get(
-                self._manifest_locator, num_retries=self.num_retries)
+                self._manifest_locator, num_retries=self.num_retries).decode()
         except Exception as e:
             return e