class ArvPutCollectionWriter(arvados.ResumableCollectionWriter):
    """Resumable collection writer that tracks and checkpoints upload progress.

    Extends ``arvados.ResumableCollectionWriter`` with:

    * a running ``bytes_written`` count, reported through an optional
      ``reporter(bytes_written, bytes_expected)`` callback;
    * checkpointing of the writer state to an optional ``cache`` object
      (anything providing ``load()`` and ``save(state)``), from which a
      writer can be rebuilt with ``from_cache``;
    * de-duplication of inputs: a file or directory tree already recorded
      in ``_seen_inputs`` is not written a second time.

    ``bytes_written`` and ``_seen_inputs`` are appended to ``STATE_PROPS``,
    presumably so the parent's state dump/restore carries them along.
    """

    STATE_PROPS = (arvados.ResumableCollectionWriter.STATE_PROPS +
                   ['bytes_written', '_seen_inputs'])

    def __init__(self, cache=None, reporter=None, bytes_expected=None, **kwargs):
        """Create a writer.

        :param cache: object with ``load()``/``save(state)`` used for
            checkpointing, or None to disable checkpointing.
        :param reporter: callable invoked as
            ``reporter(bytes_written, bytes_expected)`` after data is
            flushed, or None.
        :param bytes_expected: total size passed through to ``reporter``.
        :param kwargs: forwarded unchanged to the parent constructor.
        """
        # These are set before the parent constructor runs, as in the
        # original ordering.
        self.bytes_written = 0
        self._seen_inputs = []
        self.cache = cache
        self.reporter = reporter
        self.bytes_expected = bytes_expected
        super(ArvPutCollectionWriter, self).__init__(**kwargs)

    @classmethod
    def from_cache(cls, cache, reporter=None, bytes_expected=None,
                   num_retries=0, replication=0):
        """Rebuild a writer from ``cache``, or start a fresh one if that fails.

        The cached ``_data_buffer`` is a single base64 string (as written by
        ``cache_state``). It is decoded back into a one-chunk buffer list
        before the state is passed to ``from_state``.

        If the cached state is malformed (e.g. not a dict, or lacking
        ``_data_buffer``) or is rejected as stale, a new empty writer with
        the same settings is returned instead. In code terms, the fallback
        applies to ``TypeError``, ``ValueError``, ``KeyError`` or
        ``arvados.errors.StaleWriterStateError``.
        """
        try:
            state = cache.load()
            # NOTE: base64.decodestring is a deprecated alias (removed in
            # Python 3.9); kept here to match the rest of this Python 2 code.
            state['_data_buffer'] = [base64.decodestring(state['_data_buffer'])]
            writer = cls.from_state(state, cache, reporter, bytes_expected,
                                    num_retries=num_retries,
                                    replication=replication)
        except (TypeError, ValueError, KeyError,
                arvados.errors.StaleWriterStateError):
            # Unusable cached state: fall back to a brand-new writer.
            return cls(cache, reporter, bytes_expected,
                       num_retries=num_retries,
                       replication=replication)
        return writer

    def cache_state(self):
        """Save a serializable snapshot of the writer state to ``self.cache``.

        Does nothing when no cache is configured. ``_data_buffer`` is joined
        and base64-encoded into one string. Any deque-like value (one with a
        ``popleft`` method) is converted to a list, presumably so the cache
        can serialize it as JSON.
        """
        if self.cache is None:
            return
        state = self.dump_state()
        # Transform attributes for serialization. Iterate over a snapshot so
        # rebinding values never interferes with the iteration.
        for attr, value in list(state.items()):
            if attr == '_data_buffer':
                state[attr] = base64.encodestring(''.join(value))
            elif hasattr(value, 'popleft'):
                state[attr] = list(value)
        self.cache.save(state)

    def report_progress(self):
        """Call ``reporter(bytes_written, bytes_expected)`` if a reporter is set."""
        if self.reporter is not None:
            self.reporter(self.bytes_written, self.bytes_expected)

    def flush_data(self):
        """Flush buffered data via the parent class and track progress.

        If the flush shrank the data buffer, the difference is added to
        ``bytes_written`` and progress is reported. State is checkpointed only
        when ``bytes_written`` crosses a ``KEEP_BLOCK_SIZE`` boundary, not on
        every flush.
        """
        block_size = arvados.config.KEEP_BLOCK_SIZE
        start_buffer_len = self._data_buffer_len
        # Floor division counts whole blocks. With true division (Python 3 or
        # ``from __future__ import division``), every byte written would look
        # like a new block and force a cache write on every flush.
        start_block_count = self.bytes_written // block_size
        super(ArvPutCollectionWriter, self).flush_data()
        bytes_flushed = start_buffer_len - self._data_buffer_len
        if bytes_flushed > 0:  # We actually PUT data.
            self.bytes_written += bytes_flushed
            self.report_progress()
            if self.bytes_written // block_size > start_block_count:
                self.cache_state()

    def _record_new_input(self, input_type, source_name, dest_name):
        """Record an input; return True if it is new, False if already seen."""
        # The key needs to be a list because that's what we'll get back
        # from JSON deserialization.
        key = [input_type, source_name, dest_name]
        if key in self._seen_inputs:
            return False
        self._seen_inputs.append(key)
        return True

    def write_file(self, source, filename=None):
        """Write ``source`` as ``filename`` unless this pair was already written."""
        if self._record_new_input('file', source, filename):
            super(ArvPutCollectionWriter, self).write_file(source, filename)

    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        """Write the tree at ``path`` into ``stream_name`` unless already written.

        Only ``(path, stream_name)`` is used for de-duplication;
        ``max_manifest_depth`` is forwarded to the parent unchanged.
        """
        if self._record_new_input('directory', path, stream_name):
            super(ArvPutCollectionWriter, self).write_directory_tree(
                path, stream_name, max_manifest_depth)
-
-