# [arvados.git] / sdk / python / arvados.py
import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time

from apiclient import errors
from apiclient.discovery import build

if 'ARVADOS_DEBUG' in os.environ:
    logging.basicConfig(level=logging.DEBUG)

class CredentialsFromEnv:
    @staticmethod
    def http_request(self, uri, **kwargs):
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

url = ('https://%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()

# Use system's CA certificates (if we find them) instead of httplib2's
ca_certs = '/etc/ssl/certs/ca-certificates.crt'
if not os.path.exists(ca_certs):
    ca_certs = None             # use httplib2 default

http = httplib2.Http(ca_certs=ca_certs)
http = credentials.authorize(http)
if re.match(r'(?i)^(true|1|yes)$',
            os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
    http.disable_ssl_certificate_validation = True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)

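# The module-level client above is configured entirely from the environment:
# ARVADOS_API_HOST and ARVADOS_API_TOKEN are required, ARVADOS_API_HOST_INSECURE
# disables SSL certificate checks, and ARVADOS_DEBUG enables debug logging.
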
def task_set_output(self, s):
    service.job_tasks().update(uuid=self['uuid'],
                               job_task=json.dumps({
                                   'output': s,
                                   'success': True,
                                   'progress': 1.0
                               })).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def api():
    return service

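# A minimal usage sketch (inside a running crunch task, where the dispatcher
# has set TASK_UUID, JOB_UUID, TASK_WORK and JOB_WORK):
#
#   import arvados
#   job = arvados.current_job()      # cached UserDict of the job record
#   task = arvados.current_task()    # cached UserDict of this task
#   params = job['script_parameters']
#   ...do work in task.tmpdir...
#   task.set_output(output_locator)  # marks the task successful
#
# arvados.api() returns the underlying discovery-built service object.
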
class JobTask:
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                    }
                }
                service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            exit(0)

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input': task_input
                }
            }
            service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            exit(0)

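# A minimal crunch script sketch using job_setup (assumes the standard crunch
# environment variables are set by the dispatcher, and that the job has an
# 'input' script parameter naming a collection):
#
#   import arvados
#   arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#   this_task = arvados.current_task()
#   out = arvados.CollectionWriter()
#   for f in arvados.CollectionReader(this_task['parameters']['input']).all_files():
#       out.start_new_file(f.name())
#       for buf in f.readall():
#           out.write(buf)
#   this_task.set_output(out.finish())
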
class util:
    @staticmethod
    def run_command(execargs, **kwargs):
        if 'stdin' not in kwargs:
            kwargs['stdin'] = subprocess.PIPE
        if 'stdout' not in kwargs:
            kwargs['stdout'] = subprocess.PIPE
        if 'stderr' not in kwargs:
            kwargs['stderr'] = subprocess.PIPE
        p = subprocess.Popen(execargs, close_fds=True, shell=False,
                             **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("run_command %s exit %d:\n%s" %
                            (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

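    # Example (sketch): capture a command's output, raising on failure.
    #
    #   stdout, stderr = util.run_command(['git', '--version'])
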
    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search('\.(tbz|tar.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search('\.(tgz|tar.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search('\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise Exception("tarball_extract cannot handle filename %s"
                                    % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

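    # Example (sketch): fetch and unpack a tool shipped as a tarball in a
    # collection; 'bwa_collection' here is a hypothetical script parameter.
    #
    #   bwa_dir = util.tarball_extract(
    #       current_job()['script_parameters']['bwa_collection'], 'bwa')
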
    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search('\.zip$', f.name()):
                    raise Exception("zipball_extract cannot handle filename %s"
                                    % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" % (files, files_got, map(lambda z: z.name(), list(CollectionReader(collection).all_files()))))
        os.symlink(collection, os.path.join(path, '.locator'))

        lockfile.close()
        return path

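    # Example (sketch): fetch one reference file from a collection;
    # 'reference_collection' is a hypothetical script parameter and
    # 'chr1.fa' a hypothetical file name (decompressed on the way out).
    #
    #   ref_dir = util.collection_extract(
    #       current_job()['script_parameters']['reference_collection'],
    #       'reference', files=['chr1.fa'], decompress=True)
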
    @staticmethod
    def mkdir_dash_p(path):
        if not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got, map(lambda z: z.name(),
                                                   list(stream.all_files()))))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

class StreamFileReader:
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub('\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search('\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search('\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"

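# Example (sketch): iterate over the decompressed lines of each file in a
# collection, using the readers defined in this module.
#
#   for f in CollectionReader(locator).all_files():
#       for line in f.readlines():
#           process(line)    # 'process' is a placeholder
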
class StreamReader:
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name == None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")

    def tokens(self):
        return self._tokens
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                return_all_tokens = True
            if return_all_tokens:
                # A bare "next" here was a no-op in Python; "continue" is what
                # was intended: skip the size arithmetic for this locator once
                # we have decided to return every token.
                resp += [locator]
                continue
            blocksize = int(sizehint.group(0))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data == None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(0))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos != None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

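# Manifest format, as parsed above: each stream line consists of a stream
# name, followed by one or more 32-hex-digit block locators (optionally
# carrying "+size" and other hints), followed by one or more
# "position:size:filename" tokens describing files laid end-to-end across
# those blocks.
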
class CollectionReader:
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams != None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text

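# Example (sketch): read a collection either by locator or by manifest text.
#
#   cr = CollectionReader(collection_locator)
#   for f in cr.all_files():
#       print "%s/%s (%d bytes)" % (f.stream_name(), f.name(), f.size())
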
class CollectionWriter:
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        self.start_new_stream(stream_name)
        todo = []
        if max_manifest_depth == 0:
            dirents = util.listdir_recursive(path)
        else:
            dirents = sorted(os.listdir(path))
        for dirent in dirents:
            target = os.path.join(path, dirent)
            if os.path.isdir(target):
                todo += [[target,
                          os.path.join(stream_name, dirent),
                          max_manifest_depth-1]]
            else:
                self.start_new_file(dirent)
                with open(target, 'rb') as f:
                    while True:
                        buf = f.read(2**26)
                        if len(buf) == 0:
                            break
                        self.write(buf)
        self.finish_current_stream()
        map(lambda x: self.write_directory_tree(*x), todo)

    def write(self, newdata):
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        newfilename = re.sub(r' ', '\\\\040', newfilename)
        if re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name == None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname='.'):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name == None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0]
            if len(stream[1]) == 0:
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest

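# Example (sketch): build a small collection and store its manifest in Keep.
#
#   cw = CollectionWriter()
#   cw.start_new_file('hello.txt')
#   cw.write('hello world\n')
#   manifest_locator = cw.finish()
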
global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object == None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator):
        return Keep.global_client_object().get(locator)

    @staticmethod
    def put(data):
        return Keep.global_client_object().put(data)

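# Example (sketch): round-trip a block of data through Keep. Setting
# KEEP_LOCAL_STORE to a local directory makes put/get use plain files
# instead of Keep servers, which is handy for testing.
#
#   locator = Keep.put('some data')      # e.g. "<md5 hex digest>+9"
#   assert Keep.get(locator) == 'some data'
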
class KeepClient:
    def __init__(self):
        self.service_roots = None

    def shuffled_service_roots(self, hash):
        if self.service_roots == None:
            keep_disks = api().keep_disks().list().execute()['items']
            roots = (("http%s://%s:%d/" %
                      ('s' if f['service_ssl_flag'] else '',
                       f['service_host'],
                       f['service_port']))
                     for f in keep_disks)
            self.service_roots = sorted(set(roots))
            logging.debug(str(self.service_roots))
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

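    # The probe order above is derived deterministically from the data hash:
    # successive 8-hex-digit windows of the hash pick the next service out of
    # the remaining pool, so every client computes the same per-block ordering
    # without coordinating with the others.
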
    def get(self, locator):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = os.environ['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url, 'GET', headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise Exception("Not found: %s" % expect_hash)

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        for service_root in self.shuffled_service_roots(data_hash):
            h = httplib2.Http()
            url = service_root + data_hash
            api_token = os.environ['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token}
            try:
                resp, content = h.request(url, 'PUT',
                                          headers=headers,
                                          body=data)
                if (resp['status'] == '401' and
                    re.match(r'Timestamp verification failed', content)):
                    body = self.sign_for_old_server(data_hash, data)
                    h = httplib2.Http()
                    resp, content = h.request(url, 'PUT',
                                              headers=headers,
                                              body=body)
                if re.match(r'^2\d\d$', resp['status']):
                    have_copies += 1
                    if have_copies == want_copies:
                        return data_hash + '+' + str(len(data))
                else:
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
            except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                logging.warning("Request fail: PUT %s => %s: %s" %
                                (url, type(e), str(e)))
        raise Exception("Write fail for %s: wanted %d but wrote %d" %
                        (data_hash, want_copies, have_copies))

    def sign_for_old_server(self, data_hash, data):
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()