3663: update test_file_reader unit test
[arvados.git] / sdk / python / arvados / collection.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20
21 from collections import deque
22 from stat import *
23
24 from keep import *
25 from stream import *
26 import config
27 import errors
28 import util
29
# Logger for this module; CollectionReader._populate() uses it to report
# API lookups that failed before falling back to a direct Keep fetch.
_logger = logging.getLogger('arvados.collection')
31
def normalize_stream(s, stream):
    """Build the normalized manifest tokens for a single stream.

    s -- the stream name; it becomes the first token.
    stream -- dict mapping each file name to a list of segment records,
        indexed with arvados.LOCATOR, arvados.BLOCKSIZE, arvados.OFFSET
        and arvados.SEGMENTSIZE.

    Returns a list of tokens: the stream name, each distinct block locator
    in order of first use (files visited in sorted name order), then one
    "position:size:name" token per contiguous span of each file.
    """
    filenames = sorted(stream.keys())
    tokens = [s]

    # First pass: lay the distinct blocks end to end and remember where each
    # one starts within the stream.
    block_start = {}
    next_start = 0
    for name in filenames:
        for seg in stream[name]:
            locator = seg[arvados.LOCATOR]
            if locator in block_start:
                continue
            tokens.append(locator)
            block_start[locator] = next_start
            next_start += seg[arvados.BLOCKSIZE]

    # A stream that references no blocks still needs one locator token.
    if len(tokens) == 1:
        tokens.append(config.EMPTY_BLOCK_LOCATOR)

    # Second pass: emit file tokens, merging segments that are adjacent in
    # the stream into a single span.
    for name in filenames:
        escaped = name.replace(' ', '\\040')
        span = None
        for seg in stream[name]:
            begin = block_start[seg[arvados.LOCATOR]] + seg[arvados.OFFSET]
            size = seg[arvados.SEGMENTSIZE]
            if span is not None and begin == span[1]:
                # Picks up exactly where the open span ends: extend it.
                span[1] += size
                continue
            if span is not None:
                tokens.append("{0}:{1}:{2}".format(span[0], span[1] - span[0], escaped))
            span = [begin, begin + size]

        if span is not None:
            tokens.append("{0}:{1}:{2}".format(span[0], span[1] - span[0], escaped))

        # A file with no segments is recorded as a zero-length entry.
        if len(stream[name]) == 0:
            tokens.append("0:0:{0}".format(escaped))

    return tokens
70
def normalize(collection):
    """Regroup a collection's files by directory and normalize each stream.

    collection -- object providing all_streams(); each stream must provide
        name(), all_files() and locators_and_ranges(), and each file must
        provide name() and a `segments` sequence.

    Returns a list of token lists (one per stream, in sorted stream-name
    order), each produced by normalize_stream().
    """
    grouped = {}
    for reader in collection.all_streams():
        for fileobj in reader.all_files():
            # A file name may itself contain '/', so split the combined path
            # at its last slash to find the stream the file belongs to.
            full_path = reader.name() + "/" + fileobj.name()
            stream_name, _, file_name = full_path.rpartition("/")
            pieces = grouped.setdefault(stream_name, {}).setdefault(file_name, [])
            for seg in fileobj.segments:
                pieces.extend(reader.locators_and_ranges(seg[0], seg[1]))

    return [normalize_stream(name, grouped[name]) for name in sorted(grouped.keys())]
