Fix out-of-scope function use
[arvados.git] / sdk / python / arvados.py
import gflags
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl

from apiclient import errors
from apiclient.discovery import build

class CredentialsFromEnv:
    @staticmethod
    def http_request(self, uri, **kwargs):
        # Bound to the httplib2.Http object via types.MethodType in
        # authorize() below, so "self" here is the Http object, not a
        # CredentialsFromEnv instance.
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

url = ('https://%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()
http = httplib2.Http()
http = credentials.authorize(http)
http.disable_ssl_certificate_validation = True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)

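# Usage sketch (illustrative, not executed on import): with
# ARVADOS_API_HOST and ARVADOS_API_TOKEN set, the discovery-built client is
# used like any other apiclient service object. The exact resources and
# parameters depend on the server's discovery document; job_tasks() is one
# resource this module itself uses, e.g.:
#
#   task = service.job_tasks().get(uuid=some_task_uuid).execute()
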
def task_set_output(self, s):
    service.job_tasks().update(uuid=self['uuid'],
                               job_task=json.dumps({
                                   'output': s,
                                   'success': True,
                                   'progress': 1.0
                               })).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def api():
    return service

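# Sketch of the typical crunch-script pattern (assumes the dispatcher has
# set JOB_UUID, TASK_UUID, JOB_WORK and TASK_WORK in the environment, as
# the lookups above expect):
#
#   job = current_job()
#   task = current_task()
#   ... do work under task.tmpdir ...
#   task.set_output(output_locator)
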
class JobTask:
    def __init__(self, parameters=dict(), resource_limits=dict()):
        print "init jobtask %s %s" % (parameters, resource_limits)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                        }
                    }
                service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            sys.exit(0)

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input': task_input
                    }
                }
            service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            sys.exit(0)

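# Sketch: a map-style script queues one task per input file at sequence 0,
# then each queued task processes the single file named in its own
# parameters:
#
#   job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#   this_task_input = current_task()['parameters']['input']
#   for f in CollectionReader(this_task_input).all_files():
#       ...
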
class util:
    @staticmethod
    def run_command(execargs, **kwargs):
        if 'stdin' not in kwargs:
            kwargs['stdin'] = subprocess.PIPE
        if 'stdout' not in kwargs:
            kwargs['stdout'] = subprocess.PIPE
        if 'stderr' not in kwargs:
            kwargs['stderr'] = subprocess.PIPE
        p = subprocess.Popen(execargs, close_fds=True, shell=False,
                             **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("run_command %s exit %d:\n%s" %
                            (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search(r'\.(tbz|tar\.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search(r'\.(tgz|tar\.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search(r'\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise Exception("tarball_extract cannot handle filename %s"
                                    % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

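    # Sketch: extract a reference tarball once per job, relative to the job
    # work directory (the locator variable below is a hypothetical
    # placeholder):
    #
    #   refdir = util.tarball_extract(reference_tarball_locator, 'ref')
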
    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search(r'\.zip$', f.name()):
                    raise Exception("zipball_extract cannot handle filename %s"
                                    % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        files -- extract only these named files (default: all files)
        decompress -- whether to decompress .bz2/.gz files on the way out
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for f in CollectionReader(collection).all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    continue
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got,
                             map(lambda z: z.name(),
                                 list(CollectionReader(collection).all_files()))))
        os.symlink(collection, os.path.join(path, '.locator'))

        lockfile.close()
        return path

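    # Sketch: fetch two named files from a collection into the job work
    # directory (the locator and filenames are hypothetical placeholders):
    #
    #   dir = util.collection_extract(input_locator, 'input',
    #                                 files=['a.txt', 'b.txt.gz'],
    #                                 decompress=True)
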
    @staticmethod
    def mkdir_dash_p(path):
        if not os.path.exists(path):
            # make sure the parent directory exists first, like "mkdir -p"
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    continue
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got, map(lambda z: z.name(),
                                                   list(stream.all_files()))))
        lockfile.close()
        return path

class DataReader:
    def __init__(self, data_locator):
        self.data_locator = data_locator
        self.p = subprocess.Popen(["whget", "-r", self.data_locator, "-"],
                                  stdout=subprocess.PIPE,
                                  stdin=None, stderr=subprocess.PIPE,
                                  shell=False, close_fds=True)
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
    def read(self, size, **kwargs):
        return self.p.stdout.read(size, **kwargs)
    def close(self):
        self.p.stdout.close()
        if not self.p.stderr.closed:
            for err in self.p.stderr:
                print >> sys.stderr, err
            self.p.stderr.close()
        self.p.wait()
        if self.p.returncode != 0:
            raise Exception("whget subprocess exited %d" % self.p.returncode)

class StreamFileReader:
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub(r'\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data:
                yield data
    def gunzip(self, size):
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data:
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search(r'\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search(r'\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = data.find("\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return ' '.join(self._stream.tokens_for_range(self._pos, self._size)) + "\n"

class StreamReader:
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        # A manifest stream is "<name> <locator>... <pos>:<size>:<file>..."
        for tok in self._tokens:
            if self._stream_name is None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")

    def tokens(self):
        return self._tokens
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                # without size hints we cannot tell which blocks to skip,
                # so keep every locator from here on
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(1))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data is None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search(r'\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(1))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos is not None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

class CollectionReader:
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams is not None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            # skip the empty line after the trailing newline
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text

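# Sketch: read a collection file by file (the locator variable is a
# hypothetical placeholder; a manifest text can be passed instead of a
# locator):
#
#   cr = CollectionReader(input_locator)
#   for f in cr.all_files():
#       for line in f.readlines():
#           ...
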
class CollectionWriter:
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write(self, newdata):
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        if re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name is None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname='.'):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name is None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0]
            if len(stream[1]) == 0:
                # empty stream: use the well-known locator of the empty block
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest

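# Sketch: write a two-file collection and store its manifest in Keep:
#
#   cw = CollectionWriter()
#   cw.start_new_file('a.txt')
#   cw.write('hello\n')
#   cw.start_new_file('b.txt')
#   cw.write('world\n')
#   manifest_locator = cw.finish()
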
class Keep:
    @staticmethod
    def put(data):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_put(data)
        p = subprocess.Popen(["whput", "-"],
                             stdout=subprocess.PIPE,
                             stdin=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(data)
        if p.returncode != 0:
            raise Exception("whput subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        return stdoutdata.rstrip()
    @staticmethod
    def get(locator):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_get(locator)
        p = subprocess.Popen(["whget", locator, "-"],
                             stdout=subprocess.PIPE,
                             stdin=None,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("whget subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        # verify the data against the md5 hash embedded in the locator
        m = hashlib.new('md5')
        m.update(stdoutdata)
        try:
            if locator.index(m.hexdigest()) == 0:
                return stdoutdata
        except ValueError:
            pass
        raise Exception("md5 checksum mismatch: md5(get(%s)) == %s" % (locator, m.hexdigest()))
    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()
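
# Sketch: round-trip a block through Keep. With KEEP_LOCAL_STORE set to a
# writable directory, the local-store code paths above are used instead of
# the whput/whget tools:
#
#   locator = Keep.put('foo')   # 'acbd18db4cc2f85cedef654fccc4a4d8+3'
#   assert Keep.get(locator) == 'foo'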