Accept data from a generator in arvados.CollectionWriter.write()
[arvados.git] / sdk / python / arvados.py
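A minimal usage sketch of what this change enables (not part of the file below; the file name and helper are illustrative). CollectionWriter.write() accepts any iterable of strings, such as a generator, in addition to a plain string:

    import arvados

    def chunks():
        # any iterable of strings works: list, tuple, or generator
        yield "part one\n"
        yield "part two\n"

    writer = arvados.CollectionWriter()
    writer.start_new_file('example.txt')    # hypothetical file name
    writer.write(chunks())                  # generator consumed element by element
    writer.write("a plain string still works\n")
    locator = writer.finish()               # stores the manifest in Keep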
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20
21 from apiclient import errors
22 from apiclient.discovery import build
23
24 if 'ARVADOS_DEBUG' in os.environ:
25     logging.basicConfig(level=logging.DEBUG)
26
27 class CredentialsFromEnv(object):
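    """Credentials stub for the Google API client: instead of running an
    OAuth2 flow, inject the ARVADOS_API_TOKEN environment variable into
    every request, and retry once when the server has closed a kept-alive
    connection (BadStatusLine)."""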
28     @staticmethod
29     def http_request(self, uri, **kwargs):
30         from httplib import BadStatusLine
31         if 'headers' not in kwargs:
32             kwargs['headers'] = {}
33         kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
34         try:
35             return self.orig_http_request(uri, **kwargs)
36         except BadStatusLine:
37             # This is how httplib tells us that it tried to reuse an
38             # existing connection but it was already closed by the
39             # server. In that case, yes, we would like to retry.
40             # Unfortunately, we are not absolutely certain that the
41             # previous call did not succeed, so this is slightly
42             # risky.
43             return self.orig_http_request(uri, **kwargs)
44     def authorize(self, http):
45         http.orig_http_request = http.request
46         http.request = types.MethodType(self.http_request, http)
47         return http
48
49 url = ('https://%s/discovery/v1/apis/'
50        '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
51 credentials = CredentialsFromEnv()
52
53 # Use system's CA certificates (if we find them) instead of httplib2's
54 ca_certs = '/etc/ssl/certs/ca-certificates.crt'
55 if not os.path.exists(ca_certs):
56     ca_certs = None             # use httplib2 default
57
58 http = httplib2.Http(ca_certs=ca_certs)
59 http = credentials.authorize(http)
60 if re.match(r'(?i)^(true|1|yes)$',
61             os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
62     http.disable_ssl_certificate_validation = True
63 service = build("arvados", "v1", http=http, discoveryServiceUrl=url)
64
65 def task_set_output(self,s):
66     service.job_tasks().update(uuid=self['uuid'],
67                                job_task=json.dumps({
68                 'output':s,
69                 'success':True,
70                 'progress':1.0
71                 })).execute()
72
73 _current_task = None
74 def current_task():
75     global _current_task
76     if _current_task:
77         return _current_task
78     t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
79     t = UserDict.UserDict(t)
80     t.set_output = types.MethodType(task_set_output, t)
81     t.tmpdir = os.environ['TASK_WORK']
82     _current_task = t
83     return t
84
85 _current_job = None
86 def current_job():
87     global _current_job
88     if _current_job:
89         return _current_job
90     t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
91     t = UserDict.UserDict(t)
92     t.tmpdir = os.environ['JOB_WORK']
93     _current_job = t
94     return t
95
96 def api():
97     return service
98
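# api() exposes the discovery-built Arvados client used throughout this
# module. For example (sketch, mirroring current_job() above):
#
#   job = api().jobs().get(uuid=os.environ['JOB_UUID']).execute()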
99 class JobTask(object):
100     def __init__(self, parameters=dict(), runtime_constraints=dict()):
101         print "init jobtask %s %s" % (parameters, runtime_constraints)
102
103 class job_setup:
104     @staticmethod
105     def one_task_per_input_file(if_sequence=0, and_end_task=True):
106         if if_sequence != current_task()['sequence']:
107             return
108         job_input = current_job()['script_parameters']['input']
109         cr = CollectionReader(job_input)
110         for s in cr.all_streams():
111             for f in s.all_files():
112                 task_input = f.as_manifest()
113                 new_task_attrs = {
114                     'job_uuid': current_job()['uuid'],
115                     'created_by_job_task_uuid': current_task()['uuid'],
116                     'sequence': if_sequence + 1,
117                     'parameters': {
118                         'input':task_input
119                         }
120                     }
121                 service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
122         if and_end_task:
123             service.job_tasks().update(uuid=current_task()['uuid'],
124                                        job_task=json.dumps({'success':True})
125                                        ).execute()
126             exit(0)
127
128     @staticmethod
129     def one_task_per_input_stream(if_sequence=0, and_end_task=True):
130         if if_sequence != current_task()['sequence']:
131             return
132         job_input = current_job()['script_parameters']['input']
133         cr = CollectionReader(job_input)
134         for s in cr.all_streams():
135             task_input = s.tokens()
136             new_task_attrs = {
137                 'job_uuid': current_job()['uuid'],
138                 'created_by_job_task_uuid': current_task()['uuid'],
139                 'sequence': if_sequence + 1,
140                 'parameters': {
141                     'input':task_input
142                     }
143                 }
144             service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
145         if and_end_task:
146             service.job_tasks().update(uuid=current_task()['uuid'],
147                                        job_task=json.dumps({'success':True})
148                                        ).execute()
149             exit(0)
150
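# Typical crunch-script use of job_setup above (sketch): the first task
# queues one child task per input file, then each child reads its own
# input from its task parameters.
#
#   arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#   this_task_input = arvados.current_task()['parameters']['input']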
151 class util:
152     @staticmethod
153     def clear_tmpdir(path=None):
154         """
155         Ensure the given directory (or the current task's temporary
156         directory if none given) exists and is empty.
157         """
158         if path == None:
159             path = current_task().tmpdir
160         if os.path.exists(path):
161             p = subprocess.Popen(['rm', '-rf', path])
162             stdout, stderr = p.communicate(None)
163             if p.returncode != 0:
164                 raise Exception('rm -rf %s: %s' % (path, stderr))
165         os.mkdir(path)
166
167     @staticmethod
168     def run_command(execargs, **kwargs):
169         kwargs.setdefault('stdin', subprocess.PIPE)
170         kwargs.setdefault('stdout', subprocess.PIPE)
171         kwargs.setdefault('stderr', sys.stderr)
172         kwargs.setdefault('close_fds', True)
173         kwargs.setdefault('shell', False)
174         p = subprocess.Popen(execargs, **kwargs)
175         stdoutdata, stderrdata = p.communicate(None)
176         if p.returncode != 0:
177             raise Exception("run_command %s exit %d:\n%s" %
178                             (execargs, p.returncode, stderrdata))
179         return stdoutdata, stderrdata
180
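    # Example use of run_command above (sketch): returns (stdout, stderr)
    # and raises if the command exits non-zero.
    #
    #   listing, _ = util.run_command(["ls", "-l"], cwd="/tmp")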
181     @staticmethod
182     def git_checkout(url, version, path):
183         if not re.search('^/', path):
184             path = os.path.join(current_job().tmpdir, path)
185         if not os.path.exists(path):
186             util.run_command(["git", "clone", url, path],
187                              cwd=os.path.dirname(path))
188         util.run_command(["git", "checkout", version],
189                          cwd=path)
190         return path
191
192     @staticmethod
193     def tar_extractor(path, decompress_flag):
194         return subprocess.Popen(["tar",
195                                  "-C", path,
196                                  ("-x%sf" % decompress_flag),
197                                  "-"],
198                                 stdout=None,
199                                 stdin=subprocess.PIPE, stderr=sys.stderr,
200                                 shell=False, close_fds=True)
201
202     @staticmethod
203     def tarball_extract(tarball, path):
204         """Retrieve a tarball from Keep and extract it to a local
205         directory.  Return the absolute path where the tarball was
206         extracted. If the top level of the tarball contained just one
207         file or directory, return the absolute path of that single
208         item.
209
210         tarball -- collection locator
211         path -- where to extract the tarball: absolute, or relative to job tmp
212         """
213         if not re.search('^/', path):
214             path = os.path.join(current_job().tmpdir, path)
215         lockfile = open(path + '.lock', 'w')
216         fcntl.flock(lockfile, fcntl.LOCK_EX)
217         try:
218             os.stat(path)
219         except OSError:
220             os.mkdir(path)
221         already_have_it = False
222         try:
223             if os.readlink(os.path.join(path, '.locator')) == tarball:
224                 already_have_it = True
225         except OSError:
226             pass
227         if not already_have_it:
228
229             # emulate "rm -f" (i.e., if the file does not exist, we win)
230             try:
231                 os.unlink(os.path.join(path, '.locator'))
232             except OSError:
233                 if os.path.exists(os.path.join(path, '.locator')):
234                     os.unlink(os.path.join(path, '.locator'))
235
236             for f in CollectionReader(tarball).all_files():
237                 if re.search('\.(tbz|tar.bz2)$', f.name()):
238                     p = util.tar_extractor(path, 'j')
239                 elif re.search('\.(tgz|tar.gz)$', f.name()):
240                     p = util.tar_extractor(path, 'z')
241                 elif re.search('\.tar$', f.name()):
242                     p = util.tar_extractor(path, '')
243                 else:
244                     raise Exception("tarball_extract cannot handle filename %s"
245                                     % f.name())
246                 while True:
247                     buf = f.read(2**20)
248                     if len(buf) == 0:
249                         break
250                     p.stdin.write(buf)
251                 p.stdin.close()
252                 p.wait()
253                 if p.returncode != 0:
254                     lockfile.close()
255                     raise Exception("tar exited %d" % p.returncode)
256             os.symlink(tarball, os.path.join(path, '.locator'))
257         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
258         lockfile.close()
259         if len(tld_extracts) == 1:
260             return os.path.join(path, tld_extracts[0])
261         return path
262
263     @staticmethod
264     def zipball_extract(zipball, path):
265         """Retrieve a zip archive from Keep and extract it to a local
266         directory.  Return the absolute path where the archive was
267         extracted. If the top level of the archive contained just one
268         file or directory, return the absolute path of that single
269         item.
270
271         zipball -- collection locator
272         path -- where to extract the archive: absolute, or relative to job tmp
273         """
274         if not re.search('^/', path):
275             path = os.path.join(current_job().tmpdir, path)
276         lockfile = open(path + '.lock', 'w')
277         fcntl.flock(lockfile, fcntl.LOCK_EX)
278         try:
279             os.stat(path)
280         except OSError:
281             os.mkdir(path)
282         already_have_it = False
283         try:
284             if os.readlink(os.path.join(path, '.locator')) == zipball:
285                 already_have_it = True
286         except OSError:
287             pass
288         if not already_have_it:
289
290             # emulate "rm -f" (i.e., if the file does not exist, we win)
291             try:
292                 os.unlink(os.path.join(path, '.locator'))
293             except OSError:
294                 if os.path.exists(os.path.join(path, '.locator')):
295                     os.unlink(os.path.join(path, '.locator'))
296
297             for f in CollectionReader(zipball).all_files():
298                 if not re.search('\.zip$', f.name()):
299                     raise Exception("zipball_extract cannot handle filename %s"
300                                     % f.name())
301                 zip_filename = os.path.join(path, os.path.basename(f.name()))
302                 zip_file = open(zip_filename, 'wb')
303                 while True:
304                     buf = f.read(2**20)
305                     if len(buf) == 0:
306                         break
307                     zip_file.write(buf)
308                 zip_file.close()
309
310                 p = subprocess.Popen(["unzip",
311                                       "-q", "-o",
312                                       "-d", path,
313                                       zip_filename],
314                                      stdout=None,
315                                      stdin=None, stderr=sys.stderr,
316                                      shell=False, close_fds=True)
317                 p.wait()
318                 if p.returncode != 0:
319                     lockfile.close()
320                     raise Exception("unzip exited %d" % p.returncode)
321                 os.unlink(zip_filename)
322             os.symlink(zipball, os.path.join(path, '.locator'))
323         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
324         lockfile.close()
325         if len(tld_extracts) == 1:
326             return os.path.join(path, tld_extracts[0])
327         return path
328
329     @staticmethod
330     def collection_extract(collection, path, files=[], decompress=True):
331         """Retrieve a collection from Keep and extract it to a local
332         directory.  Return the absolute path where the collection was
333         extracted.
334
335         collection -- collection locator
336         path -- where to extract: absolute, or relative to job tmp
337         """
338         matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
339         if matches:
340             collection_hash = matches.group(1)
341         else:
342             collection_hash = hashlib.md5(collection).hexdigest()
343         if not re.search('^/', path):
344             path = os.path.join(current_job().tmpdir, path)
345         lockfile = open(path + '.lock', 'w')
346         fcntl.flock(lockfile, fcntl.LOCK_EX)
347         try:
348             os.stat(path)
349         except OSError:
350             os.mkdir(path)
351         already_have_it = False
352         try:
353             if os.readlink(os.path.join(path, '.locator')) == collection_hash:
354                 already_have_it = True
355         except OSError:
356             pass
357
358         # emulate "rm -f" (i.e., if the file does not exist, we win)
359         try:
360             os.unlink(os.path.join(path, '.locator'))
361         except OSError:
362             if os.path.exists(os.path.join(path, '.locator')):
363                 os.unlink(os.path.join(path, '.locator'))
364
365         files_got = []
366         for s in CollectionReader(collection).all_streams():
367             stream_name = s.name()
368             for f in s.all_files():
369                 if (files == [] or
370                     ((f.name() not in files_got) and
371                      (f.name() in files or
372                       (decompress and f.decompressed_name() in files)))):
373                     outname = f.decompressed_name() if decompress else f.name()
374                     files_got += [outname]
375                     if os.path.exists(os.path.join(path, stream_name, outname)):
376                         continue
377                     util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
378                     outfile = open(os.path.join(path, stream_name, outname), 'wb')
379                     for buf in (f.readall_decompressed() if decompress
380                                 else f.readall()):
381                         outfile.write(buf)
382                     outfile.close()
383         if len(files_got) < len(files):
384             raise Exception("Wanted files %s but only got %s from %s" % (files, files_got, map(lambda z: z.name(), list(CollectionReader(collection).all_files()))))
385         os.symlink(collection_hash, os.path.join(path, '.locator'))
386
387         lockfile.close()
388         return path
389
390     @staticmethod
391     def mkdir_dash_p(path):
392         if not os.path.exists(path):
393             util.mkdir_dash_p(os.path.dirname(path))
394             try:
395                 os.mkdir(path)
396             except OSError:
397                 if not os.path.exists(path):
398                     os.mkdir(path)
399
400     @staticmethod
401     def stream_extract(stream, path, files=[], decompress=True):
402         """Retrieve a stream from Keep and extract it to a local
403         directory.  Return the absolute path where the stream was
404         extracted.
405
406         stream -- StreamReader object
407         path -- where to extract: absolute, or relative to job tmp
408         """
409         if not re.search('^/', path):
410             path = os.path.join(current_job().tmpdir, path)
411         lockfile = open(path + '.lock', 'w')
412         fcntl.flock(lockfile, fcntl.LOCK_EX)
413         try:
414             os.stat(path)
415         except OSError:
416             os.mkdir(path)
417
418         files_got = []
419         for f in stream.all_files():
420             if (files == [] or
421                 ((f.name() not in files_got) and
422                  (f.name() in files or
423                   (decompress and f.decompressed_name() in files)))):
424                 outname = f.decompressed_name() if decompress else f.name()
425                 files_got += [outname]
426                 if os.path.exists(os.path.join(path, outname)):
427                     os.unlink(os.path.join(path, outname))
428                 util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
429                 outfile = open(os.path.join(path, outname), 'wb')
430                 for buf in (f.readall_decompressed() if decompress
431                             else f.readall()):
432                     outfile.write(buf)
433                 outfile.close()
434         if len(files_got) < len(files):
435             raise Exception("Wanted files %s but only got %s from %s" %
436                             (files, files_got, map(lambda z: z.name(),
437                                                    list(stream.all_files()))))
438         lockfile.close()
439         return path
440
441     @staticmethod
442     def listdir_recursive(dirname, base=None):
443         allfiles = []
444         for ent in sorted(os.listdir(dirname)):
445             ent_path = os.path.join(dirname, ent)
446             ent_base = os.path.join(base, ent) if base else ent
447             if os.path.isdir(ent_path):
448                 allfiles += util.listdir_recursive(ent_path, ent_base)
449             else:
450                 allfiles += [ent_base]
451         return allfiles
452
453 class StreamFileReader(object):
454     def __init__(self, stream, pos, size, name):
455         self._stream = stream
456         self._pos = pos
457         self._size = size
458         self._name = name
459         self._filepos = 0
460     def name(self):
461         return self._name
462     def decompressed_name(self):
463         return re.sub('\.(bz2|gz)$', '', self._name)
464     def size(self):
465         return self._size
466     def stream_name(self):
467         return self._stream.name()
468     def read(self, size, **kwargs):
469         self._stream.seek(self._pos + self._filepos)
470         data = self._stream.read(min(size, self._size - self._filepos))
471         self._filepos += len(data)
472         return data
473     def readall(self, size=2**20, **kwargs):
474         while True:
475             data = self.read(size, **kwargs)
476             if data == '':
477                 break
478             yield data
479     def bunzip2(self, size):
480         decompressor = bz2.BZ2Decompressor()
481         for chunk in self.readall(size):
482             data = decompressor.decompress(chunk)
483             if data and data != '':
484                 yield data
485     def gunzip(self, size):
486         decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
487         for chunk in self.readall(size):
488             data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
489             if data and data != '':
490                 yield data
491     def readall_decompressed(self, size=2**20):
492         self._stream.seek(self._pos + self._filepos)
493         if re.search('\.bz2$', self._name):
494             return self.bunzip2(size)
495         elif re.search('\.gz$', self._name):
496             return self.gunzip(size)
497         else:
498             return self.readall(size)
499     def readlines(self, decompress=True):
500         if decompress:
501             datasource = self.readall_decompressed()
502         else:
503             self._stream.seek(self._pos + self._filepos)
504             datasource = self.readall()
505         data = ''
506         for newdata in datasource:
507             data += newdata
508             sol = 0
509             while True:
510                 eol = string.find(data, "\n", sol)
511                 if eol < 0:
512                     break
513                 yield data[sol:eol+1]
514                 sol = eol+1
515             data = data[sol:]
516         if data != '':
517             yield data
518     def as_manifest(self):
519         if self.size() == 0:
520             return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
521                     % (self._stream.name(), self.name()))
522         return string.join(self._stream.tokens_for_range(self._pos, self._size),
523                            " ") + "\n"
524
525 class StreamReader(object):
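    """Reader for a single manifest stream: a token list consisting of the
    stream name, one or more data block locators (md5 hash plus size hint),
    and one or more "position:size:filename" file tokens, in the format
    emitted by CollectionWriter.manifest_text() below."""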
526     def __init__(self, tokens):
527         self._tokens = tokens
528         self._current_datablock_data = None
529         self._current_datablock_pos = 0
530         self._current_datablock_index = -1
531         self._pos = 0
532
533         self._stream_name = None
534         self.data_locators = []
535         self.files = []
536
537         for tok in self._tokens:
538             if self._stream_name == None:
539                 self._stream_name = tok
540             elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
541                 self.data_locators += [tok]
542             elif re.search(r'^\d+:\d+:\S+', tok):
543                 pos, size, name = tok.split(':',2)
544                 self.files += [[int(pos), int(size), name]]
545             else:
546                 raise Exception("Invalid manifest format")
547
548     def tokens(self):
549         return self._tokens
550     def tokens_for_range(self, range_start, range_size):
551         resp = [self._stream_name]
552         return_all_tokens = False
553         block_start = 0
554         token_bytes_skipped = 0
555         for locator in self.data_locators:
556             sizehint = re.search(r'\+(\d+)', locator)
557             if not sizehint:
558                 return_all_tokens = True
559             if return_all_tokens:
560                 resp += [locator]
561                 continue
562             blocksize = int(sizehint.group(0))
563             if range_start + range_size <= block_start:
564                 break
565             if range_start < block_start + blocksize:
566                 resp += [locator]
567             else:
568                 token_bytes_skipped += blocksize
569             block_start += blocksize
570         for f in self.files:
571             if ((f[0] < range_start + range_size)
572                 and
573                 (f[0] + f[1] > range_start)
574                 and
575                 f[1] > 0):
576                 resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
577         return resp
578     def name(self):
579         return self._stream_name
580     def all_files(self):
581         for f in self.files:
582             pos, size, name = f
583             yield StreamFileReader(self, pos, size, name)
584     def nextdatablock(self):
585         if self._current_datablock_index < 0:
586             self._current_datablock_pos = 0
587             self._current_datablock_index = 0
588         else:
589             self._current_datablock_pos += self.current_datablock_size()
590             self._current_datablock_index += 1
591         self._current_datablock_data = None
592     def current_datablock_data(self):
593         if self._current_datablock_data == None:
594             self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
595         return self._current_datablock_data
596     def current_datablock_size(self):
597         if self._current_datablock_index < 0:
598             self.nextdatablock()
599         sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
600         if sizehint:
601             return int(sizehint.group(0))
602         return len(self.current_datablock_data())
603     def seek(self, pos):
604         """Set the position of the next read operation."""
605         self._pos = pos
606     def really_seek(self):
607         """Find and load the appropriate data block, so the byte at
608         _pos is in memory.
609         """
610         if self._pos == self._current_datablock_pos:
611             return True
612         if (self._current_datablock_pos != None and
613             self._pos >= self._current_datablock_pos and
614             self._pos <= self._current_datablock_pos + self.current_datablock_size()):
615             return True
616         if self._pos < self._current_datablock_pos:
617             self._current_datablock_index = -1
618             self.nextdatablock()
619         while (self._pos > self._current_datablock_pos and
620                self._pos > self._current_datablock_pos + self.current_datablock_size()):
621             self.nextdatablock()
622     def read(self, size):
623         """Read no more than size bytes -- but at least one byte,
624         unless _pos is already at the end of the stream.
625         """
626         if size == 0:
627             return ''
628         self.really_seek()
629         while self._pos >= self._current_datablock_pos + self.current_datablock_size():
630             self.nextdatablock()
631             if self._current_datablock_index >= len(self.data_locators):
632                 return None
633         data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
634         self._pos += len(data)
635         return data
636
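# Reading files back (sketch): CollectionReader below accepts a manifest
# locator or literal manifest text and yields StreamFileReader objects
# (defined above), which stream their contents in chunks. 'locator' and
# 'process' are placeholders.
#
#   for f in CollectionReader(locator).all_files():
#       for chunk in f.readall():
#           process(chunk)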
637 class CollectionReader(object):
638     def __init__(self, manifest_locator_or_text):
639         if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
640             self._manifest_text = manifest_locator_or_text
641             self._manifest_locator = None
642         else:
643             self._manifest_locator = manifest_locator_or_text
644             self._manifest_text = None
645         self._streams = None
646     def __enter__(self):
647         return self
648     def __exit__(self, exc_type, exc_value, traceback):
649         pass
650     def _populate(self):
651         if self._streams != None:
652             return
653         if not self._manifest_text:
654             self._manifest_text = Keep.get(self._manifest_locator)
655         self._streams = []
656         for stream_line in self._manifest_text.split("\n"):
657             if stream_line != '':
658                 stream_tokens = stream_line.split()
659                 self._streams += [stream_tokens]
660     def all_streams(self):
661         self._populate()
662         resp = []
663         for s in self._streams:
664             resp += [StreamReader(s)]
665         return resp
666     def all_files(self):
667         for s in self.all_streams():
668             for f in s.all_files():
669                 yield f
670     def manifest_text(self):
671         self._populate()
672         return self._manifest_text
673
674 class CollectionWriter(object):
675     KEEP_BLOCK_SIZE = 2**26
676     def __init__(self):
677         self._data_buffer = []
678         self._data_buffer_len = 0
679         self._current_stream_files = []
680         self._current_stream_length = 0
681         self._current_stream_locators = []
682         self._current_stream_name = '.'
683         self._current_file_name = None
684         self._current_file_pos = 0
685         self._finished_streams = []
686     def __enter__(self):
687         return self
688     def __exit__(self, exc_type, exc_value, traceback):
689         self.finish()
690     def write_directory_tree(self,
691                              path, stream_name='.', max_manifest_depth=-1):
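        """Add a local directory tree to the collection. Subdirectories up
        to max_manifest_depth levels deep get manifest streams of their
        own; below that depth, files are flattened (with relative names)
        into their ancestor's stream. The default of -1 means no limit."""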
692         self.start_new_stream(stream_name)
693         todo = []
694         if max_manifest_depth == 0:
695             dirents = sorted(util.listdir_recursive(path))
696         else:
697             dirents = sorted(os.listdir(path))
698         for dirent in dirents:
699             target = os.path.join(path, dirent)
700             if os.path.isdir(target):
701                 todo += [[target,
702                           os.path.join(stream_name, dirent),
703                           max_manifest_depth-1]]
704             else:
705                 self.start_new_file(dirent)
706                 with open(target, 'rb') as f:
707                     while True:
708                         buf = f.read(2**26)
709                         if len(buf) == 0:
710                             break
711                         self.write(buf)
712         self.finish_current_stream()
713         map(lambda x: self.write_directory_tree(*x), todo)
714
715     def write(self, newdata):
716         if hasattr(newdata, '__iter__'):
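            # Python 2 strings have no __iter__, so this branch fires only
            # for lists, tuples, generators, and other iterables, whose
            # elements are written one at a time.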
717             for s in newdata:
718                 self.write(s)
719             return
720         self._data_buffer += [newdata]
721         self._data_buffer_len += len(newdata)
722         self._current_stream_length += len(newdata)
723         while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
724             self.flush_data()
725     def flush_data(self):
726         data_buffer = ''.join(self._data_buffer)
727         if data_buffer != '':
728             self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
729             self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
730             self._data_buffer_len = len(self._data_buffer[0])
731     def start_new_file(self, newfilename=None):
732         self.finish_current_file()
733         self.set_current_file_name(newfilename)
734     def set_current_file_name(self, newfilename):
735         newfilename = re.sub(r' ', '\\\\040', newfilename)
736         if re.search(r'[ \t\n]', newfilename):
737             raise AssertionError("Manifest filenames cannot contain whitespace")
738         self._current_file_name = newfilename
739     def current_file_name(self):
740         return self._current_file_name
741     def finish_current_file(self):
742         if self._current_file_name == None:
743             if self._current_file_pos == self._current_stream_length:
744                 return
745             raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
746         self._current_stream_files += [[self._current_file_pos,
747                                        self._current_stream_length - self._current_file_pos,
748                                        self._current_file_name]]
749         self._current_file_pos = self._current_stream_length
750     def start_new_stream(self, newstreamname='.'):
751         self.finish_current_stream()
752         self.set_current_stream_name(newstreamname)
753     def set_current_stream_name(self, newstreamname):
754         if re.search(r'[ \t\n]', newstreamname):
755             raise AssertionError("Manifest stream names cannot contain whitespace")
756         self._current_stream_name = newstreamname
757     def current_stream_name(self):
758         return self._current_stream_name
759     def finish_current_stream(self):
760         self.finish_current_file()
761         self.flush_data()
762         if len(self._current_stream_files) == 0:
763             pass
764         elif self._current_stream_name == None:
765             raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
766         else:
767             self._finished_streams += [[self._current_stream_name,
768                                        self._current_stream_locators,
769                                        self._current_stream_files]]
770         self._current_stream_files = []
771         self._current_stream_length = 0
772         self._current_stream_locators = []
773         self._current_stream_name = None
774         self._current_file_pos = 0
775         self._current_file_name = None
776     def finish(self):
777         return Keep.put(self.manifest_text())
778     def manifest_text(self):
779         self.finish_current_stream()
780         manifest = ''
781         for stream in self._finished_streams:
782             if not re.search(r'^\.(/.*)?$', stream[0]):
783                 manifest += './'
784             manifest += stream[0]
785             if len(stream[1]) == 0:
786                 manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
787             else:
788                 for locator in stream[1]:
789                     manifest += " %s" % locator
790             for sfile in stream[2]:
791                 manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
792             manifest += "\n"
793         return manifest
794
795 global_client_object = None
796
797 class Keep:
798     @staticmethod
799     def global_client_object():
800         global global_client_object
801         if global_client_object == None:
802             global_client_object = KeepClient()
803         return global_client_object
804
805     @staticmethod
806     def get(locator, **kwargs):
807         return Keep.global_client_object().get(locator, **kwargs)
808
809     @staticmethod
810     def put(data, **kwargs):
811         return Keep.global_client_object().put(data, **kwargs)
812
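# Example (sketch) use of the module-level Keep helper above; put() returns
# a locator of the form "<md5 hexdigest>+<size in bytes>".
#
#   locator = Keep.put("hello world\n")
#   data = Keep.get(locator)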
813 class KeepClient(object):
814
815     class ThreadLimiter(object):
816         """
817         Limit the number of threads running at a given time to
818         {desired successes} minus {successes reported}. When successes
819         reported == desired, wake up the remaining threads and tell
820         them to quit.
821
822         Should be used in a "with" block.
823         """
824         def __init__(self, todo):
825             self._todo = todo
826             self._done = 0
827             self._todo_lock = threading.Semaphore(todo)
828             self._done_lock = threading.Lock()
829         def __enter__(self):
830             self._todo_lock.acquire()
831             return self
832         def __exit__(self, type, value, traceback):
833             self._todo_lock.release()
834         def shall_i_proceed(self):
835             """
836             Return true if the current thread should do stuff. Return
837             false if the current thread should just stop.
838             """
839             with self._done_lock:
840                 return (self._done < self._todo)
841         def increment_done(self):
842             """
843             Report that the current thread was successful.
844             """
845             with self._done_lock:
846                 self._done += 1
847         def done(self):
848             """
849             Return how many successes were reported.
850             """
851             with self._done_lock:
852                 return self._done
853
854     class KeepWriterThread(threading.Thread):
855         """
856         Write a blob of data to the given Keep server. Call
857         increment_done() of the given ThreadLimiter if the write
858         succeeds.
859         """
860         def __init__(self, **kwargs):
861             super(KeepClient.KeepWriterThread, self).__init__()
862             self.args = kwargs
863         def run(self):
864             with self.args['thread_limiter'] as limiter:
865                 if not limiter.shall_i_proceed():
866                     # My turn arrived, but the job has been done without
867                     # me.
868                     return
869                 logging.debug("KeepWriterThread %s proceeding %s %s" %
870                               (str(threading.current_thread()),
871                                self.args['data_hash'],
872                                self.args['service_root']))
873                 h = httplib2.Http()
874                 url = self.args['service_root'] + self.args['data_hash']
875                 api_token = os.environ['ARVADOS_API_TOKEN']
876                 headers = {'Authorization': "OAuth2 %s" % api_token}
877                 try:
878                     resp, content = h.request(url.encode('utf-8'), 'PUT',
879                                               headers=headers,
880                                               body=self.args['data'])
881                     if (resp['status'] == '401' and
882                         re.match(r'Timestamp verification failed', content)):
883                         body = KeepClient.sign_for_old_server(
884                             self.args['data_hash'],
885                             self.args['data'])
886                         h = httplib2.Http()
887                         resp, content = h.request(url.encode('utf-8'), 'PUT',
888                                                   headers=headers,
889                                                   body=body)
890                     if re.match(r'^2\d\d$', resp['status']):
891                         logging.debug("KeepWriterThread %s succeeded %s %s" %
892                                       (str(threading.current_thread()),
893                                        self.args['data_hash'],
894                                        self.args['service_root']))
895                         return limiter.increment_done()
896                     logging.warning("Request fail: PUT %s => %s %s" %
897                                     (url, resp['status'], content))
898                 except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
899                     logging.warning("Request fail: PUT %s => %s: %s" %
900                                     (url, type(e), str(e)))
901
902     def __init__(self):
903         self.lock = threading.Lock()
904         self.service_roots = None
905
906     def shuffled_service_roots(self, hash):
907         if self.service_roots == None:
908             self.lock.acquire()
909             keep_disks = api().keep_disks().list().execute()['items']
910             roots = (("http%s://%s:%d/" %
911                       ('s' if f['service_ssl_flag'] else '',
912                        f['service_host'],
913                        f['service_port']))
914                      for f in keep_disks)
915             self.service_roots = sorted(set(roots))
916             logging.debug(str(self.service_roots))
917             self.lock.release()
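        # Build the probe order: repeatedly take the next 8 hex digits of
        # the (re-seeded) hash, reduce them modulo the number of servers
        # still in the pool, and append that server to the sequence, so
        # every client probes the same servers in the same order for a
        # given block hash.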
918         seed = hash
919         pool = self.service_roots[:]
920         pseq = []
921         while len(pool) > 0:
922             if len(seed) < 8:
923                 if len(pseq) < len(hash) / 4: # first time around
924                     seed = hash[-4:] + hash
925                 else:
926                     seed += hash
927             probe = int(seed[0:8], 16) % len(pool)
928             pseq += [pool[probe]]
929             pool = pool[:probe] + pool[probe+1:]
930             seed = seed[8:]
931         logging.debug(str(pseq))
932         return pseq
933
934     def get(self, locator):
935         if 'KEEP_LOCAL_STORE' in os.environ:
936             return KeepClient.local_store_get(locator)
937         expect_hash = re.sub(r'\+.*', '', locator)
938         for service_root in self.shuffled_service_roots(expect_hash):
939             h = httplib2.Http()
940             url = service_root + expect_hash
941             api_token = os.environ['ARVADOS_API_TOKEN']
942             headers = {'Authorization': "OAuth2 %s" % api_token,
943                        'Accept': 'application/octet-stream'}
944             try:
945                 resp, content = h.request(url.encode('utf-8'), 'GET',
946                                           headers=headers)
947                 if re.match(r'^2\d\d$', resp['status']):
948                     m = hashlib.new('md5')
949                     m.update(content)
950                     md5 = m.hexdigest()
951                     if md5 == expect_hash:
952                         return content
953                     logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
954             except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
955                 logging.info("Request fail: GET %s => %s: %s" %
956                              (url, type(e), str(e)))
957         raise Exception("Not found: %s" % expect_hash)
958
959     def put(self, data, **kwargs):
960         if 'KEEP_LOCAL_STORE' in os.environ:
961             return KeepClient.local_store_put(data)
962         m = hashlib.new('md5')
963         m.update(data)
964         data_hash = m.hexdigest()
965         have_copies = 0
966         want_copies = kwargs.get('copies', 2)
967         if not (want_copies > 0):
968             return data_hash
969         threads = []
970         thread_limiter = KeepClient.ThreadLimiter(want_copies)
971         for service_root in self.shuffled_service_roots(data_hash):
972             t = KeepClient.KeepWriterThread(data=data,
973                                             data_hash=data_hash,
974                                             service_root=service_root,
975                                             thread_limiter=thread_limiter)
976             t.start()
977             threads += [t]
978         for t in threads:
979             t.join()
980         have_copies = thread_limiter.done()
981         if have_copies == want_copies:
982             return (data_hash + '+' + str(len(data)))
983         raise Exception("Write fail for %s: wanted %d but wrote %d" %
984                         (data_hash, want_copies, have_copies))
985
986     @staticmethod
987     def sign_for_old_server(data_hash, data):
988         return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
989
990
991     @staticmethod
992     def local_store_put(data):
993         m = hashlib.new('md5')
994         m.update(data)
995         md5 = m.hexdigest()
996         locator = '%s+%d' % (md5, len(data))
997         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
998             f.write(data)
999         os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
1000                   os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
1001         return locator
1002     @staticmethod
1003     def local_store_get(locator):
1004         r = re.search('^([0-9a-f]{32,})', locator)
1005         if not r:
1006             raise Exception("Keep.get: invalid data locator '%s'" % locator)
1007         if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
1008             return ''
1009         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
1010             return f.read()