Merge branch '2411-check-copyright'
[arvados.git] / sdk / python / arvados / commands / put.py
index d616f3087ed08af5c616483936dd8a92eb82c0fb..548f4b0948ae715393eb657a1693364e8b500639 100644 (file)
@@ -1,8 +1,11 @@
-#!/usr/bin/env python
-
-# TODO:
-# --md5sum - display md5 of each file as read from disk
-
+# Copyright (C) The Arvados Authors. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import division
+from future.utils import listitems, listvalues
+from builtins import str
+from builtins import object
 import argparse
 import arvados
 import arvados.collection
@@ -40,7 +43,9 @@ upload_opts.add_argument('--version', action='version',
                          help='Print version and exit.')
 upload_opts.add_argument('paths', metavar='path', type=str, nargs='*',
                          help="""
-Local file or directory. Default: read from standard input.
+Local file or directory. If path is a directory reference with a trailing
+slash, then just upload the directory's contents; otherwise upload the
+directory itself. Default: read from standard input.
 """)
 
 _group = upload_opts.add_mutually_exclusive_group()
@@ -188,12 +193,11 @@ Do not continue interrupted uploads from cached state.
 _group = run_opts.add_mutually_exclusive_group()
 _group.add_argument('--follow-links', action='store_true', default=True,
                     dest='follow_links', help="""
-Traverse directory symlinks (default).
-Multiple symlinks pointing to the same directory will only be followed once.
+Follow file and directory symlinks (default).
 """)
 _group.add_argument('--no-follow-links', action='store_false', dest='follow_links',
                     help="""
-Do not traverse directory symlinks.
+Do not follow file and directory symlinks.
 """)
 
 _group = run_opts.add_mutually_exclusive_group()
@@ -216,7 +220,7 @@ def parse_arguments(arguments):
     if len(args.paths) == 0:
         args.paths = ['-']
 
-    args.paths = map(lambda x: "-" if x == "/dev/stdin" else x, args.paths)
+    args.paths = ["-" if x == "/dev/stdin" else x for x in args.paths]  # a real list: len() and [0] are used below
 
     if len(args.paths) != 1 or os.path.isdir(args.paths[0]):
         if args.filename:
@@ -247,6 +251,10 @@ def parse_arguments(arguments):
     return args
 
 
+class PathDoesNotExistError(Exception):  # an input file or directory is missing
+    pass  # main() logs it and exits with status 1
+
+
 class CollectionUpdateError(Exception):
     pass
 
@@ -289,13 +297,13 @@ class ResumeCache(object):
     @classmethod
     def make_path(cls, args):
         md5 = hashlib.md5()
-        md5.update(arvados.config.get('ARVADOS_API_HOST', '!nohost'))
+        md5.update(arvados.config.get('ARVADOS_API_HOST', '!nohost').encode())  # hashlib.update() requires bytes on Python 3
         realpaths = sorted(os.path.realpath(path) for path in args.paths)
-        md5.update('\0'.join(realpaths))
+        md5.update(b'\0'.join([p.encode() for p in realpaths]))  # same NUL-joined layout as before, as bytes
         if any(os.path.isdir(path) for path in realpaths):
-            md5.update("-1")
+            md5.update(b'-1')
         elif args.filename:
-            md5.update(args.filename)
+            md5.update(args.filename.encode())
         return os.path.join(
             arv_cmd.make_home_conf_dir(cls.CACHE_DIR, 0o700, 'raise'),
             md5.hexdigest())
@@ -407,7 +415,6 @@ class ArvPutUploadJob(object):
         self.dry_run = dry_run
         self._checkpoint_before_quit = True
         self.follow_links = follow_links
-        self._traversed_links = set()
 
         if not self.use_cache and self.resume:
             raise ArvPutArgumentConflict('resume cannot be True when use_cache is False')
@@ -419,21 +426,6 @@ class ArvPutUploadJob(object):
         # Load cached data if any and if needed
         self._setup_state(update_collection)
 
