Added --storage-classes argument to arv-put.
[arvados.git] / sdk / python / arvados / commands / put.py
index 9fa68ecfe4a3cfc6a587481b6e08fdc96b4882fa..cba00c3c8cf153039de990d27867558d0dbc699a 100644 (file)
@@ -1,3 +1,7 @@
+# Copyright (C) The Arvados Authors. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
 from __future__ import division
 from future.utils import listitems, listvalues
 from builtins import str
@@ -30,7 +34,6 @@ from arvados._version import __version__
 
 import arvados.commands._util as arv_cmd
 
-CAUGHT_SIGNALS = [signal.SIGINT, signal.SIGQUIT, signal.SIGTERM]
 api_client = None
 
 upload_opts = argparse.ArgumentParser(add_help=False)
@@ -137,6 +140,10 @@ physical storage devices (e.g., disks) should have a copy of each data
 block. Default is to use the server-provided default (if any) or 2.
 """)
 
+upload_opts.add_argument('--storage-classes', help="""
+Specify a comma-separated list of storage classes to be used when saving data to Keep.
+""")
+
 upload_opts.add_argument('--threads', type=int, metavar='N', default=None,
                          help="""
 Set the number of upload threads to be used. Take into account that
@@ -164,6 +171,8 @@ using a path-like pattern like 'subdir/*.txt', all text files inside 'subdir'
 directory (relative to the provided input dirs) will be excluded.
 When using a filename pattern like '*.txt', any text file will be excluded
 no matter where it is placed.
+As a special case, to exclude only files or dirs directly below the given
+input directory, use a pattern like './exclude_this.gif'.
 You can specify multiple patterns by using this argument more than once.
 """)
 
@@ -187,6 +196,12 @@ Display machine-readable progress on stderr (bytes and, if known,
 total data size).
 """)
 
+run_opts.add_argument('--silent', action='store_true',
+                      help="""
+Do not print any debug messages to the console. (Any error messages will
+still be displayed.)
+""")
+
 _group = run_opts.add_mutually_exclusive_group()
 _group.add_argument('--resume', action='store_true', default=True,
                     help="""
@@ -237,7 +252,7 @@ def parse_arguments(arguments):
     """)
 
     # Turn on --progress by default if stderr is a tty.
-    if (not (args.batch_progress or args.no_progress)
+    if (not (args.batch_progress or args.no_progress or args.silent)
         and os.isatty(sys.stderr.fileno())):
         args.progress = True
 
@@ -297,6 +312,24 @@ class FileUploadList(list):
         super(FileUploadList, self).append(other)
 
 
+# Appends the X-Request-Id to the first log message with level ERROR or DEBUG
+class ArvPutLogFormatter(logging.Formatter):
+    std_fmtr = logging.Formatter(arvados.log_format, arvados.log_date_format)
+    err_fmtr = None
+    request_id_informed = False
+
+    def __init__(self, request_id):
+        self.err_fmtr = logging.Formatter(
+            arvados.log_format+' (X-Request-Id: {})'.format(request_id),
+            arvados.log_date_format)
+
+    def format(self, record):
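+        # Append the request ID only to the first DEBUG or ERROR record, so
+        # users can quote it when reporting problems without it cluttering
+        # every log line.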
+        if (not self.request_id_informed) and (record.levelno in (logging.DEBUG, logging.ERROR)):
+            self.request_id_informed = True
+            return self.err_fmtr.format(record)
+        return self.std_fmtr.format(record)
+
+
 class ResumeCache(object):
     CACHE_DIR = '.cache/arvados/arv-put'
 
@@ -387,10 +420,10 @@ class ArvPutUploadJob(object):
     }
 
     def __init__(self, paths, resume=True, use_cache=True, reporter=None,
-                 name=None, owner_uuid=None,
+                 name=None, owner_uuid=None, api_client=None,
                  ensure_unique_name=False, num_retries=None,
-                 put_threads=None, replication_desired=None,
-                 filename=None, update_time=60.0, update_collection=None,
+                 put_threads=None, replication_desired=None, filename=None,
+                 update_time=60.0, update_collection=None, storage_classes=None,
                  logger=logging.getLogger('arvados.arv_put'), dry_run=False,
                  follow_links=True, exclude_paths=[], exclude_names=None):
         self.paths = paths
@@ -410,6 +443,8 @@ class ArvPutUploadJob(object):
         self.replication_desired = replication_desired
         self.put_threads = put_threads
         self.filename = filename
+        self.storage_classes = storage_classes
+        self._api_client = api_client
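+        # Reusing the caller's API client means its X-Request-Id tagging
+        # also applies to the collection requests made by this job.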
         self._state_lock = threading.Lock()
         self._state = None # Previous run state (file list & manifest)
         self._current_files = [] # Current run file list
@@ -452,8 +487,8 @@ class ArvPutUploadJob(object):
         """
         # If there are no special files to be read, reset the total bytes
         # count to zero to start counting.
-        if not any(filter(lambda p: not (os.path.isfile(p) or os.path.isdir(p)),
-                          self.paths)):
+        if not any([p for p in self.paths
+                    if not (os.path.isfile(p) or os.path.isdir(p))]):
             self.bytes_expected = 0
 
         for path in self.paths:
@@ -486,22 +521,20 @@ class ArvPutUploadJob(object):
                         root_relpath = ''
                     # Exclude files/dirs by full path matching pattern
                     if self.exclude_paths:
-                        dirs[:] = filter(
-                            lambda d: not any(
-                                [pathname_match(os.path.join(root_relpath, d),
-                                                pat)
-                                 for pat in self.exclude_paths]),
-                            dirs)
-                        files = filter(
-                            lambda f: not any(
-                                [pathname_match(os.path.join(root_relpath, f),
-                                                pat)
-                                 for pat in self.exclude_paths]),
-                            files)
+                        dirs[:] = [d for d in dirs
+                                   if not any(pathname_match(
+                                           os.path.join(root_relpath, d), pat)
+                                              for pat in self.exclude_paths)]
+                        files = [f for f in files
+                                 if not any(pathname_match(
+                                         os.path.join(root_relpath, f), pat)
+                                            for pat in self.exclude_paths)]
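+                        # (list comprehensions keep `files` a plain list on
+                        # Python 3, where filter() would return a one-shot
+                        # iterator)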
                     # Exclude files/dirs by name matching pattern
                     if self.exclude_names is not None:
-                        dirs[:] = filter(lambda d: not self.exclude_names.match(d), dirs)
-                        files = filter(lambda f: not self.exclude_names.match(f), files)
+                        dirs[:] = [d for d in dirs
+                                   if not self.exclude_names.match(d)]
+                        files = [f for f in files
+                                 if not self.exclude_names.match(f)]
                     # Make os.walk()'s dir traversing order deterministic
                     dirs.sort()
                     files.sort()
@@ -586,10 +619,14 @@ class ArvPutUploadJob(object):
                 else:
                     # The file already exists on the remote collection, skip it.
                     pass
-            self._remote_collection.save(num_retries=self.num_retries)
+            self._remote_collection.save(storage_classes=self.storage_classes,
+                                         num_retries=self.num_retries)
         else:
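+            # New collections are saved with the 'default' storage class
+            # unless the user requested otherwise.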
+            if self.storage_classes is None:
+                self.storage_classes = ['default']
             self._local_collection.save_new(
                 name=self.name, owner_uuid=self.owner_uuid,
+                storage_classes=self.storage_classes,
                 ensure_unique_name=self.ensure_unique_name,
                 num_retries=self.num_retries)
 
@@ -701,6 +738,7 @@ class ArvPutUploadJob(object):
             elif file_in_local_collection.permission_expired():
                 # Permission token expired, re-upload file. This will change whenever
                 # we have an API for refreshing tokens.
+                self.logger.warning(
+                    "Access token for uploaded file '{}' expired; "
+                    "re-uploading it from scratch.".format(filename))
                 should_upload = True
                 self._local_collection.remove(filename)
             elif cached_file_data['size'] == file_in_local_collection.size():
@@ -765,7 +803,8 @@ class ArvPutUploadJob(object):
         if update_collection and re.match(arvados.util.collection_uuid_pattern,
                                           update_collection):
             try:
-                self._remote_collection = arvados.collection.Collection(update_collection)
+                self._remote_collection = arvados.collection.Collection(
+                    update_collection, api_client=self._api_client)
             except arvados.errors.ApiError as error:
                 raise CollectionUpdateError("Cannot read collection {} ({})".format(update_collection, error))
             else:
@@ -812,7 +851,11 @@ class ArvPutUploadJob(object):
                 # No cache file, set empty state
                 self._state = copy.deepcopy(self.EMPTY_STATE)
             # Load the previous manifest so we can check if files were modified remotely.
-            self._local_collection = arvados.collection.Collection(self._state['manifest'], replication_desired=self.replication_desired, put_threads=self.put_threads)
+            self._local_collection = arvados.collection.Collection(
+                self._state['manifest'],
+                replication_desired=self.replication_desired,
+                put_threads=self.put_threads,
+                api_client=self._api_client)
 
     def collection_file_paths(self, col, path_prefix='.'):
         """Return a list of file paths by recursively go through the entire collection `col`"""
@@ -870,7 +913,7 @@ class ArvPutUploadJob(object):
         m = self._my_collection().stripped_manifest().encode()
         local_pdh = '{}+{}'.format(hashlib.md5(m).hexdigest(), len(m))
         if pdh != local_pdh:
-            logger.warning("\n".join([
+            self.logger.warning("\n".join([
                 "arv-put: API server provided PDH differs from local manifest.",
                 "         This should not happen; showing API server version."]))
         return pdh
@@ -918,7 +961,7 @@ _machine_format = "{} {}: {{}} written {{}} total\n".format(sys.argv[0],
 def pathname_match(pathname, pattern):
     name = pathname.split(os.sep)
     # Fix patterns like 'some/subdir/' or 'some//subdir'
-    pat = [x for x in pattern.split(os.sep) if x != '']
+    pat = [x for x in pattern.split(os.sep) if x != '' and x != '.']
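+    # ('.' components are also dropped, so a pattern like './name' matches
+    # 'name' directly below an input directory)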
     if len(name) != len(pat):
         return False
     for i in range(len(name)):
@@ -943,9 +986,6 @@ def progress_writer(progress_func, outfile=sys.stderr):
         outfile.write(progress_func(bytes_written, bytes_expected))
     return write_progress
 
-def exit_signal_handler(sigcode, frame):
-    sys.exit(-sigcode)
-
 def desired_project_uuid(api_client, project_uuid, num_retries):
     if not project_uuid:
         query = api_client.users().current()
@@ -957,15 +997,28 @@ def desired_project_uuid(api_client, project_uuid, num_retries):
         raise ValueError("Not a valid project UUID: {}".format(project_uuid))
     return query.execute(num_retries=num_retries)['uuid']
 
-def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
+def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr,
+         install_sig_handlers=True):
     global api_client
 
-    logger = logging.getLogger('arvados.arv_put')
-    logger.setLevel(logging.INFO)
     args = parse_arguments(arguments)
+    logger = logging.getLogger('arvados.arv_put')
+    if args.silent:
+        logger.setLevel(logging.WARNING)
+    else:
+        logger.setLevel(logging.INFO)
     status = 0
+
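+    # Create a request ID for this invocation; it tags the API client's
+    # requests and the first DEBUG/ERROR log message (see ArvPutLogFormatter).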
+    request_id = arvados.util.new_request_id()
+
+    formatter = ArvPutLogFormatter(request_id)
+    logging.getLogger('arvados').handlers[0].setFormatter(formatter)
+
     if api_client is None:
-        api_client = arvados.api('v1')
+        api_client = arvados.api('v1', request_id=request_id)
+
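+    # Signal handling now lives in arvados.commands._util; embedders can
+    # opt out by calling main() with install_sig_handlers=False.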
+    if install_sig_handlers:
+        arv_cmd.install_signal_handlers()
 
     # Determine the name to use
     if args.name:
@@ -1001,41 +1054,52 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
     else:
         reporter = None
 
+    # Split the storage-classes argument
+    storage_classes = None
+    if args.storage_classes:
+        storage_classes = [c.strip()
+                           for c in args.storage_classes.split(',') if c.strip()]
+        if len(storage_classes) > 1:
+            logger.error("Multiple storage classes are not currently supported.")
+            sys.exit(1)
+
     # Set up the exclude regex from all the --exclude arguments provided
     name_patterns = []
     exclude_paths = []
     exclude_names = None
     if len(args.exclude) > 0:
         # We're supporting 2 kinds of exclusion patterns:
-        # 1) --exclude '*.jpg'      (file/dir name patterns, will only match
-        #                            the name)
-        # 2) --exclude 'foo/bar'    (file/dir path patterns, will match the
+        # 1)   --exclude '*.jpg'    (file/dir name patterns, will only match
+        #                            the name, wherever the file is in the tree)
+        # 2.1) --exclude 'foo/bar'  (file/dir path patterns, will match the
         #                            entire path, and should be relative to
         #                            any input dir argument)
+        # 2.2) --exclude './*.jpg'  (Special case for excluding files/dirs
+        #                            placed directly underneath the input dir)
         for p in args.exclude:
             # Only relative paths patterns allowed
             if p.startswith(os.sep):
                 logger.error("Cannot use absolute paths with --exclude")
                 sys.exit(1)
             if os.path.dirname(p):
-                # We don't support of path patterns with '.' or '..'
+                # We don't support path patterns with '..'
                 p_parts = p.split(os.sep)
-                if '.' in p_parts or '..' in p_parts:
+                if '..' in p_parts:
                     logger.error(
-                        "Cannot use path patterns that include '.' or '..'")
+                        "Cannot use path patterns that include '..'")
                     sys.exit(1)
                 # Path search pattern
                 exclude_paths.append(p)
             else:
                 # Name-only search pattern
                 name_patterns.append(p)
-        # For name only matching, we can combine all patterns into a single regexp,
-        # for better performance.
+        # For name only matching, we can combine all patterns into a single
+        # regexp, for better performance.
         exclude_names = re.compile('|'.join(
             [fnmatch.translate(p) for p in name_patterns]
         )) if len(name_patterns) > 0 else None
-        # Show the user the patterns to be used, just in case they weren't specified inside
-        # quotes and got changed by the shell expansion.
+        # Show the user the patterns to be used, just in case they weren't
+        # specified inside quotes and got changed by the shell expansion.
         logger.info("Exclude patterns: {}".format(args.exclude))
 
     # If this is used by a human, and there's at least one directory to be
@@ -1048,6 +1112,7 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
                                  use_cache = args.use_cache,
                                  filename = args.filename,
                                  reporter = reporter,
+                                 api_client = api_client,
                                  num_retries = args.retries,
                                  replication_desired = args.replication,
                                  put_threads = args.threads,
@@ -1055,6 +1120,7 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
                                  owner_uuid = project_uuid,
                                  ensure_unique_name = True,
                                  update_collection = args.update_collection,
+                                 storage_classes = storage_classes,
                                  logger=logger,
                                  dry_run=args.dry_run,
                                  follow_links=args.follow_links,
@@ -1080,11 +1146,6 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
             "arv-put: %s" % str(error)]))
         sys.exit(1)
 
-    # Install our signal handler for each code in CAUGHT_SIGNALS, and save
-    # the originals.
-    orig_signal_handlers = {sigcode: signal.signal(sigcode, exit_signal_handler)
-                            for sigcode in CAUGHT_SIGNALS}
-
     if not args.dry_run and not args.update_collection and args.resume and writer.bytes_written > 0:
         logger.warning("\n".join([
             "arv-put: Resuming previous upload from last checkpoint.",
@@ -1129,13 +1190,13 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
     # Print the locator (uuid) of the new collection.
     if output is None:
         status = status or 1
-    else:
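+    # With --silent, skip printing the collection locator to stdout.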
+    elif not args.silent:
         stdout.write(output)
         if not output.endswith('\n'):
             stdout.write('\n')
 
-    for sigcode, orig_handler in listitems(orig_signal_handlers):
-        signal.signal(sigcode, orig_handler)
+    if install_sig_handlers:
+        arv_cmd.restore_signal_handlers()
 
     if status != 0:
         sys.exit(status)