4520: manifest_text() is utf-8 encoded by default so it can be safely put() to
[arvados.git] / sdk / python / arvados / collection.py
index 59dc49fbbc8fbcb21107f8cfca61c40bf6176a7a..42ee16f21c1d8e4c10222dab6ec3521550c0a1fa 100644 (file)
@@ -1,3 +1,4 @@
+import functools
 import logging
 import os
 import re
 import logging
 import os
 import re
@@ -5,8 +6,9 @@ import re
 from collections import deque
 from stat import *
 
 from collections import deque
 from stat import *
 
+from .arvfile import ArvadosFileBase
 from keep import *
 from keep import *
-from stream import *
+from .stream import StreamReader, split
 import config
 import errors
 import util
 import config
 import errors
 import util
@@ -123,6 +125,7 @@ class CollectionReader(CollectionBase):
         else:
             raise errors.ArgumentError(
                 "Argument to CollectionReader must be a manifest or a collection UUID")
         else:
             raise errors.ArgumentError(
                 "Argument to CollectionReader must be a manifest or a collection UUID")
+        self._api_response = None
         self._streams = None
 
     def _populate_from_api_server(self):
         self._streams = None
 
     def _populate_from_api_server(self):
@@ -137,10 +140,10 @@ class CollectionReader(CollectionBase):
             if self._api_client is None:
                 self._api_client = arvados.api('v1')
                 self._keep_client = None  # Make a new one with the new api.
             if self._api_client is None:
                 self._api_client = arvados.api('v1')
                 self._keep_client = None  # Make a new one with the new api.
-            c = self._api_client.collections().get(
+            self._api_response = self._api_client.collections().get(
                 uuid=self._manifest_locator).execute(
                 num_retries=self.num_retries)
                 uuid=self._manifest_locator).execute(
                 num_retries=self.num_retries)
-            self._manifest_text = c['manifest_text']
+            self._manifest_text = self._api_response['manifest_text']
             return None
         except Exception as e:
             return e
             return None
         except Exception as e:
             return e
@@ -157,8 +160,6 @@ class CollectionReader(CollectionBase):
             return e
 
     def _populate(self):
             return e
 
     def _populate(self):
-        if self._streams is not None:
-            return
         error_via_api = None
         error_via_keep = None
         should_try_keep = ((self._manifest_text is None) and
         error_via_api = None
         error_via_keep = None
         should_try_keep = ((self._manifest_text is None) and
@@ -189,17 +190,32 @@ class CollectionReader(CollectionBase):
                          for sline in self._manifest_text.split("\n")
                          if sline]
 
                          for sline in self._manifest_text.split("\n")
                          if sline]
 
-    def normalize(self):
-        self._populate()
+    def _populate_first(orig_func):
+        # Decorator for methods that read actual Collection data.
+        @functools.wraps(orig_func)
+        def wrapper(self, *args, **kwargs):
+            if self._streams is None:
+                self._populate()
+            return orig_func(self, *args, **kwargs)
+        return wrapper
+
+    @_populate_first
+    def api_response(self):
+        """api_response() -> dict or None
+
+        Returns information about this Collection fetched from the API server.
+        If the Collection exists in Keep but not the API server, currently
+        returns None.  Future versions may provide a synthetic response.
+        """
+        return self._api_response
 
 
+    @_populate_first
+    def normalize(self):
         # Rearrange streams
         streams = {}
         for s in self.all_streams():
             for f in s.all_files():
         # Rearrange streams
         streams = {}
         for s in self.all_streams():
             for f in s.all_files():
-                filestream = s.name() + "/" + f.name()
-                r = filestream.rindex("/")
-                streamname = filestream[:r]
-                filename = filestream[r+1:]
+                streamname, filename = split(s.name() + "/" + f.name())
                 if streamname not in streams:
                     streams[streamname] = {}
                 if filename not in streams[streamname]:
                 if streamname not in streams:
                     streams[streamname] = {}
                 if filename not in streams[streamname]:
@@ -215,8 +231,33 @@ class CollectionReader(CollectionBase):
             [StreamReader(stream, keep=self._my_keep()).manifest_text()
              for stream in self._streams])
 
             [StreamReader(stream, keep=self._my_keep()).manifest_text()
              for stream in self._streams])
 