92
93
class CollectionReader(object):
    """Read a collection given its locator, its UUID, or its manifest text.

    Nothing is fetched at construction time.  The manifest is loaded and
    normalized on first use of all_streams(), all_files() or
    manifest_text().
    """

    def __init__(self, manifest_locator_or_text, api_client=None):
        """Classify the argument as a locator, a UUID or manifest text.

        manifest_locator_or_text -- a 32-hex-digit locator with optional
            hints, a UUID-shaped identifier, or literal manifest text.
        api_client -- API client to use for lookups; when None, one is
            created on demand by _populate().

        Raises errors.ArgumentError if the argument matches none of the
        accepted forms.
        """
        self._api_client = api_client
        self._keep_client = None
        if re.match(r'[a-f0-9]{32}(\+\d+)?(\+\S+)*$', manifest_locator_or_text):
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        elif re.match(r'[a-z0-9]{5}-[a-z0-9]{5}-[a-z0-9]{15}$', manifest_locator_or_text):
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        elif re.match(r'(\S+)( [a-f0-9]{32}(\+\d+)(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            raise errors.ArgumentError(
                "Argument to CollectionReader must be a manifest or a collection UUID")
        self._streams = None

    def __enter__(self):
        """Return this reader so `with CollectionReader(...) as r` binds it."""
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        """Leave the `with` block; there is nothing to release.

        The arguments default to None so that a direct zero-argument call
        keeps working.  Returns None, so exceptions raised inside the block
        propagate.
        """
        pass

    def _populate(self):
        """Load the manifest if needed, then parse and normalize it.

        Does nothing once the streams have been loaded.  When no manifest
        text is available yet, the API server is asked for the collection;
        if that fails for any reason the failure is logged and the manifest
        is fetched from Keep by locator instead.
        """
        if self._streams is not None:
            return
        if not self._manifest_text:
            try:
                # As in KeepClient itself, we must wait until the last possible
                # moment to instantiate an API client, in order to avoid
                # tripping up clients that don't have access to an API server.
                # If we do build one, make sure our Keep client uses it.
                # If instantiation fails, we'll fall back to the except clause,
                # just like any other Collection lookup failure.
                if self._api_client is None:
                    self._api_client = arvados.api('v1')
                    self._keep_client = KeepClient(api_client=self._api_client)
                if self._keep_client is None:
                    self._keep_client = KeepClient(api_client=self._api_client)
                c = self._api_client.collections().get(
                    uuid=self._manifest_locator).execute()
                self._manifest_text = c['manifest_text']
            except Exception as e:
                _logger.warning("API lookup failed for collection %s (%s: %s)",
                                self._manifest_locator, type(e), str(e))
                if self._keep_client is None:
                    self._keep_client = KeepClient(api_client=self._api_client)
                self._manifest_text = self._keep_client.get(self._manifest_locator)
        # Store the raw token lists first: normalize() reads them back through
        # all_streams(), which returns early from _populate() because
        # self._streams is no longer None.
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
        self._streams = normalize(self)

        # Regenerate the manifest text from the normalized streams.
        self._manifest_text = ''.join([StreamReader(stream).manifest_text() for stream in self._streams])

    def all_streams(self):
        """Return a list with one StreamReader per normalized stream."""
        self._populate()
        return [StreamReader(s) for s in self._streams]

    def all_files(self):
        """Yield every file of every stream, in stream order."""
        for s in self.all_streams():
            for f in s.all_files():
                yield f

    def manifest_text(self, strip=False):
        """Return the normalized manifest text.

        strip -- when true, rebuild the text by calling each stream's
            manifest_text(strip=True) instead of returning the cached text.
        """
        self._populate()
        if strip:
            m = ''.join([StreamReader(stream).manifest_text(strip=True) for stream in self._streams])
            return m
        else:
            return self._manifest_text
