import bz2
import functools
import hashlib
import os
import re
import threading
import zlib
import Queue

from . import config
from .errors import ArgumentError
from .keep import KeepClient
from .ranges import *
from arvados.retry import retry_method


def split(path):
    """split(path) -> streamname, filename

    Separate the stream name and file name in a /-separated stream path.
    If no stream name is available, assume '.'.
    """
    try:
        stream_name, file_name = path.rsplit('/', 1)
    except ValueError:  # No / in string
        stream_name, file_name = '.', path
    return stream_name, file_name


class ArvadosFileBase(object):
    def __init__(self, name, mode):
        self.name = name
        self.mode = mode
        self.closed = False

    @staticmethod
    def _before_close(orig_func):
        @functools.wraps(orig_func)
        def wrapper(self, *args, **kwargs):
            if self.closed:
                raise ValueError("I/O operation on closed stream file")
            return orig_func(self, *args, **kwargs)
        return wrapper

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            self.close()
        except Exception:
            if exc_type is None:
                raise

    def close(self):
        self.closed = True


class ArvadosFileReaderBase(ArvadosFileBase):
    class _NameAttribute(str):
        # The Python file API provides a plain .name attribute.
        # Older SDK provided a name() method.
        # This class provides both, for maximum compatibility.
        def __call__(self):
            return self

    def __init__(self, name, mode, num_retries=None):
        super(ArvadosFileReaderBase, self).__init__(self._NameAttribute(name), mode)
        self._filepos = 0L
        self.num_retries = num_retries
        self._readline_cache = (None, None)

    def __iter__(self):
        while True:
            data = self.readline()
            if not data:
                break
            yield data

    def decompressed_name(self):
        return re.sub(r'\.(bz2|gz)$', '', self.name)

    @ArvadosFileBase._before_close
    def seek(self, pos, whence=os.SEEK_CUR):
        # Note: unlike the standard file API, whence defaults to SEEK_CUR.
        if whence == os.SEEK_CUR:
            pos += self._filepos
        elif whence == os.SEEK_END:
            pos += self.size()
        self._filepos = min(max(pos, 0L), self.size())

    def tell(self):
        return self._filepos

    @ArvadosFileBase._before_close
    @retry_method
    def readall(self, size=2**20, num_retries=None):
        while True:
            data = self.read(size, num_retries=num_retries)
            if data == '':
                break
            yield data

    @ArvadosFileBase._before_close
    @retry_method
    def readline(self, size=float('inf'), num_retries=None):
        cache_pos, cache_data = self._readline_cache
        if self.tell() == cache_pos:
            data = [cache_data]
        else:
            data = ['']
        data_size = len(data[-1])
        while (data_size < size) and ('\n' not in data[-1]):
            next_read = self.read(2 ** 20, num_retries=num_retries)
            if not next_read:
                break
            data.append(next_read)
            data_size += len(next_read)
        data = ''.join(data)
        try:
            nextline_index = data.index('\n') + 1
        except ValueError:
            nextline_index = len(data)
        nextline_index = min(nextline_index, size)
        self._readline_cache = (self.tell(), data[nextline_index:])
        return data[:nextline_index]

    @ArvadosFileBase._before_close
    @retry_method
    def decompress(self, decompress, size, num_retries=None):
        for segment in self.readall(size, num_retries=num_retries):
            data = decompress(segment)
            if data:
                yield data

    @ArvadosFileBase._before_close
    @retry_method
    def readall_decompressed(self, size=2**20, num_retries=None):
        self.seek(0)
        if self.name.endswith('.bz2'):
            dc = bz2.BZ2Decompressor()
            return self.decompress(dc.decompress, size,
                                   num_retries=num_retries)
        elif self.name.endswith('.gz'):
            dc = zlib.decompressobj(16+zlib.MAX_WBITS)
            return self.decompress(lambda segment: dc.decompress(dc.unconsumed_tail + segment),
                                   size, num_retries=num_retries)
        else:
            return self.readall(size, num_retries=num_retries)

    @ArvadosFileBase._before_close
    @retry_method
    def readlines(self, sizehint=float('inf'), num_retries=None):
        data = []
        data_size = 0
        for s in self.readall(num_retries=num_retries):
            data.append(s)
            data_size += len(s)
            if data_size >= sizehint:
                break
        return ''.join(data).splitlines(True)


class StreamFileReader(ArvadosFileReaderBase):
    def __init__(self, stream, segments, name):
        super(StreamFileReader, self).__init__(name, 'rb', num_retries=stream.num_retries)
        self._stream = stream
        self.segments = segments

    def stream_name(self):
        return self._stream.name()

    def size(self):
        if not self.segments:
            return 0
        n = self.segments[-1]
        return n.range_start + n.range_size

    @ArvadosFileBase._before_close
    @retry_method
    def read(self, size, num_retries=None):
        """Read up to 'size' bytes from the stream, starting at the current file position"""
        if size == 0:
            return ''

        data = ''
        available_chunks = locators_and_ranges(self.segments, self._filepos, size)
        if available_chunks:
            lr = available_chunks[0]
            data = self._stream._readfrom(lr.locator+lr.segment_offset,
                                          lr.segment_size,
                                          num_retries=num_retries)

        self._filepos += len(data)
        return data

    @ArvadosFileBase._before_close
    @retry_method
    def readfrom(self, start, size, num_retries=None):
        """Read up to 'size' bytes from the stream, starting at 'start'"""
        if size == 0:
            return ''

        data = []
        for lr in locators_and_ranges(self.segments, start, size):
            data.append(self._stream._readfrom(lr.locator+lr.segment_offset, lr.segment_size,
                                               num_retries=num_retries))
        return ''.join(data)

    def as_manifest(self):
        from .stream import normalize_stream
        segs = []
        for r in self.segments:
            segs.extend(self._stream.locators_and_ranges(r.locator, r.range_size))
        return " ".join(normalize_stream(".", {self.name: segs})) + "\n"


class BufferBlock(object):
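    """A BufferBlock is a stand-in for a Keep block that is being written.

    It starts WRITABLE, becomes PENDING when handed to the BlockManager
    for upload, and COMMITTED once the data is stored in Keep and the
    buffer can be released.
    """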
    WRITABLE = 0
    PENDING = 1
    COMMITTED = 2

    def __init__(self, blockid, starting_size):
        self.blockid = blockid
        self.buffer_block = bytearray(starting_size)
        self.buffer_view = memoryview(self.buffer_block)
        self.write_pointer = 0
        self.state = BufferBlock.WRITABLE
        self._locator = None

    def append(self, data):
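        """Append some data to the buffer.

        Only valid if the block is in WRITABLE state.  If there is not
        enough room, the buffer capacity is doubled until the data fits.
        """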
        if self.state == BufferBlock.WRITABLE:
            while (self.write_pointer+len(data)) > len(self.buffer_block):
                new_buffer_block = bytearray(len(self.buffer_block) * 2)
                new_buffer_block[0:self.write_pointer] = self.buffer_block[0:self.write_pointer]
                self.buffer_block = new_buffer_block
                self.buffer_view = memoryview(self.buffer_block)
            self.buffer_view[self.write_pointer:self.write_pointer+len(data)] = data
            self.write_pointer += len(data)
            self._locator = None
        else:
            raise AssertionError("Buffer block is not writable")

    def size(self):
        return self.write_pointer

    def locator(self):
        if self._locator is None:
            self._locator = "%s+%i" % (hashlib.md5(self.buffer_view[0:self.write_pointer]).hexdigest(), self.size())
        return self._locator


class AsyncKeepWriteErrors(Exception):
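    """Roll up one or more exceptions raised by background Keep write
    threads into a single exception.
    """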
    def __init__(self, errors):
        self.errors = errors


class BlockManager(object):
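    """BlockManager handles buffer blocks, background uploads of finished
    blocks to Keep, and background prefetch of blocks that are about to
    be read.
    """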
    def __init__(self, keep):
        self._keep = keep
        self._bufferblocks = {}
        self._put_queue = None
        self._put_errors = None
        self._put_threads = None
        self._prefetch_queue = None
        self._prefetch_threads = None

    def alloc_bufferblock(self, blockid=None, starting_size=2**14):
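        """Allocate a new, empty bufferblock in WRITABLE state and return it.

        blockid: optional block identifier; if None, one is generated.
        starting_size: initial buffer capacity in bytes.
        """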
        if blockid is None:
            blockid = "bufferblock%i" % len(self._bufferblocks)
        bb = BufferBlock(blockid, starting_size=starting_size)
        self._bufferblocks[bb.blockid] = bb
        return bb

    def stop_threads(self):
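        """Shut down the background upload and prefetch threads.

        Each thread is stopped by queuing a None sentinel, then joined.
        """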
        if self._put_threads is not None:
            for t in self._put_threads:
                self._put_queue.put(None)
            for t in self._put_threads:
                t.join()
        self._put_threads = None
        self._put_queue = None
        self._put_errors = None

        if self._prefetch_threads is not None:
            for t in self._prefetch_threads:
                self._prefetch_queue.put(None)
            for t in self._prefetch_threads:
                t.join()
        self._prefetch_threads = None
        self._prefetch_queue = None

    def commit_bufferblock(self, block):
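        """Initiate a background upload of a bufferblock to Keep.

        Marks the block PENDING and places it on the put queue; a pair of
        worker threads (started on first use) drain the queue, store each
        block in Keep, and record the resulting permanent locator.
        """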
        def worker(self):
            while True:
                try:
                    b = self._put_queue.get()
                    if b is None:
                        return
                    b._locator = self._keep.put(b.buffer_view[0:b.write_pointer].tobytes())
                    b.state = BufferBlock.COMMITTED
                    b.buffer_view = None
                    b.buffer_block = None
                except Exception as e:
                    self._put_errors.put(e)
                finally:
                    if self._put_queue is not None:
                        self._put_queue.task_done()

        if self._put_threads is None:
            self._put_queue = Queue.Queue()
            self._put_errors = Queue.Queue()
            self._put_threads = [threading.Thread(target=worker, args=(self,)),
                                 threading.Thread(target=worker, args=(self,))]
            for t in self._put_threads:
                # Daemonize so a stray uploader thread cannot block interpreter exit.
                t.daemon = True
                t.start()

        block.state = BufferBlock.PENDING
        self._put_queue.put(block)

    def get_block(self, locator, num_retries):
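        """Fetch a block's contents.

        If the locator names an uncommitted bufferblock, return its current
        buffer contents; otherwise (substituting the permanent locator for
        committed bufferblocks) fetch the block from Keep.
        """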
        if locator in self._bufferblocks:
            bb = self._bufferblocks[locator]
            if bb.state != BufferBlock.COMMITTED:
                return bb.buffer_view[0:bb.write_pointer].tobytes()
            else:
                locator = bb._locator
        return self._keep.get(locator, num_retries=num_retries)

    def commit_all(self):
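        """Commit all outstanding writable bufferblocks and wait for the
        put queue to drain.  Raises AsyncKeepWriteErrors if any background
        upload failed.
        """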
        for v in self._bufferblocks.values():
            if v.state == BufferBlock.WRITABLE:
                self.commit_bufferblock(v)
        if self._put_queue is not None:
            self._put_queue.join()
            if not self._put_errors.empty():
                e = []
                try:
                    while True:
                        e.append(self._put_errors.get(False))
                except Queue.Empty:
                    pass
                raise AsyncKeepWriteErrors(e)

    def block_prefetch(self, locator):
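        """Initiate a background download of a block, so a subsequent read
        will likely find it in the Keep client's local cache.  Fetch errors
        are deliberately ignored; prefetch is only an optimization, and the
        real read will retry and report any failure.
        """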
        def worker(self):
            while True:
                try:
                    b = self._prefetch_queue.get()
                    if b is None:
                        return
                    self._keep.get(b)
                except Exception:
                    pass

        if locator in self._bufferblocks:
            return
        if self._prefetch_threads is None:
            self._prefetch_queue = Queue.Queue()
            self._prefetch_threads = [threading.Thread(target=worker, args=(self,))]
            self._prefetch_threads[0].daemon = True
            self._prefetch_threads[0].start()
        self._prefetch_queue.put(locator)


class ArvadosFile(object):
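    """ArvadosFile represents a file's contents as a sequence of segments
    over Keep blocks, supporting random-access reads and buffered writes
    through the parent's BlockManager.
    """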
    def __init__(self, parent, stream=(), segments=(), keep=None):
        '''
        stream: a list of Range objects representing a block stream
        segments: a list of Range objects representing segments
        '''
        self.parent = parent
        self._modified = True
        self._segments = []
        for s in segments:
            self.add_segment(stream, s.range_start, s.range_size)
        self._current_bblock = None
        self._keep = keep

    def set_unmodified(self):
        self._modified = False

    def modified(self):
        return self._modified

    def truncate(self, size):
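        """Shrink the file to the given size.

        Segments that lie entirely past the new size are dropped, and a
        segment straddling the boundary is cut down.  Only shrinking is
        supported; a size at or past the current end leaves the file as-is.
        """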
        new_segs = []
        for r in self._segments:
            range_end = r.range_start+r.range_size
            if r.range_start >= size:
                # segment is past the truncate size, all done
                break
            elif size < range_end:
                nr = Range(r.locator, r.range_start, size - r.range_start)
                nr.segment_offset = r.segment_offset
                new_segs.append(nr)
                break
            else:
                new_segs.append(r)

        self._segments = new_segs
        self._modified = True

    def readfrom(self, offset, size, num_retries):
        if size == 0 or offset >= self.size():
            return ''
        if self._keep is None:
            self._keep = KeepClient(num_retries=num_retries)
        data = []

        for lr in locators_and_ranges(self._segments, offset, size + config.KEEP_BLOCK_SIZE):
            self.parent._my_block_manager().block_prefetch(lr.locator)

        for lr in locators_and_ranges(self._segments, offset, size):
            # TODO: if data is empty, wait on block get, otherwise only
            # get more data if the block is already in the cache.
            data.append(self.parent._my_block_manager().get_block(lr.locator, num_retries=num_retries)[lr.segment_offset:lr.segment_offset+lr.segment_size])
        return ''.join(data)

    def _repack_writes(self):
        '''Test if the buffer block has more data than is referenced by actual segments
        (this happens when a buffered write over-writes a file range written in
        a previous buffered write).  Re-pack the buffer block for efficiency
        and to avoid leaking information.
        '''
        segs = self._segments

        # Sum up the segments to get the total bytes of the file referencing
        # into the buffer block.
        bufferblock_segs = [s for s in segs if s.locator == self._current_bblock.blockid]
        write_total = sum([s.range_size for s in bufferblock_segs])

        if write_total < self._current_bblock.size():
            # There is more data in the buffer block than is actually accounted for by segments, so
            # re-pack into a new buffer by copying over to a new buffer block.
            new_bb = self.parent._my_block_manager().alloc_bufferblock(self._current_bblock.blockid, starting_size=write_total)
            for t in bufferblock_segs:
                new_bb.append(self._current_bblock.buffer_view[t.segment_offset:t.segment_offset+t.range_size].tobytes())
                t.segment_offset = new_bb.size() - t.range_size

            self._current_bblock = new_bb

    def writeto(self, offset, data, num_retries):
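        """Write 'data' to the file starting at 'offset'.

        The data is appended to the current bufferblock (allocating or
        committing blocks as needed), and the segment list is updated so
        the range [offset, offset+len(data)) refers to the new bytes.
        """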
        if len(data) == 0:
            return

        if offset > self.size():
            raise ArgumentError("Offset is past the end of the file")

        if len(data) > config.KEEP_BLOCK_SIZE:
            raise ArgumentError("Please append data in chunks smaller than %i bytes (config.KEEP_BLOCK_SIZE)" % (config.KEEP_BLOCK_SIZE))

        self._modified = True

        if self._current_bblock is None or self._current_bblock.state != BufferBlock.WRITABLE:
            self._current_bblock = self.parent._my_block_manager().alloc_bufferblock()

        if (self._current_bblock.size() + len(data)) > config.KEEP_BLOCK_SIZE:
            self._repack_writes()
            if (self._current_bblock.size() + len(data)) > config.KEEP_BLOCK_SIZE:
                self.parent._my_block_manager().commit_bufferblock(self._current_bblock)
                self._current_bblock = self.parent._my_block_manager().alloc_bufferblock()

        self._current_bblock.append(data)
        replace_range(self._segments, offset, len(data), self._current_bblock.blockid, self._current_bblock.write_pointer - len(data))

    def add_segment(self, blocks, pos, size):
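        """Add a segment to the end of the file, referencing the byte range
        [pos, pos+size) within the block stream 'blocks'.
        """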
        self._modified = True
        for lr in locators_and_ranges(blocks, pos, size):
            last = self._segments[-1] if self._segments else Range(0, 0, 0)
            r = Range(lr.locator, last.range_start+last.range_size, lr.segment_size, lr.segment_offset)
            self._segments.append(r)

    def size(self):
        if self._segments:
            n = self._segments[-1]
            return n.range_start + n.range_size
        else:
            return 0


class ArvadosFileReader(ArvadosFileReaderBase):
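    """Wraps ArvadosFile in a file-like reader API: read(), readfrom(),
    seek(), tell(), and iteration over lines.
    """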
    def __init__(self, arvadosfile, name, mode="r", num_retries=None):
        super(ArvadosFileReader, self).__init__(name, mode, num_retries=num_retries)
        self.arvadosfile = arvadosfile

    def size(self):
        return self.arvadosfile.size()

    @ArvadosFileBase._before_close
    @retry_method
    def read(self, size, num_retries=None):
        """Read up to 'size' bytes from the stream, starting at the current file position"""
        data = self.arvadosfile.readfrom(self._filepos, size, num_retries=num_retries)
        self._filepos += len(data)
        return data

    @ArvadosFileBase._before_close
    @retry_method
    def readfrom(self, offset, size, num_retries=None):
        """Read up to 'size' bytes from the stream, starting at 'offset'"""
        return self.arvadosfile.readfrom(offset, size, num_retries)

    def flush(self):
        pass


class ArvadosFileWriter(ArvadosFileReader):
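    """Wraps ArvadosFile in a file-like writer API: write(), writelines(),
    and truncate(), in addition to the reader methods.  In append mode
    ("a"), writes always go to the end of the file.
    """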
    def __init__(self, arvadosfile, name, mode, num_retries=None):
        super(ArvadosFileWriter, self).__init__(arvadosfile, name, mode, num_retries=num_retries)

    @ArvadosFileBase._before_close
    @retry_method
    def write(self, data, num_retries=None):
        if self.mode[0] == "a":
            self.arvadosfile.writeto(self.size(), data, num_retries)
        else:
            self.arvadosfile.writeto(self._filepos, data, num_retries)
            self._filepos += len(data)

    @ArvadosFileBase._before_close
    @retry_method
    def writelines(self, seq, num_retries=None):
        for s in seq:
            self.write(s, num_retries=num_retries)

    def truncate(self, size=None):
        if size is None:
            size = self._filepos
        self.arvadosfile.truncate(size)
        if self._filepos > self.size():
            self._filepos = self.size()