+    @_populate_first
+    def open(self, streampath, filename=None):
+        """open(streampath[, filename]) -> file-like object
+
+        Pass in the path of a file to read from the Collection, either as a
+        single string or as two separate stream name and file name arguments.
+        This method returns a file-like object to read that file.
+        """
+        if filename is None:
+            streampath, filename = split(streampath)
+        keep_client = self._my_keep()
+        for stream_s in self._streams:
+            stream = StreamReader(stream_s, keep_client,
+                                  num_retries=self.num_retries)
+            if stream.name() == streampath:
+                break
+        else:
+            raise ValueError("stream '{}' not found in Collection".
+                             format(streampath))
+        try:
+            return stream.files()[filename]
+        except KeyError:
+            raise ValueError("file '{}' not found in Collection stream '{}'".
+                             format(filename, streampath))
+
+    @_populate_first
     def all_streams(self):
     def all_streams(self):
-        self._populate()
         return [StreamReader(s, self._my_keep(), num_retries=self.num_retries)
                 for s in self._streams]
 
         return [StreamReader(s, self._my_keep(), num_retries=self.num_retries)
                 for s in self._streams]
 
@@ -225,6 +266,7 @@ class CollectionReader(CollectionBase):
             for f in s.all_files():
                 yield f
 
             for f in s.all_files():
                 yield f
 
+    @_populate_first
     def manifest_text(self, strip=False, normalize=False):
         if normalize:
             cr = CollectionReader(self.manifest_text())
     def manifest_text(self, strip=False, normalize=False):
         if normalize:
             cr = CollectionReader(self.manifest_text())
@@ -233,14 +275,36 @@ class CollectionReader(CollectionBase):
         elif strip:
             return self.stripped_manifest()
         else:
         elif strip:
             return self.stripped_manifest()
         else:
-            self._populate()
             return self._manifest_text
 
 
             return self._manifest_text
 
 
+class _WriterFile(ArvadosFileBase):
+    def __init__(self, coll_writer, name):
+        super(_WriterFile, self).__init__(name, 'wb')
+        self.dest = coll_writer
+
+    def close(self):
+        super(_WriterFile, self).close()
+        self.dest.finish_current_file()
+
+    @ArvadosFileBase._before_close
+    def write(self, data):
+        self.dest.write(data)
+
+    @ArvadosFileBase._before_close
+    def writelines(self, seq):
+        for data in seq:
+            self.write(data)
+
+    @ArvadosFileBase._before_close
+    def flush(self):
+        self.dest.flush_data()
+
+
 class CollectionWriter(CollectionBase):
     KEEP_BLOCK_SIZE = 2**26
 
 class CollectionWriter(CollectionBase):
     KEEP_BLOCK_SIZE = 2**26
 
-    def __init__(self, api_client=None, num_retries=0):
+    def __init__(self, api_client=None, num_retries=0, replication=None):
         """Instantiate a CollectionWriter.
 
         CollectionWriter lets you build a new Arvados Collection from scratch.
         """Instantiate a CollectionWriter.
 
         CollectionWriter lets you build a new Arvados Collection from scratch.
@@ -256,9 +320,13 @@ class CollectionWriter(CollectionBase):
           service requests.  Default 0.  You may change this value
           after instantiation, but note those changes may not
           propagate to related objects like the Keep client.
           service requests.  Default 0.  You may change this value
           after instantiation, but note those changes may not
           propagate to related objects like the Keep client.
+        * replication: The number of copies of each block to store.
+          If this argument is None or not supplied, replication is
+          the server-provided default if available, otherwise 2.
         """
         self._api_client = api_client
         self.num_retries = num_retries
         """
         self._api_client = api_client
         self.num_retries = num_retries
+        self.replication = (2 if replication is None else replication)
         self._keep_client = None
         self._data_buffer = []
         self._data_buffer_len = 0
         self._keep_client = None
         self._data_buffer = []
         self._data_buffer_len = 0
@@ -273,6 +341,7 @@ class CollectionWriter(CollectionBase):
         self._queued_file = None
         self._queued_dirents = deque()
         self._queued_trees = deque()
         self._queued_file = None
         self._queued_dirents = deque()
         self._queued_trees = deque()
+        self._last_open = None
 
     def __exit__(self, exc_type, exc_value, traceback):
         if exc_type is None:
 
     def __exit__(self, exc_type, exc_value, traceback):
         if exc_type is None:
@@ -379,11 +448,42 @@ class CollectionWriter(CollectionBase):
         while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
             self.flush_data()
 
         while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
             self.flush_data()
 
