2752: Refactor progress reporting in arv-put's CollectionWriter.
author Brett Smith <brett@curoverse.com>
Fri, 23 May 2014 20:22:39 +0000 (16:22 -0400)
committer Brett Smith <brett@curoverse.com>
Fri, 30 May 2014 14:40:09 +0000 (10:40 -0400)
sdk/python/arvados/commands/put.py
sdk/python/tests/test_arv-put.py

index 705dcfd8f35b321c261eb3953b17ad4030271287..37974c09459cd19e9615368fd8a21c245ce46214 100644 (file)
@@ -199,22 +199,28 @@ class ResumeCache(object):
         self.close()
 
 
-class ResumeCacheCollectionWriter(arvados.ResumableCollectionWriter):
-    def __init__(self, cache=None):
+class ArvPutCollectionWriter(arvados.ResumableCollectionWriter):
+    def __init__(self, cache=None, reporter=None, bytes_expected=None):
+        self._init_locals(cache, reporter, bytes_expected)
+        super(ArvPutCollectionWriter, self).__init__()
+
+    def _init_locals(self, cache, reporter, bytes_expected):
         self.cache = cache
-        super(ResumeCacheCollectionWriter, self).__init__()
+        self.report_func = reporter  # called as f(written, expected)
+        self.bytes_written = 0  # restarts at 0 when resumed via from_cache
+        self.bytes_expected = bytes_expected  # None if total unknown
 
     @classmethod
-    def from_cache(cls, cache):
+    def from_cache(cls, cache, reporter=None, bytes_expected=None):
         try:
             state = cache.load()
             state['_data_buffer'] = [base64.decodestring(state['_data_buffer'])]
             writer = cls.from_state(state)
         except (TypeError, ValueError,
                 arvados.errors.StaleWriterStateError) as error:
-            return cls(cache)
+            return cls(cache, reporter, bytes_expected)
         else:
-            writer.cache = cache
+            writer._init_locals(cache, reporter, bytes_expected)
             return writer
 
     def checkpoint_state(self):
@@ -229,41 +235,17 @@ class ResumeCacheCollectionWriter(arvados.ResumableCollectionWriter):
                 state[attr] = list(value)
         self.cache.save(state)
 
-
-class CollectionWriterWithProgress(arvados.CollectionWriter):
-    def flush_data(self, *args, **kwargs):
-        if not getattr(self, 'display_type', None):
-            return
-        if not hasattr(self, 'bytes_flushed'):
-            self.bytes_flushed = 0
-        self.bytes_flushed += self._data_buffer_len
-        super(CollectionWriterWithProgress, self).flush_data(*args, **kwargs)
-        self.bytes_flushed -= self._data_buffer_len
-        if self.display_type == 'machine':
-            sys.stderr.write('%s %d: %d written %d total\n' %
-                             (sys.argv[0],
-                              os.getpid(),
-                              self.bytes_flushed,
-                              getattr(self, 'bytes_expected', -1)))
-        elif getattr(self, 'bytes_expected', 0) > 0:
-            pct = 100.0 * self.bytes_flushed / self.bytes_expected
-            sys.stderr.write('\r%dM / %dM %.1f%% ' %
-                             (self.bytes_flushed >> 20,
-                              self.bytes_expected >> 20, pct))
-        else:
-            sys.stderr.write('\r%d ' % self.bytes_flushed)
-
-    def manifest_text(self, *args, **kwargs):
-        manifest_text = (super(CollectionWriterWithProgress, self)
-                         .manifest_text(*args, **kwargs))
-        if getattr(self, 'display_type', None):
-            if self.display_type == 'human':
-                sys.stderr.write('\n')
-            self.display_type = None
-        return manifest_text
+    def flush_data(self):
+        buffered_before = self._data_buffer_len  # snapshot pre-flush size
+        super(ArvPutCollectionWriter, self).flush_data()
+        self.bytes_written += buffered_before - self._data_buffer_len
+        if self.report_func is not None:  # reporting is optional
+            self.report_func(self.bytes_written, self.bytes_expected)
 
 
 def expected_bytes_for(pathlist):
+    # Walk the given directory trees and stat files, adding up file sizes,
+    # so we can display progress as percent
     bytesum = 0
     for path in pathlist:
         if os.path.isdir(path):
@@ -289,23 +271,32 @@ def human_progress(bytes_written, bytes_expected):
     else:
         return "\r{} ".format(bytes_written)
 
+def progress_writer(progress_func, outfile=None):
+    """Build a reporter callback for ArvPutCollectionWriter.
+
+    The returned function takes (bytes_written, bytes_expected), formats
+    them with progress_func, and writes the resulting string to outfile.
+    outfile defaults to None, meaning sys.stderr as of each call: a
+    default of sys.stderr would be bound once at definition time and
+    ignore any later reassignment of sys.stderr.
+    """
+    def write_progress(bytes_written, bytes_expected):
+        stream = sys.stderr if outfile is None else outfile
+        stream.write(progress_func(bytes_written, bytes_expected))
+    return write_progress
+
 def main(arguments=None):
     args = parse_arguments(arguments)
 
     if args.progress:
-        writer = CollectionWriterWithProgress()
-        writer.display_type = 'human'
+        reporter = progress_writer(human_progress)
     elif args.batch_progress:
