[arvados.git] / sdk / python / arvados / __init__.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20
21 import apiclient
22 import apiclient.discovery
23
24 # Arvados configuration settings are taken from $HOME/.config/arvados.
25 # Environment variables override settings in the config file.
26 #
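# A typical config file contains one NAME=value setting per line, for
# example (the values shown are placeholders, not real credentials):
#
#   ARVADOS_API_HOST=arvados.example.com
#   ARVADOS_API_TOKEN=0123456789abcdef0123456789abcdef
#   ARVADOS_API_HOST_INSECURE=no
#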
27 class ArvadosConfig(dict):
28     def __init__(self, config_file):
29         dict.__init__(self)
30         with open(config_file, "r") as f:
31             for config_line in f:
32                 var, val = config_line.rstrip().split('=', 1)
33                 self[var] = val
34         for var in os.environ:
35             if var.startswith('ARVADOS_'):
36                 self[var] = os.environ[var]
37
38
39 config = ArvadosConfig(os.environ['HOME'] + '/.config/arvados')
40
41 if 'ARVADOS_DEBUG' in config:
42     logging.basicConfig(level=logging.DEBUG)
43
44 EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'
45
46 services = {}
47
48 class errors:
49     class SyntaxError(Exception):
50         pass
51     class AssertionError(Exception):
52         pass
53     class NotFoundError(Exception):
54         pass
55     class CommandFailedError(Exception):
56         pass
57     class KeepWriteError(Exception):
58         pass
59     class NotImplementedError(Exception):
60         pass
61
62 class CredentialsFromEnv(object):
63     @staticmethod
64     def http_request(self, uri, **kwargs):
65         global config
66         from httplib import BadStatusLine
67         if 'headers' not in kwargs:
68             kwargs['headers'] = {}
69         kwargs['headers']['Authorization'] = 'OAuth2 %s' % config.get('ARVADOS_API_TOKEN', 'ARVADOS_API_TOKEN_not_set')
70         try:
71             return self.orig_http_request(uri, **kwargs)
72         except BadStatusLine:
73             # This is how httplib tells us that it tried to reuse an
74             # existing connection but it was already closed by the
75             # server. In that case, yes, we would like to retry.
76             # Unfortunately, we are not absolutely certain that the
77             # previous call did not succeed, so this is slightly
78             # risky.
79             return self.orig_http_request(uri, **kwargs)
80     def authorize(self, http):
81         http.orig_http_request = http.request
82         http.request = types.MethodType(self.http_request, http)
83         return http
84
85 def task_set_output(self,s):
86     api('v1').job_tasks().update(uuid=self['uuid'],
87                                  body={
88             'output':s,
89             'success':True,
90             'progress':1.0
91             }).execute()
92
93 _current_task = None
94 def current_task():
95     global _current_task
96     if _current_task:
97         return _current_task
98     t = api('v1').job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
99     t = UserDict.UserDict(t)
100     t.set_output = types.MethodType(task_set_output, t)
101     t.tmpdir = os.environ['TASK_WORK']
102     _current_task = t
103     return t
104
105 _current_job = None
106 def current_job():
107     global _current_job
108     if _current_job:
109         return _current_job
110     t = api('v1').jobs().get(uuid=os.environ['JOB_UUID']).execute()
111     t = UserDict.UserDict(t)
112     t.tmpdir = os.environ['JOB_WORK']
113     _current_job = t
114     return t
115
116 def getjobparam(*args):
117     return current_job()['script_parameters'].get(*args)
118
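# Sketch of how a crunch script typically uses these helpers; the
# parameter name 'input' and the variable names are illustrative:
#
#   this_task = current_task()
#   job_input = getjobparam('input')
#   # ... do work, using this_task.tmpdir for scratch space ...
#   this_task.set_output(output_locator)
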
119 # Monkey patch discovery._cast() so objects and arrays get serialized
120 # with json.dumps() instead of str().
121 _cast_orig = apiclient.discovery._cast
122 def _cast_objects_too(value, schema_type):
123     global _cast_orig
124     if (type(value) != type('') and
125         (schema_type == 'object' or schema_type == 'array')):
126         return json.dumps(value)
127     else:
128         return _cast_orig(value, schema_type)
129 apiclient.discovery._cast = _cast_objects_too
130
131 def api(version=None):
132     global services, config
133     if not services.get(version):
134         apiVersion = version
135         if not version:
136             apiVersion = 'v1'
137             logging.info("Using default API version. " +
138                          "Call arvados.api('%s') instead." %
139                          apiVersion)
140         if 'ARVADOS_API_HOST' not in config:
141             raise Exception("ARVADOS_API_HOST is not set. Aborting.")
142         url = ('https://%s/discovery/v1/apis/{api}/{apiVersion}/rest' %
143                config['ARVADOS_API_HOST'])
144         credentials = CredentialsFromEnv()
145
146         # Use system's CA certificates (if we find them) instead of httplib2's
147         ca_certs = '/etc/ssl/certs/ca-certificates.crt'
148         if not os.path.exists(ca_certs):
149             ca_certs = None             # use httplib2 default
150
151         http = httplib2.Http(ca_certs=ca_certs)
152         http = credentials.authorize(http)
153         if re.match(r'(?i)^(true|1|yes)$',
154                     config.get('ARVADOS_API_HOST_INSECURE', 'no')):
155             http.disable_ssl_certificate_validation=True
156         services[version] = apiclient.discovery.build(
157             'arvados', apiVersion, http=http, discoveryServiceUrl=url)
158     return services[version]
159
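# Typical use of the client returned by api(); the job_tasks calls
# below mirror the ones made elsewhere in this module:
#
#   this_task = arvados.api('v1').job_tasks().get(uuid=task_uuid).execute()
#
# ARVADOS_API_HOST and ARVADOS_API_TOKEN must be available (from the
# config file or the environment) before the first call.
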
160 class JobTask(object):
161     def __init__(self, parameters=dict(), runtime_constraints=dict()):
162         logging.debug("init jobtask %s %s" % (parameters, runtime_constraints))
163
164 class job_setup:
165     @staticmethod
166     def one_task_per_input_file(if_sequence=0, and_end_task=True):
167         if if_sequence != current_task()['sequence']:
168             return
169         job_input = current_job()['script_parameters']['input']
170         cr = CollectionReader(job_input)
171         for s in cr.all_streams():
172             for f in s.all_files():
173                 task_input = f.as_manifest()
174                 new_task_attrs = {
175                     'job_uuid': current_job()['uuid'],
176                     'created_by_job_task_uuid': current_task()['uuid'],
177                     'sequence': if_sequence + 1,
178                     'parameters': {
179                         'input':task_input
180                         }
181                     }
182                 api('v1').job_tasks().create(body=new_task_attrs).execute()
183         if and_end_task:
184             api('v1').job_tasks().update(uuid=current_task()['uuid'],
185                                        body={'success':True}
186                                        ).execute()
187             sys.exit(0)
188
189     @staticmethod
190     def one_task_per_input_stream(if_sequence=0, and_end_task=True):
191         if if_sequence != current_task()['sequence']:
192             return
193         job_input = current_job()['script_parameters']['input']
194         cr = CollectionReader(job_input)
195         for s in cr.all_streams():
196             task_input = s.tokens()
197             new_task_attrs = {
198                 'job_uuid': current_job()['uuid'],
199                 'created_by_job_task_uuid': current_task()['uuid'],
200                 'sequence': if_sequence + 1,
201                 'parameters': {
202                     'input':task_input
203                     }
204                 }
205             api('v1').job_tasks().create(body=new_task_attrs).execute()
206         if and_end_task:
207             api('v1').job_tasks().update(uuid=current_task()['uuid'],
208                                        body={'success':True}
209                                        ).execute()
210             sys.exit(0)
211
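# Typical fan-out pattern: the sequence-0 task calls
# one_task_per_input_file(0), which queues one sequence-1 task per
# input file and then ends itself. When the same script runs again as
# a sequence-1 task the call returns immediately, and the script reads
# its own per-file input:
#
#   arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#   task_input = arvados.current_task()['parameters']['input']
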
212 class util:
213     @staticmethod
214     def clear_tmpdir(path=None):
215         """
216         Ensure the given directory (or TASK_TMPDIR if none given)
217         exists and is empty.
218         """
219         if path == None:
220             path = current_task().tmpdir
221         if os.path.exists(path):
222             p = subprocess.Popen(['rm', '-rf', path])
223             stdout, stderr = p.communicate(None)
224             if p.returncode != 0:
225                 raise Exception('rm -rf %s: %s' % (path, stderr))
226         os.mkdir(path)
227
228     @staticmethod
229     def run_command(execargs, **kwargs):
230         kwargs.setdefault('stdin', subprocess.PIPE)
231         kwargs.setdefault('stdout', subprocess.PIPE)
232         kwargs.setdefault('stderr', sys.stderr)
233         kwargs.setdefault('close_fds', True)
234         kwargs.setdefault('shell', False)
235         p = subprocess.Popen(execargs, **kwargs)
236         stdoutdata, stderrdata = p.communicate(None)
237         if p.returncode != 0:
238             raise errors.CommandFailedError(
239                 "run_command %s exit %d:\n%s" %
240                 (execargs, p.returncode, stderrdata))
241         return stdoutdata, stderrdata
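    # Example use of run_command (the command and arguments are
    # illustrative). With the defaults above, stdout is captured and
    # stderr passes through to the parent's stderr:
    #
    #   out, _ = util.run_command(['md5sum', input_file],
    #                             cwd=current_task().tmpdir)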
242
243     @staticmethod
244     def git_checkout(url, version, path):
245         if not re.search('^/', path):
246             path = os.path.join(current_job().tmpdir, path)
247         if not os.path.exists(path):
248             util.run_command(["git", "clone", url, path],
249                              cwd=os.path.dirname(path))
250         util.run_command(["git", "checkout", version],
251                          cwd=path)
252         return path
253
254     @staticmethod
255     def tar_extractor(path, decompress_flag):
256         return subprocess.Popen(["tar",
257                                  "-C", path,
258                                  ("-x%sf" % decompress_flag),
259                                  "-"],
260                                 stdout=None,
261                                 stdin=subprocess.PIPE, stderr=sys.stderr,
262                                 shell=False, close_fds=True)
263
264     @staticmethod
265     def tarball_extract(tarball, path):
266         """Retrieve a tarball from Keep and extract it to a local
267         directory.  Return the absolute path where the tarball was
268         extracted. If the top level of the tarball contained just one
269         file or directory, return the absolute path of that single
270         item.
271
272         tarball -- collection locator
273         path -- where to extract the tarball: absolute, or relative to job tmp
274         """
275         if not re.search('^/', path):
276             path = os.path.join(current_job().tmpdir, path)
277         lockfile = open(path + '.lock', 'w')
278         fcntl.flock(lockfile, fcntl.LOCK_EX)
279         try:
280             os.stat(path)
281         except OSError:
282             os.mkdir(path)
283         already_have_it = False
284         try:
285             if os.readlink(os.path.join(path, '.locator')) == tarball:
286                 already_have_it = True
287         except OSError:
288             pass
289         if not already_have_it:
290
291             # emulate "rm -f" (i.e., if the file does not exist, we win)
292             try:
293                 os.unlink(os.path.join(path, '.locator'))
294             except OSError:
295                 if os.path.exists(os.path.join(path, '.locator')):
296                     os.unlink(os.path.join(path, '.locator'))
297
298             for f in CollectionReader(tarball).all_files():
299                 if re.search(r'\.(tbz|tar\.bz2)$', f.name()):
300                     p = util.tar_extractor(path, 'j')
301                 elif re.search(r'\.(tgz|tar\.gz)$', f.name()):
302                     p = util.tar_extractor(path, 'z')
303                 elif re.search('\.tar$', f.name()):
304                     p = util.tar_extractor(path, '')
305                 else:
306                     raise errors.AssertionError(
307                         "tarball_extract cannot handle filename %s" % f.name())
308                 while True:
309                     buf = f.read(2**20)
310                     if len(buf) == 0:
311                         break
312                     p.stdin.write(buf)
313                 p.stdin.close()
314                 p.wait()
315                 if p.returncode != 0:
316                     lockfile.close()
317                     raise errors.CommandFailedError(
318                         "tar exited %d" % p.returncode)
319             os.symlink(tarball, os.path.join(path, '.locator'))
320         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
321         lockfile.close()
322         if len(tld_extracts) == 1:
323             return os.path.join(path, tld_extracts[0])
324         return path
325
326     @staticmethod
327     def zipball_extract(zipball, path):
328         """Retrieve a zip archive from Keep and extract it to a local
329         directory.  Return the absolute path where the archive was
330         extracted. If the top level of the archive contained just one
331         file or directory, return the absolute path of that single
332         item.
333
334         zipball -- collection locator
335         path -- where to extract the archive: absolute, or relative to job tmp
336         """
337         if not re.search('^/', path):
338             path = os.path.join(current_job().tmpdir, path)
339         lockfile = open(path + '.lock', 'w')
340         fcntl.flock(lockfile, fcntl.LOCK_EX)
341         try:
342             os.stat(path)
343         except OSError:
344             os.mkdir(path)
345         already_have_it = False
346         try:
347             if os.readlink(os.path.join(path, '.locator')) == zipball:
348                 already_have_it = True
349         except OSError:
350             pass
351         if not already_have_it:
352
353             # emulate "rm -f" (i.e., if the file does not exist, we win)
354             try:
355                 os.unlink(os.path.join(path, '.locator'))
356             except OSError:
357                 if os.path.exists(os.path.join(path, '.locator')):
358                     os.unlink(os.path.join(path, '.locator'))
359
360             for f in CollectionReader(zipball).all_files():
361                 if not re.search('\.zip$', f.name()):
362                     raise errors.NotImplementedError(
363                         "zipball_extract cannot handle filename %s" % f.name())
364                 zip_filename = os.path.join(path, os.path.basename(f.name()))
365                 zip_file = open(zip_filename, 'wb')
366                 while True:
367                     buf = f.read(2**20)
368                     if len(buf) == 0:
369                         break
370                     zip_file.write(buf)
371                 zip_file.close()
372                 
373                 p = subprocess.Popen(["unzip",
374                                       "-q", "-o",
375                                       "-d", path,
376                                       zip_filename],
377                                      stdout=None,
378                                      stdin=None, stderr=sys.stderr,
379                                      shell=False, close_fds=True)
380                 p.wait()
381                 if p.returncode != 0:
382                     lockfile.close()
383                     raise errors.CommandFailedError(
384                         "unzip exited %d" % p.returncode)
385                 os.unlink(zip_filename)
386             os.symlink(zipball, os.path.join(path, '.locator'))
387         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
388         lockfile.close()
389         if len(tld_extracts) == 1:
390             return os.path.join(path, tld_extracts[0])
391         return path
392
393     @staticmethod
394     def collection_extract(collection, path, files=[], decompress=True):
395         """Retrieve a collection from Keep and extract it to a local
396         directory.  Return the absolute path where the collection was
397         extracted.
398
399         collection -- collection locator
400         path -- where to extract: absolute, or relative to job tmp
401         """
402         matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
403         if matches:
404             collection_hash = matches.group(1)
405         else:
406             collection_hash = hashlib.md5(collection).hexdigest()
407         if not re.search('^/', path):
408             path = os.path.join(current_job().tmpdir, path)
409         lockfile = open(path + '.lock', 'w')
410         fcntl.flock(lockfile, fcntl.LOCK_EX)
411         try:
412             os.stat(path)
413         except OSError:
414             os.mkdir(path)
415         already_have_it = False
416         try:
417             if os.readlink(os.path.join(path, '.locator')) == collection_hash:
418                 already_have_it = True
419         except OSError:
420             pass
421
422         # emulate "rm -f" (i.e., if the file does not exist, we win)
423         try:
424             os.unlink(os.path.join(path, '.locator'))
425         except OSError:
426             if os.path.exists(os.path.join(path, '.locator')):
427                 os.unlink(os.path.join(path, '.locator'))
428
429         files_got = []
430         for s in CollectionReader(collection).all_streams():
431             stream_name = s.name()
432             for f in s.all_files():
433                 if (files == [] or
434                     ((f.name() not in files_got) and
435                      (f.name() in files or
436                       (decompress and f.decompressed_name() in files)))):
437                     outname = f.decompressed_name() if decompress else f.name()
438                     files_got += [outname]
439                     if os.path.exists(os.path.join(path, stream_name, outname)):
440                         continue
441                     util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
442                     outfile = open(os.path.join(path, stream_name, outname), 'wb')
443                     for buf in (f.readall_decompressed() if decompress
444                                 else f.readall()):
445                         outfile.write(buf)
446                     outfile.close()
447         if len(files_got) < len(files):
448             raise errors.AssertionError(
449                 "Wanted files %s but only got %s from %s" %
450                 (files, files_got,
451                  [z.name() for z in CollectionReader(collection).all_files()]))
452         os.symlink(collection_hash, os.path.join(path, '.locator'))
453
454         lockfile.close()
455         return path
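    # Example use (the locator and file name are illustrative):
    #
    #   ref_dir = util.collection_extract(
    #       collection='0123456789abcdef0123456789abcdef+1234',
    #       path='reference', files=['genome.fa'], decompress=True)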
456
457     @staticmethod
458     def mkdir_dash_p(path):
459         if not os.path.exists(path):
460             util.mkdir_dash_p(os.path.dirname(path))
461             try:
462                 os.mkdir(path)
463             except OSError:
464                 if not os.path.exists(path):
465                     os.mkdir(path)
466
467     @staticmethod
468     def stream_extract(stream, path, files=[], decompress=True):
469         """Retrieve a stream from Keep and extract it to a local
470         directory.  Return the absolute path where the stream was
471         extracted.
472
473         stream -- StreamReader object
474         path -- where to extract: absolute, or relative to job tmp
475         """
476         if not re.search('^/', path):
477             path = os.path.join(current_job().tmpdir, path)
478         lockfile = open(path + '.lock', 'w')
479         fcntl.flock(lockfile, fcntl.LOCK_EX)
480         try:
481             os.stat(path)
482         except OSError:
483             os.mkdir(path)
484
485         files_got = []
486         for f in stream.all_files():
487             if (files == [] or
488                 ((f.name() not in files_got) and
489                  (f.name() in files or
490                   (decompress and f.decompressed_name() in files)))):
491                 outname = f.decompressed_name() if decompress else f.name()
492                 files_got += [outname]
493                 if os.path.exists(os.path.join(path, outname)):
494                     os.unlink(os.path.join(path, outname))
495                 util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
496                 outfile = open(os.path.join(path, outname), 'wb')
497                 for buf in (f.readall_decompressed() if decompress
498                             else f.readall()):
499                     outfile.write(buf)
500                 outfile.close()
501         if len(files_got) < len(files):
502             raise errors.AssertionError(
503                 "Wanted files %s but only got %s from %s" %
504                 (files, files_got, [z.name() for z in stream.all_files()]))
505         lockfile.close()
506         return path
507
508     @staticmethod
509     def listdir_recursive(dirname, base=None):
510         allfiles = []
511         for ent in sorted(os.listdir(dirname)):
512             ent_path = os.path.join(dirname, ent)
513             ent_base = os.path.join(base, ent) if base else ent
514             if os.path.isdir(ent_path):
515                 allfiles += util.listdir_recursive(ent_path, ent_base)
516             else:
517                 allfiles += [ent_base]
518         return allfiles
519
520 class StreamFileReader(object):
521     def __init__(self, stream, pos, size, name):
522         self._stream = stream
523         self._pos = pos
524         self._size = size
525         self._name = name
526         self._filepos = 0
527     def name(self):
528         return self._name
529     def decompressed_name(self):
530         return re.sub('\.(bz2|gz)$', '', self._name)
531     def size(self):
532         return self._size
533     def stream_name(self):
534         return self._stream.name()
535     def read(self, size, **kwargs):
536         self._stream.seek(self._pos + self._filepos)
537         data = self._stream.read(min(size, self._size - self._filepos))
538         self._filepos += len(data)
539         return data
540     def readall(self, size=2**20, **kwargs):
541         while True:
542             data = self.read(size, **kwargs)
543             if data == '':
544                 break
545             yield data
546     def bunzip2(self, size):
547         decompressor = bz2.BZ2Decompressor()
548         for chunk in self.readall(size):
549             data = decompressor.decompress(chunk)
550             if data and data != '':
551                 yield data
552     def gunzip(self, size):
553         decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
554         for chunk in self.readall(size):
555             data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
556             if data and data != '':
557                 yield data
558     def readall_decompressed(self, size=2**20):
559         self._stream.seek(self._pos + self._filepos)
560         if re.search('\.bz2$', self._name):
561             return self.bunzip2(size)
562         elif re.search('\.gz$', self._name):
563             return self.gunzip(size)
564         else:
565             return self.readall(size)
566     def readlines(self, decompress=True):
567         if decompress:
568             datasource = self.readall_decompressed()
569         else:
570             self._stream.seek(self._pos + self._filepos)
571             datasource = self.readall()
572         data = ''
573         for newdata in datasource:
574             data += newdata
575             sol = 0
576             while True:
577                 eol = string.find(data, "\n", sol)
578                 if eol < 0:
579                     break
580                 yield data[sol:eol+1]
581                 sol = eol+1
582             data = data[sol:]
583         if data != '':
584             yield data
585     def as_manifest(self):
586         if self.size() == 0:
587             return ("%s %s 0:0:%s\n"
588                     % (self._stream.name(), EMPTY_BLOCK_LOCATOR, self.name()))
589         return string.join(self._stream.tokens_for_range(self._pos, self._size),
590                            " ") + "\n"
591
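# StreamFileReader objects are normally obtained from
# CollectionReader.all_files() or StreamReader.all_files() rather than
# constructed directly. readall(), readall_decompressed() and
# readlines() are generators, so large files can be consumed
# incrementally (names below are placeholders):
#
#   for line in some_stream_file.readlines():
#       handle(line)
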
592 class StreamReader(object):
593     def __init__(self, tokens):
594         self._tokens = tokens
595         self._current_datablock_data = None
596         self._current_datablock_pos = 0
597         self._current_datablock_index = -1
598         self._pos = 0
599
600         self._stream_name = None
601         self.data_locators = []
602         self.files = []
603
604         for tok in self._tokens:
605             if self._stream_name == None:
606                 self._stream_name = tok
607             elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
608                 self.data_locators += [tok]
609             elif re.search(r'^\d+:\d+:\S+', tok):
610                 pos, size, name = tok.split(':',2)
611                 self.files += [[int(pos), int(size), name]]
612             else:
613                 raise errors.SyntaxError("Invalid manifest format")
614
615     def tokens(self):
616         return self._tokens
617     def tokens_for_range(self, range_start, range_size):
618         resp = [self._stream_name]
619         return_all_tokens = False
620         block_start = 0
621         token_bytes_skipped = 0
622         for locator in self.data_locators:
623             sizehint = re.search(r'\+(\d+)', locator)
624             if not sizehint:
625                 return_all_tokens = True
626             if return_all_tokens:
627                 resp += [locator]
628                 continue
629             blocksize = int(sizehint.group(0))
630             if range_start + range_size <= block_start:
631                 break
632             if range_start < block_start + blocksize:
633                 resp += [locator]
634             else:
635                 token_bytes_skipped += blocksize
636             block_start += blocksize
637         for f in self.files:
638             if ((f[0] < range_start + range_size)
639                 and
640                 (f[0] + f[1] > range_start)
641                 and
642                 f[1] > 0):
643                 resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
644         return resp
645     def name(self):
646         return self._stream_name
647     def all_files(self):
648         for f in self.files:
649             pos, size, name = f
650             yield StreamFileReader(self, pos, size, name)
651     def nextdatablock(self):
652         if self._current_datablock_index < 0:
653             self._current_datablock_pos = 0
654             self._current_datablock_index = 0
655         else:
656             self._current_datablock_pos += self.current_datablock_size()
657             self._current_datablock_index += 1
658         self._current_datablock_data = None
659     def current_datablock_data(self):
660         if self._current_datablock_data == None:
661             self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
662         return self._current_datablock_data
663     def current_datablock_size(self):
664         if self._current_datablock_index < 0:
665             self.nextdatablock()
666         sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
667         if sizehint:
668             return int(sizehint.group(0))
669         return len(self.current_datablock_data())
670     def seek(self, pos):
671         """Set the position of the next read operation."""
672         self._pos = pos
673     def really_seek(self):
674         """Find and load the appropriate data block, so the byte at
675         _pos is in memory.
676         """
677         if self._pos == self._current_datablock_pos:
678             return True
679         if (self._current_datablock_pos != None and
680             self._pos >= self._current_datablock_pos and
681             self._pos <= self._current_datablock_pos + self.current_datablock_size()):
682             return True
683         if self._pos < self._current_datablock_pos:
684             self._current_datablock_index = -1
685             self.nextdatablock()
686         while (self._pos > self._current_datablock_pos and
687                self._pos > self._current_datablock_pos + self.current_datablock_size()):
688             self.nextdatablock()
689     def read(self, size):
690         """Read no more than size bytes -- but at least one byte,
691         unless _pos is already at the end of the stream.
692         """
693         if size == 0:
694             return ''
695         self.really_seek()
696         while self._pos >= self._current_datablock_pos + self.current_datablock_size():
697             self.nextdatablock()
698             if self._current_datablock_index >= len(self.data_locators):
699                 return None
700         data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
701         self._pos += len(data)
702         return data
703
704 class CollectionReader(object):
705     def __init__(self, manifest_locator_or_text):
706         if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
707             self._manifest_text = manifest_locator_or_text
708             self._manifest_locator = None
709         else:
710             self._manifest_locator = manifest_locator_or_text
711             self._manifest_text = None
712         self._streams = None
713     def __enter__(self):
714         return self
715     def __exit__(self, exc_type, exc_value, traceback):
716         pass
717     def _populate(self):
718         if self._streams != None:
719             return
720         if not self._manifest_text:
721             self._manifest_text = Keep.get(self._manifest_locator)
722         self._streams = []
723         for stream_line in self._manifest_text.split("\n"):
724             if stream_line != '':
725                 stream_tokens = stream_line.split()
726                 self._streams += [stream_tokens]
727     def all_streams(self):
728         self._populate()
729         resp = []
730         for s in self._streams:
731             resp += [StreamReader(s)]
732         return resp
733     def all_files(self):
734         for s in self.all_streams():
735             for f in s.all_files():
736                 yield f
737     def manifest_text(self):
738         self._populate()
739         return self._manifest_text
740
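# Reading a collection: the constructor accepts either a manifest
# locator or literal manifest text (job_input below is a placeholder):
#
#   cr = CollectionReader(job_input)
#   for f in cr.all_files():
#       data = ''.join(f.readall())
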
741 class CollectionWriter(object):
742     KEEP_BLOCK_SIZE = 2**26
743     def __init__(self):
744         self._data_buffer = []
745         self._data_buffer_len = 0
746         self._current_stream_files = []
747         self._current_stream_length = 0
748         self._current_stream_locators = []
749         self._current_stream_name = '.'
750         self._current_file_name = None
751         self._current_file_pos = 0
752         self._finished_streams = []
753     def __enter__(self):
754         return self
755     def __exit__(self, exc_type, exc_value, traceback):
756         self.finish()
757     def write_directory_tree(self,
758                              path, stream_name='.', max_manifest_depth=-1):
759         self.start_new_stream(stream_name)
760         todo = []
761         if max_manifest_depth == 0:
762             dirents = sorted(util.listdir_recursive(path))
763         else:
764             dirents = sorted(os.listdir(path))
765         for dirent in dirents:
766             target = os.path.join(path, dirent)
767             if os.path.isdir(target):
768                 todo += [[target,
769                           os.path.join(stream_name, dirent),
770                           max_manifest_depth-1]]
771             else:
772                 self.start_new_file(dirent)
773                 with open(target, 'rb') as f:
774                     while True:
775                         buf = f.read(2**26)
776                         if len(buf) == 0:
777                             break
778                         self.write(buf)
779         self.finish_current_stream()
780         map(lambda x: self.write_directory_tree(*x), todo)
781
782     def write(self, newdata):
783         if hasattr(newdata, '__iter__'):
784             for s in newdata:
785                 self.write(s)
786             return
787         self._data_buffer += [newdata]
788         self._data_buffer_len += len(newdata)
789         self._current_stream_length += len(newdata)
790         while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
791             self.flush_data()
792     def flush_data(self):
793         data_buffer = ''.join(self._data_buffer)
794         if data_buffer != '':
795             self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
796             self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
797             self._data_buffer_len = len(self._data_buffer[0])
798     def start_new_file(self, newfilename=None):
799         self.finish_current_file()
800         self.set_current_file_name(newfilename)
801     def set_current_file_name(self, newfilename):
802         newfilename = re.sub(r' ', '\\\\040', newfilename)
803         if re.search(r'[ \t\n]', newfilename):
804             raise errors.AssertionError(
805                 "Manifest filenames cannot contain whitespace: %s" %
806                 newfilename)
807         self._current_file_name = newfilename
808     def current_file_name(self):
809         return self._current_file_name
810     def finish_current_file(self):
811         if self._current_file_name == None:
812             if self._current_file_pos == self._current_stream_length:
813                 return
814             raise errors.AssertionError(
815                 "Cannot finish an unnamed file " +
816                 "(%d bytes at offset %d in '%s' stream)" %
817                 (self._current_stream_length - self._current_file_pos,
818                  self._current_file_pos,
819                  self._current_stream_name))
820         self._current_stream_files += [[self._current_file_pos,
821                                        self._current_stream_length - self._current_file_pos,
822                                        self._current_file_name]]
823         self._current_file_pos = self._current_stream_length
824     def start_new_stream(self, newstreamname='.'):
825         self.finish_current_stream()
826         self.set_current_stream_name(newstreamname)
827     def set_current_stream_name(self, newstreamname):
828         if re.search(r'[ \t\n]', newstreamname):
829             raise errors.AssertionError(
830                 "Manifest stream names cannot contain whitespace")
831         self._current_stream_name = '.' if newstreamname=='' else newstreamname
832     def current_stream_name(self):
833         return self._current_stream_name
834     def finish_current_stream(self):
835         self.finish_current_file()
836         self.flush_data()
837         if len(self._current_stream_files) == 0:
838             pass
839         elif self._current_stream_name == None:
840             raise errors.AssertionError(
841                 "Cannot finish an unnamed stream (%d bytes in %d files)" %
842                 (self._current_stream_length, len(self._current_stream_files)))
843         else:
844             if len(self._current_stream_locators) == 0:
845                 self._current_stream_locators += [EMPTY_BLOCK_LOCATOR]
846             self._finished_streams += [[self._current_stream_name,
847                                        self._current_stream_locators,
848                                        self._current_stream_files]]
849         self._current_stream_files = []
850         self._current_stream_length = 0
851         self._current_stream_locators = []
852         self._current_stream_name = None
853         self._current_file_pos = 0
854         self._current_file_name = None
855     def finish(self):
856         return Keep.put(self.manifest_text())
857     def manifest_text(self):
858         self.finish_current_stream()
859         manifest = ''
860         for stream in self._finished_streams:
861             if not re.search(r'^\.(/.*)?$', stream[0]):
862                 manifest += './'
863             manifest += stream[0]
864             for locator in stream[1]:
865                 manifest += " %s" % locator
866             for sfile in stream[2]:
867                 manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
868             manifest += "\n"
869         return manifest
870     def data_locators(self):
871         ret = []
872         for name, locators, files in self._finished_streams:
873             ret += locators
874         return ret
875
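# Writing a collection: finish() stores the manifest itself in Keep and
# returns its locator, which is suitable as a task output (names below
# are placeholders):
#
#   cw = CollectionWriter()
#   cw.start_new_file('output.txt')
#   cw.write(result_data)
#   current_task().set_output(cw.finish())
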
876 global_client_object = None
877
878 class Keep:
879     @staticmethod
880     def global_client_object():
881         global global_client_object
882         if global_client_object == None:
883             global_client_object = KeepClient()
884         return global_client_object
885
886     @staticmethod
887     def get(locator, **kwargs):
888         return Keep.global_client_object().get(locator, **kwargs)
889
890     @staticmethod
891     def put(data, **kwargs):
892         return Keep.global_client_object().put(data, **kwargs)
893
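# Keep.get() and Keep.put() are convenience wrappers around a shared
# KeepClient instance:
#
#   locator = Keep.put(some_data)        # "<md5 hexdigest>+<size>"
#   assert Keep.get(locator) == some_data
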
894 class KeepClient(object):
895
896     class ThreadLimiter(object):
897         """
898         Limit the number of threads running at a given time to
899         {desired successes} minus {successes reported}. When successes
900         reported == desired, wake up the remaining threads and tell
901         them to quit.
902
903         Should be used in a "with" block.
904         """
905         def __init__(self, todo):
906             self._todo = todo
907             self._done = 0
908             self._todo_lock = threading.Semaphore(todo)
909             self._done_lock = threading.Lock()
910         def __enter__(self):
911             self._todo_lock.acquire()
912             return self
913         def __exit__(self, type, value, traceback):
914             self._todo_lock.release()
915         def shall_i_proceed(self):
916             """
917             Return True if the current thread should attempt its write,
918             or False if enough successes have already been reported.
919             """
920             with self._done_lock:
921                 return (self._done < self._todo)
922         def increment_done(self):
923             """
924             Report that the current thread was successful.
925             """
926             with self._done_lock:
927                 self._done += 1
928         def done(self):
929             """
930             Return how many successes were reported.
931             """
932             with self._done_lock:
933                 return self._done
934
935     class KeepWriterThread(threading.Thread):
936         """
937         Write a blob of data to the given Keep server. Call
938         increment_done() of the given ThreadLimiter if the write
939         succeeds.
940         """
941         def __init__(self, **kwargs):
942             super(KeepClient.KeepWriterThread, self).__init__()
943             self.args = kwargs
944         def run(self):
945             global config
946             with self.args['thread_limiter'] as limiter:
947                 if not limiter.shall_i_proceed():
948                     # My turn arrived, but the job has been done without
949                     # me.
950                     return
951                 logging.debug("KeepWriterThread %s proceeding %s %s" %
952                               (str(threading.current_thread()),
953                                self.args['data_hash'],
954                                self.args['service_root']))
955                 h = httplib2.Http()
956                 url = self.args['service_root'] + self.args['data_hash']
957                 api_token = config['ARVADOS_API_TOKEN']
958                 headers = {'Authorization': "OAuth2 %s" % api_token}
959                 try:
960                     resp, content = h.request(url.encode('utf-8'), 'PUT',
961                                               headers=headers,
962                                               body=self.args['data'])
963                     if (resp['status'] == '401' and
964                         re.match(r'Timestamp verification failed', content)):
965                         body = KeepClient.sign_for_old_server(
966                             self.args['data_hash'],
967                             self.args['data'])
968                         h = httplib2.Http()
969                         resp, content = h.request(url.encode('utf-8'), 'PUT',
970                                                   headers=headers,
971                                                   body=body)
972                     if re.match(r'^2\d\d$', resp['status']):
973                         logging.debug("KeepWriterThread %s succeeded %s %s" %
974                                       (str(threading.current_thread()),
975                                        self.args['data_hash'],
976                                        self.args['service_root']))
977                         return limiter.increment_done()
978                     logging.warning("Request fail: PUT %s => %s %s" %
979                                     (url, resp['status'], content))
980                 except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
981                     logging.warning("Request fail: PUT %s => %s: %s" %
982                                     (url, type(e), str(e)))
983
984     def __init__(self):
985         self.lock = threading.Lock()
986         self.service_roots = None
987
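    # shuffled_service_roots() maps a block hash to a deterministic
    # probe order: successive 8-hex-digit chunks of the hash pick which
    # of the remaining servers to try next, so every client probes the
    # same servers in the same order for a given block, while different
    # blocks are spread across different servers.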
988     def shuffled_service_roots(self, hash):
989         if self.service_roots == None:
990             self.lock.acquire()
991             keep_disks = api().keep_disks().list().execute()['items']
992             roots = (("http%s://%s:%d/" %
993                       ('s' if f['service_ssl_flag'] else '',
994                        f['service_host'],
995                        f['service_port']))
996                      for f in keep_disks)
997             self.service_roots = sorted(set(roots))
998             logging.debug(str(self.service_roots))
999             self.lock.release()
1000         seed = hash
1001         pool = self.service_roots[:]
1002         pseq = []
1003         while len(pool) > 0:
1004             if len(seed) < 8:
1005                 if len(pseq) < len(hash) / 4: # first time around
1006                     seed = hash[-4:] + hash
1007                 else:
1008                     seed += hash
1009             probe = int(seed[0:8], 16) % len(pool)
1010             pseq += [pool[probe]]
1011             pool = pool[:probe] + pool[probe+1:]
1012             seed = seed[8:]
1013         logging.debug(str(pseq))
1014         return pseq
1015
1016     def get(self, locator):
1017         global config
1018         if re.search(r',', locator):
1019             return ''.join(self.get(x) for x in locator.split(','))
1020         if 'KEEP_LOCAL_STORE' in os.environ:
1021             return KeepClient.local_store_get(locator)
1022         expect_hash = re.sub(r'\+.*', '', locator)
1023         for service_root in self.shuffled_service_roots(expect_hash):
1024             h = httplib2.Http()
1025             url = service_root + expect_hash
1026             api_token = config['ARVADOS_API_TOKEN']
1027             headers = {'Authorization': "OAuth2 %s" % api_token,
1028                        'Accept': 'application/octet-stream'}
1029             try:
1030                 resp, content = h.request(url.encode('utf-8'), 'GET',
1031                                           headers=headers)
1032                 if re.match(r'^2\d\d$', resp['status']):
1033                     m = hashlib.new('md5')
1034                     m.update(content)
1035                     md5 = m.hexdigest()
1036                     if md5 == expect_hash:
1037                         return content
1038                     logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
1039             except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
1040                 logging.info("Request fail: GET %s => %s: %s" %
1041                              (url, type(e), str(e)))
1042         raise errors.NotFoundError("Block not found: %s" % expect_hash)
1043
1044     def put(self, data, **kwargs):
1045         if 'KEEP_LOCAL_STORE' in os.environ:
1046             return KeepClient.local_store_put(data)
1047         m = hashlib.new('md5')
1048         m.update(data)
1049         data_hash = m.hexdigest()
1050         have_copies = 0
1051         want_copies = kwargs.get('copies', 2)
1052         if not (want_copies > 0):
1053             return data_hash
1054         threads = []
1055         thread_limiter = KeepClient.ThreadLimiter(want_copies)
1056         for service_root in self.shuffled_service_roots(data_hash):
1057             t = KeepClient.KeepWriterThread(data=data,
1058                                             data_hash=data_hash,
1059                                             service_root=service_root,
1060                                             thread_limiter=thread_limiter)
1061             t.start()
1062             threads += [t]
1063         for t in threads:
1064             t.join()
1065         have_copies = thread_limiter.done()
1066         if have_copies == want_copies:
1067             return (data_hash + '+' + str(len(data)))
1068         raise errors.KeepWriteError(
1069             "Write fail for %s: wanted %d but wrote %d" %
1070             (data_hash, want_copies, have_copies))
1071
1072     @staticmethod
1073     def sign_for_old_server(data_hash, data):
1074         return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
1075
1076
1077     @staticmethod
1078     def local_store_put(data):
1079         m = hashlib.new('md5')
1080         m.update(data)
1081         md5 = m.hexdigest()
1082         locator = '%s+%d' % (md5, len(data))
1083         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
1084             f.write(data)
1085         os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
1086                   os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
1087         return locator
1088     @staticmethod
1089     def local_store_get(locator):
1090         r = re.search('^([0-9a-f]{32,})', locator)
1091         if not r:
1092             raise errors.NotFoundError(
1093                 "Invalid data locator: '%s'" % locator)
1094         if r.group(0) == EMPTY_BLOCK_LOCATOR.split('+')[0]:
1095             return ''
1096         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
1097             return f.read()
1098
1099 # We really shouldn't do this but some clients still use
1100 # arvados.service.* directly instead of arvados.api().*
1101 service = api()