+    def open(self, streampath, filename=None):
+        """open(streampath[, filename]) -> file-like object
+
+        Pass in the path of a file to write to the Collection, either as a
+        single string or as two separate stream name and file name arguments.
+        This method returns a file-like object you can write to add it to the
+        Collection.
+
+        You may only have one file object from the Collection open at a time,
+        so be sure to close the object when you're done.  Using the object in
+        a with statement makes that easy::
+
+          with cwriter.open('./doc/page1.txt') as outfile:
+              outfile.write(page1_data)
+          with cwriter.open('./doc/page2.txt') as outfile:
+              outfile.write(page2_data)
+        """
+        if filename is None:
+            streampath, filename = split(streampath)
+        if self._last_open and not self._last_open.closed:
+            raise errors.AssertionError(
+                "can't open '{}' when '{}' is still open".format(
+                    filename, self._last_open.name))
+        if streampath != self.current_stream_name():
+            self.start_new_stream(streampath)
+        self.set_current_file_name(filename)
+        self._last_open = _WriterFile(self, filename)
+        return self._last_open
+
     def flush_data(self):
         data_buffer = ''.join(self._data_buffer)
         if data_buffer:
             self._current_stream_locators.append(
     def flush_data(self):
         data_buffer = ''.join(self._data_buffer)
         if data_buffer:
             self._current_stream_locators.append(
-                self._my_keep().put(data_buffer[0:self.KEEP_BLOCK_SIZE]))
+                self._my_keep().put(
+                    data_buffer[0:self.KEEP_BLOCK_SIZE],
+                    copies=self.replication))
             self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
             self._data_buffer_len = len(self._data_buffer[0])
 
             self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
             self._data_buffer_len = len(self._data_buffer[0])
 
@@ -396,6 +496,10 @@ class CollectionWriter(CollectionBase):
             raise errors.AssertionError(
                 "Manifest filenames cannot contain whitespace: %s" %
                 newfilename)
             raise errors.AssertionError(
                 "Manifest filenames cannot contain whitespace: %s" %
                 newfilename)
+        elif re.search(r'\x00', newfilename):
+            raise errors.AssertionError(
+                "Manifest filenames cannot contain NUL characters: %s" %
+                newfilename)
         self._current_file_name = newfilename
 
     def current_file_name(self):
         self._current_file_name = newfilename
 
     def current_file_name(self):
@@ -454,8 +558,16 @@ class CollectionWriter(CollectionBase):
         self._current_file_name = None
 
     def finish(self):
         self._current_file_name = None
 
     def finish(self):
-        # Store the manifest in Keep and return its locator.
-        return self._my_keep().put(self.manifest_text())
+        """Store the manifest in Keep and return its locator.
+
+        This is useful for storing manifest fragments (task outputs)
+        temporarily in Keep during a Crunch job.
+
+        In other cases you should make a collection instead, by
+        sending manifest_text() to the API server's "create
+        collection" endpoint.
+        """
+        return self._my_keep().put(self.manifest_text(), copies=self.replication)
 
     def portable_data_hash(self):
         stripped = self.stripped_manifest()
 
     def portable_data_hash(self):
         stripped = self.stripped_manifest()
@@ -473,7 +585,7 @@ class CollectionWriter(CollectionBase):
             manifest += ' ' + ' '.join("%d:%d:%s" % (sfile[0], sfile[1], sfile[2].replace(' ', '\\040')) for sfile in stream[2])
             manifest += "\n"
 
             manifest += ' ' + ' '.join("%d:%d:%s" % (sfile[0], sfile[1], sfile[2].replace(' ', '\\040')) for sfile in stream[2])
             manifest += "\n"
 
-        return manifest
+        return manifest.encode("utf-8")
 
     def data_locators(self):
         ret = []
 
     def data_locators(self):
         ret = []
@@ -489,10 +601,9 @@ class ResumableCollectionWriter(CollectionWriter):
                    '_data_buffer', '_dependencies', '_finished_streams',
                    '_queued_dirents', '_queued_trees']
 
                    '_data_buffer', '_dependencies', '_finished_streams',
                    '_queued_dirents', '_queued_trees']
 
-    def __init__(self, api_client=None, num_retries=0):
+    def __init__(self, api_client=None, **kwargs):
         self._dependencies = {}
         self._dependencies = {}
-        super(ResumableCollectionWriter, self).__init__(
-            api_client, num_retries=num_retries)
+        super(ResumableCollectionWriter, self).__init__(api_client, **kwargs)
 
     @classmethod
     def from_state(cls, state, *init_args, **init_kwargs):
 
     @classmethod
     def from_state(cls, state, *init_args, **init_kwargs):