20933: Use [0-9] instead of \d in regex
[arvados.git] / sdk / python / arvados / commands / arv_copy.py
index 95e4f61d17c40d9a6012e0089cf14522bd93cb68..6c7d873af4a0d7741123502c3444112ecafcf395 100755 (executable)
@@ -146,29 +146,37 @@ def main():
 
     # Identify the kind of object we have been given, and begin copying.
     t = uuid_type(src_arv, args.object_uuid)
-    if t == 'Collection':
-        set_src_owner_uuid(src_arv.collections(), args.object_uuid, args)
-        result = copy_collection(args.object_uuid,
-                                 src_arv, dst_arv,
-                                 args)
-    elif t == 'Workflow':
-        set_src_owner_uuid(src_arv.workflows(), args.object_uuid, args)
-        result = copy_workflow(args.object_uuid, src_arv, dst_arv, args)
-    elif t == 'Group':
-        set_src_owner_uuid(src_arv.groups(), args.object_uuid, args)
-        result = copy_project(args.object_uuid, src_arv, dst_arv, args.project_uuid, args)
-    elif t == 'httpURL':
-        result = copy_from_http(args.object_uuid, src_arv, dst_arv, args)
-    else:
-        abort("cannot copy object {} of type {}".format(args.object_uuid, t))
+
+    try:
+        if t == 'Collection':
+            set_src_owner_uuid(src_arv.collections(), args.object_uuid, args)
+            result = copy_collection(args.object_uuid,
+                                     src_arv, dst_arv,
+                                     args)
+        elif t == 'Workflow':
+            set_src_owner_uuid(src_arv.workflows(), args.object_uuid, args)
+            result = copy_workflow(args.object_uuid, src_arv, dst_arv, args)
+        elif t == 'Group':
+            set_src_owner_uuid(src_arv.groups(), args.object_uuid, args)
+            result = copy_project(args.object_uuid, src_arv, dst_arv, args.project_uuid, args)
+        elif t == 'httpURL':
+            result = copy_from_http(args.object_uuid, src_arv, dst_arv, args)
+        else:
+            abort("cannot copy object {} of type {}".format(args.object_uuid, t))
+    except Exception as e:
+        logger.error("%s", e, exc_info=args.verbose)
+        exit(1)
 
     # Clean up any outstanding temp git repositories.
     for d in listvalues(local_repo_dir):
         shutil.rmtree(d, ignore_errors=True)
 
+    if not result:
+        exit(1)
+
     # If no exception was thrown and the response does not have an
     # error_token field, presume success
-    if result in None or 'error_token' in result or 'uuid' not in result:
+    if result is None or 'error_token' in result or 'uuid' not in result:
         if result:
             logger.error("API server returned an error result: {}".format(result))
         exit(1)
@@ -319,21 +327,20 @@ def copy_workflow(wf_uuid, src, dst, args):
 
     # copy collections and docker images
     if args.recursive and wf["definition"]:
-        wf_def = yaml.safe_load(wf["definition"])
-        if wf_def is not None:
-            locations = []
-            docker_images = {}
-            graph = wf_def.get('$graph', None)
-            if graph is not None:
-                workflow_collections(graph, locations, docker_images)
-            else:
-                workflow_collections(wf_def, locations, docker_images)
+        env = {"ARVADOS_API_HOST": urllib.parse.urlparse(src._rootDesc["rootUrl"]).netloc,
+               "ARVADOS_API_TOKEN": src.api_token,
+               "PATH": os.environ["PATH"]}
+        try:
+            result = subprocess.run(["arvados-cwl-runner", "--quiet", "--print-keep-deps", "arvwf:"+wf_uuid],
+                                    capture_output=True, env=env)
+        except (FileNotFoundError, subprocess.CalledProcessError):
+            logger.error('Copying workflows requires arvados-cwl-runner 2.7.1 or later to be installed in PATH.')
+            return
 
-            if locations:
-                copy_collections(locations, src, dst, args)
+        locations = json.loads(result.stdout)
 
-            for image in docker_images:
-                copy_docker_image(image, docker_images[image], src, dst, args)
+        if locations:
+            copy_collections(locations, src, dst, args)
 
     # copy the workflow itself
     del wf['uuid']
