fix extra empty stream in all_streams() response
[arvados.git] / sdk / python / arvados.py
import gflags
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl

from apiclient import errors
from apiclient.discovery import build

class CredentialsFromEnv:
    @staticmethod
    def http_request(self, uri, **kwargs):
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        # http_request is a @staticmethod so that MethodType binds its
        # 'self' argument to the Http object itself, giving the wrapper
        # access to the original request method saved here.
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

url = ('https://%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()
http = httplib2.Http()
http = credentials.authorize(http)
http.disable_ssl_certificate_validation = True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)

def task_set_output(self, s):
    service.job_tasks().update(uuid=self['uuid'],
                               job_task=json.dumps({
                                   'output': s,
                                   'success': True,
                                   'progress': 1.0
                               })).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def api():
    return service
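
# A brief sketch of intended use, assuming the dispatcher has set the
# usual ARVADOS_API_HOST / ARVADOS_API_TOKEN / JOB_UUID / TASK_UUID /
# JOB_WORK / TASK_WORK environment variables:
#
#   import arvados
#   job = arvados.current_job()     # cached UserDict of this job's record
#   task = arvados.current_task()   # cached UserDict of this task's record
#   arvados.api().jobs().get(uuid=job['uuid']).execute()
#   task.set_output(output_locator) # records output, marks task successful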

class JobTask:
    def __init__(self, parameters=dict(), resource_limits=dict()):
        print "init jobtask %s %s" % (parameters, resource_limits)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                        }
                    }
                service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            exit(0)

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input': task_input
                    }
                }
            service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            exit(0)
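
# Hypothetical crunch-script sketch: a script whose first sequence fans
# out one task per input file might begin
#
#   import arvados
#   arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#   this_task = arvados.current_task()
#   input_manifest = this_task['parameters']['input']
#   for f in arvados.CollectionReader(input_manifest).all_files():
#       pass  # each task sees exactly one input file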

class util:
    @staticmethod
    def run_command(execargs, **kwargs):
        if 'stdin' not in kwargs:
            kwargs['stdin'] = subprocess.PIPE
        if 'stdout' not in kwargs:
            kwargs['stdout'] = subprocess.PIPE
        if 'stderr' not in kwargs:
            kwargs['stderr'] = subprocess.PIPE
        p = subprocess.Popen(execargs, close_fds=True, shell=False,
                             **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("run_command %s exit %d:\n%s" %
                            (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search(r'\.(tbz|tar\.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search(r'\.(tgz|tar\.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search(r'\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise Exception("tarball_extract cannot handle filename %s"
                                    % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search(r'\.zip$', f.name()):
                    raise Exception("zipball_extract cannot handle filename %s"
                                    % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for f in CollectionReader(collection).all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    continue
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got,
                             map(lambda z: z.name(),
                                 list(CollectionReader(collection).all_files()))))
        os.symlink(collection, os.path.join(path, '.locator'))

        lockfile.close()
        return path

    @staticmethod
    def mkdir_dash_p(path):
        if not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                # Another process may have created it first; only
                # re-raise if the directory still does not exist.
                if not os.path.exists(path):
                    raise

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got, map(lambda z: z.name(),
                                                   list(stream.all_files()))))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles
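
# Hypothetical usage sketch (the parameter values are illustrative, not
# from this file): a crunch task could fetch and unpack its reference
# data with
#
#   ref_dir = util.collection_extract(
#       collection=current_job()['script_parameters']['reference'],
#       path='reference', decompress=True)
#
# collection_extract returns the absolute extraction directory (under
# the job's tmpdir when path is relative) and skips any file that
# already exists there.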

class DataReader:
    def __init__(self, data_locator):
        self.data_locator = data_locator
        self.p = subprocess.Popen(["whget", "-r", self.data_locator, "-"],
                                  stdout=subprocess.PIPE,
                                  stdin=None, stderr=subprocess.PIPE,
                                  shell=False, close_fds=True)
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
    def read(self, size, **kwargs):
        return self.p.stdout.read(size, **kwargs)
    def close(self):
        self.p.stdout.close()
        if not self.p.stderr.closed:
            for err in self.p.stderr:
                print >> sys.stderr, err
            self.p.stderr.close()
        self.p.wait()
        if self.p.returncode != 0:
            raise Exception("whget subprocess exited %d" % self.p.returncode)

class StreamFileReader:
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub(r'\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        # 16+MAX_WBITS tells zlib to expect a gzip header
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search(r'\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search(r'\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"

class StreamReader:
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name is None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")

    def tokens(self):
        return self._tokens
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                # Without a size hint we cannot tell which blocks
                # overlap the range, so keep every remaining locator.
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(1))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data is None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search(r'\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(1))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos is not None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                # Past the last block: report end-of-stream the same
                # way a zero-byte read does.
                return ''
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data
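
# For reference, a stream in a manifest is a single line of tokens: a
# stream name, one or more data block locators (an md5 hash with an
# optional "+size" hint), and one or more "pos:size:filename" tokens.
# For example (hash shown is md5 of the 3-byte block "foo"):
#
#   StreamReader(['.', 'acbd18db4cc2f85cedef654fccc4a4d8+3', '0:3:foo.txt'])
#
# yields one StreamFileReader for a 3-byte foo.txt at stream offset 0.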

class CollectionReader:
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams is not None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            # The manifest ends with "\n", so the final element of
            # split() is an empty string; skip it rather than emit an
            # empty stream.
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text
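
# Manifest text always ends with a newline, so splitting on "\n" leaves
# a trailing empty string; the guard in _populate() skips it so that
# all_streams() does not return an extra, empty stream. For example:
#
#   text = ". d41d8cd98f00b204e9800998ecf8427e+0 0:0:empty.txt\n"
#   assert len(CollectionReader(text).all_streams()) == 1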

class CollectionWriter:
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write(self, newdata):
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        if re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name is None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname='.'):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name is None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0]
            if len(stream[1]) == 0:
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest
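
# A minimal writing sketch: collect one file into a stream and store
# the manifest (finish() sends the manifest text to Keep and returns
# its locator, so whput -- or KEEP_LOCAL_STORE -- must be available):
#
#   cw = CollectionWriter()
#   cw.start_new_stream('.')
#   cw.start_new_file('hello.txt')
#   cw.write('hello\n')
#   collection_locator = cw.finish()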

class Keep:
    @staticmethod
    def put(data):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_put(data)
        p = subprocess.Popen(["whput", "-"],
                             stdout=subprocess.PIPE,
                             stdin=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(data)
        if p.returncode != 0:
            raise Exception("whput subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        return stdoutdata.rstrip()
    @staticmethod
    def get(locator):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_get(locator)
        p = subprocess.Popen(["whget", locator, "-"],
                             stdout=subprocess.PIPE,
                             stdin=None,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("whget subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        m = hashlib.new('md5')
        m.update(stdoutdata)
        try:
            if locator.index(m.hexdigest()) == 0:
                return stdoutdata
        except ValueError:
            pass
        raise Exception("md5 checksum mismatch: md5(get(%s)) == %s" % (locator, m.hexdigest()))
    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()
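
# Round-trip sketch using the local-store fallback. KEEP_LOCAL_STORE
# must name an existing, writable directory (the path below is
# hypothetical):
#
#   os.environ['KEEP_LOCAL_STORE'] = '/tmp/keep'
#   loc = Keep.put('foo')        # 'acbd18db4cc2f85cedef654fccc4a4d8+3'
#   assert Keep.get(loc) == 'foo'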