import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading

import apiclient
import apiclient.discovery

# Arvados configuration settings are taken from $HOME/.config/arvados.
# Environment variables override settings in the config file.
#
class ArvadosConfig(dict):
    def __init__(self, config_file):
        dict.__init__(self)
        with open(config_file, "r") as f:
            for config_line in f:
                # Split on the first '=' only, so values may themselves
                # contain '=' characters.
                var, val = config_line.rstrip().split('=', 1)
                self[var] = val
        for var in os.environ:
            if var.startswith('ARVADOS_'):
                self[var] = os.environ[var]


config = ArvadosConfig(os.environ['HOME'] + '/.config/arvados')
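
# The config file consists of NAME=value lines, for example (hypothetical
# values):
#   ARVADOS_API_HOST=arvados.example.com
#   ARVADOS_API_TOKEN=xyzzy
#   ARVADOS_API_HOST_INSECURE=no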

if 'ARVADOS_DEBUG' in config:
    logging.basicConfig(level=logging.DEBUG)

EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'

services = {}

class errors:
    class SyntaxError(Exception):
        pass
    class AssertionError(Exception):
        pass
    class NotFoundError(Exception):
        pass
    class CommandFailedError(Exception):
        pass
    class KeepWriteError(Exception):
        pass
    class NotImplementedError(Exception):
        pass

class CredentialsFromEnv(object):
    @staticmethod
    def http_request(self, uri, **kwargs):
        global config
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % config.get('ARVADOS_API_TOKEN', 'ARVADOS_API_TOKEN_not_set')
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

def task_set_output(self,s):
    api('v1').job_tasks().update(uuid=self['uuid'],
                                 body={
            'output':s,
            'success':True,
            'progress':1.0
            }).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = api('v1').job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = api('v1').jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def getjobparam(*args):
    return current_job()['script_parameters'].get(*args)

# Monkey patch discovery._cast() so objects and arrays get serialized
# with json.dumps() instead of str().
_cast_orig = apiclient.discovery._cast
def _cast_objects_too(value, schema_type):
    global _cast_orig
    if (type(value) != type('') and
        (schema_type == 'object' or schema_type == 'array')):
        return json.dumps(value)
    else:
        return _cast_orig(value, schema_type)
apiclient.discovery._cast = _cast_objects_too
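
# With the patch above, dict- and list-valued query parameters (for example a
# "where" filter passed to a list() call) reach the API server as JSON rather
# than as their Python str() form.  Hypothetical sketch; the exact parameters
# depend on the server's discovery document:
#   api('v1').jobs().list(where={'script': 'hash'}).execute()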

def api(version=None):
    global services, config
    if not services.get(version):
        apiVersion = version
        if not version:
            apiVersion = 'v1'
            logging.info("Using default API version. " +
                         "Call arvados.api('%s') instead." %
                         apiVersion)
        if 'ARVADOS_API_HOST' not in config:
            raise Exception("ARVADOS_API_HOST is not set. Aborting.")
        url = ('https://%s/discovery/v1/apis/{api}/{apiVersion}/rest' %
               config['ARVADOS_API_HOST'])
        credentials = CredentialsFromEnv()

        # Use system's CA certificates (if we find them) instead of httplib2's
        ca_certs = '/etc/ssl/certs/ca-certificates.crt'
        if not os.path.exists(ca_certs):
            ca_certs = None             # use httplib2 default

        http = httplib2.Http(ca_certs=ca_certs)
        http = credentials.authorize(http)
        if re.match(r'(?i)^(true|1|yes)$',
                    config.get('ARVADOS_API_HOST_INSECURE', 'no')):
            http.disable_ssl_certificate_validation=True
        services[version] = apiclient.discovery.build(
            'arvados', apiVersion, http=http, discoveryServiceUrl=url)
    return services[version]
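
# Typical use (sketch; job_uuid is a hypothetical variable, and the jobs()
# call mirrors usage elsewhere in this module):
#   arv = api('v1')
#   job = arv.jobs().get(uuid=job_uuid).execute()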

class JobTask(object):
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)

class job_setup:
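    # Helpers for fanning a job out into one task per input file or stream.
    # Typical crunch-script usage (sketch):
    #   arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)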
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input':task_input
                        }
                    }
                api('v1').job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                         body={'success':True}
                                         ).execute()
            exit(0)

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input':task_input
                    }
                }
            api('v1').job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                         body={'success':True}
                                         ).execute()
            exit(0)

class util:
    @staticmethod
    def clear_tmpdir(path=None):
        """
        Ensure the given directory (or the current task's tmpdir if
        none given) exists and is empty.
        """
        if path == None:
            path = current_task().tmpdir
        if os.path.exists(path):
            p = subprocess.Popen(['rm', '-rf', path])
            stdout, stderr = p.communicate(None)
            if p.returncode != 0:
                raise Exception('rm -rf %s: %s' % (path, stderr))
        os.mkdir(path)

    @staticmethod
    def run_command(execargs, **kwargs):
        kwargs.setdefault('stdin', subprocess.PIPE)
        kwargs.setdefault('stdout', subprocess.PIPE)
        kwargs.setdefault('stderr', sys.stderr)
        kwargs.setdefault('close_fds', True)
        kwargs.setdefault('shell', False)
        p = subprocess.Popen(execargs, **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise errors.CommandFailedError(
                "run_command %s exit %d:\n%s" %
                (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata
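    # A minimal run_command() sketch:
    #   out, err = util.run_command(['ls', '-l'])
    # A nonzero exit status raises errors.CommandFailedError.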

    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search('\.(tbz|tar.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search('\.(tgz|tar.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search('\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise errors.AssertionError(
                        "tarball_extract cannot handle filename %s" % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search('\.zip$', f.name()):
                    raise errors.NotImplementedError(
                        "zipball_extract cannot handle filename %s" % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
        if matches:
            collection_hash = matches.group(1)
        else:
            collection_hash = hashlib.md5(collection).hexdigest()
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection_hash:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got,
                 [z.name() for z in CollectionReader(collection).all_files()]))
        os.symlink(collection_hash, os.path.join(path, '.locator'))

        lockfile.close()
        return path

    @staticmethod
    def mkdir_dash_p(path):
        # Create parent directories first, then the leaf; stop recursing at
        # an existing directory or an empty path component.  Tolerate a race
        # where another process creates the directory concurrently.
        if path != '' and not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got, [z.name() for z in stream.all_files()]))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

class StreamFileReader(object):
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub('\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search('\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search('\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s %s 0:0:%s\n"
                    % (self._stream.name(), EMPTY_BLOCK_LOCATOR, self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"

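# A manifest stream line consists of a stream name, one or more data block
# locators ("<md5 hexdigest>+<size>[+hints]"), and one or more
# "<position>:<size>:<filename>" tokens, e.g. (hypothetical single-file
# collection whose only block is the 3-byte string "foo"):
#   . acbd18db4cc2f85cedef654fccc4a4d8+3 0:3:foo.txt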
class StreamReader(object):
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name == None:
                self._stream_name = tok.replace('\\040', ' ')
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':',2)
                self.files += [[int(pos), int(size), name.replace('\\040', ' ')]]
            else:
                raise errors.SyntaxError("Invalid manifest format")

    def tokens(self):
        return self._tokens
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(0))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data == None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(0))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos != None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

class CollectionReader(object):
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams != None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text
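
# Reading a collection (sketch; the constructor accepts either a manifest
# locator or the manifest text itself, and collection_locator here is a
# hypothetical variable):
#   cr = CollectionReader(collection_locator)
#   for f in cr.all_files():
#       print f.stream_name(), f.name(), f.size()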

class CollectionWriter(object):
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        self.start_new_stream(stream_name)
        todo = []
        if max_manifest_depth == 0:
            dirents = sorted(util.listdir_recursive(path))
        else:
            dirents = sorted(os.listdir(path))
        for dirent in dirents:
            target = os.path.join(path, dirent)
            if os.path.isdir(target):
                todo += [[target,
                          os.path.join(stream_name, dirent),
                          max_manifest_depth-1]]
            else:
                self.start_new_file(dirent)
                with open(target, 'rb') as f:
                    while True:
                        buf = f.read(2**26)
                        if len(buf) == 0:
                            break
                        self.write(buf)
        self.finish_current_stream()
        map(lambda x: self.write_directory_tree(*x), todo)

    def write(self, newdata):
        if hasattr(newdata, '__iter__'):
            for s in newdata:
                self.write(s)
            return
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        if re.search(r'[\t\n]', newfilename):
            raise errors.AssertionError(
                "Manifest filenames cannot contain whitespace: %s" %
                newfilename)
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name == None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise errors.AssertionError(
                "Cannot finish an unnamed file " +
                "(%d bytes at offset %d in '%s' stream)" %
                (self._current_stream_length - self._current_file_pos,
                 self._current_file_pos,
                 self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname='.'):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[\t\n]', newstreamname):
            raise errors.AssertionError(
                "Manifest stream names cannot contain whitespace")
        self._current_stream_name = '.' if newstreamname=='' else newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name == None:
            raise errors.AssertionError(
                "Cannot finish an unnamed stream (%d bytes in %d files)" %
                (self._current_stream_length, len(self._current_stream_files)))
        else:
            if len(self._current_stream_locators) == 0:
                self._current_stream_locators += [EMPTY_BLOCK_LOCATOR]
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0].replace(' ', '\\040')
            for locator in stream[1]:
                manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2].replace(' ', '\\040'))
            manifest += "\n"
        return manifest
    def data_locators(self):
        ret = []
        for name, locators, files in self._finished_streams:
            ret += locators
        return ret
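
# Writing a collection (sketch; output_data is a hypothetical string):
#   cw = CollectionWriter()
#   cw.start_new_file('output.txt')
#   cw.write(output_data)
#   manifest_locator = cw.finish()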

global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object == None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)
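
# Keep.get()/Keep.put() above are the usual entry points, e.g. (sketch):
#   locator = Keep.put("foo")   # "acbd18db4cc2f85cedef654fccc4a4d8+3", 2 copies by default
#   data = Keep.get(locator)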

class KeepClient(object):

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()
        def __enter__(self):
            self._todo_lock.acquire()
            return self
        def __exit__(self, type, value, traceback):
            self._todo_lock.release()
        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)
        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1
        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs
        def run(self):
            global config
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config['ARVADOS_API_TOKEN']
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None

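    # Derive a deterministic, hash-seeded probe order over the Keep service
    # roots, so every client tries the same servers in the same order for a
    # given block.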
    def shuffled_service_roots(self, hash):
        if self.service_roots == None:
            self.lock.acquire()
            keep_disks = api().keep_disks().list().execute()['items']
            roots = (("http%s://%s:%d/" %
                      ('s' if f['service_ssl_flag'] else '',
                       f['service_host'],
                       f['service_port']))
                     for f in keep_disks)
            self.service_roots = sorted(set(roots))
            logging.debug(str(self.service_roots))
            self.lock.release()
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    def get(self, locator):
        global config
        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = config['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise errors.NotFoundError("Block not found: %s" % expect_hash)

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        # Concurrent writers can report more successes than requested.
        if have_copies >= want_copies:
            return (data_hash + '+' + str(len(data)))
        raise errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()

# We really shouldn't do this but some clients still use
# arvados.service.* directly instead of arvados.api().*
service = api()