maintain directory structure in collection_extract
[arvados.git] / sdk / python / arvados.py
1 import gflags
2 import httplib2
3 import logging
4 import os
5 import pprint
6 import sys
7 import types
8 import subprocess
9 import json
10 import UserDict
11 import re
12 import hashlib
13 import string
14 import bz2
15 import zlib
16 import fcntl
17
18 from apiclient import errors
19 from apiclient.discovery import build
20
class CredentialsFromEnv:
    """Attach the Arvados API token from the environment to HTTP requests.

    ``authorize`` patches an httplib2.Http-like object so that each call to
    its ``request`` method is routed through ``http_request`` below.
    """

    @staticmethod
    def http_request(self, uri, **kwargs):
        """Replacement ``request`` method installed by ``authorize``.

        Declared static on purpose: ``types.MethodType`` then binds it to the
        patched HTTP object, so inside this function ``self`` is that HTTP
        object, not a CredentialsFromEnv.  It adds an
        ``Authorization: OAuth2 <token>`` header (token read from
        ARVADOS_API_TOKEN on every call, creating the ``headers`` dict when
        the caller passed none) and delegates to the saved original
        ``request``.
        """
        from httplib import BadStatusLine
        headers = kwargs.setdefault('headers', {})
        headers['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # httplib raises this when it tried to reuse a connection that
            # the server had already closed, so one retry is worthwhile.
            # Caveat: we cannot be sure the first attempt did not succeed,
            # so this is slightly risky for non-idempotent requests.
            return self.orig_http_request(uri, **kwargs)

    def authorize(self, http):
        """Patch *http* in place and return it.

        The original ``http.request`` is saved as ``http.orig_http_request``
        and replaced by ``http_request`` bound to *http*.
        """
        saved_request = http.request
        http.orig_http_request = saved_request
        http.request = types.MethodType(self.http_request, http)
        return http
42
# Build the module-wide API client at import time.  The '%' fills in only
# the host; the {api}/{apiVersion} placeholders stay in the URL, presumably
# for apiclient's build() to substitute (verify against the apiclient docs).
# Importing this module raises KeyError if ARVADOS_API_HOST is unset.
url = ('https://%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()
http = httplib2.Http()
# Patch http.request so every request carries the ARVADOS_API_TOKEN header.
http = credentials.authorize(http)
# NOTE(review): TLS certificate validation is disabled unconditionally, so
# the API server's identity is never verified.  Consider making this opt-in
# (e.g. via an environment variable) instead of the default.
http.disable_ssl_certificate_validation=True
# NOTE(review): build() presumably fetches the discovery document from `url`
# over the network, so importing this module likely needs the API server
# to be reachable -- confirm.
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)
50
def task_set_output(self, s):
    """Record *s* as the output of task *self* and mark it finished.

    Meant to be bound onto a task record as its ``set_output`` method;
    *self* is that record and its ``'uuid'`` entry selects the task.  The
    update also sets success to True and progress to 1.0.
    """
    update_body = {
        'output': s,
        'success': True,
        'progress': 1.0,
    }
    request = service.job_tasks().update(uuid=self['uuid'],
                                         job_task=json.dumps(update_body))
    request.execute()
58
# Cache for current_task(); filled on the first successful call.
_current_task = None


def current_task():
    """Return the job task record for this process (TASK_UUID).

    The record is fetched from the API once and then cached.  It is wrapped
    in a UserDict carrying a bound ``set_output`` method and a ``tmpdir``
    attribute taken from TASK_WORK.
    """
    global _current_task
    if not _current_task:
        record = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
        task = UserDict.UserDict(record)
        task.set_output = types.MethodType(task_set_output, task)
        task.tmpdir = os.environ['TASK_WORK']
        _current_task = task
    return _current_task
70
# Cache for current_job(); filled on the first successful call.
_current_job = None


def current_job():
    """Return the job record for this process (JOB_UUID).

    The record is fetched from the API once and then cached, wrapped in a
    UserDict with a ``tmpdir`` attribute taken from JOB_WORK.
    """
    global _current_job
    if not _current_job:
        record = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
        job = UserDict.UserDict(record)
        job.tmpdir = os.environ['JOB_WORK']
        _current_job = job
    return _current_job
81
def api():
    """Return the module-wide Arvados API client built at import time."""
    client = service
    return client
84
class JobTask:
    """Placeholder for a job task description; currently it only prints
    the parameters and resource limits it was constructed with."""

    def __init__(self, parameters=None, resource_limits=None):
        # None defaults instead of shared mutable dict() defaults; omitted
        # arguments still print as empty dicts, exactly as before.
        if parameters is None:
            parameters = {}
        if resource_limits is None:
            resource_limits = {}
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("init jobtask %s %s" % (parameters, resource_limits))
88
class job_setup:
    """Job-setup helpers that queue one new task per input file or per
    input stream of the job's ``input`` collection."""

    @staticmethod
    def _queue_task(job_uuid, parent_task_uuid, sequence, task_input):
        """Create a job task in job *job_uuid*, created by task
        *parent_task_uuid*, with *sequence* and ``parameters['input']``."""
        new_task_attrs = {
            'job_uuid': job_uuid,
            'created_by_job_task_uuid': parent_task_uuid,
            'sequence': sequence,
            'parameters': {
                'input': task_input,
            },
        }
        service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()

    @staticmethod
    def _succeed_and_exit(task_uuid):
        """Mark task *task_uuid* successful, then terminate the process."""
        service.job_tasks().update(uuid=task_uuid,
                                   job_task=json.dumps({'success': True})
                                   ).execute()
        # sys.exit, not the site-module exit(): the latter is intended for
        # the interactive interpreter and is missing under `python -S`.
        sys.exit(0)

    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        """Fan out the job input one file per task.

        Does nothing unless the current task's sequence equals
        *if_sequence*.  Otherwise it queues one task with sequence
        ``if_sequence + 1`` per file in the input collection; each task's
        input is that file's manifest text.  If *and_end_task*, the current
        task is then marked successful and the process exits.
        """
        if if_sequence != current_task()['sequence']:
            return
        job = current_job()
        task = current_task()
        cr = CollectionReader(job['script_parameters']['input'])
        for s in cr.all_streams():
            for f in s.all_files():
                job_setup._queue_task(job['uuid'], task['uuid'],
                                      if_sequence + 1, f.as_manifest())
        if and_end_task:
            job_setup._succeed_and_exit(task['uuid'])

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        """Fan out the job input one stream per task.

        Does nothing unless the current task's sequence equals
        *if_sequence*.  Otherwise it queues one task with sequence
        ``if_sequence + 1`` per stream in the input collection; each task's
        input is that stream's token list.  If *and_end_task*, the current
        task is then marked successful and the process exits.
        """
        if if_sequence != current_task()['sequence']:
            return
        job = current_job()
        task = current_task()
        cr = CollectionReader(job['script_parameters']['input'])
        for s in cr.all_streams():
            job_setup._queue_task(job['uuid'], task['uuid'],
                                  if_sequence + 1, s.tokens())
        if and_end_task:
            job_setup._succeed_and_exit(task['uuid'])
136
class util:
    """Static helpers for job scripts: running commands, checking out git
    repositories, and extracting Keep data into local directories.

    Path arguments that do not start with '/' are taken relative to the
    current job's tmpdir.
    """

    @staticmethod
    def _job_tmp_path(path):
        """Return *path* if it starts with '/', else *path* joined onto
        current_job().tmpdir."""
        if path.startswith('/'):
            return path
        return os.path.join(current_job().tmpdir, path)

    @staticmethod
    def _ensure_dir(path):
        """Create directory *path* (one level only) if nothing is there."""
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

    @staticmethod
    def _rm_f(path):
        """Remove *path* if present, emulating ``rm -f``."""
        try:
            os.unlink(path)
        except OSError:
            # Retry (and let that raise) only if something is really still
            # there.  lexists, not exists: a .locator link normally points
            # at a locator string rather than a real path, i.e. it dangles,
            # and exists() reports dangling symlinks as absent.
            if os.path.lexists(path):
                os.unlink(path)

    @staticmethod
    def _read_locator(path):
        """Return the target of the ``.locator`` symlink in directory
        *path*, or None if it cannot be read."""
        try:
            return os.readlink(os.path.join(path, '.locator'))
        except OSError:
            return None

    @staticmethod
    def _copy_chunks(src, dst, chunk_size=2**20):
        """Copy everything readable from *src* to *dst*, chunk by chunk."""
        while True:
            buf = src.read(chunk_size)
            if not buf:
                break
            dst.write(buf)

    @staticmethod
    def _want_file(f, files, files_got, decompress):
        """Decide whether stream file *f* should be extracted.

        True when *files* is empty (no filter).  Otherwise *f*'s stored name
        must not already be in *files_got*, and either its stored name or
        (when *decompress*) its decompressed name must be in *files*.
        """
        if not files:
            return True
        if f.name() in files_got:
            return False
        return f.name() in files or bool(decompress and f.decompressed_name() in files)

    @staticmethod
    def _write_file(f, outpath, decompress):
        """Write the (optionally decompressed) contents of *f* to *outpath*.

        On failure the partial output is removed before re-raising, so a
        truncated file is never mistaken for a complete one later.
        """
        try:
            with open(outpath, 'wb') as outfile:
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
        except BaseException:
            if os.path.exists(outpath):
                os.unlink(outpath)
            raise

    @staticmethod
    def run_command(execargs, **kwargs):
        """Run *execargs* (argv list, no shell); return (stdout, stderr).

        stdin/stdout/stderr default to pipes and may be overridden through
        *kwargs*, which are passed on to subprocess.Popen.  Raises Exception
        (including the captured stderr) if the command exits non-zero.
        """
        kwargs.setdefault('stdin', subprocess.PIPE)
        kwargs.setdefault('stdout', subprocess.PIPE)
        kwargs.setdefault('stderr', subprocess.PIPE)
        p = subprocess.Popen(execargs, close_fds=True, shell=False,
                             **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("run_command %s exit %d:\n%s" %
                            (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

    @staticmethod
    def git_checkout(url, version, path):
        """Clone *url* into *path* unless *path* exists, then check out
        *version* there.  Return the absolute checkout path."""
        path = util._job_tmp_path(path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        """Start ``tar -C path -x<flag>f -`` reading the archive from a pipe
        on its stdin; return the Popen object."""
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        path = util._job_tmp_path(path)
        locator_link = os.path.join(path, '.locator')
        # Closing the lock file at the end of the with-block releases the
        # flock, including when an exception escapes.
        with open(path + '.lock', 'w') as lockfile:
            fcntl.flock(lockfile, fcntl.LOCK_EX)
            util._ensure_dir(path)
            if util._read_locator(path) != tarball:
                util._rm_f(locator_link)
                for f in CollectionReader(tarball).all_files():
                    if re.search(r'\.(tbz|tar\.bz2)$', f.name()):
                        p = util.tar_extractor(path, 'j')
                    elif re.search(r'\.(tgz|tar\.gz)$', f.name()):
                        p = util.tar_extractor(path, 'z')
                    elif re.search(r'\.tar$', f.name()):
                        p = util.tar_extractor(path, '')
                    else:
                        raise Exception("tarball_extract cannot handle filename %s"
                                        % f.name())
                    try:
                        util._copy_chunks(f, p.stdin)
                    finally:
                        # Always reap tar, even if feeding it failed.
                        p.stdin.close()
                        p.wait()
                    if p.returncode != 0:
                        raise Exception("tar exited %d" % p.returncode)
                # Only recorded once every archive extracted successfully.
                os.symlink(tarball, locator_link)
            tld_extracts = [e for e in os.listdir(path) if e != '.locator']
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        path = util._job_tmp_path(path)
        locator_link = os.path.join(path, '.locator')
        # Closing the lock file at the end of the with-block releases the
        # flock, including when an exception escapes.
        with open(path + '.lock', 'w') as lockfile:
            fcntl.flock(lockfile, fcntl.LOCK_EX)
            util._ensure_dir(path)
            if util._read_locator(path) != zipball:
                util._rm_f(locator_link)
                for f in CollectionReader(zipball).all_files():
                    if not re.search(r'\.zip$', f.name()):
                        raise Exception("zipball_extract cannot handle filename %s"
                                        % f.name())
                    # unzip needs a seekable file, so stage the archive on disk.
                    zip_filename = os.path.join(path, os.path.basename(f.name()))
                    with open(zip_filename, 'wb') as zip_file:
                        util._copy_chunks(f, zip_file)
                    p = subprocess.Popen(["unzip",
                                          "-q", "-o",
                                          "-d", path,
                                          zip_filename],
                                         stdout=None,
                                         stdin=None, stderr=sys.stderr,
                                         shell=False, close_fds=True)
                    p.wait()
                    if p.returncode != 0:
                        raise Exception("unzip exited %d" % p.returncode)
                    os.unlink(zip_filename)
                # Only recorded once every archive extracted successfully.
                os.symlink(zipball, locator_link)
            tld_extracts = [e for e in os.listdir(path) if e != '.locator']
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=None, decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        files -- if non-empty, extract only these files (matched by stored
                 name, or by decompressed name when decompress is true) and
                 raise Exception if fewer files than requested were found
        decompress -- write .gz/.bz2 files decompressed (suffix stripped)

        Files of each stream are written under path/<stream name>/, keeping
        the collection's directory structure.  Output files that already
        exist are left untouched.
        """
        if files is None:
            files = []
        path = util._job_tmp_path(path)
        locator_link = os.path.join(path, '.locator')
        # Closing the lock file at the end of the with-block releases the
        # flock, including when an exception escapes.
        with open(path + '.lock', 'w') as lockfile:
            fcntl.flock(lockfile, fcntl.LOCK_EX)
            util._ensure_dir(path)
            # A matching .locator is deliberately not treated as "already
            # done": an earlier call may have extracted a different subset
            # of files.  Files already on disk are skipped below instead.
            util._rm_f(locator_link)
            files_got = []
            for s in CollectionReader(collection).all_streams():
                stream_name = s.name()
                for f in s.all_files():
                    if not util._want_file(f, files, files_got, decompress):
                        continue
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got.append(outname)
                    outpath = os.path.join(path, stream_name, outname)
                    if os.path.exists(outpath):
                        continue
                    util.mkdir_dash_p(os.path.dirname(outpath))
                    util._write_file(f, outpath, decompress)
            if len(files_got) < len(files):
                raise Exception("Wanted files %s but only got %s from %s" %
                                (files, files_got,
                                 [z.name() for z in CollectionReader(collection).all_files()]))
            os.symlink(collection, locator_link)
        return path

    @staticmethod
    def mkdir_dash_p(path):
        """Create directory *path* and any missing parents (``mkdir -p``).

        An empty *path* is a no-op; this also ends the recursion for
        relative paths, whose topmost dirname is ''.
        """
        if not path or os.path.exists(path):
            return
        util.mkdir_dash_p(os.path.dirname(path))
        try:
            os.mkdir(path)
        except OSError:
            # Someone else may have created it meanwhile; only retry (and
            # let the error surface) if it is still missing.
            if not os.path.exists(path):
                os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=None, decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        files -- if non-empty, extract only these files (see
                 collection_extract) and raise Exception if fewer were found
        decompress -- write .gz/.bz2 files decompressed (suffix stripped)

        Existing output files are replaced.
        """
        if files is None:
            files = []
        path = util._job_tmp_path(path)
        # Closing the lock file at the end of the with-block releases the
        # flock, including when an exception escapes.
        with open(path + '.lock', 'w') as lockfile:
            fcntl.flock(lockfile, fcntl.LOCK_EX)
            util._ensure_dir(path)
            files_got = []
            for f in stream.all_files():
                if not util._want_file(f, files, files_got, decompress):
                    continue
                outname = f.decompressed_name() if decompress else f.name()
                files_got.append(outname)
                outpath = os.path.join(path, outname)
                if os.path.exists(outpath):
                    os.unlink(outpath)
                util.mkdir_dash_p(os.path.dirname(outpath))
                util._write_file(f, outpath, decompress)
            if len(files_got) < len(files):
                raise Exception("Wanted files %s but only got %s from %s" %
                                (files, files_got,
                                 [z.name() for z in stream.all_files()]))
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        """Return the non-directory entries under *dirname*, recursively,
        as sorted paths relative to *dirname* (prefixed by *base* if given)."""
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles
420
class DataReader:
    """Stream the data behind a Keep locator through a ``whget -r``
    subprocess.  Usable as a context manager; leaving the block calls
    close()."""

    def __init__(self, data_locator):
        self.data_locator = data_locator
        self.p = subprocess.Popen(["whget", "-r", self.data_locator, "-"],
                                  stdout=subprocess.PIPE,
                                  stdin=None, stderr=subprocess.PIPE,
                                  shell=False, close_fds=True)

    def __enter__(self):
        # Return self so `with DataReader(loc) as r:` binds the reader.
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        # Standard 3-argument protocol (defaults keep direct calls working).
        # Returns None, so exceptions from the with-body propagate.
        self.close()

    def read(self, size, **kwargs):
        """Read up to *size* bytes from whget's stdout."""
        return self.p.stdout.read(size, **kwargs)

    def close(self):
        """Close the pipes, relay whget's stderr to ours, and wait for it.

        Raises Exception if whget exited non-zero.
        """
        self.p.stdout.close()
        if not self.p.stderr.closed:
            for err in self.p.stderr:
                # Same output as the former `print >> sys.stderr, err`.
                sys.stderr.write("%s\n" % (err,))
            self.p.stderr.close()
        self.p.wait()
        if self.p.returncode != 0:
            raise Exception("whget subprocess exited %d" % self.p.returncode)
443
class StreamFileReader:
    """Read one file out of a stream, given the file's byte offset *pos*
    and *size* within the stream.

    The stream object must provide seek(pos) and read(n).  name() and
    tokens_for_range() are also needed by stream_name()/as_manifest().
    """

    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0  # bytes of this file consumed so far

    def name(self):
        return self._name

    def decompressed_name(self):
        """File name with a trailing .bz2 or .gz removed."""
        return re.sub(r'\.(bz2|gz)$', '', self._name)

    def size(self):
        return self._size

    def stream_name(self):
        return self._stream.name()

    def read(self, size, **kwargs):
        """Read up to *size* bytes, never past the end of this file."""
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data

    def readall(self, size=2**20, **kwargs):
        """Yield the rest of the file in chunks of at most *size* bytes."""
        while True:
            data = self.read(size, **kwargs)
            # `not data` (rather than == '') also ends on an empty bytes
            # chunk, which would otherwise loop forever.
            if not data:
                break
            yield data

    def bunzip2(self, size):
        """Yield the bzip2-decompressed contents, chunk by chunk."""
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data:
                yield data

    def gunzip(self, size):
        """Yield the gzip-decompressed contents, chunk by chunk."""
        # 16 + MAX_WBITS tells zlib to expect a gzip header and trailer.
        decompressor = zlib.decompressobj(16 + zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data:
                yield data
        # Emit any output zlib is still holding once the input is exhausted.
        data = decompressor.flush()
        if data:
            yield data

    def readall_decompressed(self, size=2**20):
        """Like readall(), but decompress .bz2/.gz files by name."""
        self._stream.seek(self._pos + self._filepos)
        if re.search(r'\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search(r'\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)

    def readlines(self, decompress=True):
        """Yield the file line by line, each line keeping its trailing
        newline (the last one may lack it)."""
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = data.find("\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            # Carry the incomplete last line over to the next chunk.
            data = data[sol:]
        if data:
            yield data

    def as_manifest(self):
        """Return a one-line manifest containing just this file."""
        if self.size() == 0:
            # An empty file needs no data; use the empty block's locator.
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return " ".join(self._stream.tokens_for_range(self._pos, self._size)) + "\n"
515
class StreamReader:
    """One manifest stream, parsed from its whitespace-split tokens.

    Token layout: the stream name first, then data block locators (32 hex
    digits plus optional ``+hint`` suffixes), then ``pos:size:name`` file
    tokens.  Data blocks are fetched lazily with Keep.get when read.
    """

    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name is None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")

    def tokens(self):
        """Return the token list this stream was built from."""
        return self._tokens

    def tokens_for_range(self, range_start, range_size):
        """Return manifest tokens covering bytes
        [range_start, range_start + range_size) of this stream.

        The result is the stream name, the locators of blocks overlapping
        the range, then the non-empty files overlapping the range.  File
        offsets are shifted down by the size of the skipped leading blocks.
        Once a locator without a ``+size`` hint is seen, block extents can
        no longer be computed, so that locator and every later one are kept.
        """
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                # Was `next` (a no-op expression), which fell through to
                # sizehint.group() on None and raised AttributeError.
                continue
            blocksize = int(sizehint.group(1))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if (f[0] < range_start + range_size and
                    f[0] + f[1] > range_start and
                    f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp

    def name(self):
        return self._stream_name

    def all_files(self):
        """Yield a StreamFileReader for each file in this stream."""
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)

    def nextdatablock(self):
        """Advance to the next data block (or to the first, if none is
        current yet) and drop the cached block data."""
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None

    def current_datablock_data(self):
        """Return the current block's data, fetching it once via Keep.get."""
        if self._current_datablock_data is None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data

    def current_datablock_size(self):
        """Return the current block's size: from its ``+size`` hint when
        present, otherwise by fetching the block and measuring it."""
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search(r'\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(1))
        return len(self.current_datablock_data())

    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos

    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos is not None and
                self._pos >= self._current_datablock_pos and
                self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            # Seeking backwards: restart from the first block.
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()

    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.

        Returns '' for size 0.  Returns None when the data blocks run out
        before _pos is reached.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        offset = self._pos - self._current_datablock_pos
        data = self.current_datablock_data()[offset:offset + size]
        self._pos += len(data)
        return data
627
class CollectionReader:
    """Read a collection given either its manifest text or a Keep locator
    for the manifest.  The manifest is fetched and split lazily."""

    def __init__(self, manifest_locator_or_text):
        # Input that looks like at least one complete manifest line is taken
        # as manifest text; anything else is treated as a locator.
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None

    def __enter__(self):
        # Return self so `with CollectionReader(x) as cr:` binds the reader.
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        # Nothing to release; returning None lets exceptions propagate.
        pass

    def _populate(self):
        """Fetch the manifest (if only a locator is known) and split it into
        per-stream token lists, once."""
        if self._streams is not None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]

    def all_streams(self):
        """Return a new StreamReader for each stream in the manifest."""
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp

    def all_files(self):
        """Yield every file of every stream, in manifest order."""
        for s in self.all_streams():
            for f in s.all_files():
                yield f

    def manifest_text(self):
        """Return the manifest text, fetching it from Keep if needed."""
        self._populate()
        return self._manifest_text
664
class CollectionWriter:
    """Build a collection manifest, storing file data in Keep blocks of at
    most KEEP_BLOCK_SIZE bytes as the buffer fills.

    Example sequence: start_new_stream(name), start_new_file(name),
    write(data) ..., then finish() to store the manifest and get its
    locator.
    """
    KEEP_BLOCK_SIZE = 2**26

    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []

    def __enter__(self):
        # Return self so `with CollectionWriter() as cw:` binds the writer.
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        # Finalize only on a clean exit: finishing after an error could store
        # a partial collection or mask the original exception.  Returning
        # None lets any exception propagate.
        if exc_type is None:
            self.finish()

    def write(self, newdata):
        """Append *newdata* to the current file, flushing full blocks."""
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()

    def flush_data(self):
        """Store up to KEEP_BLOCK_SIZE buffered bytes as one Keep block."""
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])

    def start_new_file(self, newfilename=None):
        """Finish the current file and begin one named *newfilename*."""
        self.finish_current_file()
        self.set_current_file_name(newfilename)

    def set_current_file_name(self, newfilename):
        """Set the current file name; None means "unnamed".

        Raises AssertionError if the name contains whitespace.
        """
        # None is checked first: it is the default of start_new_file and
        # would otherwise crash inside re.search.
        if newfilename is not None and re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename

    def current_file_name(self):
        return self._current_file_name

    def finish_current_file(self):
        """Record the bytes written since the last file boundary as the
        current file.  An unnamed file with no data is silently dropped;
        an unnamed file with data raises Exception."""
        if self._current_file_name is None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
        # NOTE(review): the file name is not cleared here, so calling this
        # twice records a second, zero-length entry with the same name.
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length

    def start_new_stream(self, newstreamname='.'):
        """Finish the current stream and begin one named *newstreamname*."""
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)

    def set_current_stream_name(self, newstreamname):
        """Set the current stream name; None means "unnamed".

        Raises AssertionError if the name contains whitespace.
        """
        if newstreamname is not None and re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = newstreamname

    def current_stream_name(self):
        return self._current_stream_name

    def finish_current_stream(self):
        """Finish the current file, flush buffered data, and record the
        stream if it contains any files.  Resets all per-stream state,
        leaving the stream name unset."""
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name is None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None

    def finish(self):
        """Store the manifest in Keep and return its locator."""
        return Keep.put(self.manifest_text())

    def manifest_text(self):
        """Finish the current stream and return the manifest text, one line
        per finished stream."""
        self.finish_current_stream()
        lines = []
        for stream_name, locators, stream_files in self._finished_streams:
            if re.search(r'^\.(/.*)?$', stream_name):
                tokens = [stream_name]
            else:
                # Stream names are always written relative to '.'.
                tokens = ['./' + stream_name]
            # A stream without data still needs a locator: the empty block.
            tokens += locators or ["d41d8cd98f00b204e9800998ecf8427e+0"]
            tokens += ["%d:%d:%s" % tuple(sfile) for sfile in stream_files]
            lines.append(" ".join(tokens) + "\n")
        return ''.join(lines)
755
class Keep:
    """Store and fetch data blocks by content locator ("<md5>+<size>").

    Uses the whput/whget command-line tools, or a plain local directory
    when the KEEP_LOCAL_STORE environment variable is set.
    """

    @staticmethod
    def put(data):
        """Store *data* and return its locator (whput's stdout, stripped)."""
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_put(data)
        p = subprocess.Popen(["whput", "-"],
                             stdout=subprocess.PIPE,
                             stdin=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(data)
        if p.returncode != 0:
            raise Exception("whput subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        return stdoutdata.rstrip()

    @staticmethod
    def get(locator):
        """Fetch the block for *locator*.

        Raises Exception if whget fails or the data's md5 does not match
        the start of the locator.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_get(locator)
        p = subprocess.Popen(["whget", locator, "-"],
                             stdout=subprocess.PIPE,
                             stdin=None,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("whget subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        m = hashlib.new('md5')
        m.update(stdoutdata)
        if locator.startswith(m.hexdigest()):
            return stdoutdata
        raise Exception("md5 checksum mismatch: md5(get(%s)) == %s" % (locator, m.hexdigest()))

    @staticmethod
    def local_store_put(data):
        """Store *data* as KEEP_LOCAL_STORE/<md5>; return "<md5>+<len>"."""
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        final_path = os.path.join(os.environ['KEEP_LOCAL_STORE'], md5)
        tmp_path = final_path + '.tmp'
        # Binary mode: block data is bytes.  Write under a temporary name
        # and rename, so the final name only ever holds a complete file.
        with open(tmp_path, 'wb') as f:
            f.write(data)
        os.rename(tmp_path, final_path)
        return locator

    @staticmethod
    def local_store_get(locator):
        """Read the block for *locator* from KEEP_LOCAL_STORE.

        The empty block's md5 returns '' without touching the store.
        Raises Exception if *locator* does not start with hex digits.
        """
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'rb') as f:
            return f.read()