Added --storage-classes argument to arv-put.
[arvados.git] / sdk / python / arvados / commands / put.py
index 2032da06203781dddde420efb4dbaeb33638aa8d..cba00c3c8cf153039de990d27867558d0dbc699a 100644 (file)
@@ -34,7 +34,6 @@ from arvados._version import __version__
 
 import arvados.commands._util as arv_cmd
 
-CAUGHT_SIGNALS = [signal.SIGINT, signal.SIGQUIT, signal.SIGTERM]
 api_client = None
 
 upload_opts = argparse.ArgumentParser(add_help=False)
@@ -141,6 +140,10 @@ physical storage devices (e.g., disks) should have a copy of each data
 block. Default is to use the server-provided default (if any) or 2.
 """)
 
+upload_opts.add_argument('--storage-classes', help="""
+Specify comma separated list of storage classes to be used when saving data to Keep.
+""")
+
 upload_opts.add_argument('--threads', type=int, metavar='N', default=None,
                          help="""
 Set the number of upload threads to be used. Take into account that
@@ -309,6 +312,24 @@ class FileUploadList(list):
         super(FileUploadList, self).append(other)
 
 
+# Appends the X-Request-Id to the first log message emitted at ERROR or DEBUG level
+class ArvPutLogFormatter(logging.Formatter):
+    std_fmtr = logging.Formatter(arvados.log_format, arvados.log_date_format)
+    err_fmtr = None
+    request_id_informed = False
+
+    def __init__(self, request_id):
+        self.err_fmtr = logging.Formatter(
+            arvados.log_format+' (X-Request-Id: {})'.format(request_id),
+            arvados.log_date_format)
+
+    def format(self, record):
+        if (not self.request_id_informed) and (record.levelno in (logging.DEBUG, logging.ERROR)):
+            self.request_id_informed = True
+            return self.err_fmtr.format(record)
+        return self.std_fmtr.format(record)
+
+
 class ResumeCache(object):
     CACHE_DIR = '.cache/arvados/arv-put'
 
@@ -401,8 +422,8 @@ class ArvPutUploadJob(object):
     def __init__(self, paths, resume=True, use_cache=True, reporter=None,
                  name=None, owner_uuid=None, api_client=None,
                  ensure_unique_name=False, num_retries=None,
-                 put_threads=None, replication_desired=None,
-                 filename=None, update_time=60.0, update_collection=None,
+                 put_threads=None, replication_desired=None, filename=None,
+                 update_time=60.0, update_collection=None, storage_classes=None,
                  logger=logging.getLogger('arvados.arv_put'), dry_run=False,
                  follow_links=True, exclude_paths=[], exclude_names=None):
         self.paths = paths
@@ -422,6 +443,7 @@ class ArvPutUploadJob(object):
         self.replication_desired = replication_desired
         self.put_threads = put_threads
         self.filename = filename
+        self.storage_classes = storage_classes
         self._api_client = api_client
         self._state_lock = threading.Lock()
         self._state = None # Previous run state (file list & manifest)
@@ -597,10 +619,14 @@ class ArvPutUploadJob(object):
                 else:
                     # The file already exist on remote collection, skip it.
                     pass
-            self._remote_collection.save(num_retries=self.num_retries)
+            self._remote_collection.save(storage_classes=self.storage_classes,
+                                         num_retries=self.num_retries)
         else:
+            if self.storage_classes is None:
+                self.storage_classes = ['default']
             self._local_collection.save_new(
                 name=self.name, owner_uuid=self.owner_uuid,
+                storage_classes=self.storage_classes,
                 ensure_unique_name=self.ensure_unique_name,
                 num_retries=self.num_retries)
 
@@ -960,9 +986,6 @@ def progress_writer(progress_func, outfile=sys.stderr):
         outfile.write(progress_func(bytes_written, bytes_expected))
     return write_progress
 
-def exit_signal_handler(sigcode, frame):
-    sys.exit(-sigcode)
-
 def desired_project_uuid(api_client, project_uuid, num_retries):
     if not project_uuid:
         query = api_client.users().current()
@@ -974,7 +997,8 @@ def desired_project_uuid(api_client, project_uuid, num_retries):
         raise ValueError("Not a valid project UUID: {}".format(project_uuid))
     return query.execute(num_retries=num_retries)['uuid']
 
-def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
+def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr,
+         install_sig_handlers=True):
     global api_client
 
     args = parse_arguments(arguments)
@@ -986,11 +1010,16 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
     status = 0
 
     request_id = arvados.util.new_request_id()
-    logger.info('X-Request-Id: '+request_id)
+
+    formatter = ArvPutLogFormatter(request_id)
+    logging.getLogger('arvados').handlers[0].setFormatter(formatter)
 
     if api_client is None:
         api_client = arvados.api('v1', request_id=request_id)
 
+    if install_sig_handlers:
+        arv_cmd.install_signal_handlers()
+
     # Determine the name to use
     if args.name:
         if args.stream or args.raw:
@@ -1025,6 +1054,15 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
     else:
         reporter = None
 
+    # Split the comma-separated --storage-classes argument into a list
+    storage_classes = None
+    if args.storage_classes:
+        storage_classes = args.storage_classes.strip().split(',')
+        if len(storage_classes) > 1:
+            logger.error("Multiple storage classes are not supported currently.")
+            sys.exit(1)
+
+
     # Setup exclude regex from all the --exclude arguments provided
     name_patterns = []
     exclude_paths = []
@@ -1082,6 +1120,7 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
                                  owner_uuid = project_uuid,
                                  ensure_unique_name = True,
                                  update_collection = args.update_collection,
+                                 storage_classes=storage_classes,
                                  logger=logger,
                                  dry_run=args.dry_run,
                                  follow_links=args.follow_links,
@@ -1107,11 +1146,6 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
             "arv-put: %s" % str(error)]))
         sys.exit(1)
 
-    # Install our signal handler for each code in CAUGHT_SIGNALS, and save
-    # the originals.
-    orig_signal_handlers = {sigcode: signal.signal(sigcode, exit_signal_handler)
-                            for sigcode in CAUGHT_SIGNALS}
-
     if not args.dry_run and not args.update_collection and args.resume and writer.bytes_written > 0:
         logger.warning("\n".join([
             "arv-put: Resuming previous upload from last checkpoint.",
@@ -1161,8 +1195,8 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
         if not output.endswith('\n'):
             stdout.write('\n')
 
-    for sigcode, orig_handler in listitems(orig_signal_handlers):
-        signal.signal(sigcode, orig_handler)
+    if install_sig_handlers:
+        arv_cmd.restore_signal_handlers()
 
     if status != 0:
         sys.exit(status)