import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time

from apiclient import errors
from apiclient.discovery import build

if 'ARVADOS_DEBUG' in os.environ:
    logging.basicConfig(level=logging.DEBUG)

class CredentialsFromEnv:
    # authorize() rebinds http_request onto the httplib2.Http object, so
    # inside http_request `self` is that Http instance (which carries
    # orig_http_request), not a CredentialsFromEnv instance.
    @staticmethod
    def http_request(self, uri, **kwargs):
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

url = ('https://%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()

# Use system's CA certificates (if we find them) instead of httplib2's
ca_certs = '/etc/ssl/certs/ca-certificates.crt'
if not os.path.exists(ca_certs):
    ca_certs = None             # use httplib2 default

http = httplib2.Http(ca_certs=ca_certs)
http = credentials.authorize(http)
if re.match(r'(?i)^(true|1|yes)$',
            os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
    http.disable_ssl_certificate_validation = True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)

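# Note (illustrative, not part of the SDK): this module configures itself
# entirely from the environment at import time, so the following must be
# set before "import arvados":
#
#   ARVADOS_API_HOST           -- API server used to build the discovery URL
#   ARVADOS_API_TOKEN          -- OAuth2 token sent with every request
#   ARVADOS_API_HOST_INSECURE  -- optional; "true"/"1"/"yes" disables SSL
#                                 certificate validation
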
def task_set_output(self, s):
    service.job_tasks().update(uuid=self['uuid'],
                               job_task=json.dumps({
                                   'output': s,
                                   'success': True,
                                   'progress': 1.0
                               })).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def api():
    return service

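# Example (illustrative sketch, not part of the SDK): a crunch script would
# typically use api(), current_job() and current_task() like this, relying
# on the JOB_UUID/TASK_UUID/JOB_WORK/TASK_WORK environment variables set by
# the job dispatcher. The variable names below are made up for illustration.
#
#   this_job = current_job()
#   this_task = current_task()
#   input_locator = this_job['script_parameters']['input']
#   ...do the work in this_task.tmpdir...
#   this_task.set_output(output_locator)
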
class JobTask:
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                    }
                }
                service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            exit(0)

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input': task_input
                }
            }
            service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            exit(0)

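# Example (illustrative sketch): a typical two-phase crunch script first
# fans out one task per input file, then processes the single file named in
# its own task parameters. Variable names here are made up for
# illustration; 'input' is the script parameter used above.
#
#   job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#   my_input = current_task()['parameters']['input']
#   ...process my_input, write results...
#   current_task().set_output(output_locator)
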
class util:
    @staticmethod
    def run_command(execargs, **kwargs):
        if 'stdin' not in kwargs:
            kwargs['stdin'] = subprocess.PIPE
        if 'stdout' not in kwargs:
            kwargs['stdout'] = subprocess.PIPE
        if 'stderr' not in kwargs:
            kwargs['stderr'] = subprocess.PIPE
        p = subprocess.Popen(execargs, close_fds=True, shell=False,
                             **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("run_command %s exit %d:\n%s" %
                            (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search('\.(tbz|tar.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search('\.(tgz|tar.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search('\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise Exception("tarball_extract cannot handle filename %s"
                                    % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

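    # Example (illustrative sketch): unpack a tool shipped as a tarball
    # collection into the job work directory before running it. The
    # 'tool_tarball' parameter and 'mytool' binary are made up for
    # illustration.
    #
    #   tooldir = util.tarball_extract(
    #       tarball=current_job()['script_parameters']['tool_tarball'],
    #       path='tool')
    #   out, err = util.run_command([os.path.join(tooldir, 'mytool'), '--version'])
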
    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search('\.zip$', f.name()):
                    raise Exception("zipball_extract cannot handle filename %s"
                                    % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got,
                             map(lambda z: z.name(),
                                 list(CollectionReader(collection).all_files()))))
        os.symlink(collection, os.path.join(path, '.locator'))

        lockfile.close()
        return path

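    # Example (illustrative sketch): fetch a reference collection into the
    # job work directory, decompressing any .gz/.bz2 members on the way.
    # The parameter name and file list are made up for illustration.
    #
    #   refdir = util.collection_extract(
    #       collection=current_job()['script_parameters']['reference'],
    #       path='reference',
    #       files=['chr1.fa', 'chr2.fa'],
    #       decompress=True)
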
    @staticmethod
    def mkdir_dash_p(path):
        if not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got, map(lambda z: z.name(),
                                                   list(stream.all_files()))))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

class DataReader:
    def __init__(self, data_locator):
        self.data_locator = data_locator
        self.p = subprocess.Popen(["whget", "-r", self.data_locator, "-"],
                                  stdout=subprocess.PIPE,
                                  stdin=None, stderr=subprocess.PIPE,
                                  shell=False, close_fds=True)
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
    def read(self, size, **kwargs):
        return self.p.stdout.read(size, **kwargs)
    def close(self):
        self.p.stdout.close()
        if not self.p.stderr.closed:
            for err in self.p.stderr:
                print >> sys.stderr, err
            self.p.stderr.close()
        self.p.wait()
        if self.p.returncode != 0:
            raise Exception("whget subprocess exited %d" % self.p.returncode)

class StreamFileReader:
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub('\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search('\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search('\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"

class StreamReader:
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name == None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")

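    # Example (illustrative): one manifest stream is a whitespace-separated
    # token list such as
    #
    #   . d41d8cd98f00b204e9800998ecf8427e+0 0:0:emptyfile.txt
    #
    # where the first token is the stream name, each 32-hex-digit token
    # (optionally followed by "+" hints such as the block size) is a data
    # block locator, and each pos:size:name token describes one file within
    # the concatenated data blocks.
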
    def tokens(self):
        return self._tokens
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                # Without a size hint we cannot compute block offsets, so
                # fall back to returning every locator in the stream.
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(0))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data == None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(0))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos != None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

class CollectionReader:
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams != None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text

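# Example (illustrative sketch): read every line of every file in a
# collection, given either its locator or its manifest text. The
# input_locator variable is made up for illustration.
#
#   cr = CollectionReader(input_locator)
#   for f in cr.all_files():
#       for line in f.readlines():
#           ...handle one (decompressed) line of f...
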
class CollectionWriter:
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        self.start_new_stream(stream_name)
        todo = []
        if max_manifest_depth == 0:
            dirents = util.listdir_recursive(path)
        else:
            dirents = sorted(os.listdir(path))
        for dirent in dirents:
            target = os.path.join(path, dirent)
            if os.path.isdir(target):
                todo += [[target,
                          os.path.join(stream_name, dirent),
                          max_manifest_depth-1]]
            else:
                self.start_new_file(dirent)
                with open(target, 'rb') as f:
                    while True:
                        buf = f.read(2**26)
                        if len(buf) == 0:
                            break
                        self.write(buf)
        self.finish_current_stream()
        map(lambda x: self.write_directory_tree(*x), todo)

    def write(self, newdata):
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        newfilename = re.sub(r' ', '\\\\040', newfilename)
        if re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name == None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname='.'):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name == None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0]
            if len(stream[1]) == 0:
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest

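# Example (illustrative sketch): write two files into a new collection and
# record its manifest locator as the task output. Variable names are made
# up for illustration.
#
#   cw = CollectionWriter()
#   cw.start_new_file('summary.txt')
#   cw.write(summary_text)
#   cw.start_new_file('details.txt')
#   cw.write(details_text)
#   current_task().set_output(cw.finish())
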
global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object == None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator):
        return Keep.global_client_object().get(locator)

    @staticmethod
    def put(data):
        return Keep.global_client_object().put(data)

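# Example (illustrative sketch): store a blob in Keep and read it back by
# the returned locator. With KEEP_LOCAL_STORE set in the environment, the
# same calls use the local-store code path at the bottom of KeepClient
# instead of contacting Keep servers.
#
#   locator = Keep.put("hello\n")   # e.g. "<md5 of data>+6"
#   assert Keep.get(locator) == "hello\n"
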
class KeepClient:
    def __init__(self):
        self.service_roots = None

    def shuffled_service_roots(self, hash):
        if self.service_roots == None:
            keep_disks = api().keep_disks().list().execute()['items']
            roots = (("http%s://%s:%d/" %
                      ('s' if f['service_ssl_flag'] else '',
                       f['service_host'],
                       f['service_port']))
                     for f in keep_disks)
            self.service_roots = sorted(set(roots))
            logging.debug(str(self.service_roots))
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    def get(self, locator):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = os.environ['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url, 'GET', headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise Exception("Not found: %s" % expect_hash)

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        for service_root in self.shuffled_service_roots(data_hash):
            h = httplib2.Http()
            url = service_root + data_hash
            api_token = os.environ['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token}
            try:
                resp, content = h.request(url, 'PUT',
                                          headers=headers,
                                          body=data)
                if (resp['status'] == '401' and
                    re.match(r'Timestamp verification failed', content)):
                    body = self.sign_for_old_server(data_hash, data)
                    h = httplib2.Http()
                    resp, content = h.request(url, 'PUT',
                                              headers=headers,
                                              body=body)
                if re.match(r'^2\d\d$', resp['status']):
                    have_copies += 1
                    if have_copies == want_copies:
                        return data_hash + '+' + str(len(data))
                else:
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
            except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                logging.warning("Request fail: PUT %s => %s: %s" %
                                (url, type(e), str(e)))
        raise Exception("Write fail for %s: wanted %d but wrote %d" %
                        (data_hash, want_copies, have_copies))

    def sign_for_old_server(self, data_hash, data):
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)

    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()