4823: Add tests for Collection.clone and Collection.merge
[arvados.git] / sdk / python / arvados / collection.py
index 24362cd088081b5f22038b0fb5a4a7a0084664a0..1f2974b3488ba77772c963e93c5702502a2e7988 100644 (file)
@@ -3,17 +3,20 @@ import logging
 import os
 import re
 import errno
+import time
 
 from collections import deque
 from stat import *
 
-from .arvfile import ArvadosFileBase, split, ArvadosFile, ArvadosFileWriter, ArvadosFileReader, BlockManager
+from .arvfile import ArvadosFileBase, split, ArvadosFile, ArvadosFileWriter, ArvadosFileReader, BlockManager, _synchronized, _must_be_writable, SYNC_READONLY, SYNC_EXPLICIT, SYNC_LIVE, NoopLock
 from keep import *
 from .stream import StreamReader, normalize_stream, locator_block_size
 from .ranges import Range, LocatorAndRange
+from .safeapi import SafeApi
 import config
 import errors
 import util
+import events
 
 _logger = logging.getLogger('arvados.collection')
 
@@ -152,7 +155,6 @@ class CollectionReader(CollectionBase):
                          for sline in self._manifest_text.split("\n")
                          if sline]
 
-    @staticmethod
     def _populate_first(orig_func):
         # Decorator for methods that read actual Collection data.
         @functools.wraps(orig_func)
@@ -640,15 +642,13 @@ class ResumableCollectionWriter(CollectionWriter):
                 "resumable writer can't accept unsourced data")
         return super(ResumableCollectionWriter, self).write(data)
 
+ADD = "add"
+DEL = "del"
 
 class SynchronizedCollectionBase(CollectionBase):
-    SYNC_READONLY = 1
-    SYNC_EXPLICIT = 2
-    SYNC_LIVE = 3
-
     def __init__(self, parent=None):
         self.parent = parent
-        self._items = None
+        self._items = {}
 
     def _my_api(self):
         raise NotImplementedError()
@@ -665,21 +665,13 @@ class SynchronizedCollectionBase(CollectionBase):
     def _populate(self):
         raise NotImplementedError()
 
-    def _sync_mode(self):
+    def sync_mode(self):
         raise NotImplementedError()
 
-    @staticmethod
-    def _populate_first(orig_func):
-        # Decorator for methods that read actual Collection data.
-        @functools.wraps(orig_func)
-        def wrapper(self, *args, **kwargs):
-            if self._items is None:
-                self._populate()
-            return orig_func(self, *args, **kwargs)
-        return wrapper
+    def notify(self, collection, event, name, item):
+        raise NotImplementedError()
 
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def find(self, path, create=False, create_collection=False):
         """Recursively search the specified file path.  May return either a Collection
         or ArvadosFile.
@@ -697,14 +689,14 @@ class SynchronizedCollectionBase(CollectionBase):
           component.
 
         """
-        if create and self._sync_mode() == SynchronizedCollectionBase.SYNC_READONLY:
+        if create and self.sync_mode() == SYNC_READONLY:
             raise IOError((errno.EROFS, "Collection is read only"))
 
         p = path.split("/")
         if p[0] == '.':
             del p[0]
 
-        if len(p) > 0:
+        if p and p[0]:
             item = self._items.get(p[0])
             if len(p) == 1:
                 # item must be a file
@@ -715,12 +707,14 @@ class SynchronizedCollectionBase(CollectionBase):
                     else:
                         item = ArvadosFile(self)
                     self._items[p[0]] = item
+                    self.notify(self, ADD, p[0], item)
                 return item
             else:
                 if item is None and create:
                     # create new collection
                     item = Subcollection(self)
                     self._items[p[0]] = item
+                    self.notify(self, ADD, p[0], item)
                 del p[0]
                 return item.find("/".join(p), create=create)
         else:
@@ -749,10 +743,11 @@ class SynchronizedCollectionBase(CollectionBase):
             raise ArgumentError("Bad mode '%s'" % mode)
         create = (mode != "r")
 