@@ -570,7 +577,7 @@ def copy_collection(obj_uuid, src, dst, args):
     dst_keep = arvados.keep.KeepClient(api_client=dst, num_retries=args.retries)
     dst_manifest = io.StringIO()
     dst_locators = {}
-    bytes_written = [0]
+    bytes_written = 0
     bytes_expected = total_collection_size(manifest)
     if args.progress:
         progress_writer = ProgressWriter(human_progress)
@@ -586,8 +593,17 @@ def copy_collection(obj_uuid, src, dst, args):
     # again and build dst_manifest
 
     lock = threading.Lock()
+
+    # the get queue should be unbounded because we'll add all the
+    # block hashes we want to get, but these are small
     get_queue = queue.Queue()
-    put_queue = queue.Queue()
+
+    threadcount = 4
+
+    # the put queue contains full data blocks
+    # and if 'get' is faster than 'put' we could end up consuming
+    # a great deal of RAM if it isn't bounded.
+    put_queue = queue.Queue(threadcount)
     transfer_error = []
 
     def get_thread():
@@ -601,6 +617,8 @@ def copy_collection(obj_uuid, src, dst, args):
             blockhash = arvados.KeepLocator(word).md5sum
             with lock:
                 if blockhash in dst_locators:
+                    # Already uploaded
+                    get_queue.task_done()
                     continue
 
             try:
@@ -614,12 +632,14 @@ def copy_collection(obj_uuid, src, dst, args):
                     # Drain the 'get' queue so we end early
                     while True:
                         get_queue.get(False)
+                        get_queue.task_done()
                 except queue.Empty:
                     pass
             finally:
                 get_queue.task_done()
 
     def put_thread():
+        nonlocal bytes_written
         while True:
             item = put_queue.get()
             if item is None:
@@ -631,6 +651,8 @@ def copy_collection(obj_uuid, src, dst, args):
             blockhash = loc.md5sum
             with lock:
                 if blockhash in dst_locators:
+                    # Already uploaded
+                    put_queue.task_done()
                     continue
 
             try:
@@ -638,15 +660,16 @@ def copy_collection(obj_uuid, src, dst, args):
                 dst_locator = dst_keep.put(data, classes=(args.storage_classes or []))
                 with lock:
                     dst_locators[blockhash] = dst_locator
-                    bytes_written[0] += loc.size
+                    bytes_written += loc.size
                     if progress_writer:
-                        progress_writer.report(obj_uuid, bytes_written[0], bytes_expected)
+                        progress_writer.report(obj_uuid, bytes_written, bytes_expected)
             except e:
                 logger.error("Error putting block %s (%s bytes): %s", blockhash, loc.size, e)
                 try:
                     # Drain the 'get' queue so we end early
                     while True:
                         get_queue.get(False)
+                        get_queue.task_done()
                 except queue.Empty:
                     pass
                 transfer_error.append(e)
@@ -665,20 +688,14 @@ def copy_collection(obj_uuid, src, dst, args):
 
             get_queue.put(word)
 
-    get_queue.put(None)
-    get_queue.put(None)
-    get_queue.put(None)
-    get_queue.put(None)
+    for i in range(0, threadcount):
+        get_queue.put(None)
 
-    threading.Thread(target=get_thread, daemon=True).start()
-    threading.Thread(target=get_thread, daemon=True).start()
-    threading.Thread(target=get_thread, daemon=True).start()
-    threading.Thread(target=get_thread, daemon=True).start()
+    for i in range(0, threadcount):
+        threading.Thread(target=get_thread, daemon=True).start()
 
-    threading.Thread(target=put_thread, daemon=True).start()
-    threading.Thread(target=put_thread, daemon=True).start()
-    threading.Thread(target=put_thread, daemon=True).start()
-    threading.Thread(target=put_thread, daemon=True).start()
+    for i in range(0, threadcount):
+        threading.Thread(target=put_thread, daemon=True).start()
 
     get_queue.join()
     put_queue.join()
@@ -704,7 +721,7 @@ def copy_collection(obj_uuid, src, dst, args):
         dst_manifest.write("\n")
 
     if progress_writer:
-        progress_writer.report(obj_uuid, bytes_written[0], bytes_expected)
+        progress_writer.report(obj_uuid, bytes_written, bytes_expected)
         progress_writer.finish()
 
     # Copy the manifest and save the collection.