import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading

from apiclient import errors
from apiclient.discovery import build

if 'ARVADOS_DEBUG' in os.environ:
    logging.basicConfig(level=logging.DEBUG)

EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'

# Note: this local exception hierarchy shadows the `errors` module imported
# from apiclient above; SDK callers are expected to use these classes.
class errors:
    class SyntaxError(Exception):
        pass
    class AssertionError(Exception):
        pass
    class NotFoundError(Exception):
        pass
    class CommandFailedError(Exception):
        pass
    class KeepWriteError(Exception):
        pass
    class NotImplementedError(Exception):
        pass

class CredentialsFromEnv(object):
    @staticmethod
    def http_request(self, uri, **kwargs):
        # Bound to the httplib2.Http object in authorize() below, so `self`
        # here is the http object, not a CredentialsFromEnv instance.
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

url = ('https://%s:%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' %
       (os.environ['ARVADOS_API_HOST'],
        os.environ.get('ARVADOS_API_PORT') or "443"))
credentials = CredentialsFromEnv()

# Use the system's CA certificates (if we find them) instead of httplib2's.
ca_certs = '/etc/ssl/certs/ca-certificates.crt'
if not os.path.exists(ca_certs):
    ca_certs = None             # use httplib2 default

http = httplib2.Http(ca_certs=ca_certs)
http = credentials.authorize(http)
if re.match(r'(?i)^(true|1|yes)$',
            os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
    http.disable_ssl_certificate_validation = True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)

def task_set_output(self, s):
    service.job_tasks().update(uuid=self['uuid'],
                               body={
                                   'output': s,
                                   'success': True,
                                   'progress': 1.0
                               }).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def getjobparam(*args):
    return current_job()['script_parameters'].get(*args)

def api():
    return service

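# A minimal usage sketch (not part of the SDK itself): inside a crunch task, a
# job script might fetch its parameters and record its output like this. The
# variable names and the 'input' parameter below are hypothetical.
#
#   import arvados
#   job = arvados.current_job()            # requires JOB_UUID, JOB_WORK
#   task = arvados.current_task()          # requires TASK_UUID, TASK_WORK
#   input_locator = arvados.getjobparam('input')
#   # ... do the work, write results to Keep ...
#   task.set_output(output_locator)        # marks this task as successful
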
class JobTask(object):
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                    }
                }
                service.job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       body={'success': True}
                                       ).execute()
            exit(0)

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input': task_input
                }
            }
            service.job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       body={'success': True}
                                       ).execute()
            exit(0)

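# A minimal sketch (a hypothetical crunch script, not part of the SDK) of how
# one_task_per_input_file is typically used: task 0 queues one child task per
# input file, and each later task processes only its own file.
#
#   import arvados
#   arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#   this_task_input = arvados.current_task()['parameters']['input']
#   # ... process this_task_input, then call current_task().set_output(...) ...
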
class util:
    @staticmethod
    def clear_tmpdir(path=None):
        """
        Ensure the given directory (or the current task's temporary
        directory, if none is given) exists and is empty.
        """
        if path is None:
            path = current_task().tmpdir
        if os.path.exists(path):
            p = subprocess.Popen(['rm', '-rf', path])
            stdout, stderr = p.communicate(None)
            if p.returncode != 0:
                raise Exception('rm -rf %s: %s' % (path, stderr))
        os.mkdir(path)

    @staticmethod
    def run_command(execargs, **kwargs):
        kwargs.setdefault('stdin', subprocess.PIPE)
        kwargs.setdefault('stdout', subprocess.PIPE)
        kwargs.setdefault('stderr', sys.stderr)
        kwargs.setdefault('close_fds', True)
        kwargs.setdefault('shell', False)
        p = subprocess.Popen(execargs, **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise errors.CommandFailedError(
                "run_command %s exit %d:\n%s" %
                (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

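    # A minimal sketch of run_command usage (the command shown is arbitrary):
    #
    #   out, err = arvados.util.run_command(['ls', '-l', '.'])
    #
    # A non-zero exit status raises errors.CommandFailedError; by default
    # stdout is captured and returned, while stderr passes through to the
    # task log.
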
    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search(r'\.(tbz|tar\.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search(r'\.(tgz|tar\.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search(r'\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise errors.AssertionError(
                        "tarball_extract cannot handle filename %s" % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

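    # A minimal sketch of tarball_extract usage (tarball_locator is a
    # hypothetical Keep collection locator, e.g. from a job parameter):
    #
    #   src_dir = arvados.util.tarball_extract(tarball=tarball_locator,
    #                                          path='src')
    #
    # The '.locator' symlink left in the target directory lets repeated calls
    # with the same locator skip the download and extraction step.
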
    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search(r'\.zip$', f.name()):
                    raise errors.NotImplementedError(
                        "zipball_extract cannot handle filename %s" % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        files -- if not empty, extract only these named files
        decompress -- if True, transparently decompress .gz/.bz2 files
        """
        matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
        if matches:
            collection_hash = matches.group(1)
        else:
            collection_hash = hashlib.md5(collection).hexdigest()
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection_hash:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got,
                 [z.name() for z in CollectionReader(collection).all_files()]))
        os.symlink(collection_hash, os.path.join(path, '.locator'))

        lockfile.close()
        return path

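    # A minimal sketch of collection_extract usage (the 'input' parameter name
    # and the file list are hypothetical):
    #
    #   ref_dir = arvados.util.collection_extract(
    #       collection=arvados.getjobparam('input'),
    #       path='reference',
    #       files=['genome.fa', 'genome.fa.fai'],
    #       decompress=True)
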
    @staticmethod
    def mkdir_dash_p(path):
        if not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got, [z.name() for z in stream.all_files()]))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

class StreamFileReader(object):
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub(r'\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        decompressor = zlib.decompressobj(16 + zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search(r'\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search(r'\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s %s 0:0:%s\n"
                    % (self._stream.name(), EMPTY_BLOCK_LOCATOR, self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"

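# For reference, a manifest is one line per stream: a stream name, one or more
# data block locators, then file tokens of the form position:size:filename.
# A single-file example for an empty file, using the empty-block locator
# defined above (filename hypothetical):
#
#   . d41d8cd98f00b204e9800998ecf8427e+0 0:0:empty.txt
#
# as_manifest() returns exactly this kind of single-stream manifest text
# covering one file's byte range.
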
class StreamReader(object):
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name is None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise errors.SyntaxError("Invalid manifest format")

    def tokens(self):
        return self._tokens
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(1))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data is None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search(r'\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(1))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos is not None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

class CollectionReader(object):
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams is not None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text

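# A minimal reading sketch (collection_locator is a hypothetical variable):
# iterate over every file in a collection and stream its decompressed
# contents.
#
#   cr = arvados.CollectionReader(collection_locator)
#   for f in cr.all_files():
#       print f.stream_name(), f.name(), f.size()
#       for chunk in f.readall_decompressed():
#           pass  # process chunk
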
class CollectionWriter(object):
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        self.start_new_stream(stream_name)
        todo = []
        if max_manifest_depth == 0:
            dirents = sorted(util.listdir_recursive(path))
        else:
            dirents = sorted(os.listdir(path))
        for dirent in dirents:
            target = os.path.join(path, dirent)
            if os.path.isdir(target):
                todo += [[target,
                          os.path.join(stream_name, dirent),
                          max_manifest_depth - 1]]
            else:
                self.start_new_file(dirent)
                with open(target, 'rb') as f:
                    while True:
                        buf = f.read(2**26)
                        if len(buf) == 0:
                            break
                        self.write(buf)
        self.finish_current_stream()
        map(lambda x: self.write_directory_tree(*x), todo)

    def write(self, newdata):
        if hasattr(newdata, '__iter__'):
            for s in newdata:
                self.write(s)
            return
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        newfilename = re.sub(r' ', '\\\\040', newfilename)
        if re.search(r'[ \t\n]', newfilename):
            raise errors.AssertionError(
                "Manifest filenames cannot contain whitespace: %s" %
                newfilename)
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name is None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise errors.AssertionError(
                "Cannot finish an unnamed file " +
                "(%d bytes at offset %d in '%s' stream)" %
                (self._current_stream_length - self._current_file_pos,
                 self._current_file_pos,
                 self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname='.'):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise errors.AssertionError(
                "Manifest stream names cannot contain whitespace")
        self._current_stream_name = '.' if newstreamname == '' else newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name is None:
            raise errors.AssertionError(
                "Cannot finish an unnamed stream (%d bytes in %d files)" %
                (self._current_stream_length, len(self._current_stream_files)))
        else:
            if len(self._current_stream_locators) == 0:
                self._current_stream_locators += [EMPTY_BLOCK_LOCATOR]
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0]
            for locator in stream[1]:
                manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest
    def data_locators(self):
        ret = []
        for name, locators, files in self._finished_streams:
            ret += locators
        return ret

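# A minimal writing sketch (the file name and contents are hypothetical):
#
#   cw = arvados.CollectionWriter()
#   cw.start_new_file('results.txt')
#   cw.write('hello\n')
#   output_locator = cw.finish()        # stores the manifest in Keep
#   arvados.current_task().set_output(output_locator)
#
# write_directory_tree() can be used instead to snapshot a whole directory,
# e.g. cw.write_directory_tree(arvados.current_task().tmpdir + '/out').
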
global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)

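# A minimal round-trip sketch. With KEEP_LOCAL_STORE set (e.g. in tests), data
# is written to and read from a local directory instead of Keep servers.
#
#   locator = arvados.Keep.put('foo')   # 'acbd18db4cc2f85cedef654fccc4a4d8+3'
#   assert arvados.Keep.get(locator) == 'foo'
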
class KeepClient(object):

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()
        def __enter__(self):
            self._todo_lock.acquire()
            return self
        def __exit__(self, type, value, traceback):
            self._todo_lock.release()
        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)
        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1
        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() on the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs
        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has already been done
                    # without me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = os.environ['ARVADOS_API_TOKEN']
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None

    def shuffled_service_roots(self, hash):
        if self.service_roots is None:
            self.lock.acquire()
            keep_disks = api().keep_disks().list().execute()['items']
            roots = (("http%s://%s:%d/" %
                      ('s' if f['service_ssl_flag'] else '',
                       f['service_host'],
                       f['service_port']))
                     for f in keep_disks)
            self.service_roots = sorted(set(roots))
            logging.debug(str(self.service_roots))
            self.lock.release()
        # Build a probe order for this block: repeatedly take the next 8 hex
        # digits of the hash (re-seeding from the hash when it runs short) and
        # use them to pick the next service from the remaining pool, so every
        # client probes the same servers in the same order for a given hash.
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

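    # A toy illustration of the probe ordering above (the hash value is
    # hypothetical): with service_roots = [A, B, C] and a block hash beginning
    # '3f78a...', the first probe index is int(hash[0:8], 16) % 3; the chosen
    # root is removed from the pool and the next 8 hex digits pick among the
    # remaining two, giving a deterministic, hash-specific ordering of all
    # three roots.
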
    def get(self, locator):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = os.environ['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise errors.NotFoundError("Block not found: %s" % expect_hash)

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies == want_copies:
            return (data_hash + '+' + str(len(data)))
        raise errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)

    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()