-    def _check_traversed_dir_links(self, root, dirs):
-        """
-        Remove from the 'dirs' list the already traversed directory symlinks,
-        register the new dir symlinks as traversed.
-        """
-        for d in [d for d in dirs if os.path.isdir(os.path.join(root, d)) and
-                  os.path.islink(os.path.join(root, d))]:
-            real_dirpath = os.path.realpath(os.path.join(root, d))
-            if real_dirpath in self._traversed_links:
-                dirs.remove(d)
-                self.logger.warning("Skipping '{}' symlink to directory '{}' because it was already uploaded".format(os.path.join(root, d), real_dirpath))
-            else:
-                self._traversed_links.add(real_dirpath)
-        return dirs
-
     def start(self, save_collection):
         """
         Start supporting thread & file uploading
@@ -447,18 +439,23 @@ class ArvPutUploadJob(object):
                     if self.dry_run:
                         raise ArvPutUploadIsPending()
                     self._write_stdin(self.filename or 'stdin')
+                elif not os.path.exists(path):  # main() reports this error without a traceback
+                    raise PathDoesNotExistError("file or directory '{}' does not exist.".format(path))
                 elif os.path.isdir(path):
                     # Use absolute paths on cache index so CWD doesn't interfere
                     # with the caching logic.
-                    prefixdir = path = os.path.abspath(path)
-                    if prefixdir != '/':
-                        prefixdir += '/'
-                    # If following symlinks, avoid recursive traversals
-                    if self.follow_links and os.path.islink(path):
-                        self._traversed_links.add(os.path.realpath(path))
+                    orig_path = path
+                    path = os.path.abspath(path)
+                    if orig_path[-1:] == os.sep:
+                        # When passing a directory reference with a trailing slash,
+                        # its contents should be uploaded directly to the collection's root.
+                        prefixdir = path
+                    else:
+                        # When passing a directory reference with no trailing slash,
+                        # upload the directory to the collection's root.
+                        prefixdir = os.path.dirname(path)
+                    prefixdir = os.path.join(prefixdir, '')  # exactly one trailing os.sep, even when prefixdir is '/'
                     for root, dirs, files in os.walk(path, followlinks=self.follow_links):
-                        if self.follow_links:
-                            dirs = self._check_traversed_dir_links(root, dirs)
                         # Make os.walk()'s dir traversing order deterministic
                         dirs.sort()
                         files.sort()
@@ -487,11 +484,16 @@ class ArvPutUploadJob(object):
         except (SystemExit, Exception) as e:
             self._checkpoint_before_quit = False
             # Log stack trace only when Ctrl-C isn't pressed (SIGINT)
-            # Note: We're expecting SystemExit instead of KeyboardInterrupt because
-            #   we have a custom signal handler in place that raises SystemExit with
-            #   the catched signal's code.
-            if not isinstance(e, SystemExit) or e.code != -2:
-                self.logger.warning("Abnormal termination:\n{}".format(traceback.format_exc(e)))
+            # Note: We're expecting SystemExit instead of
+            # KeyboardInterrupt because we have a custom signal
+            # handler in place that raises SystemExit with the caught
+            # signal's code.
+            if isinstance(e, PathDoesNotExistError):
+                # main() logs this error itself, so the traceback would only add noise
+                pass
+            elif not isinstance(e, SystemExit) or e.code != -2:
+                self.logger.warning("Abnormal termination:\n{}".format(
+                    traceback.format_exc()))
             raise
         finally:
             if not self.dry_run:
@@ -543,7 +545,7 @@ class ArvPutUploadJob(object):
         Recursively get the total size of the collection
         """
         size = 0
-        for item in collection.values():
+        for item in listvalues(collection):  # future.utils.listvalues: a list of values on py2 and py3
             if isinstance(item, arvados.collection.Collection) or isinstance(item, arvados.collection.Subcollection):
                 size += self._collection_size(item)
             else:
@@ -591,12 +593,17 @@ class ArvPutUploadJob(object):
             self.reporter(self.bytes_written, self.bytes_expected)
 
     def _write_stdin(self, filename):
-        output = self._local_collection.open(filename, 'w')
+        output = self._local_collection.open(filename, 'wb')  # NOTE(review): binary mode, but py3 sys.stdin is text — confirm _write copes (sys.stdin.buffer?)
         self._write(sys.stdin, output)
         output.close()
 
     def _check_file(self, source, filename):
-        """Check if this file needs to be uploaded"""
+        """
+        Check if this file needs to be uploaded
+        """
+        # With follow_links disabled (--no-follow-links), skip symlinked files entirely
+        if (not self.follow_links) and os.path.islink(source):
+            return
         resume_offset = 0
         should_upload = False
         new_file_in_cache = False
@@ -657,17 +664,17 @@ class ArvPutUploadJob(object):
 
     def _upload_files(self):
         for source, resume_offset, filename in self._files_to_upload:
-            with open(source, 'r') as source_fd:
+            with open(source, 'rb') as source_fd:  # raw bytes: no decoding or newline translation
                 with self._state_lock:
                     self._state['files'][source]['mtime'] = os.path.getmtime(source)
                     self._state['files'][source]['size'] = os.path.getsize(source)
                 if resume_offset > 0:
                     # Start upload where we left off
-                    output = self._local_collection.open(filename, 'a')
+                    output = self._local_collection.open(filename, 'ab')
                     source_fd.seek(resume_offset)
                 else:
                     # Start from scratch
-                    output = self._local_collection.open(filename, 'w')
+                    output = self._local_collection.open(filename, 'wb')
                 self._write(source_fd, output)
                 output.close(flush=False)
 
@@ -701,11 +708,11 @@ class ArvPutUploadJob(object):
         if self.use_cache:
             # Set up cache file name from input paths.
             md5 = hashlib.md5()
