add Real Time Genomics pipeline template
[arvados.git] / sdk / python / arvados.py
1 import gflags
2 import httplib2
3 import logging
4 import os
5 import pprint
6 import sys
7 import types
8 import subprocess
9 import json
10 import UserDict
11 import re
12 import hashlib
13 import string
14 import bz2
15 import zlib
16 import fcntl
17
18 from apiclient import errors
19 from apiclient.discovery import build
20
class CredentialsFromEnv:
    """Authorize httplib2-style requests with the token in $ARVADOS_API_TOKEN.

    authorize() patches an Http object so every request carries an
    "Authorization: OAuth2 <token>" header and is retried once if the
    server dropped a reused connection.
    """

    @staticmethod
    def http_request(self, uri, **kwargs):
        """Stand-in for ``Http.request``.

        authorize() binds this function to the Http object with
        types.MethodType, so ``self`` is that Http object (not a
        CredentialsFromEnv), and ``self.orig_http_request`` is the
        original request method.
        """
        from httplib import BadStatusLine
        headers = kwargs.setdefault('headers', {})
        headers['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # httplib raises this when it reused a keep-alive connection
            # that the server had already closed, so one retry is wanted.
            # It is slightly risky: the first attempt may in fact have
            # reached the server.
            pass
        return self.orig_http_request(uri, **kwargs)

    def authorize(self, http):
        """Route ``http.request`` through http_request and return ``http``."""
        original_request = http.request
        http.orig_http_request = original_request
        http.request = types.MethodType(self.http_request, http)
        return http
42
# Discovery document URL for the API server named by $ARVADOS_API_HOST.
# Only the host is substituted here; "{api}" and "{apiVersion}" are left
# as literal placeholders (presumably filled in by build()).
url = ('https://%s/discovery/v1/apis/{api}/{apiVersion}/rest'
       % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()
http = credentials.authorize(httplib2.Http())
# NOTE(review): certificate validation is switched off unconditionally, so
# the OAuth2 token added by CredentialsFromEnv is sent over an unverified
# TLS connection. Presumably meant for self-signed dev servers; consider
# making this opt-in.
http.disable_ssl_certificate_validation = True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)
50
def task_set_output(self, s):
    """Record ``s`` as the output of job task ``self`` and mark it done.

    Intended to be bound onto a task record (see current_task); sends one
    job_tasks().update() call setting output, success=True and
    progress=1.0.
    """
    completed = {'output': s,
                 'success': True,
                 'progress': 1.0}
    request = service.job_tasks().update(uuid=self['uuid'],
                                         job_task=json.dumps(completed))
    request.execute()
58
_current_task = None
def current_task():
    """Return the job task this process is running, fetching it only once.

    The task UUID is read from $TASK_UUID. The API record is wrapped in a
    UserDict, given a ``set_output(s)`` method (task_set_output) and a
    ``tmpdir`` attribute from $TASK_WORK, and cached in ``_current_task``.
    """
    global _current_task
    # Explicit None test: a cached UserDict is falsy when empty, and a
    # truthiness check would then re-fetch it on every call.
    if _current_task is not None:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t
70
_current_job = None
def current_job():
    """Return the job this process belongs to, fetching it only once.

    The job UUID is read from $JOB_UUID. The API record is wrapped in a
    UserDict with a ``tmpdir`` attribute taken from $JOB_WORK and cached in
    ``_current_job``.
    """
    global _current_job
    # Explicit None test: an empty UserDict is falsy and would otherwise be
    # re-fetched on every call.
    if _current_job is not None:
        return _current_job
    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t
81
def api():
    """Return the module-level Arvados API client (``service``)."""
    client = service
    return client
84
class JobTask:
    """Job task stub: the constructor only prints its arguments and stores
    nothing."""

    def __init__(self, parameters=None, resource_limits=None):
        # None sentinels instead of dict() defaults: default values are
        # evaluated once at definition time and shared between calls.
        if parameters is None:
            parameters = {}
        if resource_limits is None:
            resource_limits = {}
        # Single parenthesized argument prints identically on Python 2 and 3.
        print("init jobtask %s %s" % (parameters, resource_limits))
88
class job_setup:
    """Helpers that fan a job's ``input`` collection out into new tasks.

    Each helper acts only when the current task's ``sequence`` equals
    ``if_sequence``. It creates one task with sequence ``if_sequence + 1``
    per unit of input. If ``and_end_task`` is true, it then marks the
    current task successful and exits the process with status 0.
    """

    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        """Create one task per file; its ``input`` is that file's manifest."""
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                job_setup._create_task(if_sequence + 1, f.as_manifest())
        if and_end_task:
            job_setup._end_current_task()

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        """Create one task per stream; its ``input`` is the stream's tokens."""
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            job_setup._create_task(if_sequence + 1, s.tokens())
        if and_end_task:
            job_setup._end_current_task()

    @staticmethod
    def _create_task(sequence, task_input):
        """Create a task in the current job, created by the current task."""
        new_task_attrs = {
            'job_uuid': current_job()['uuid'],
            'created_by_job_task_uuid': current_task()['uuid'],
            'sequence': sequence,
            'parameters': {
                'input': task_input
                }
            }
        service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()

    @staticmethod
    def _end_current_task():
        """Mark the current task successful, then exit with status 0."""
        service.job_tasks().update(uuid=current_task()['uuid'],
                                   job_task=json.dumps({'success': True})
                                   ).execute()
        # sys.exit, not the exit() builtin: exit() is added by the site
        # module for interactive use and is absent under `python -S`.
        sys.exit(0)
136
class util:
    """Helpers for job scripts: running commands, git checkouts, and
    extracting Keep data into local directories.

    The *_extract methods accept an absolute path or one relative to the
    current job's tmpdir. They serialize concurrent extractions into the
    same path with an exclusive flock on ``<path>.lock``. The lock is always
    released on return, including when an exception propagates.
    """

    @staticmethod
    def run_command(execargs, **kwargs):
        """Run ``execargs`` (no shell) and return ``(stdoutdata, stderrdata)``.

        stdin, stdout and stderr default to pipes unless given in
        ``kwargs``. Other ``kwargs`` are passed to subprocess.Popen.
        Raises Exception, including the captured stderr, on non-zero exit.
        """
        for stream in ('stdin', 'stdout', 'stderr'):
            kwargs.setdefault(stream, subprocess.PIPE)
        p = subprocess.Popen(execargs, close_fds=True, shell=False,
                             **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("run_command %s exit %d:\n%s" %
                            (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

    @staticmethod
    def git_checkout(url, version, path):
        """Clone ``url`` into ``path`` unless it exists, then check out
        ``version`` there. Relative paths resolve against the job tmpdir.
        Returns the checkout path."""
        path = util._job_path(path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        """Start ``tar -x<flag>f -`` in ``path``; caller feeds p.stdin."""
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        path = util._job_path(path)
        lockfile = open(path + '.lock', 'w')
        try:
            fcntl.flock(lockfile, fcntl.LOCK_EX)
            util._ensure_dir(path)
            # path/.locator is a symlink whose target names the collection
            # already extracted here; skip the work when it matches.
            if not util._locator_matches(path, tarball):
                util._remove_locator(path)
                for f in CollectionReader(tarball).all_files():
                    util._untar_file(f, path)
                os.symlink(tarball, os.path.join(path, '.locator'))
            return util._single_entry_or_dir(path)
        finally:
            # Closing the lock file releases the flock.
            lockfile.close()

    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        path = util._job_path(path)
        lockfile = open(path + '.lock', 'w')
        try:
            fcntl.flock(lockfile, fcntl.LOCK_EX)
            util._ensure_dir(path)
            if not util._locator_matches(path, zipball):
                util._remove_locator(path)
                for f in CollectionReader(zipball).all_files():
                    util._unzip_file(f, path)
                os.symlink(zipball, os.path.join(path, '.locator'))
            return util._single_entry_or_dir(path)
        finally:
            lockfile.close()

    @staticmethod
    def collection_extract(collection, path, files=None, decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        files -- names to extract (default: all files)
        decompress -- write .gz/.bz2 files decompressed, named without
                      the suffix; such a file also matches ``files`` by
                      its decompressed name
        Raises Exception if fewer than len(files) files were found.
        """
        if files is None:
            files = []
        path = util._job_path(path)
        lockfile = open(path + '.lock', 'w')
        try:
            fcntl.flock(lockfile, fcntl.LOCK_EX)
            util._ensure_dir(path)
            # This always walks the collection rather than trusting
            # path/.locator (the requested ``files`` may differ between
            # calls). Output files that already exist are left untouched.
            util._remove_locator(path)
            files_got = []
            for s in CollectionReader(collection).all_streams():
                stream_name = s.name()
                for f in s.all_files():
                    if not util._want_file(f, files, files_got, decompress):
                        continue
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    outpath = os.path.join(path, stream_name, outname)
                    if os.path.exists(outpath):
                        continue
                    util.mkdir_dash_p(os.path.dirname(outpath))
                    util._write_file(f, outpath, decompress)
            if len(files_got) < len(files):
                raise Exception("Wanted files %s but only got %s from %s" %
                                (files, files_got,
                                 [z.name() for z in
                                  CollectionReader(collection).all_files()]))
            os.symlink(collection, os.path.join(path, '.locator'))
            return path
        finally:
            lockfile.close()

    @staticmethod
    def mkdir_dash_p(path):
        """Create directory ``path`` and any missing parents (``mkdir -p``).

        An empty path, which is what os.path.dirname returns for a bare
        relative name, means the current directory and is a no-op. A
        directory created concurrently by another process is tolerated.
        """
        if not path or os.path.exists(path):
            return
        util.mkdir_dash_p(os.path.dirname(path))
        try:
            os.mkdir(path)
        except OSError:
            # Lost a race with another creator: fine. Still missing: retry
            # so the real error propagates.
            if not os.path.exists(path):
                os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=None, decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        files -- names to extract (default: all files)
        decompress -- write .gz/.bz2 files decompressed, named without
                      the suffix
        Existing output files are replaced. Raises Exception if fewer than
        len(files) files were found.
        """
        if files is None:
            files = []
        path = util._job_path(path)
        lockfile = open(path + '.lock', 'w')
        try:
            fcntl.flock(lockfile, fcntl.LOCK_EX)
            util._ensure_dir(path)
            files_got = []
            for f in stream.all_files():
                if not util._want_file(f, files, files_got, decompress):
                    continue
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                outpath = os.path.join(path, outname)
                if os.path.exists(outpath):
                    os.unlink(outpath)
                util.mkdir_dash_p(os.path.dirname(outpath))
                util._write_file(f, outpath, decompress)
            if len(files_got) < len(files):
                raise Exception("Wanted files %s but only got %s from %s" %
                                (files, files_got,
                                 [z.name() for z in stream.all_files()]))
            return path
        finally:
            lockfile.close()

    @staticmethod
    def listdir_recursive(dirname, base=None):
        """Return sorted relative paths of all non-directories under
        ``dirname``, each prefixed with ``base`` when given."""
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

    # --- private helpers shared by the *_extract methods ---

    @staticmethod
    def _job_path(path):
        """Return ``path`` if absolute, else ``path`` under the job tmpdir."""
        if not path.startswith('/'):
            path = os.path.join(current_job().tmpdir, path)
        return path

    @staticmethod
    def _ensure_dir(path):
        """Create directory ``path`` if nothing exists there yet."""
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

    @staticmethod
    def _locator_matches(path, locator):
        """True if path/.locator is a symlink whose target is ``locator``."""
        try:
            return os.readlink(os.path.join(path, '.locator')) == locator
        except OSError:
            return False

    @staticmethod
    def _remove_locator(path):
        """Remove path/.locator if present ("rm -f" semantics)."""
        locator_path = os.path.join(path, '.locator')
        try:
            os.unlink(locator_path)
        except OSError:
            # lexists, not exists: .locator is a symlink to a locator
            # string, i.e. dangling, and exists() is False for it.
            if os.path.lexists(locator_path):
                os.unlink(locator_path)

    @staticmethod
    def _single_entry_or_dir(path):
        """Return path/<entry> if ``path`` holds exactly one entry besides
        .locator, else ``path``."""
        tld_extracts = [f for f in os.listdir(path) if f != '.locator']
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def _copy_file_chunks(f, write):
        """Pass successive 1 MiB chunks read from ``f`` to ``write``."""
        while True:
            buf = f.read(2**20)
            if not buf:
                break
            write(buf)

    @staticmethod
    def _untar_file(f, path):
        """Pipe Keep file ``f`` through tar into ``path``; the
        decompression flag is chosen from the file name's extension."""
        name = f.name()
        if re.search(r'\.(tbz|tar\.bz2)$', name):
            decompress_flag = 'j'
        elif re.search(r'\.(tgz|tar\.gz)$', name):
            decompress_flag = 'z'
        elif re.search(r'\.tar$', name):
            decompress_flag = ''
        else:
            raise Exception("tarball_extract cannot handle filename %s"
                            % name)
        p = util.tar_extractor(path, decompress_flag)
        util._copy_file_chunks(f, p.stdin.write)
        p.stdin.close()
        p.wait()
        if p.returncode != 0:
            raise Exception("tar exited %d" % p.returncode)

    @staticmethod
    def _unzip_file(f, path):
        """Copy Keep file ``f`` into ``path``, unzip it there, then delete
        the copied archive."""
        if not re.search(r'\.zip$', f.name()):
            raise Exception("zipball_extract cannot handle filename %s"
                            % f.name())
        zip_filename = os.path.join(path, os.path.basename(f.name()))
        with open(zip_filename, 'wb') as zip_file:
            util._copy_file_chunks(f, zip_file.write)
        p = subprocess.Popen(["unzip",
                              "-q", "-o",
                              "-d", path,
                              zip_filename],
                             stdout=None,
                             stdin=None, stderr=sys.stderr,
                             shell=False, close_fds=True)
        p.wait()
        if p.returncode != 0:
            raise Exception("unzip exited %d" % p.returncode)
        os.unlink(zip_filename)

    @staticmethod
    def _want_file(f, files, files_got, decompress):
        """True if ``f`` passes the ``files`` filter (empty means all)."""
        if not files:
            return True
        return ((f.name() not in files_got) and
                (f.name() in files or
                 (decompress and f.decompressed_name() in files)))

    @staticmethod
    def _write_file(f, outpath, decompress):
        """Write the (optionally decompressed) contents of ``f`` to
        ``outpath``."""
        chunks = f.readall_decompressed() if decompress else f.readall()
        with open(outpath, 'wb') as outfile:
            for buf in chunks:
                outfile.write(buf)
420
class DataReader:
    """Read the data behind ``data_locator`` from a ``whget`` subprocess.

    Usable as a context manager; leaving the with-block calls close(),
    which raises if whget exited non-zero.
    """

    def __init__(self, data_locator):
        self.data_locator = data_locator
        self.p = subprocess.Popen(["whget", "-r", self.data_locator, "-"],
                                  stdout=subprocess.PIPE,
                                  stdin=None, stderr=subprocess.PIPE,
                                  shell=False, close_fds=True)

    def __enter__(self):
        # Return self so `with DataReader(loc) as r:` binds the reader.
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        # The with-statement passes three arguments. The defaults keep the
        # old zero-argument call working. Returning None does not suppress
        # exceptions raised in the with-body.
        self.close()

    def read(self, size, **kwargs):
        """Read up to ``size`` bytes of whget's stdout."""
        return self.p.stdout.read(size, **kwargs)

    def close(self):
        """Close stdout, relay whget's stderr lines to our stderr, wait
        for whget, and raise Exception if it exited non-zero."""
        self.p.stdout.close()
        if not self.p.stderr.closed:
            for err in self.p.stderr:
                # Same output as the former `print >> sys.stderr, err`.
                sys.stderr.write("%s\n" % (err,))
            self.p.stderr.close()
        self.p.wait()
        if self.p.returncode != 0:
            raise Exception("whget subprocess exited %d" % self.p.returncode)
443
class StreamFileReader:
    """Read one file, i.e. ``size`` bytes starting at ``pos``, out of a
    stream object that provides seek()/read()/name()."""

    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0

    def name(self):
        return self._name

    def decompressed_name(self):
        """Return the name with one trailing ".bz2" or ".gz" removed."""
        return re.sub(r'\.(bz2|gz)$', '', self._name)

    def size(self):
        return self._size

    def stream_name(self):
        return self._stream.name()

    def read(self, size, **kwargs):
        """Read up to ``size`` bytes, never past the end of this file."""
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data

    def readall(self, size=2**20, **kwargs):
        """Yield the rest of the file in chunks of at most ``size`` bytes."""
        while True:
            data = self.read(size, **kwargs)
            # `not data` also ends the loop for bytes (b''), which never
            # compares equal to '' on Python 3.
            if not data:
                break
            yield data

    def bunzip2(self, size):
        """Yield the bzip2-decompressed rest of the file."""
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data:
                yield data

    def gunzip(self, size):
        """Yield the gzip-decompressed rest of the file."""
        # 16+MAX_WBITS: expect a gzip header/trailer, not a raw zlib stream.
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            # No max_length is passed, so unconsumed_tail stays empty and
            # each chunk can be fed in directly.
            data = decompressor.decompress(chunk)
            if data:
                yield data
        # Emit anything the decompressor is still holding back.
        tail = decompressor.flush()
        if tail:
            yield tail

    def readall_decompressed(self, size=2**20):
        """Like readall(), but decompress based on a .bz2/.gz name suffix."""
        self._stream.seek(self._pos + self._filepos)
        if re.search(r'\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search(r'\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)

    def readlines(self, decompress=True):
        """Yield lines, each keeping its "\\n"; a final unterminated line is
        yielded as-is."""
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = data.find("\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data:
            yield data

    def as_manifest(self):
        """Return a one-line manifest describing just this file."""
        if self.size() == 0:
            # d41d8cd98f00b204e9800998ecf8427e is the MD5 of empty input.
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return " ".join(self._stream.tokens_for_range(self._pos, self._size)) + "\n"
515
class StreamReader:
    """Parse one manifest stream (a list of tokens) and read its data.

    Token order: the stream name, then block locators (32 hex digits plus
    optional "+hint" parts), then "pos:size:name" file tokens. Data blocks
    are fetched with Keep.get() one at a time as reads reach them.
    """
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name is None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")

    def tokens(self):
        return self._tokens

    def tokens_for_range(self, range_start, range_size):
        """Return manifest tokens covering bytes
        [range_start, range_start + range_size) of this stream.

        Locators of blocks entirely before the range are dropped (and file
        offsets shifted down by their total size). Blocks starting at or
        after the range end are dropped too. Once a locator has no +size
        hint, it and all later locators are kept. Zero-length files are
        omitted.
        """
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                return_all_tokens = True
            if return_all_tokens:
                # Block boundaries are unknown from here on: keep the
                # locator and skip the size arithmetic below.
                resp += [locator]
                continue
            blocksize = int(sizehint.group(1))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp

    def name(self):
        return self._stream_name

    def all_files(self):
        """Yield a StreamFileReader for each file token."""
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)

    def nextdatablock(self):
        """Advance to the next data block (or to the first, if none yet)."""
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None

    def current_datablock_data(self):
        """Return the current block's data, fetching it from Keep once."""
        if self._current_datablock_data is None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data

    def current_datablock_size(self):
        """Size of the current block: its +size hint if present, otherwise
        the length of its fetched data."""
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search(r'\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(1))
        return len(self.current_datablock_data())

    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos

    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos is not None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            # Seeking backwards: restart the scan from the first block.
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()

    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream, in which case
        None is returned (or '' when size is 0).
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data
627
class CollectionReader:
    """Read a collection given either its manifest text or a Keep locator.

    A locator's manifest is fetched with Keep.get() lazily, on first use.
    """
    def __init__(self, manifest_locator_or_text):
        # Text shaped like a manifest line (stream name, one or more block
        # locators, one or more pos:size:name tokens) is used as the
        # manifest; anything else is taken to be a locator. The first line
        # may end with a newline or at the end of the string.
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+(\n|$)',
                     manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None

    def __enter__(self):
        # Return self so `with CollectionReader(x) as cr:` binds the reader.
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        # Nothing to release. Returning None lets exceptions propagate.
        pass

    def _populate(self):
        """Fetch the manifest if needed and split it into per-stream
        token lists (blank lines skipped)."""
        if self._streams is not None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = [stream_line.split()
                         for stream_line in self._manifest_text.split("\n")
                         if stream_line != '']

    def all_streams(self):
        """Return a new StreamReader for each stream in the manifest."""
        self._populate()
        return [StreamReader(s) for s in self._streams]

    def all_files(self):
        """Yield a reader for every file of every stream, in manifest order."""
        for s in self.all_streams():
            for f in s.all_files():
                yield f

    def manifest_text(self):
        self._populate()
        return self._manifest_text
664
class CollectionWriter:
    """Build a collection: buffer file data into Keep blocks and render a
    manifest with one line per stream.

    finish() stores the manifest itself with Keep.put() and returns the
    result of that call.
    """
    KEEP_BLOCK_SIZE = 2**26

    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []

    def __enter__(self):
        # Return self so `with CollectionWriter() as w:` binds the writer.
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        # Finish only when the with-body succeeded; on an exception nothing
        # is stored and the exception propagates. The value returned by
        # finish() is discarded here, so call finish() directly to keep it.
        if exc_type is None:
            self.finish()

    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        """Add every file under local directory ``path``.

        Files directly in ``path`` go into stream ``stream_name``. Each
        subdirectory becomes its own stream, recursively. When
        ``max_manifest_depth`` reaches 0, the remaining tree is flattened
        into that one stream. The default -1 never reaches 0, so there is
        no limit.
        """
        self.start_new_stream(stream_name)
        todo = []
        if max_manifest_depth == 0:
            dirents = util.listdir_recursive(path)
        else:
            dirents = sorted(os.listdir(path))
        for dirent in dirents:
            target = os.path.join(path, dirent)
            if os.path.isdir(target):
                todo += [[target,
                          os.path.join(stream_name, dirent),
                          max_manifest_depth-1]]
            else:
                self.start_new_file(dirent)
                with open(target, 'rb') as f:
                    while True:
                        buf = f.read(2**26)
                        if len(buf) == 0:
                            break
                        self.write(buf)
        self.finish_current_stream()
        # Explicit loop rather than map(): this runs only for its side
        # effects, and map() is lazy on Python 3.
        for subdir, sub_stream_name, sub_depth in todo:
            self.write_directory_tree(subdir, sub_stream_name, sub_depth)

    def write(self, newdata):
        """Append data to the current file, flushing full Keep blocks."""
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()

    def flush_data(self):
        """Store up to one Keep block of buffered data; keep the rest."""
        if not self._data_buffer:
            return
        # Join with an empty value of the buffered type, so str (Python 2)
        # and bytes (Python 3, e.g. files opened 'rb') both work.
        data_buffer = self._data_buffer[0][:0].join(self._data_buffer)
        if data_buffer:
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])

    def start_new_file(self, newfilename=None):
        """Close the current file and begin one named ``newfilename``."""
        self.finish_current_file()
        self.set_current_file_name(newfilename)

    def set_current_file_name(self, newfilename):
        """Name the current file; spaces are escaped as "\\040".

        None leaves the file unnamed. finish_current_file() raises if an
        unnamed file has received any data.
        """
        if newfilename is None:
            self._current_file_name = None
            return
        newfilename = re.sub(r' ', '\\\\040', newfilename)
        if re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename

    def current_file_name(self):
        return self._current_file_name

    def finish_current_file(self):
        """Record the current file's pos:size:name in the current stream."""
        if self._current_file_name is None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                       self._current_stream_length - self._current_file_pos,
                                       self._current_file_name]]
        self._current_file_pos = self._current_stream_length

    def start_new_stream(self, newstreamname='.'):
        """Close the current stream and begin one named ``newstreamname``."""
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)

    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = newstreamname

    def current_stream_name(self):
        return self._current_stream_name

    def finish_current_stream(self):
        """Finish the current file, flush data, and record the stream if it
        has any files. The stream name is reset to None afterwards."""
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name is None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                       self._current_stream_locators,
                                       self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None

    def finish(self):
        """Store the manifest text with Keep.put() and return its result."""
        return Keep.put(self.manifest_text())

    def manifest_text(self):
        """Finish the current stream and render all finished streams."""
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            # Stream names not already "." or "./..." get a "./" prefix.
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0]
            if len(stream[1]) == 0:
                # MD5 of empty input: an empty stream still needs a block.
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest
781
class Keep:
    """Minimal Keep client.

    Blocks are stored and fetched by shelling out to the ``whput`` /
    ``whget`` command-line tools. When the KEEP_LOCAL_STORE environment
    variable is set, blocks are instead read and written as files in that
    directory, named by their md5 hex digest.
    """
    @staticmethod
    def put(data):
        """Store *data* and return its locator string.

        Raises Exception if the whput subprocess exits non-zero.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_put(data)
        p = subprocess.Popen(["whput", "-"],
                             stdout=subprocess.PIPE,
                             stdin=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(data)
        if p.returncode != 0:
            raise Exception("whput subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        return stdoutdata.rstrip()

    @staticmethod
    def get(locator):
        """Fetch and return the block identified by *locator*.

        When fetching through whget, the data's md5 hex digest must be a
        prefix of the locator; otherwise an Exception is raised. An Exception
        is also raised if whget exits non-zero.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_get(locator)
        p = subprocess.Popen(["whget", locator, "-"],
                             stdout=subprocess.PIPE,
                             stdin=None,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("whget subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        digest = hashlib.new('md5')
        digest.update(stdoutdata)
        hexdigest = digest.hexdigest()
        if locator.startswith(hexdigest):
            return stdoutdata
        raise Exception("md5 checksum mismatch: md5(get(%s)) == %s" % (locator, hexdigest))

    @staticmethod
    def local_store_put(data):
        """Write *data* into $KEEP_LOCAL_STORE and return "<md5>+<length>".

        The data is written to a ".tmp" file first and then renamed into
        place, so readers never see a partially written block.
        """
        store = os.environ['KEEP_LOCAL_STORE']
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        tmp_path = os.path.join(store, md5 + '.tmp')
        # Blocks are arbitrary bytes: binary mode avoids newline translation.
        with open(tmp_path, 'wb') as f:
            f.write(data)
        os.rename(tmp_path, os.path.join(store, md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Read the block for *locator* from $KEEP_LOCAL_STORE.

        The locator must start with a hex digest of at least 32 characters.
        The empty-block digest returns empty data without touching the
        filesystem. Raises Exception for a malformed locator.
        """
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return b''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'rb') as f:
            return f.read()