176
class CollectionWriter(object):
    """Build a collection by writing data to Keep and recording a manifest.

    Data is buffered and stored in blocks of at most KEEP_BLOCK_SIZE bytes.
    Files are grouped into named streams; finish() stores the resulting
    manifest in Keep and returns its locator.
    """
    KEEP_BLOCK_SIZE = 2**26

    def __init__(self, api_client=None):
        """Create an empty writer positioned in stream '.' with no file open.

        api_client -- passed to KeepClient when the Keep client is first
            needed; may be None.
        """
        self._api_client = api_client
        self._keep_client = None
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
        self._close_file = None
        self._queued_file = None
        self._queued_dirents = deque()
        self._queued_trees = deque()

    def __enter__(self):
        """Return this writer so `with CollectionWriter() as w` binds it."""
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        """Store the manifest when the `with` block ends without an error.

        If the block raised, nothing is stored and the exception propagates
        (this method returns None).  The arguments default to None so that a
        direct zero-argument call still finishes the collection.
        """
        if exc_type is None:
            self.finish()

    def _prep_keep_client(self):
        """Create the Keep client on first use."""
        if self._keep_client is None:
            self._keep_client = KeepClient(api_client=self._api_client)

    def do_queued_work(self):
        """Process queued files, directory entries and trees until none remain."""
        # The work queue consists of three pieces:
        # * _queued_file: The file object we're currently writing to the
        #   Collection.
        # * _queued_dirents: Entries under the current directory
        #   (_queued_trees[0]) that we want to write or recurse through.
        #   This may contain files from subdirectories if
        #   max_manifest_depth == 0 for this directory.
        # * _queued_trees: Directories that should be written as separate
        #   streams to the Collection.
        # This function handles the smallest piece of work currently queued
        # (current file, then current directory, then next directory) until
        # no work remains.  The _work_THING methods each do a unit of work on
        # THING.  _queue_THING methods add a THING to the work queue.
        while True:
            if self._queued_file:
                self._work_file()
            elif self._queued_dirents:
                self._work_dirents()
            elif self._queued_trees:
                self._work_trees()
            else:
                break

    def _work_file(self):
        """Copy the queued file into the collection, then dequeue it."""
        while True:
            buf = self._queued_file.read(self.KEEP_BLOCK_SIZE)
            if not buf:
                break
            self.write(buf)
        self.finish_current_file()
        # Only close files this writer opened itself (see _queue_file).
        if self._close_file:
            self._queued_file.close()
        self._close_file = None
        self._queued_file = None

    def _work_dirents(self):
        """Handle queued directory entries up to and including the next file.

        Subdirectories are queued as trees of their own; the first regular
        entry found is queued as the current file and processing stops so
        that do_queued_work() can write it.
        """
        path, stream_name, max_manifest_depth = self._queued_trees[0]
        if stream_name != self.current_stream_name():
            self.start_new_stream(stream_name)
        while self._queued_dirents:
            dirent = self._queued_dirents.popleft()
            target = os.path.join(path, dirent)
            if os.path.isdir(target):
                self._queue_tree(target,
                                 os.path.join(stream_name, dirent),
                                 max_manifest_depth - 1)
            else:
                self._queue_file(target, dirent)
                break
        # The current tree is done once all of its entries have been taken.
        if not self._queued_dirents:
            self._queued_trees.popleft()

    def _work_trees(self):
        """List the first queued tree and queue its entries.

        With max_manifest_depth == 0 the listing is recursive, so everything
        below the directory lands in a single stream.  An empty directory is
        simply dropped from the queue.
        """
        path, stream_name, max_manifest_depth = self._queued_trees[0]
        make_dirents = (util.listdir_recursive if (max_manifest_depth == 0)
                        else os.listdir)
        d = make_dirents(path)
        if len(d) > 0:
            self._queue_dirents(stream_name, d)
        else:
            self._queued_trees.popleft()

    def _queue_file(self, source, filename=None):
        """Make `source` the file being written and start a new manifest file.

        source -- a path, or an object with a read() method.  Paths are
            opened here and closed by _work_file(); file objects are left
            open for the caller.
        filename -- name to record in the manifest; defaults to the base
            name of source.name.
        """
        assert (self._queued_file is None), "tried to queue more than one file"
        if not hasattr(source, 'read'):
            source = open(source, 'rb')
            self._close_file = True
        else:
            self._close_file = False
        if filename is None:
            filename = os.path.basename(source.name)
        self.start_new_file(filename)
        self._queued_file = source

    def _queue_dirents(self, stream_name, dirents):
        """Queue the sorted directory entries (stream_name is not used here)."""
        assert (not self._queued_dirents), "tried to queue more than one tree"
        self._queued_dirents = deque(sorted(dirents))

    def _queue_tree(self, path, stream_name, max_manifest_depth):
        """Append a directory to the queue of trees still to be written."""
        self._queued_trees.append((path, stream_name, max_manifest_depth))

    def write_file(self, source, filename=None):
        """Write one file (a path or a readable object) to the collection."""
        self._queue_file(source, filename)
        self.do_queued_work()

    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        """Write a directory tree, starting at `path`, into `stream_name`.

        max_manifest_depth is decremented for each level of subdirectory;
        a directory reached with depth 0 is listed recursively into a single
        stream.
        """
        self._queue_tree(path, stream_name, max_manifest_depth)
        self.do_queued_work()

    def write(self, newdata):
        """Append data to the current file.

        newdata -- a string, or an iterable whose items are written one by
            one (recursively).
        """
        # Strings are always data.  Checking explicitly avoids depending on
        # str having no __iter__ attribute, which would otherwise send a
        # string down the "iterable of chunks" branch forever.
        if not isinstance(newdata, (str, bytes)) and hasattr(newdata, '__iter__'):
            for s in newdata:
                self.write(s)
            return
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()

    def flush_data(self):
        """Store up to one block of buffered data in Keep.

        Any data beyond KEEP_BLOCK_SIZE stays in the buffer.  Does nothing
        when the buffer is empty.
        """
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._prep_keep_client()
            self._current_stream_locators.append(
                self._keep_client.put(data_buffer[0:self.KEEP_BLOCK_SIZE]))
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])

    def start_new_file(self, newfilename=None):
        """Close the current file and begin a new one, optionally unnamed."""
        self.finish_current_file()
        self.set_current_file_name(newfilename)

    def set_current_file_name(self, newfilename):
        """Set the name of the file being written.

        None leaves the file unnamed, which is what start_new_file() passes
        by default.  Raises errors.AssertionError if the name contains a tab
        or a newline.
        """
        if newfilename is not None and re.search(r'[\t\n]', newfilename):
            raise errors.AssertionError(
                "Manifest filenames cannot contain whitespace: %s" %
                newfilename)
        self._current_file_name = newfilename

    def current_file_name(self):
        """Return the name of the file being written, or None."""
        return self._current_file_name

    def finish_current_file(self):
        """Record the current file in the current stream.

        Does nothing if no file is named and no data has been written since
        the last file ended.  Raises errors.AssertionError if data was
        written without a file name.
        """
        if self._current_file_name is None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise errors.AssertionError(
                "Cannot finish an unnamed file " +
                "(%d bytes at offset %d in '%s' stream)" %
                (self._current_stream_length - self._current_file_pos,
                 self._current_file_pos,
                 self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length

    def start_new_stream(self, newstreamname='.'):
        """Finish the current stream and begin a new one."""
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)

    def set_current_stream_name(self, newstreamname):
        """Set the current stream name; an empty name means '.'.

        Raises errors.AssertionError if the name contains a tab or newline.
        """
        if re.search(r'[\t\n]', newstreamname):
            raise errors.AssertionError(
                "Manifest stream names cannot contain whitespace")
        self._current_stream_name = '.' if newstreamname=='' else newstreamname

    def current_stream_name(self):
        """Return the name of the stream being written, or None."""
        return self._current_stream_name

    def finish_current_stream(self):
        """Close the current file and stream and reset the per-stream state.

        A stream without files is discarded.  Afterwards the stream name is
        None until set_current_stream_name() or start_new_stream() is called.
        Raises errors.AssertionError if files were written to an unnamed
        stream.
        """
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name is None:
            raise errors.AssertionError(
                "Cannot finish an unnamed stream (%d bytes in %d files)" %
                (self._current_stream_length, len(self._current_stream_files)))
        else:
            # Files exist but no data was stored: the stream still needs a
            # locator token.
            if len(self._current_stream_locators) == 0:
                self._current_stream_locators += [config.EMPTY_BLOCK_LOCATOR]
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None

    def finish(self):
        """Store the manifest in Keep and return its locator."""
        self._prep_keep_client()
        return self._keep_client.put(self.manifest_text())

    def stripped_manifest(self):
        """
        Return the manifest for the current collection with all permission
        hints removed from the locators in the manifest.
        """
        raw = self.manifest_text()
        cleaned_lines = []
        for line in raw.split("\n"):
            fields = line.split()
            if len(fields) > 0:
                # NOTE(review): every token between the stream name and the
                # last token is treated as a locator, including any file
                # tokens other than the last one.
                locators = [ re.sub(r'\+A[a-z0-9@_-]+', '', x)
                             for x in fields[1:-1] ]
                cleaned_lines.append(
                    fields[0] + ' ' + ' '.join(locators) + ' ' + fields[-1] + "\n")
        return ''.join(cleaned_lines)

    def manifest_text(self):
        """Finish the current stream and return the normalized manifest text.

        Returns "" when no stream has been finished with files in it.
        Otherwise the text is passed through CollectionReader to normalize it.
        """
        self.finish_current_stream()
        lines = []
        for stream in self._finished_streams:
            # Stream names must be '.' or start with './'.
            prefix = '' if re.search(r'^\.(/.*)?$', stream[0]) else './'
            file_tokens = ' '.join(
                "%d:%d:%s" % (sfile[0], sfile[1], sfile[2].replace(' ', '\\040'))
                for sfile in stream[2])
            lines.append(prefix + stream[0].replace(' ', '\\040') +
                         ' ' + ' '.join(stream[1]) +
                         ' ' + file_tokens + "\n")
        manifest = ''.join(lines)

        if len(manifest) > 0:
            return CollectionReader(manifest).manifest_text()
        else:
            return ""

    def data_locators(self):
        """Return the block locators of all finished streams, in order."""
        ret = []
        for name, locators, files in self._finished_streams:
            ret += locators
        return ret
424
425
class ResumableCollectionWriter(CollectionWriter):
    """CollectionWriter whose progress can be dumped and restored later.

    Every file written is recorded in _dependencies with its stat result, so
    that a restored writer can refuse to resume when a source file has
    changed.  Only data read from queued files is accepted.
    """
    # Attributes copied by dump_state() and restored by from_state().
    STATE_PROPS = ['_current_stream_files', '_current_stream_length',
                   '_current_stream_locators', '_current_stream_name',
                   '_current_file_name', '_current_file_pos', '_close_file',
                   '_data_buffer', '_dependencies', '_finished_streams',
                   '_queued_dirents', '_queued_trees']

    def __init__(self, api_client=None):
        """Create an empty resumable writer with no recorded dependencies."""
        self._dependencies = {}
        super(ResumableCollectionWriter, self).__init__(api_client)

    @classmethod
    def from_state(cls, state, *init_args, **init_kwargs):
        """Build a writer from a state produced by dump_state().

        state -- mapping with one entry per STATE_PROPS name plus
            '_current_file' (None, or a (path, position) pair).
        init_args, init_kwargs -- forwarded to the constructor.

        Raises errors.StaleWriterStateError if the state cannot be resumed
        (expired locator permission, changed or missing source file, or the
        active file cannot be reopened).  The caller is responsible for
        calling do_queued_work() on the returned writer.
        """
        writer = cls(*init_args, **init_kwargs)
        for attr_name in cls.STATE_PROPS:
            attr_value = state[attr_name]
            attr_class = getattr(writer, attr_name).__class__
            # Coerce the value into the same type as the initial value, if
            # needed.
            if attr_class not in (type(None), attr_value.__class__):
                attr_value = attr_class(attr_value)
            setattr(writer, attr_name, attr_value)
        # _data_buffer_len is not part of the saved state.  Recompute it so
        # that write() flushes as soon as the restored buffer plus new data
        # reaches a full block; otherwise the buffer can outgrow a block and
        # the single flush at the end of the stream leaves data behind.
        writer._data_buffer_len = sum(len(chunk) for chunk in writer._data_buffer)
        # Check dependencies before we try to resume anything.
        if any(KeepLocator(ls).permission_expired()
               for ls in writer._current_stream_locators):
            raise errors.StaleWriterStateError(
                "locators include expired permission hint")
        writer.check_dependencies()
        if state['_current_file'] is not None:
            path, pos = state['_current_file']
            try:
                writer._queued_file = open(path, 'rb')
                writer._queued_file.seek(pos)
            except IOError as error:
                raise errors.StaleWriterStateError(
                    "failed to reopen active file {}: {}".format(path, error))
        return writer

    def check_dependencies(self):
        """Verify that every recorded source file is unchanged.

        Raises errors.StaleWriterStateError if a dependency was not a
        regular file, can no longer be stat'ed, or differs in type,
        modification time or size from its recorded stat result.
        """
        for path, orig_stat in self._dependencies.items():
            if not S_ISREG(orig_stat[ST_MODE]):
                raise errors.StaleWriterStateError("{} not file".format(path))
            try:
                now_stat = tuple(os.stat(path))
            except OSError as error:
                raise errors.StaleWriterStateError(
                    "failed to stat {}: {}".format(path, error))
            if ((not S_ISREG(now_stat[ST_MODE])) or
                (orig_stat[ST_MTIME] != now_stat[ST_MTIME]) or
                (orig_stat[ST_SIZE] != now_stat[ST_SIZE])):
                raise errors.StaleWriterStateError("{} changed".format(path))

    def dump_state(self, copy_func=lambda x: x):
        """Return a dict describing the writer's current progress.

        copy_func -- applied to each STATE_PROPS value before it is stored;
            the default stores the values themselves.

        The '_current_file' entry is None, or the real path and read
        position of the file currently being written.
        """
        state = {attr: copy_func(getattr(self, attr))
                 for attr in self.STATE_PROPS}
        if self._queued_file is None:
            state['_current_file'] = None
        else:
            state['_current_file'] = (os.path.realpath(self._queued_file.name),
                                      self._queued_file.tell())
        return state

    def _queue_file(self, source, filename=None):
        """Queue a file by path and record its stat result as a dependency.

        Raises errors.AssertionError if source is not usable as a path, if a
        regular file could not be stat'ed by path, or if the path and the
        opened file refer to different inodes.
        """
        try:
            src_path = os.path.realpath(source)
        except Exception:
            raise errors.AssertionError("{} not a file path".format(source))
        # Keep the stat failure (if any) in a variable of its own so it is
        # still available for the error message further down.
        stat_error = None
        try:
            path_stat = os.stat(src_path)
        except OSError as error:
            path_stat = None
            stat_error = error
        super(ResumableCollectionWriter, self)._queue_file(source, filename)
        fd_stat = os.fstat(self._queued_file.fileno())
        if not S_ISREG(fd_stat.st_mode):
            # We won't be able to resume from this cache anyway, so don't
            # worry about further checks.
            self._dependencies[source] = tuple(fd_stat)
        elif path_stat is None:
            raise errors.AssertionError(
                "could not stat {}: {}".format(source, stat_error))
        elif path_stat.st_ino != fd_stat.st_ino:
            raise errors.AssertionError(
                "{} changed between open and stat calls".format(source))
        else:
            self._dependencies[src_path] = tuple(fd_stat)

    def write(self, data):
        """Write data read from the queued file.

        Raises errors.AssertionError when no file is queued, because data
        without a source file cannot be reproduced on resume.
        """
        if self._queued_file is None:
            raise errors.AssertionError(
                "resumable writer can't accept unsourced data")
        return super(ResumableCollectionWriter, self).write(data)