-            md5.update(arvados.config.get('ARVADOS_API_HOST', '!nohost'))
+            md5.update(arvados.config.get('ARVADOS_API_HOST', '!nohost').encode())  # hashlib.update() requires bytes on Python 3
             realpaths = sorted(os.path.realpath(path) for path in self.paths)
-            md5.update('\0'.join(realpaths))
+            md5.update(b'\0'.join([p.encode() for p in realpaths]))  # NOTE(review): similar to ResumeCache.make_path, which also hashes b'-1' for dirs — keep in sync
             if self.filename:
-                md5.update(self.filename)
+                md5.update(self.filename.encode())
             cache_filename = md5.hexdigest()
             cache_filepath = os.path.join(
                 arv_cmd.make_home_conf_dir(self.CACHE_DIR, 0o700, 'raise'),
@@ -741,7 +748,7 @@ class ArvPutUploadJob(object):
     def collection_file_paths(self, col, path_prefix='.'):
         """Return a list of file paths by recursively go through the entire collection `col`"""
         file_paths = []
-        for name, item in col.items():
+        for name, item in listitems(col):  # future.utils.listitems: a list of (name, item) pairs on py2 and py3
             if isinstance(item, arvados.arvfile.ArvadosFile):
                 file_paths.append(os.path.join(path_prefix, name))
             elif isinstance(item, arvados.collection.Subcollection):
@@ -766,6 +773,7 @@ class ArvPutUploadJob(object):
             state = json.dumps(self._state)
         try:
             new_cache = tempfile.NamedTemporaryFile(
+                mode='w+',  # text mode: `state` is the str produced by json.dumps()
                 dir=os.path.dirname(self._cache_filename), delete=False)
             self._lock_file(new_cache)
             new_cache.write(state)
@@ -790,8 +798,8 @@ class ArvPutUploadJob(object):
 
     def portable_data_hash(self):
         pdh = self._my_collection().portable_data_hash()
-        m = self._my_collection().stripped_manifest()
-        local_pdh = hashlib.md5(m).hexdigest() + '+' + str(len(m))
+        m = self._my_collection().stripped_manifest().encode()  # hash and length are computed over bytes, not characters
+        local_pdh = '{}+{}'.format(hashlib.md5(m).hexdigest(), len(m))
         if pdh != local_pdh:
             logger.warning("\n".join([
                 "arv-put: API server provided PDH differs from local manifest.",
@@ -817,7 +825,7 @@ class ArvPutUploadJob(object):
                     locators.append(loc)
                 return locators
         elif isinstance(item, arvados.collection.Collection):
-            l = [self._datablocks_on_item(x) for x in item.values()]
+            l = [self._datablocks_on_item(x) for x in listvalues(item)]  # future.utils.listvalues (py2/py3 compatible)
             # Fast list flattener method taken from:
             # http://stackoverflow.com/questions/952914/making-a-flat-list-out-of-list-of-lists-in-python
             return [loc for sublist in l for loc in sublist]
@@ -835,24 +843,17 @@ class ArvPutUploadJob(object):
 def expected_bytes_for(pathlist, follow_links=True):
     # Walk the given directory trees and stat files, adding up file sizes,
     # so we can display progress as percent
-    linked_dirs = set()
     bytesum = 0
     for path in pathlist:
         if os.path.isdir(path):
             for root, dirs, files in os.walk(path, followlinks=follow_links):
-                if follow_links:
-                    # Skip those linked dirs that were visited more than once.
-                    for d in [x for x in dirs if os.path.islink(os.path.join(root, x))]:
-                        d_realpath = os.path.realpath(os.path.join(root, d))
-                        if d_realpath in linked_dirs:
-                            # Linked dir already visited, skip it.
-                            dirs.remove(d)
-                        else:
-                            # Will only visit this dir once
-                            linked_dirs.add(d_realpath)
                 # Sum file sizes
                 for f in files:
-                    bytesum += os.path.getsize(os.path.join(root, f))
+                    filepath = os.path.join(root, f)  # NOTE(review): os.walk(followlinks=True) does not detect symlink cycles
+                    # Ignore symlinked files when follow_links is False (they are not counted)
+                    if (not follow_links) and os.path.islink(filepath):
+                        continue
+                    bytesum += os.path.getsize(filepath)
         elif not os.path.isfile(path):
             return None
         else:
@@ -1000,6 +1001,10 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
     except ArvPutUploadNotPending:
         # No files pending for upload
         sys.exit(0)
+    except PathDoesNotExistError as error:  # e.g. a missing input path; reported without a traceback
+        logger.error(
+            "arv-put: %s", error)
+        sys.exit(1)
 
     if args.progress:  # Print newline to split stderr from stdout for humans.
         logger.info("\n")
@@ -1035,7 +1040,7 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
         if not output.endswith('\n'):
             stdout.write('\n')
 
-    for sigcode, orig_handler in orig_signal_handlers.items():
+    for sigcode, orig_handler in listitems(orig_signal_handlers):  # restore the handlers saved in orig_signal_handlers
         signal.signal(sigcode, orig_handler)
 
     if status != 0: