import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading

from apiclient import errors
from apiclient.discovery import build

if 'ARVADOS_DEBUG' in os.environ:
    logging.basicConfig(level=logging.DEBUG)

class CredentialsFromEnv(object):
    # http_request is declared @staticmethod but still takes the Http
    # object as its first argument: authorize() binds it to the Http
    # instance with types.MethodType, replacing the original request().
    @staticmethod
    def http_request(self, uri, **kwargs):
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

url = ('https://%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()

# Use system's CA certificates (if we find them) instead of httplib2's
ca_certs = '/etc/ssl/certs/ca-certificates.crt'
if not os.path.exists(ca_certs):
    ca_certs = None             # use httplib2 default

http = httplib2.Http(ca_certs=ca_certs)
http = credentials.authorize(http)
if re.match(r'(?i)^(true|1|yes)$',
            os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
    http.disable_ssl_certificate_validation = True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)

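# All client configuration comes from the environment: ARVADOS_API_HOST
# and ARVADOS_API_TOKEN are required, ARVADOS_API_HOST_INSECURE disables
# SSL certificate verification, and ARVADOS_DEBUG turns on debug logging.
# Illustrative setup (host and token below are placeholders):
#
#   export ARVADOS_API_HOST=arvados.example.com
#   export ARVADOS_API_TOKEN=...
#   python -c "import arvados; print arvados.api().keep_disks().list().execute()"
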
def task_set_output(self, s):
    service.job_tasks().update(uuid=self['uuid'],
                               job_task=json.dumps({
                                   'output': s,
                                   'success': True,
                                   'progress': 1.0
                               })).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def getjobparam(*args):
    return current_job()['script_parameters'].get(*args)

def api():
    return service

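# Sketch of how a crunch script typically uses these helpers (the
# 'input' parameter name and the output variable are illustrative, not
# required by this SDK):
#
#   import arvados
#   job_input = arvados.getjobparam('input')
#   this_task = arvados.current_task()
#   # ... produce results under this_task.tmpdir ...
#   this_task.set_output(output_manifest_locator)
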
class JobTask(object):
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                        }
                    }
                service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            exit(0)

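    # Illustrative use of one_task_per_input_file() above, from a crunch
    # script: at sequence 0 it queues one task per input file and exits;
    # each queued task then finds its single file in its own parameters.
    #
    #   arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
    #   this_task = arvados.current_task()
    #   task_input_manifest = this_task['parameters']['input']
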
    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input': task_input
                    }
                }
            service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            exit(0)

class util:
    @staticmethod
    def clear_tmpdir(path=None):
        """
        Ensure the given directory (or the current task's temp dir if
        none given) exists and is empty.
        """
        if path == None:
            path = current_task().tmpdir
        if os.path.exists(path):
            # Capture stderr so the exception message is useful if rm fails.
            p = subprocess.Popen(['rm', '-rf', path], stderr=subprocess.PIPE)
            stdout, stderr = p.communicate(None)
            if p.returncode != 0:
                raise Exception('rm -rf %s: %s' % (path, stderr))
        os.mkdir(path)

    @staticmethod
    def run_command(execargs, **kwargs):
        kwargs.setdefault('stdin', subprocess.PIPE)
        kwargs.setdefault('stdout', subprocess.PIPE)
        kwargs.setdefault('stderr', sys.stderr)
        kwargs.setdefault('close_fds', True)
        kwargs.setdefault('shell', False)
        p = subprocess.Popen(execargs, **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("run_command %s exit %d:\n%s" %
                            (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

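    # Illustrative use of run_command() above; any Popen keyword
    # argument (e.g. cwd) can be passed through:
    #
    #   stdout, stderr = util.run_command(
    #       ['ls', '-l'], cwd=current_job().tmpdir)
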
    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search('\.(tbz|tar.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search('\.(tgz|tar.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search('\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise Exception("tarball_extract cannot handle filename %s"
                                    % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

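    # Illustrative use of tarball_extract() above (the script parameter
    # name is a placeholder):
    #
    #   ref_dir = util.tarball_extract(
    #       tarball=current_job()['script_parameters']['reference_tarball'],
    #       path='reference')
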
    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search('\.zip$', f.name()):
                    raise Exception("zipball_extract cannot handle filename %s"
                                    % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
        if matches:
            collection_hash = matches.group(1)
        else:
            collection_hash = hashlib.md5(collection).hexdigest()
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection_hash:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" % (files, files_got, map(lambda z: z.name(), list(CollectionReader(collection).all_files()))))
        os.symlink(collection_hash, os.path.join(path, '.locator'))

        lockfile.close()
        return path

    @staticmethod
    def mkdir_dash_p(path):
        # Emulate "mkdir -p": create path and any missing parents,
        # tolerating a concurrent mkdir by another process. The empty
        # string check stops the recursion once a relative path has
        # been reduced to nothing.
        if path != '' and not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got, map(lambda z: z.name(),
                                                   list(stream.all_files()))))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

class StreamFileReader(object):
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub('\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search('\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search('\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"

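# A manifest stream line has the form
#
#   <stream name> <block locator> [<block locator> ...] <pos:size:filename> [...]
#
# where each block locator is a 32-hex-digit MD5 (optionally followed by
# "+" hints such as the block size) and each file token gives a byte
# offset and length within the stream's concatenated blocks. StreamReader
# below parses a single tokenized stream line; CollectionReader handles a
# whole manifest, one stream per line.
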
class StreamReader(object):
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name == None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")

    def tokens(self):
        return self._tokens
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                # No size hint on this locator: give up on trimming and
                # return every remaining data locator.
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(1))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data == None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search(r'\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(1))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos != None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

class CollectionReader(object):
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams != None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text

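# Illustrative use of CollectionReader (the locator is a placeholder);
# the constructor accepts either a manifest locator or raw manifest text:
#
#   cr = CollectionReader(collection_locator)
#   for f in cr.all_files():
#       print f.stream_name(), f.name(), f.size()
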
class CollectionWriter(object):
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        self.start_new_stream(stream_name)
        todo = []
        if max_manifest_depth == 0:
            dirents = sorted(util.listdir_recursive(path))
        else:
            dirents = sorted(os.listdir(path))
        for dirent in dirents:
            target = os.path.join(path, dirent)
            if os.path.isdir(target):
                todo += [[target,
                          os.path.join(stream_name, dirent),
                          max_manifest_depth-1]]
            else:
                self.start_new_file(dirent)
                with open(target, 'rb') as f:
                    while True:
                        buf = f.read(2**26)
                        if len(buf) == 0:
                            break
                        self.write(buf)
        self.finish_current_stream()
        map(lambda x: self.write_directory_tree(*x), todo)

    def write(self, newdata):
        if hasattr(newdata, '__iter__'):
            for s in newdata:
                self.write(s)
            return
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        newfilename = re.sub(r' ', '\\\\040', newfilename)
        if re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name == None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname='.'):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = '.' if newstreamname == '' else newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name == None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0]
            if len(stream[1]) == 0:
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest

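# Illustrative use of CollectionWriter: store a local directory tree in
# Keep and register the resulting manifest as the current task's output
# (out_dir is a placeholder):
#
#   cw = CollectionWriter()
#   cw.write_directory_tree(out_dir, stream_name='.')
#   current_task().set_output(cw.finish())
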
global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object == None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)

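# Illustrative round trip through Keep (works against a reachable Keep
# service or, if KEEP_LOCAL_STORE is set, a local directory): put()
# returns a "<md5>+<size>" locator and get() verifies the MD5 of the
# fetched data before returning it.
#
#   locator = Keep.put("hello\n")
#   assert Keep.get(locator) == "hello\n"
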
class KeepClient(object):

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()
        def __enter__(self):
            self._todo_lock.acquire()
            return self
        def __exit__(self, type, value, traceback):
            self._todo_lock.release()
        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)
        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1
        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs
        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = os.environ['ARVADOS_API_TOKEN']
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None

    def shuffled_service_roots(self, hash):
        if self.service_roots == None:
            self.lock.acquire()
            keep_disks = api().keep_disks().list().execute()['items']
            roots = (("http%s://%s:%d/" %
                      ('s' if f['service_ssl_flag'] else '',
                       f['service_host'],
                       f['service_port']))
                     for f in keep_disks)
            self.service_roots = sorted(set(roots))
            logging.debug(str(self.service_roots))
            self.lock.release()
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

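    # get() and put() below probe servers in the order produced by
    # shuffled_service_roots(): successive 8-hex-digit slices of the block
    # hash select the next server from the remaining pool, so every client
    # visits the same servers in the same order for a given block.
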
    def get(self, locator):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = os.environ['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise Exception("Not found: %s" % expect_hash)

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies == want_copies:
            return (data_hash + '+' + str(len(data)))
        raise Exception("Write fail for %s: wanted %d but wrote %d" %
                        (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()