-        writer = CollectionWriterWithProgress()
-        writer.display_type = 'machine'
+        reporter = progress_writer(machine_progress)
     else:
-        writer = arvados.CollectionWriter()
+        reporter = None
 
-    # Walk the given directory trees and stat files, adding up file sizes,
-    # so we can display progress as percent
-    writer.bytes_expected = expected_bytes_for(args.paths)
-    if writer.bytes_expected is None:
-        del writer.bytes_expected
+    writer = ArvPutCollectionWriter(
+        reporter=reporter, bytes_expected=expected_bytes_for(args.paths))
 
     # Copy file data to Keep.
     for path in args.paths:
index db03eca43035d5405381fb5681aa13b2189952a7..e765482106554ff77eedd63e79abbf9b27dc67f3 100644 (file)
@@ -171,62 +171,73 @@ class ArvadosPutResumeCacheTest(ArvadosBaseTestCase):
                 os.unlink(cachefile.name)
 
 
-class ArvadosPutResumeCacheCollectionWriterTest(ArvadosKeepLocalStoreTestCase):
+class ArvadosPutCollectionWriterTest(ArvadosKeepLocalStoreTestCase):
     def setUp(self):
-        super(ArvadosPutResumeCacheCollectionWriterTest, self).setUp()
+        super(ArvadosPutCollectionWriterTest, self).setUp()
         with tempfile.NamedTemporaryFile(delete=False) as cachefile:
             self.cache = arv_put.ResumeCache(cachefile.name)
             self.cache_filename = cachefile.name
 
     def tearDown(self):
-        super(ArvadosPutResumeCacheCollectionWriterTest, self).tearDown()
+        super(ArvadosPutCollectionWriterTest, self).tearDown()
         if os.path.exists(self.cache_filename):
             self.cache.destroy()
         self.cache.close()
 
     def test_writer_caches(self):
-        cwriter = arv_put.ResumeCacheCollectionWriter(self.cache)
+        cwriter = arv_put.ArvPutCollectionWriter(self.cache)
         cwriter.write_file('/dev/null')
         self.assertTrue(self.cache.load())
         self.assertEquals(". 0:0:null\n", cwriter.manifest_text())
 
     def test_writer_works_without_cache(self):
-        cwriter = arv_put.ResumeCacheCollectionWriter()
+        cwriter = arv_put.ArvPutCollectionWriter()
         cwriter.write_file('/dev/null')
         self.assertEquals(". 0:0:null\n", cwriter.manifest_text())
 
     def test_writer_resumes_from_cache(self):
-        cwriter = arv_put.ResumeCacheCollectionWriter(self.cache)
+        cwriter = arv_put.ArvPutCollectionWriter(self.cache)
         with self.make_test_file() as testfile:
             cwriter.write_file(testfile.name, 'test')
-            new_writer = arv_put.ResumeCacheCollectionWriter.from_cache(
+            new_writer = arv_put.ArvPutCollectionWriter.from_cache(
                 self.cache)
             self.assertEquals(
                 ". 098f6bcd4621d373cade4e832627b4f6+4 0:4:test\n",
                 new_writer.manifest_text())
 
     def test_new_writer_from_stale_cache(self):
-        cwriter = arv_put.ResumeCacheCollectionWriter(self.cache)
+        cwriter = arv_put.ArvPutCollectionWriter(self.cache)
         with self.make_test_file() as testfile:
             cwriter.write_file(testfile.name, 'test')
-        new_writer = arv_put.ResumeCacheCollectionWriter.from_cache(self.cache)
+        new_writer = arv_put.ArvPutCollectionWriter.from_cache(self.cache)
         new_writer.write_file('/dev/null')
         self.assertEquals(". 0:0:null\n", new_writer.manifest_text())
 
     def test_new_writer_from_empty_cache(self):
-        cwriter = arv_put.ResumeCacheCollectionWriter.from_cache(self.cache)
+        cwriter = arv_put.ArvPutCollectionWriter.from_cache(self.cache)
         cwriter.write_file('/dev/null')
         self.assertEquals(". 0:0:null\n", cwriter.manifest_text())
 
     def test_writer_resumable_after_arbitrary_bytes(self):
-        cwriter = arv_put.ResumeCacheCollectionWriter(self.cache)
+        cwriter = arv_put.ArvPutCollectionWriter(self.cache)
         # These bytes are intentionally not valid UTF-8.
         with self.make_test_file('\x00\x07\xe2') as testfile:
             cwriter.write_file(testfile.name, 'test')
-            new_writer = arv_put.ResumeCacheCollectionWriter.from_cache(
+            new_writer = arv_put.ArvPutCollectionWriter.from_cache(
                 self.cache)
         self.assertEquals(cwriter.manifest_text(), new_writer.manifest_text())
 
+    def test_progress_reporting(self):
+        for expected in (None, 8):  # total unknown, then known
+            reports = []
+            writer = arv_put.ArvPutCollectionWriter(
+                reporter=lambda *report: reports.append(report),
+                bytes_expected=expected)
+            with self.make_test_file() as source:
+                writer.write_file(source.name, 'test')
+            writer.finish_current_stream()
+            self.assertIn((4, expected), reports)  # test file is 4 bytes
+
 
 class ArvadosExpectedBytesTest(ArvadosBaseTestCase):
     TEST_SIZE = os.path.getsize(__file__)