-        if create and self._sync_mode() == SynchronizedCollectionBase.SYNC_READONLY:
+        if create and self.sync_mode() == SYNC_READONLY:
             raise IOError((errno.EROFS, "Collection is read only"))
 
         f = self.find(path, create=create)
+
         if f is None:
             raise IOError((errno.ENOENT, "File not found"))
         if not isinstance(f, ArvadosFile):
@@ -762,12 +757,11 @@ class SynchronizedCollectionBase(CollectionBase):
             f.truncate(0)
 
         if mode == "r":
-            return ArvadosFileReader(f, path, mode)
+            return ArvadosFileReader(f, path, mode, num_retries=self.num_retries)
         else:
-            return ArvadosFileWriter(f, path, mode)
+            return ArvadosFileWriter(f, path, mode, num_retries=self.num_retries)
 
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def modified(self):
         """Test if the collection (or any subcollection or file) has been modified
         since it was created."""
@@ -776,67 +770,58 @@ class SynchronizedCollectionBase(CollectionBase):
                 return True
         return False
 
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def set_unmodified(self):
         """Recursively clear modified flag"""
         for k,v in self._items.items():
             v.set_unmodified()
 
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def __iter__(self):
         """Iterate over names of files and collections contained in this collection."""
         return self._items.keys()
 
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def iterkeys(self):
         """Iterate over names of files and collections directly contained in this collection."""
         return self._items.keys()
 
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def __getitem__(self, k):
         """Get a file or collection that is directly contained by this collection.  If
         you want to search a path, use `find()` instead.
         """
         return self._items[k]
 
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def __contains__(self, k):
         """If there is a file or collection a directly contained by this collection
         with name "k"."""
         return k in self._items
 
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def __len__(self):
         """Get the number of items directly contained in this collection"""
         return len(self._items)
 
     @_must_be_writable
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def __delitem__(self, p):
         """Delete an item by name which is directly contained by this collection."""
         del self._items[p]
+        self.notify(self, DEL, p, None)
 
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def keys(self):
         """Get a list of names of files and collections directly contained in this collection."""
         return self._items.keys()
 
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def values(self):
         """Get a list of files and collection objects directly contained in this collection."""
         return self._items.values()
 
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def items(self):
         """Get a list of (name, object) tuples directly contained in this collection."""
         return self._items.items()
@@ -846,8 +831,7 @@ class SynchronizedCollectionBase(CollectionBase):
         return self.find(path) != None
 
     @_must_be_writable
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def remove(self, path, rm_r=False):
         """Remove the file or subcollection (directory) at `path`.
         :rm_r:
@@ -863,9 +847,10 @@ class SynchronizedCollectionBase(CollectionBase):
             if item is None:
                 raise IOError((errno.ENOENT, "File not found"))
             if len(p) == 1:
-                if isinstance(SynchronizedCollection, self._items[p[0]]) and len(self._items[p[0]]) > 0 and not rm_r:
+                if isinstance(self._items[p[0]], SynchronizedCollectionBase) and len(self._items[p[0]]) > 0 and not rm_r:
                     raise IOError((errno.ENOTEMPTY, "Subcollection not empty"))
                 del self._items[p[0]]
+                self.notify(self, DEL, p[0], None)
             else:
                 del p[0]
                 item.remove("/".join(p))
@@ -873,14 +858,64 @@ class SynchronizedCollectionBase(CollectionBase):
             raise IOError((errno.ENOENT, "File not found"))
 
     def _cloneinto(self, target):
-        for k,v in self._items:
-            target._items[k] = v.clone(new_parent=target)
+        for k,v in self._items.items():
+            target._items[k] = v.clone(target)
 
     def clone(self):
         raise NotImplementedError()
 
-    @arvfile._synchronized
-    @_populate_first
+    @_must_be_writable
+    @_synchronized
+    def copy(self, source_path, target_path, source_collection=None, overwrite=False):
+        """Copy a file or subcollection to a new path in this collection.
+
+        :source_path:
+          Source file or subcollection
+
+        :target_path:
+          Destination file or path.  If the target path already exists and is a
+          subcollection, the item will be placed inside the subcollection.  If
+          the target path already exists and is a file, this will raise an error
+          unless you specify `overwrite=True`.
+
+        :source_collection:
+          Collection to copy `source_path` from (default `self`)
+
+        :overwrite:
+          Whether to overwrite target file if it already exists.
+        """
+        if source_collection is None:
+            source_collection = self
+
+        # Find the object to copy
+        sp = source_path.split("/")
+        source_obj = source_collection.find(source_path)
+        if source_obj is None:
+            raise IOError((errno.ENOENT, "File not found"))
+
+        # Find parent collection the target path
+        tp = target_path.split("/")
+        target_dir = self.find("/".join(tp[0:-1]), create=True, create_collection=True)
+
+        # Determine the name to use.
+        target_name = tp[-1] if tp[-1] else sp[-1]
+
+        if target_name in target_dir:
+            if isinstance(target_dir[target_name], SynchronizedCollectionBase):
+                target_dir = target_dir[target_name]
+                target_name = sp[-1]
+            elif not overwrite:
+                raise IOError((errno.EEXIST, "File already exists"))
+
+        # Actually make the copy.
+        dup = source_obj.clone(target_dir)
+        with target_dir.lock:
+            target_dir._items[target_name] = dup
+
+        self.notify(target_dir, ADD, target_name, dup)
+
+
+    @_synchronized
     def manifest_text(self, strip=False, normalize=False):
         """Get the manifest text for this collection, sub collections and files.
 
