Merge branch 'master' into 3618-column-ordering
[arvados.git] / sdk / python / arvados / commands / put.py
index 15551fac07db8d29b6aca0da8f40b5880edd0edd..4a926c701c9f45648dc6337fc294391986e03379 100644 (file)
@@ -3,7 +3,6 @@
 # TODO:
 # --md5sum - display md5 of each file as read from disk
 
-import apiclient.errors
 import argparse
 import arvados
 import base64
@@ -18,6 +17,7 @@ import signal
 import socket
 import sys
 import tempfile
+from apiclient import errors as apiclient_errors
 
 import arvados.commands._util as arv_cmd
 
@@ -145,7 +145,7 @@ Do not continue interrupted uploads from cached state.
 
 arg_parser = argparse.ArgumentParser(
     description='Copy data from the local filesystem to Keep.',
-    parents=[upload_opts, run_opts])
+    parents=[upload_opts, run_opts, arv_cmd.retry_opt])
 
 def parse_arguments(arguments):
     args = arg_parser.parse_args(arguments)
@@ -246,23 +246,26 @@ class ArvPutCollectionWriter(arvados.ResumableCollectionWriter):
                    ['bytes_written', '_seen_inputs'])
 
     def __init__(self, cache=None, reporter=None, bytes_expected=None,
-                 api_client=None):
+                 api_client=None, num_retries=0):
         self.bytes_written = 0
         self._seen_inputs = []
         self.cache = cache
         self.reporter = reporter
         self.bytes_expected = bytes_expected
-        super(ArvPutCollectionWriter, self).__init__(api_client)
+        super(ArvPutCollectionWriter, self).__init__(
+            api_client, num_retries=num_retries)
 
     @classmethod
-    def from_cache(cls, cache, reporter=None, bytes_expected=None):
+    def from_cache(cls, cache, reporter=None, bytes_expected=None,
+                   num_retries=0):
         try:
             state = cache.load()
             state['_data_buffer'] = [base64.decodestring(state['_data_buffer'])]
-            writer = cls.from_state(state, cache, reporter, bytes_expected)
+            writer = cls.from_state(state, cache, reporter, bytes_expected,
+                                    num_retries=num_retries)
         except (TypeError, ValueError,
                 arvados.errors.StaleWriterStateError) as error:
-            return cls(cache, reporter, bytes_expected)
+            return cls(cache, reporter, bytes_expected, num_retries=num_retries)
         else:
             return writer
 
@@ -348,26 +351,24 @@ def progress_writer(progress_func, outfile=sys.stderr):
 def exit_signal_handler(sigcode, frame):
     sys.exit(-sigcode)
 
-def desired_project_uuid(api_client, project_uuid):
-    if project_uuid:
-        if arvados.util.user_uuid_pattern.match(project_uuid):
-            api_client.users().get(uuid=project_uuid).execute()
-            return project_uuid
-        elif arvados.util.group_uuid_pattern.match(project_uuid):
-            api_client.groups().get(uuid=project_uuid).execute()
-            return project_uuid
-        else:
-            raise ValueError("Not a valid project uuid: {}".format(project_uuid))
+def desired_project_uuid(api_client, project_uuid, num_retries):
+    if not project_uuid:
+        query = api_client.users().current()
+    elif arvados.util.user_uuid_pattern.match(project_uuid):
+        query = api_client.users().get(uuid=project_uuid)
+    elif arvados.util.group_uuid_pattern.match(project_uuid):
+        query = api_client.groups().get(uuid=project_uuid)
     else:
-        return api_client.users().current().execute()['uuid']
+        raise ValueError("Not a valid project UUID: {}".format(project_uuid))
+    return query.execute(num_retries=num_retries)['uuid']
 
 def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
     global api_client
-    if api_client is None:
-        api_client = arvados.api('v1')
-    status = 0
 
     args = parse_arguments(arguments)
+    status = 0
+    if api_client is None:
+        api_client = arvados.api('v1')
 
     # Determine the name to use
     if args.name:
@@ -387,8 +388,9 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
 
     # Determine the parent project
     try:
-        project_uuid = desired_project_uuid(api_client, args.project_uuid)
-    except (apiclient.errors.Error, ValueError) as error:
+        project_uuid = desired_project_uuid(api_client, args.project_uuid,
+                                            args.retries)
+    except (apiclient_errors.Error, ValueError) as error:
         print >>stderr, error
         sys.exit(1)
 
@@ -413,10 +415,11 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
             sys.exit(1)
 
     if resume_cache is None:
-        writer = ArvPutCollectionWriter(resume_cache, reporter, bytes_expected)
+        writer = ArvPutCollectionWriter(resume_cache, reporter, bytes_expected,
+                                        num_retries=args.retries)
     else:
         writer = ArvPutCollectionWriter.from_cache(
-            resume_cache, reporter, bytes_expected)
+            resume_cache, reporter, bytes_expected, num_retries=args.retries)
 
     # Install our signal handler for each code in CAUGHT_SIGNALS, and save
     # the originals.
@@ -456,7 +459,7 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
                     'manifest_text': writer.manifest_text()
                     },
                 ensure_unique_name=True
-                ).execute()
+                ).execute(num_retries=args.retries)
 
             print >>stderr, "Collection saved as '%s'" % collection['name']
 
@@ -465,7 +468,7 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
             else:
                 output = collection['uuid']
 
-        except apiclient.errors.Error as error:
+        except apiclient_errors.Error as error:
             print >>stderr, (
                 "arv-put: Error creating Collection on project: {}.".format(
                     error))