import bz2
import fcntl
import hashlib
import httplib
import httplib2
import json
import logging
import os
import re
import string
import subprocess
import sys
import threading
import time
import types
import UserDict
import zlib

import apiclient.discovery
# Arvados configuration settings are taken from $HOME/.config/arvados.
# Environment variables override settings in the config file.
class ArvadosConfig(dict):
    def __init__(self, config_file):
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                for config_line in f:
                    # Split on the first '=' only, so values may
                    # themselves contain '=' characters.
                    var, val = config_line.rstrip().split('=', 1)
                    self[var] = val
        for var in os.environ:
            if var.startswith('ARVADOS_'):
                self[var] = os.environ[var]
config = ArvadosConfig(os.environ['HOME'] + '/.config/arvados')

if 'ARVADOS_DEBUG' in config:
    logging.basicConfig(level=logging.DEBUG)
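# A minimal example of that config file (illustrative values only; use
# the host and token for your own Arvados installation):
#
#   ARVADOS_API_HOST=arvados.example.com
#   ARVADOS_API_TOKEN=0123456789abcdef0123456789abcdef01234567
#   ARVADOS_API_HOST_INSECURE=no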
EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'

services = {}

class errors:
    class SyntaxError(Exception):
        pass
    class AssertionError(Exception):
        pass
    class NotFoundError(Exception):
        pass
    class CommandFailedError(Exception):
        pass
    class KeepWriteError(Exception):
        pass
    class NotImplementedError(Exception):
        pass
class CredentialsFromEnv(object):
    @staticmethod
    def http_request(self, uri, **kwargs):
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % config.get('ARVADOS_API_TOKEN', 'ARVADOS_API_TOKEN_not_set')
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http
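# Illustrative sketch (not used by the SDK itself): authorize() wraps a
# plain httplib2 client so every request it makes carries the OAuth2
# token from the configuration above.
def _example_authorized_http():
    http = httplib2.Http()
    return CredentialsFromEnv().authorize(http)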
def task_set_output(self, s):
    api('v1').job_tasks().update(uuid=self['uuid'],
                                 body={
                                     'output': s,
                                     'success': True,
                                     'progress': 1.0
                                 }).execute()

_current_task = None
def current_task():
    global _current_task
    if not _current_task:
        t = api('v1').job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
        t = UserDict.UserDict(t)
        t.set_output = types.MethodType(task_set_output, t)
        t.tmpdir = os.environ['TASK_WORK']
        _current_task = t
    return _current_task
_current_job = None
def current_job():
    global _current_job
    if not _current_job:
        t = api('v1').jobs().get(uuid=os.environ['JOB_UUID']).execute()
        t = UserDict.UserDict(t)
        t.tmpdir = os.environ['JOB_WORK']
        _current_job = t
    return _current_job
def getjobparam(*args):
    return current_job()['script_parameters'].get(*args)
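# Example (hypothetical parameter name): inside a crunch script,
#   reference = getjobparam('reference', 'default.fa')
# returns the job's 'reference' script parameter, or 'default.fa' if
# that parameter was not given.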
# Monkey patch discovery._cast() so objects and arrays get serialized
# with json.dumps() instead of str().
_cast_orig = apiclient.discovery._cast
def _cast_objects_too(value, schema_type):
    global _cast_orig
    if (type(value) != type('') and
        (schema_type == 'object' or schema_type == 'array')):
        return json.dumps(value)
    else:
        return _cast_orig(value, schema_type)
apiclient.discovery._cast = _cast_objects_too
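# Minimal illustration of why the patch matters: for a dict-valued
# request parameter,
#   str({'a': 1})                          -> "{'a': 1}"  (not valid JSON)
#   _cast_objects_too({'a': 1}, 'object')  -> '{"a": 1}'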
def api(version=None):
    global services, config
    if not services.get(version):
        apiVersion = version
        if not version:
            apiVersion = 'v1'
            logging.info("Using default API version. " +
                         "Call arvados.api('%s') instead." %
                         apiVersion)
        if 'ARVADOS_API_HOST' not in config:
            raise Exception("ARVADOS_API_HOST is not set. Aborting.")
        url = ('https://%s/discovery/v1/apis/{api}/{apiVersion}/rest' %
               config['ARVADOS_API_HOST'])
        credentials = CredentialsFromEnv()

        # Use system's CA certificates (if we find them) instead of httplib2's
        ca_certs = '/etc/ssl/certs/ca-certificates.crt'
        if not os.path.exists(ca_certs):
            ca_certs = None  # use httplib2 default

        http = httplib2.Http(ca_certs=ca_certs)
        http = credentials.authorize(http)
        if re.match(r'(?i)^(true|1|yes)$',
                    config.get('ARVADOS_API_HOST_INSECURE', 'no')):
            http.disable_ssl_certificate_validation = True
        services[version] = apiclient.discovery.build(
            'arvados', apiVersion, http=http, discoveryServiceUrl=url)
    return services[version]
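# Illustrative sketch, assuming ARVADOS_API_HOST and ARVADOS_API_TOKEN
# are configured (and that the generated v1 list method accepts a
# 'limit' parameter): fetch up to ten of the current user's jobs.
def _example_list_jobs():
    return api('v1').jobs().list(limit=10).execute()['items']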
class JobTask(object):
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)
class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                    }
                }
                api('v1').job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                         body={'success': True}
                                         ).execute()
            exit(0)
    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input': task_input
                }
            }
            api('v1').job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                         body={'success': True}
                                         ).execute()
            exit(0)
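# Sketch of how a crunch script typically uses job_setup (illustrative;
# assumes the job's 'input' script parameter is a collection locator).
# one_task_per_input_file() queues one task per input file and then
# exits, so the code after it only runs in the sequence-1 tasks.
def _example_crunch_script():
    job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
    this_task_input = current_task()['parameters']['input']
    return this_task_input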
class util:
    @staticmethod
    def clear_tmpdir(path=None):
        """
        Ensure the given directory (or TASK_TMPDIR if none given)
        exists and is empty.
        """
        if path == None:
            path = current_task().tmpdir
        if os.path.exists(path):
            p = subprocess.Popen(['rm', '-rf', path])
            stdout, stderr = p.communicate(None)
            if p.returncode != 0:
                raise Exception('rm -rf %s: %s' % (path, stderr))
        os.mkdir(path)
    @staticmethod
    def run_command(execargs, **kwargs):
        kwargs.setdefault('stdin', subprocess.PIPE)
        kwargs.setdefault('stdout', subprocess.PIPE)
        kwargs.setdefault('stderr', sys.stderr)
        kwargs.setdefault('close_fds', True)
        kwargs.setdefault('shell', False)
        p = subprocess.Popen(execargs, **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise errors.CommandFailedError(
                "run_command %s exit %d:\n%s" %
                (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata
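    # Example (illustrative): capture the stdout of a command, raising
    # errors.CommandFailedError if it exits nonzero.
    #   stdoutdata, stderrdata = util.run_command(['echo', '-n', 'hello'])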
    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path
    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)
    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory. Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search(r'\.(tbz|tar.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search(r'\.(tgz|tar.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search(r'\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise errors.AssertionError(
                        "tarball_extract cannot handle filename %s" % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path
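    # Example (hypothetical locator variable): extract a tarball
    # collection into the job's temporary space and use the returned
    # extraction directory.
    #   ref_dir = util.tarball_extract(tarball_locator, 'reference')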
    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory. Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search(r'\.zip$', f.name()):
                    raise errors.NotImplementedError(
                        "zipball_extract cannot handle filename %s" % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path
    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory. Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
        if matches:
            collection_hash = matches.group(1)
        else:
            collection_hash = hashlib.md5(collection).hexdigest()
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection_hash:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got,
                 [z.name() for z in CollectionReader(collection).all_files()]))
        os.symlink(collection_hash, os.path.join(path, '.locator'))
        lockfile.close()
        return path
    @staticmethod
    def mkdir_dash_p(path):
        if not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                if not os.path.exists(path):
                    raise
    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory. Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got, [z.name() for z in stream.all_files()]))
        lockfile.close()
        return path
    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles
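# For a directory tree containing a/b.txt and c.txt,
# util.listdir_recursive(path) returns ['a/b.txt', 'c.txt']: relative
# paths, recursing into subdirectories in sorted order.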
class StreamFileReader(object):
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub(r'\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search(r'\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search(r'\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol + 1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s %s 0:0:%s\n"
                    % (self._stream.name(), EMPTY_BLOCK_LOCATOR, self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"
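# For reference, a single manifest stream line looks like this
# (illustrative; the locator is md5("foo") plus a size hint):
#
#   . acbd18db4cc2f85cedef654fccc4a4d8+3 0:3:foo.txt
#
# i.e. a stream name, one or more data block locators, then
# position:size:filename tokens addressing bytes within the
# concatenated blocks. StreamReader below parses exactly these tokens.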
class StreamReader(object):
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name == None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise errors.SyntaxError("Invalid manifest format")
    def tokens(self):
        return self._tokens
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(1))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data == None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search(r'\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(1))
        return len(self.current_datablock_data())
671 """Set the position of the next read operation."""
673 def really_seek(self):
674 """Find and load the appropriate data block, so the byte at
677 if self._pos == self._current_datablock_pos:
679 if (self._current_datablock_pos != None and
680 self._pos >= self._current_datablock_pos and
681 self._pos <= self._current_datablock_pos + self.current_datablock_size()):
683 if self._pos < self._current_datablock_pos:
684 self._current_datablock_index = -1
686 while (self._pos > self._current_datablock_pos and
687 self._pos > self._current_datablock_pos + self.current_datablock_size()):
689 def read(self, size):
690 """Read no more than size bytes -- but at least one byte,
691 unless _pos is already at the end of the stream.
696 while self._pos >= self._current_datablock_pos + self.current_datablock_size():
698 if self._current_datablock_index >= len(self.data_locators):
700 data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
701 self._pos += len(data)
class CollectionReader(object):
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def _populate(self):
        if self._streams != None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text
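# Illustrative sketch (hypothetical locator variable): read every file
# in a collection into memory, keyed by file name. CollectionReader
# accepts either a locator or the manifest text itself.
def _example_read_collection(locator):
    cr = CollectionReader(locator)
    return dict((f.name(), ''.join(f.readall())) for f in cr.all_files())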
class CollectionWriter(object):
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []

    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        self.start_new_stream(stream_name)
        todo = []
        if max_manifest_depth == 0:
            dirents = sorted(util.listdir_recursive(path))
        else:
            dirents = sorted(os.listdir(path))
        for dirent in dirents:
            target = os.path.join(path, dirent)
            if os.path.isdir(target):
                todo += [[target,
                          os.path.join(stream_name, dirent),
                          max_manifest_depth-1]]
            else:
                self.start_new_file(dirent)
                with open(target, 'rb') as f:
                    while True:
                        buf = f.read(2**26)
                        if len(buf) == 0:
                            break
                        self.write(buf)
        self.finish_current_stream()
        map(lambda x: self.write_directory_tree(*x), todo)
    def write(self, newdata):
        if hasattr(newdata, '__iter__'):
            for s in newdata:
                self.write(s)
            return
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        newfilename = re.sub(r' ', '\\\\040', newfilename)
        if re.search(r'[ \t\n]', newfilename):
            raise errors.AssertionError(
                "Manifest filenames cannot contain whitespace: %s" %
                newfilename)
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name == None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise errors.AssertionError(
                "Cannot finish an unnamed file " +
                "(%d bytes at offset %d in '%s' stream)" %
                (self._current_stream_length - self._current_file_pos,
                 self._current_file_pos,
                 self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname='.'):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise errors.AssertionError(
                "Manifest stream names cannot contain whitespace")
        self._current_stream_name = '.' if newstreamname=='' else newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name == None:
            raise errors.AssertionError(
                "Cannot finish an unnamed stream (%d bytes in %d files)" %
                (self._current_stream_length, len(self._current_stream_files)))
        else:
            if len(self._current_stream_locators) == 0:
                self._current_stream_locators += [EMPTY_BLOCK_LOCATOR]
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0]
            for locator in stream[1]:
                manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest
    def data_locators(self):
        ret = []
        for name, locators, files in self._finished_streams:
            ret += locators
        return ret
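# Illustrative sketch: write two small files into the default '.'
# stream and store the resulting manifest, getting back its locator.
def _example_write_collection():
    cw = CollectionWriter()
    cw.start_new_file('hello.txt')
    cw.write('hello\n')
    cw.start_new_file('world.txt')
    cw.write('world\n')
    return cw.finish()  # locator of the manifest stored in Keep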
global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object == None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)
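# Minimal round-trip sketch. With KEEP_LOCAL_STORE set to an existing
# directory, blocks are stored as local files instead of being sent to
# Keep servers, which is handy for testing:
def _example_keep_roundtrip():
    locator = Keep.put('foo')  # e.g. 'acbd18db4cc2f85cedef654fccc4a4d8+3'
    return Keep.get(locator)   # -> 'foo'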
class KeepClient(object):

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()
        def __enter__(self):
            self._todo_lock.acquire()
            return self
        def __exit__(self, type, value, traceback):
            self._todo_lock.release()
        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)
        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1
        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done
    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs
        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config['ARVADOS_API_TOKEN']
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))
    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None

    def shuffled_service_roots(self, hash):
        if self.service_roots == None:
            self.lock.acquire()
            keep_disks = api().keep_disks().list().execute()['items']
            roots = (("http%s://%s:%d/" %
                      ('s' if f['service_ssl_flag'] else '',
                       f['service_host'],
                       f['service_port']))
                     for f in keep_disks)
            self.service_roots = sorted(set(roots))
            logging.debug(str(self.service_roots))
            self.lock.release()
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq
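    # The probe order computed above is a deterministic shuffle seeded
    # by the block hash: every client derives the same server ordering
    # for a given block, so readers probe the same servers writers
    # chose, without consulting a central directory.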
    def get(self, locator):
        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = config['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise errors.NotFoundError("Block not found: %s" % expect_hash)
    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies == want_copies:
            return (data_hash + '+' + str(len(data)))
        raise errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))
    @staticmethod
    def sign_for_old_server(data_hash, data):
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()
# We really shouldn't do this but some clients still use
# arvados.service.* directly instead of arvados.api().*
service = api()