@@ -903,6 +938,23 @@ class SynchronizedCollectionBase(CollectionBase):
             else:
                 return self._manifest_text
 
+    @_must_be_writable
+    @_synchronized
+    def merge(self, other):
+        for k in other.keys():
+            if k in self:
+                if isinstance(self[k], Subcollection) and isinstance(other[k], Subcollection):
+                    self[k].merge(other[k])
+                else:
+                    if self[k] != other[k]:
+                        name = "%s~conflict-%s~" % (k, time.strftime("%Y-%m-%d_%H:%M:%S",
+                                                                     time.gmtime()))
+                        self._items[name] = other[k].clone(self)
+                        self.notify(self, name, ADD, self[name])
+            else:
+                self._items[k] = other[k].clone(self)
+                self.notify(self, k, ADD, self[k])
+
     def portable_data_hash(self):
         """Get the portable data hash for this collection's manifest."""
         stripped = self.manifest_text(strip=True)
@@ -919,9 +971,9 @@ class Collection(SynchronizedCollectionBase):
                  config=None,
                  api_client=None,
                  keep_client=None,
-                 num_retries=0,
+                 num_retries=None,
                  block_manager=None,
-                 sync=Collection.SYNC_READONLY):
+                 sync=SYNC_READONLY):
         """:manifest_locator_or_text:
           One of Arvados collection UUID, block locator of
           a manifest, raw manifest text, or None (to create an empty collection).
@@ -945,15 +997,13 @@ class Collection(SynchronizedCollectionBase):
             Collection is read only.  No synchronization.  This mode will
             also forego locking, which gives better performance.
           :SYNC_EXPLICIT:
-            Synchronize on explicit request via `merge()` or `save()`
+            Synchronize on explicit request via `update()` or `save()`
           :SYNC_LIVE:
             Synchronize with server in response to background websocket events,
             on block write, or on file close.
 
         """
-
-        self.parent = parent
-        self._items = None
+        super(Collection, self).__init__(parent)
         self._api_client = api_client
         self._keep_client = keep_client
         self._block_manager = block_manager
@@ -964,6 +1014,8 @@ class Collection(SynchronizedCollectionBase):
         self._api_response = None
         self._sync = sync
         self.lock = threading.RLock()
+        self.callbacks = []
+        self.events = None
 
         if manifest_locator_or_text:
             if re.match(util.keep_locator_pattern, manifest_locator_or_text):
@@ -976,20 +1028,42 @@ class Collection(SynchronizedCollectionBase):
                 raise errors.ArgumentError(
                     "Argument to CollectionReader must be a manifest or a collection UUID")
 
