4823: Working on porting more StreamReader tests to ArvadosFile.
[arvados.git] / sdk / python / arvados / arvfile.py
import functools
import os
import re
import zlib
import bz2
from .ranges import *
from arvados.retry import retry_method
import config
import hashlib
import threading
import Queue

# ArgumentError is raised in ArvadosFile.writeto() but was never imported;
# assuming the SDK's errors module is the intended source.
from arvados.errors import ArgumentError

def split(path):
    """split(path) -> streamname, filename

    Separate the stream name and file name in a /-separated stream path.
    If no stream name is available, assume '.'.
    """
    try:
        stream_name, file_name = path.rsplit('/', 1)
    except ValueError:  # No / in string
        stream_name, file_name = '.', path
    return stream_name, file_name

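# For example, split('foo/bar/baz.txt') returns ('foo/bar', 'baz.txt'),
# and split('baz.txt') returns ('.', 'baz.txt').
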
class ArvadosFileBase(object):
    def __init__(self, name, mode):
        self.name = name
        self.mode = mode
        self.closed = False

    @staticmethod
    def _before_close(orig_func):
        @functools.wraps(orig_func)
        def wrapper(self, *args, **kwargs):
            if self.closed:
                raise ValueError("I/O operation on closed stream file")
            return orig_func(self, *args, **kwargs)
        return wrapper

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            self.close()
        except Exception:
            if exc_type is None:
                raise

    def close(self):
        self.closed = True


class ArvadosFileReaderBase(ArvadosFileBase):
    class _NameAttribute(str):
        # The Python file API provides a plain .name attribute.
        # The older SDK provided a name() method.
        # This class provides both, for maximum compatibility.
        def __call__(self):
            return self

    def __init__(self, name, mode, num_retries=None):
        super(ArvadosFileReaderBase, self).__init__(self._NameAttribute(name), mode)
        self._filepos = 0L
        self.num_retries = num_retries
        self._readline_cache = (None, None)

    def __iter__(self):
        while True:
            data = self.readline()
            if not data:
                break
            yield data

    def decompressed_name(self):
        return re.sub(r'\.(bz2|gz)$', '', self.name)

    @ArvadosFileBase._before_close
    def seek(self, pos, whence=os.SEEK_CUR):
        if whence == os.SEEK_CUR:
            pos += self._filepos
        elif whence == os.SEEK_END:
            pos += self.size()
        self._filepos = min(max(pos, 0L), self.size())

    def tell(self):
        return self._filepos

    @ArvadosFileBase._before_close
    @retry_method
    def readall(self, size=2**20, num_retries=None):
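        """Yield the file's contents as a sequence of chunks of up to 'size' bytes.

        For example, ''.join(f.readall()) reads the entire file into one string.
        """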
        while True:
            data = self.read(size, num_retries=num_retries)
            if data == '':
                break
            yield data

    @ArvadosFileBase._before_close
    @retry_method
    def readline(self, size=float('inf'), num_retries=None):
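        """Read one line from the file, up to 'size' bytes long.

        Bytes read past the end of the returned line are kept in
        self._readline_cache, keyed by the file position they belong to,
        so the next sequential readline() call can reuse them instead of
        re-reading the block.
        """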
        cache_pos, cache_data = self._readline_cache
        if self.tell() == cache_pos:
            data = [cache_data]
        else:
            data = ['']
        data_size = len(data[-1])
        while (data_size < size) and ('\n' not in data[-1]):
            next_read = self.read(2 ** 20, num_retries=num_retries)
            if not next_read:
                break
            data.append(next_read)
            data_size += len(next_read)
        data = ''.join(data)
        try:
            nextline_index = data.index('\n') + 1
        except ValueError:
            nextline_index = len(data)
        nextline_index = min(nextline_index, size)
        self._readline_cache = (self.tell(), data[nextline_index:])
        return data[:nextline_index]

    @ArvadosFileBase._before_close
    @retry_method
    def decompress(self, decompress, size, num_retries=None):
        for segment in self.readall(size, num_retries):
            data = decompress(segment)
            if data:
                yield data

    @ArvadosFileBase._before_close
    @retry_method
    def readall_decompressed(self, size=2**20, num_retries=None):
        self.seek(0)
        if self.name.endswith('.bz2'):
            dc = bz2.BZ2Decompressor()
            return self.decompress(dc.decompress, size,
                                   num_retries=num_retries)
        elif self.name.endswith('.gz'):
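            # wbits=16+MAX_WBITS tells zlib to expect and skip a gzip header.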
            dc = zlib.decompressobj(16+zlib.MAX_WBITS)
            return self.decompress(lambda segment: dc.decompress(dc.unconsumed_tail + segment),
                                   size, num_retries=num_retries)
        else:
            return self.readall(size, num_retries=num_retries)

    @ArvadosFileBase._before_close
    @retry_method
    def readlines(self, sizehint=float('inf'), num_retries=None):
        data = []
        data_size = 0
        for s in self.readall(num_retries=num_retries):
            data.append(s)
            data_size += len(s)
            if data_size >= sizehint:
                break
        return ''.join(data).splitlines(True)


class StreamFileReader(ArvadosFileReaderBase):
    def __init__(self, stream, segments, name):
        super(StreamFileReader, self).__init__(name, 'rb', num_retries=stream.num_retries)
        self._stream = stream
        self.segments = segments

    def stream_name(self):
        return self._stream.name()

    def size(self):
        if not self.segments:
            return 0
        n = self.segments[-1]
        return n.range_start + n.range_size

    @ArvadosFileBase._before_close
    @retry_method
    def read(self, size, num_retries=None):
        """Read up to 'size' bytes from the stream, starting at the current file position"""
        if size == 0:
            return ''

        data = ''
        available_chunks = locators_and_ranges(self.segments, self._filepos, size)
        if available_chunks:
            lr = available_chunks[0]
            data = self._stream._readfrom(lr.locator+lr.segment_offset,
                                          lr.segment_size,
                                          num_retries=num_retries)

        self._filepos += len(data)
        return data

    @ArvadosFileBase._before_close
    @retry_method
    def readfrom(self, start, size, num_retries=None):
        """Read up to 'size' bytes from the stream, starting at 'start'"""
        if size == 0:
            return ''

        data = []
        for lr in locators_and_ranges(self.segments, start, size):
            data.append(self._stream._readfrom(lr.locator+lr.segment_offset, lr.segment_size,
                                               num_retries=num_retries))
        return ''.join(data)

    def as_manifest(self):
        from stream import normalize_stream
        segs = []
        for r in self.segments:
            segs.extend(self._stream.locators_and_ranges(r.locator, r.range_size))
        return " ".join(normalize_stream(".", {self.name: segs})) + "\n"


class BufferBlock(object):
    WRITABLE = 0
    PENDING = 1
    COMMITTED = 2

    def __init__(self, blockid, starting_size):
        self.blockid = blockid
        self.buffer_block = bytearray(starting_size)
        self.buffer_view = memoryview(self.buffer_block)
        self.write_pointer = 0
        self.state = BufferBlock.WRITABLE
        self._locator = None

    def append(self, data):
        if self.state == BufferBlock.WRITABLE:
            while (self.write_pointer+len(data)) > len(self.buffer_block):
                new_buffer_block = bytearray(len(self.buffer_block) * 2)
                new_buffer_block[0:self.write_pointer] = self.buffer_block[0:self.write_pointer]
                self.buffer_block = new_buffer_block
                self.buffer_view = memoryview(self.buffer_block)
            self.buffer_view[self.write_pointer:self.write_pointer+len(data)] = data
            self.write_pointer += len(data)
            self._locator = None
        else:
            raise AssertionError("Buffer block is not writable")

    def size(self):
        return self.write_pointer

    def locator(self):
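        """Return this block's Keep locator.

        For blocks still in memory the locator is computed on demand from
        the buffer contents, in the standard "<md5 hexdigest>+<size>" form
        (e.g. "d41d8cd98f00b204e9800998ecf8427e+0" for an empty block).
        """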
        if self._locator is None:
            self._locator = "%s+%i" % (hashlib.md5(self.buffer_view[0:self.write_pointer]).hexdigest(), self.size())
        return self._locator


class AsyncKeepWriteErrors(Exception):
    def __init__(self, errors):
        self.errors = errors

    def __repr__(self):
        return "\n".join(str(e) for e in self.errors)


class BlockManager(object):
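    """Manage asynchronous writes to and prefetches from Keep.

    A minimal usage sketch, assuming 'keep_client' is an
    arvados.KeepClient instance:

      bm = BlockManager(keep_client)
      bb = bm.alloc_bufferblock()
      bb.append("hello")
      bm.commit_bufferblock(bb)   # queued for upload on background threads
      bm.commit_all()             # wait for uploads; raises AsyncKeepWriteErrors
      bm.stop_threads()           # shut down the worker threads
    """
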
    def __init__(self, keep):
        self._keep = keep
        self._bufferblocks = {}
        self._put_queue = None
        self._put_errors = None
        self._put_threads = None
        self._prefetch_queue = None
        self._prefetch_threads = None

    def alloc_bufferblock(self, blockid=None, starting_size=2**14):
        if blockid is None:
            blockid = "bufferblock%i" % len(self._bufferblocks)
        bb = BufferBlock(blockid, starting_size=starting_size)
        self._bufferblocks[bb.blockid] = bb
        return bb

    def stop_threads(self):
        if self._put_threads is not None:
            for t in self._put_threads:
                self._put_queue.put(None)
            for t in self._put_threads:
                t.join()
        self._put_threads = None
        self._put_queue = None
        self._put_errors = None

        if self._prefetch_threads is not None:
            for t in self._prefetch_threads:
                self._prefetch_queue.put(None)
            for t in self._prefetch_threads:
                t.join()
        self._prefetch_threads = None
        self._prefetch_queue = None

    def commit_bufferblock(self, block):
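        """Hand a buffer block to the background upload threads.

        The block's state moves from WRITABLE to PENDING here; a worker
        thread sets it to COMMITTED, and records the Keep locator, once
        the PUT succeeds.
        """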
        def worker(self):
            while True:
                try:
                    b = self._put_queue.get()
                    if b is None:
                        return
                    b._locator = self._keep.put(b.buffer_view[0:b.write_pointer].tobytes())
                    b.state = BufferBlock.COMMITTED
                    b.buffer_view = None
                    b.buffer_block = None
                except Exception as e:
                    self._put_errors.put(e)
                finally:
                    if self._put_queue is not None:
                        self._put_queue.task_done()

        if self._put_threads is None:
            self._put_queue = Queue.Queue(maxsize=2)
            self._put_errors = Queue.Queue()
            self._put_threads = [threading.Thread(target=worker, args=(self,)),
                                 threading.Thread(target=worker, args=(self,))]
            for t in self._put_threads:
                t.daemon = True
                t.start()

        block.state = BufferBlock.PENDING
        self._put_queue.put(block)

    def get_block(self, locator, num_retries, cache_only=False):
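        """Return the contents of the block identified by 'locator'.

        Uncommitted buffer blocks are served straight from memory; anything
        else is fetched from Keep (or, with cache_only=True, only from the
        local Keep cache).
        """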
        if locator in self._bufferblocks:
            bb = self._bufferblocks[locator]
            if bb.state != BufferBlock.COMMITTED:
                return bb.buffer_view[0:bb.write_pointer].tobytes()
            else:
                locator = bb._locator
        return self._keep.get(locator, num_retries=num_retries, cache_only=cache_only)

    def commit_all(self):
        for v in self._bufferblocks.values():
            if v.state == BufferBlock.WRITABLE:
                self.commit_bufferblock(v)
        if self._put_queue is not None:
            self._put_queue.join()
            if not self._put_errors.empty():
                e = []
                try:
                    while True:
                        e.append(self._put_errors.get(False))
                except Queue.Empty:
                    pass
                raise AsyncKeepWriteErrors(e)

    def block_prefetch(self, locator):
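        """Start fetching a block in the background, so that a later
        get_block() is likely to find it already in the local Keep cache."""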
        def worker(self):
            while True:
                try:
                    b = self._prefetch_queue.get()
                    if b is None:
                        return
                    self._keep.get(b)
                except Exception:
                    # Prefetching is best effort; errors will surface, if at
                    # all, when the block is actually read.
                    pass

        if locator in self._bufferblocks:
            return
        if self._prefetch_threads is None:
            self._prefetch_queue = Queue.Queue()
            self._prefetch_threads = [threading.Thread(target=worker, args=(self,)),
                                      threading.Thread(target=worker, args=(self,))]
            for t in self._prefetch_threads:
                t.daemon = True
                t.start()
        self._prefetch_queue.put(locator)


class ArvadosFile(object):
    def __init__(self, parent, stream=None, segments=None):
        """
        stream: a list of Range objects representing a block stream
        segments: a list of Range objects representing segments
        """
        self.parent = parent
        self._modified = True
        self.segments = []
        for s in (segments or []):
            self.add_segment(stream or [], s.locator, s.range_size)
        self._current_bblock = None

    def set_unmodified(self):
        self._modified = False

    def modified(self):
        return self._modified

    def truncate(self, size):
        new_segs = []
        for r in self.segments:
            range_end = r.range_start+r.range_size
            if r.range_start >= size:
                # segment is past the truncate size, all done
                break
            elif size < range_end:
                nr = Range(r.locator, r.range_start, size - r.range_start)
                nr.segment_offset = r.segment_offset
                new_segs.append(nr)
                break
            else:
                new_segs.append(r)

        self.segments = new_segs
        self._modified = True

    def readfrom(self, offset, size, num_retries):
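        """Read up to 'size' bytes from the file, starting at 'offset'.

        Blocks within KEEP_BLOCK_SIZE past the requested range are asked
        for in advance, so sequential reads can overlap with prefetching.
        """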
        if size == 0 or offset >= self.size():
            return ''
        data = []

        for lr in locators_and_ranges(self.segments, offset, size + config.KEEP_BLOCK_SIZE):
            self.parent._my_block_manager().block_prefetch(lr.locator)

        for lr in locators_and_ranges(self.segments, offset, size):
            d = self.parent._my_block_manager().get_block(lr.locator, num_retries=num_retries, cache_only=bool(data))
            if d:
                data.append(d[lr.segment_offset:lr.segment_offset+lr.segment_size])
            else:
                break
        return ''.join(data)

    def _repack_writes(self):
        """Test if the buffer block has more data than is referenced by actual segments
        (this happens when a buffered write over-writes a file range written in
        a previous buffered write).  Re-pack the buffer block for efficiency
        and to avoid leaking information.
        """
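        # Example: two writeto(0, "x" * 10) calls leave 20 bytes in the
        # buffer block but only 10 bytes of live segments; repacking copies
        # just the live bytes into a fresh buffer block.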
        segs = self.segments

        # Sum up the segments to get the total bytes of the file referencing
        # into the buffer block.
        bufferblock_segs = [s for s in segs if s.locator == self._current_bblock.blockid]
        write_total = sum([s.range_size for s in bufferblock_segs])

        if write_total < self._current_bblock.size():
            # There is more data in the buffer block than is actually accounted for by segments, so
            # re-pack into a new buffer by copying over to a new buffer block.
            new_bb = self.parent._my_block_manager().alloc_bufferblock(self._current_bblock.blockid, starting_size=write_total)
            for t in bufferblock_segs:
                new_bb.append(self._current_bblock.buffer_view[t.segment_offset:t.segment_offset+t.range_size].tobytes())
                t.segment_offset = new_bb.size() - t.range_size

            self._current_bblock = new_bb

    def writeto(self, offset, data, num_retries):
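        """Write 'data' at 'offset', buffering it into the current buffer
        block and committing full blocks to Keep as needed.

        Writes must start at or before the current end of file, and each
        call must be smaller than config.KEEP_BLOCK_SIZE.
        """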
        if len(data) == 0:
            return

        if offset > self.size():
            raise ArgumentError("Offset is past the end of the file")

        if len(data) > config.KEEP_BLOCK_SIZE:
            raise ArgumentError("Please append data in chunks smaller than %i bytes (config.KEEP_BLOCK_SIZE)" % (config.KEEP_BLOCK_SIZE))

        self._modified = True

        if self._current_bblock is None or self._current_bblock.state != BufferBlock.WRITABLE:
            self._current_bblock = self.parent._my_block_manager().alloc_bufferblock()

        if (self._current_bblock.size() + len(data)) > config.KEEP_BLOCK_SIZE:
            self._repack_writes()
            if (self._current_bblock.size() + len(data)) > config.KEEP_BLOCK_SIZE:
                self.parent._my_block_manager().commit_bufferblock(self._current_bblock)
                self._current_bblock = self.parent._my_block_manager().alloc_bufferblock()

        self._current_bblock.append(data)
        replace_range(self.segments, offset, len(data), self._current_bblock.blockid, self._current_bblock.write_pointer - len(data))

    def add_segment(self, blocks, pos, size):
        self._modified = True
        for lr in locators_and_ranges(blocks, pos, size):
            last = self.segments[-1] if self.segments else Range(0, 0, 0)
            r = Range(lr.locator, last.range_start+last.range_size, lr.segment_size, lr.segment_offset)
            self.segments.append(r)

    def size(self):
        if self.segments:
            n = self.segments[-1]
            return n.range_start + n.range_size
        else:
            return 0


class ArvadosFileReader(ArvadosFileReaderBase):
    def __init__(self, arvadosfile, name, mode="r", num_retries=None):
        super(ArvadosFileReader, self).__init__(name, mode, num_retries=num_retries)
        self.arvadosfile = arvadosfile

    def size(self):
        return self.arvadosfile.size()

    @ArvadosFileBase._before_close
    @retry_method
    def read(self, size, num_retries=None):
        """Read up to 'size' bytes from the file, starting at the current file position"""
        data = self.arvadosfile.readfrom(self._filepos, size, num_retries=num_retries)
        self._filepos += len(data)
        return data

    @ArvadosFileBase._before_close
    @retry_method
    def readfrom(self, offset, size, num_retries=None):
        """Read up to 'size' bytes from the file, starting at 'offset'"""
        return self.arvadosfile.readfrom(offset, size, num_retries)

    def flush(self):
        pass


class ArvadosFileWriter(ArvadosFileReader):
    def __init__(self, arvadosfile, name, mode, num_retries=None):
        super(ArvadosFileWriter, self).__init__(arvadosfile, name, mode, num_retries=num_retries)

    @ArvadosFileBase._before_close
    @retry_method
    def write(self, data, num_retries=None):
        if self.mode[0] == "a":
            self.arvadosfile.writeto(self.size(), data, num_retries)
        else:
            self.arvadosfile.writeto(self._filepos, data, num_retries)
            self._filepos += len(data)

    @ArvadosFileBase._before_close
    @retry_method
    def writelines(self, seq, num_retries=None):
        for s in seq:
            self.write(s, num_retries=num_retries)

    def truncate(self, size=None):
        if size is None:
            size = self._filepos
        self.arvadosfile.truncate(size)
        if self._filepos > self.size():
            self._filepos = self.size()
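

# A minimal usage sketch, assuming 'parent' is an object that provides
# _my_block_manager() returning a BlockManager (as a Collection would):
#
#   f = ArvadosFile(parent)
#   writer = ArvadosFileWriter(f, "example.txt", "w")
#   writer.write("hello world\n")
#   reader = ArvadosFileReader(f, "example.txt")
#   print reader.readline()   # -> "hello world\n"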