3198: VWD syncs collection before loading from API server.
[arvados.git] / crunch_scripts / crunchutil / vwd.py
index 3d54c9c2b32c1fc05c2ce536a5ece807df59d49b..0ae1c4620995014f61d17379bca756d2415f6e4c 100644 (file)
@@ -1,7 +1,8 @@
 import arvados
 import os
-import robust_put
 import stat
+import arvados.commands.run
+import logging
 
 # Implements "Virtual Working Directory"
 # Provides a way of emulating a shared writable directory in Keep based
@@ -32,23 +33,71 @@ def checkout(source_collection, target_dir, keepmount=None):
         for f in files:
             os.symlink(os.path.join(root, f), os.path.join(target_dir, rel, f))
 
-# Delete all symlinks and check in any remaining normal files.
-# If merge == True, merge the manifest with source_collection and return a
-# CollectionReader for the combined collection.
-def checkin(source_collection, target_dir, merge=True):
-    # delete symlinks, commit directory, merge manifests and return combined
-    # collection.
+def checkin(target_dir):
+    """Write files in `target_dir` to Keep.
+
+    Regular files or symlinks to files outside the keep mount are written to
+    Keep as normal files (Keep does not support symlinks).
+
+    Symlinks to files in the keep mount result in files in the new collection
+    that reference the existing Keep blocks; no data is copied.
+
+    Returns a tuple `(collection, last_error)`: the new Collection object (data
+    flushed, record not saved to the API) and the last IOError/OSError caught, or None.
+
+    """
+
+    outputcollection = arvados.collection.Collection(num_retries=5)
+
+    if target_dir[-1:] != '/':
+        target_dir += '/'
+
+    collections = {}
+
+    logger = logging.getLogger("arvados")
+
+    last_error = None
     for root, dirs, files in os.walk(target_dir):
         for f in files:
-            s = os.lstat(os.path.join(root, f))
-            if stat.S_ISLNK(s.st_mode):
-                os.unlink(os.path.join(root, f))
-
-    uuid = robust_put.upload(target_dir)
-    if merge:
-        cr1 = arvados.CollectionReader(source_collection)
-        cr2 = arvados.CollectionReader(uuid)
-        combined = arvados.CollectionReader(cr1.manifest_text() + cr2.manifest_text())
-        return combined
-    else:
-        return arvados.CollectionReader(uuid)
+            try:
+                s = os.lstat(os.path.join(root, f))
+
+                writeIt = False
+
+                if stat.S_ISREG(s.st_mode):
+                    writeIt = True
+                elif stat.S_ISLNK(s.st_mode):
+                    # 1. check if it is a link into a collection
+                    real = os.path.split(os.path.realpath(os.path.join(root, f)))
+                    (pdh, branch) = arvados.commands.run.is_in_collection(real[0], real[1])
+                    if pdh is not None:
+                        # 2. load collection
+                        if pdh not in collections:
+                            # 2.1 make sure it is flushed (see #5787 note 11)
+                            fd = os.open(real[0], os.O_RDONLY)
+                            os.fsync(fd)
+                            os.close(fd)
+
+                            # 2.2 get collection from API server
+                            collections[pdh] = arvados.collection.CollectionReader(pdh,
+                                                                                   api_client=outputcollection._my_api(),
+                                                                                   keep_client=outputcollection._my_keep(),
+                                                                                   num_retries=5)
+                        # 3. copy arvfile to new collection
+                        outputcollection.copy(branch, os.path.join(root[len(target_dir):], f), source_collection=collections[pdh])
+                    else:
+                        writeIt = True
+
+                if writeIt:
+                    reldir = root[len(target_dir):]
+                    with outputcollection.open(os.path.join(reldir, f), "wb") as writer:
+                        with open(os.path.join(root, f), "rb") as reader:
+                            dat = reader.read(64*1024)
+                            while dat:
+                                writer.write(dat)
+                                dat = reader.read(64*1024)
+            except (IOError, OSError) as e:
+                logger.error(e)
+                last_error = e
+
+    return (outputcollection, last_error)