8937: refactor cache check logic into a check_cache method and update all references.
author radhika <radhika@curoverse.com>
Wed, 27 Apr 2016 15:56:01 +0000 (11:56 -0400)
committer radhika <radhika@curoverse.com>
Mon, 2 May 2016 18:31:05 +0000 (14:31 -0400)
sdk/python/arvados/commands/put.py
sdk/python/tests/test_arv_put.py
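For orientation, a minimal sketch of the caller pattern this commit introduces: construct the ResumeCache first, then validate it with the new check_cache() method in a separate step. The helper name open_resume_cache is illustrative only; ResumeCacheConflict handling is omitted here (see the main() hunk below).

import arvados.commands.put as arv_put

def open_resume_cache(args, api_client):
    # Illustrative helper, not part of the commit: open the cache file first,
    # then validate it, instead of doing both inside ResumeCache.__init__.
    try:
        cache = arv_put.ResumeCache(arv_put.ResumeCache.make_path(args))
        # check_cache() HEADs one cached block locator; if that fails it
        # restarts the cache itself, so callers no longer need to catch
        # KeepRequestError and rebuild the cache by hand.
        cache.check_cache(api_client=api_client, num_retries=args.retries)
        return cache
    except (IOError, OSError, ValueError):
        return None  # Couldn't open cache directory/file; continue without it.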

index 8fa1c8f66b9dbd456d3c661884a4052ca4b239c1..d3510db7c5d52cf297f25ca3ac229f3b87c65056 100644
@@ -197,25 +197,10 @@ class ResumeCacheConflict(Exception):
 class ResumeCache(object):
     CACHE_DIR = '.cache/arvados/arv-put'
 
-    def __init__(self, file_spec, api_client=None, num_retries=0):
+    def __init__(self, file_spec):
         self.cache_file = open(file_spec, 'a+')
         self._lock_file(self.cache_file)
         self.filename = self.cache_file.name
-        try:
-            state = self.load()
-            locator = None
-            try:
-                if "_finished_streams" in state and len(state["_finished_streams"]) > 0:
-                    locator = state["_finished_streams"][0][1][0]
-                elif "_current_stream_locators" in state and len(state["_current_stream_locators"]) > 0:
-                    locator = state["_current_stream_locators"][0]
-                if locator is not None:
-                    kc = arvados.keep.KeepClient(api_client=api_client)
-                    kc.head(locator, num_retries=num_retries)
-            except Exception as e:
-                raise arvados.errors.KeepRequestError("Head request error for {}: {}".format(locator, e))
-        except (ValueError):
-            pass
 
     @classmethod
     def make_path(cls, args):
@@ -241,6 +226,23 @@ class ResumeCache(object):
         self.cache_file.seek(0)
         return json.load(self.cache_file)
 
+    def check_cache(self, api_client=None, num_retries=0):
+        try:
+            state = self.load()
+            locator = None
+            try:
+                if "_finished_streams" in state and len(state["_finished_streams"]) > 0:
+                    locator = state["_finished_streams"][0][1][0]
+                elif "_current_stream_locators" in state and len(state["_current_stream_locators"]) > 0:
+                    locator = state["_current_stream_locators"][0]
+                if locator is not None:
+                    kc = arvados.keep.KeepClient(api_client=api_client)
+                    kc.head(locator, num_retries=num_retries)
+            except Exception as e:
+                self.restart()
+        except (ValueError):
+            pass
+
     def save(self, data):
         try:
             new_cache_fd, new_cache_name = tempfile.mkstemp(
@@ -452,14 +454,10 @@ def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
     resume_cache = None
     if args.resume:
         try:
-            cachepath = ResumeCache.make_path(args)
-            resume_cache = ResumeCache(cachepath, api_client=api_client, num_retries=args.retries)
+            resume_cache = ResumeCache(ResumeCache.make_path(args))
+            resume_cache.check_cache(api_client=api_client, num_retries=args.retries)
         except (IOError, OSError, ValueError):
             pass  # Couldn't open cache directory/file.  Continue without it.
-        except arvados.errors.KeepRequestError:
-            # delete the cache and create a new one
-            shutil.rmtree(cachepath)
-            resume_cache = ResumeCache(cachepath)
         except ResumeCacheConflict:
             print >>stderr, "\n".join([
                 "arv-put: Another process is already uploading this data.",
index f1ed35a94ae51e2d4c7e8e7ba534fbfd5e8c5645..a6c1233067bc6c035a217fab5ed88174dd58ddd9 100644
@@ -160,8 +160,9 @@ class ArvadosPutResumeCacheTest(ArvadosBaseTestCase):
             self.last_cache = arv_put.ResumeCache(cachefile.name)
         self.last_cache.save(thing)
         self.last_cache.close()
-        with self.assertRaises(arvados.errors.KeepRequestError):
-            arv_put.ResumeCache(self.last_cache.filename)
+        resume_cache = arv_put.ResumeCache(self.last_cache.filename)
+        self.assertNotEqual(None, resume_cache)
+        resume_cache.check_cache()
 
     def test_basic_cache_storage(self):
         thing = ['test', 'list']