+            self._populate()
+
+            if self._sync == SYNC_LIVE:
+                if not self._manifest_locator or not re.match(util.collection_uuid_pattern, self._manifest_locator):
+                    raise errors.ArgumentError("Cannot SYNC_LIVE unless a collection uuid is specified")
+                self.events = events.subscribe(arvados.api(), [["object_uuid", "=", self._manifest_locator]], self.on_message)
+
+    @staticmethod
+    def create(name, owner_uuid=None, sync=SYNC_EXPLICIT):
+        c = Collection(sync=sync)
+        c.save_as(name, owner_uuid=owner_uuid, ensure_unique_name=True)
+        return c
+
     def _root_lock(self):
         return self.lock
 
     def sync_mode(self):
         return self._sync
 
-    @arvfile._synchronized
+    def on_message(self):
+        self.update()
+
+    @_synchronized
+    def update(self):
+        n = self._my_api().collections().get(uuid=self._manifest_locator, select=["manifest_text"]).execute()
+        other = import_collection(n["manifest_text"])
+        self.merge(other)
+
+    @_synchronized
     def _my_api(self):
         if self._api_client is None:
-            self._api_client = arvados.api.SafeApi(self._config)
+            self._api_client = arvados.SafeApi(self._config)
             self._keep_client = self._api_client.keep
         return self._api_client
 
-    @arvfile._synchronized
+    @_synchronized
     def _my_keep(self):
         if self._keep_client is None:
             if self._api_client is None:
@@ -998,7 +1072,7 @@ class Collection(SynchronizedCollectionBase):
                 self._keep_client = KeepClient(api=self._api_client)
         return self._keep_client
 
-    @arvfile._synchronized
+    @_synchronized
     def _my_block_manager(self):
         if self._block_manager is None:
             self._block_manager = BlockManager(self._my_keep())
@@ -1033,14 +1107,13 @@ class Collection(SynchronizedCollectionBase):
             return e
 
     def _populate(self):
-        self._items = {}
         if self._manifest_locator is None and self._manifest_text is None:
             return
         error_via_api = None
         error_via_keep = None
         should_try_keep = ((self._manifest_text is None) and
                            util.keep_locator_pattern.match(
-                self._manifest_locator))
+                               self._manifest_locator))
         if ((self._manifest_text is None) and
             util.signed_locator_pattern.match(self._manifest_locator)):
             error_via_keep = self._populate_from_keep()
@@ -1075,22 +1148,23 @@ class Collection(SynchronizedCollectionBase):
 
     def __exit__(self, exc_type, exc_value, traceback):
         """Support scoped auto-commit in a with: block"""
-        self.save(allow_no_locator=True)
+        if self._sync != SYNC_READONLY:
+            self.save(allow_no_locator=True)
         if self._block_manager is not None:
             self._block_manager.stop_threads()
 
-    @arvfile._synchronized
-    @_populate_first
-    def clone(self, new_parent=None, new_sync=Collection.SYNC_READONLY, new_config=self.config):
+    @_synchronized
+    def clone(self, new_parent=None, new_sync=SYNC_READONLY, new_config=None):
+        if new_config is None:
+            new_config = self._config
         c = Collection(parent=new_parent, config=new_config, sync=new_sync)
-        if new_sync == Collection.SYNC_READONLY:
+        if new_sync == SYNC_READONLY:
             c.lock = NoopLock()
         c._items = {}
         self._cloneinto(c)
         return c
 
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def api_response(self):
         """
         api_response() -> dict or None
@@ -1102,8 +1176,7 @@ class Collection(SynchronizedCollectionBase):
         return self._api_response
 
     @_must_be_writable
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def save(self, allow_no_locator=False):
         """Commit pending buffer blocks to Keep, write the manifest to Keep, and
         update the collection record to Keep.
@@ -1127,8 +1200,7 @@ class Collection(SynchronizedCollectionBase):
             self.set_unmodified()
 
     @_must_be_writable
-    @arvfile._synchronized
-    @_populate_first
+    @_synchronized
     def save_as(self, name, owner_uuid=None, ensure_unique_name=False):
         """Save a new collection record.
 
@@ -1152,9 +1224,29 @@ class Collection(SynchronizedCollectionBase):
         if owner_uuid:
             body["owner_uuid"] = owner_uuid
         self._api_response = self._my_api().collections().create(ensure_unique_name=ensure_unique_name, body=body).execute(num_retries=self.num_retries)
+
+        if self.events:
+            self.events.unsubscribe(filters=[["object_uuid", "=", self._manifest_locator]])
+
         self._manifest_locator = self._api_response["uuid"]
+
+        if self.events:
+            self.events.subscribe(filters=[["object_uuid", "=", self._manifest_locator]])
+
         self.set_unmodified()
 
+    @_synchronized
+    def subscribe(self, callback):
+        self.callbacks.append(callback)
+
+    @_synchronized
+    def unsubscribe(self, callback):
+        self.callbacks.remove(callback)
+
+    @_synchronized
+    def notify(self, collection, event, name, item):
+        for c in self.callbacks:
+            c(collection, event, name, item)
 
 class Subcollection(SynchronizedCollectionBase):
     """This is a subdirectory within a collection that doesn't have its own API
@@ -1164,7 +1256,7 @@ class Subcollection(SynchronizedCollectionBase):
         super(Subcollection, self).__init__(parent)
         self.lock = parent._root_lock()
 
-    def _root_lock():
+    def _root_lock(self):
         return self.parent._root_lock()
 
     def sync_mode(self):
@@ -1182,15 +1274,22 @@ class Subcollection(SynchronizedCollectionBase):
     def _populate(self):
         self.parent._populate()
 
-    @arvfile._synchronized
-    @_populate_first
+    def notify(self, collection, event, name, item):
+        self.parent.notify(collection, event, name, item)
+
+    @_synchronized
     def clone(self, new_parent):
-        c = Subcollection(parent=new_parent)
+        c = Subcollection(new_parent)
         c._items = {}
         self._cloneinto(c)
         return c
 
-def import_manifest(manifest_text, into_collection=None, api_client=None, keep=None, num_retries=None):
+def import_manifest(manifest_text,
+                    into_collection=None,
+                    api_client=None,
+                    keep=None,
+                    num_retries=None,
+                    sync=SYNC_READONLY):
     """Import a manifest into a `Collection`.
 
     :manifest_text:
@@ -1206,15 +1305,21 @@ def import_manifest(manifest_text, into_collection=None, api_client=None, keep=N
     :keep:
       The keep client object that will be used when creating a new `Collection` object.
 
-    num_retries
+    :num_retries:
       the default number of api client and keep retries on error.
+
+    :sync:
+      Collection sync mode (only if into_collection is None)
     """
     if into_collection is not None:
         if len(into_collection) > 0:
             raise ArgumentError("Can only import manifest into an empty collection")
         c = into_collection
     else:
-        c = Collection(api_client=api_client, keep_client=keep, num_retries=num_retries)
+        c = Collection(api_client=api_client, keep_client=keep, num_retries=num_retries, sync=sync)
+
+    save_sync = c.sync_mode()
+    c._sync = None
 
     STREAM_NAME = 0
     BLOCKS = 1
@@ -1262,6 +1367,7 @@ def import_manifest(manifest_text, into_collection=None, api_client=None, keep=N
             state = STREAM_NAME
 
     c.set_unmodified()
+    c._sync = save_sync
     return c
 
 def export_manifest(item, stream_name=".", portable_locators=False):
@@ -1284,7 +1390,7 @@ def export_manifest(item, stream_name=".", portable_locators=False):
         for k in [s for s in sorted_keys if isinstance(item[s], ArvadosFile)]:
             v = item[k]
             st = []
-            for s in v.segments:
+            for s in v.segments():
                 loc = s.locator
                 if loc.startswith("bufferblock"):
                     loc = v.parent._my_block_manager()._bufferblocks[loc].locator()