22 import apiclient.discovery
24 # Arvados configuration settings are taken from $HOME/.config/arvados.
25 # Environment variables override settings in the config file.
# Dict subclass holding Arvados client settings: loads KEY=VALUE lines
# from the given config file, then overlays any ARVADOS_* environment
# variables (so, per the header comment, the environment wins).
# NOTE(review): several lines are elided in this view (the read loop and
# error handling around the open/split are not fully visible) — confirm
# against the full source before changing.
27 class ArvadosConfig(dict):
28 def __init__(self, config_file):
30 with open(config_file, "r") as f:
# NOTE(review): split('=', 2) allows up to three fields; a value that
# itself contains '=' would make the two-name unpack raise ValueError.
# maxsplit=1 was probably intended — verify before fixing.
32 var, val = config_line.rstrip().split('=', 2)
# Environment overrides: copy every ARVADOS_-prefixed variable in.
34 for var in os.environ:
35 if var.startswith('ARVADOS_'):
36 self[var] = os.environ[var]
# Module-level singleton: configuration read at import time from
# $HOME/.config/arvados (environment variables override, see above).
39 config = ArvadosConfig(os.environ['HOME'] + '/.config/arvados')
# Turn on verbose logging when the user sets ARVADOS_DEBUG (any value).
41 if 'ARVADOS_DEBUG' in config:
42 logging.basicConfig(level=logging.DEBUG)
# Keep locator for the zero-length block: md5 of the empty string plus
# "+0" size hint. Used to represent empty files/streams in manifests.
44 EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'
# SDK-specific exception types. Bodies (docstrings/pass) are elided in
# this view. Later code raises these as errors.SyntaxError etc., so
# presumably these definitions are nested inside an `errors` container
# whose header is not visible here — verify against the full source.
# NOTE(review): SyntaxError, AssertionError and NotImplementedError
# shadow Python builtins of the same name; safe only because callers
# qualify them with the `errors.` prefix.
49 class SyntaxError(Exception):
51 class AssertionError(Exception):
53 class NotFoundError(Exception):
55 class CommandFailedError(Exception):
57 class KeepWriteError(Exception):
59 class NotImplementedError(Exception):
# Minimal stand-in for an oauth2 credentials object: authorize() wraps an
# httplib2.Http instance so every request carries the ARVADOS_API_TOKEN
# from the module-level config as an OAuth2 bearer header.
62 class CredentialsFromEnv(object):
# Replacement for http.request; installed by authorize() below.
64 def http_request(self, uri, **kwargs):
66 from httplib import BadStatusLine
67 if 'headers' not in kwargs:
68 kwargs['headers'] = {}
# Falls back to a sentinel string rather than raising if the token is
# unset, so the server (not the client) reports the auth failure.
69 kwargs['headers']['Authorization'] = 'OAuth2 %s' % config.get('ARVADOS_API_TOKEN', 'ARVADOS_API_TOKEN_not_set')
# NOTE(review): the try/except BadStatusLine wrapper around these two
# calls is elided in this view; the second call below is the retry path.
71 return self.orig_http_request(uri, **kwargs)
73 # This is how httplib tells us that it tried to reuse an
74 # existing connection but it was already closed by the
75 # server. In that case, yes, we would like to retry.
76 # Unfortunately, we are not absolutely certain that the
77 # previous call did not succeed, so this is slightly
79 return self.orig_http_request(uri, **kwargs)
# Monkey-patch the given Http object: stash its original request method
# and bind http_request in its place. Returns nothing visible here.
80 def authorize(self, http):
81 http.orig_http_request = http.request
82 http.request = types.MethodType(self.http_request, http)
# Helpers for the current job task / job. The enclosing def lines for
# current_task()/current_job() (and any memoization) are elided in this
# view — only their bodies are visible below.
# Bound onto task objects as t.set_output: records the task's output
# locator via the API. The update body/execute continuation is elided.
85 def task_set_output(self,s):
86 api('v1').job_tasks().update(uuid=self['uuid'],
# Body of current_task(): fetch this process's task record (TASK_UUID
# from the crunch environment), wrap it in a UserDict, and attach the
# set_output method and the task work directory.
98 t = api('v1').job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
99 t = UserDict.UserDict(t)
100 t.set_output = types.MethodType(task_set_output, t)
101 t.tmpdir = os.environ['TASK_WORK']
# Body of current_job(): same pattern for the job record / JOB_WORK dir.
110 t = api('v1').jobs().get(uuid=os.environ['JOB_UUID']).execute()
111 t = UserDict.UserDict(t)
112 t.tmpdir = os.environ['JOB_WORK']
def getjobparam(*args):
    """Look up a script parameter of the current job.

    Thin convenience proxy for current_job()['script_parameters'].get(),
    so callers can write getjobparam('input') or supply a fallback with
    getjobparam('input', default).
    """
    script_params = current_job()['script_parameters']
    return script_params.get(*args)
119 # Monkey patch discovery._cast() so objects and arrays get serialized
120 # with json.dumps() instead of str().
# Keep a handle on the library's original implementation so the wrapper
# below can delegate every case it does not handle itself.
121 _cast_orig = apiclient.discovery._cast
122 def _cast_objects_too(value, schema_type):
# Non-string values destined for 'object'/'array' schema slots get JSON
# encoded; everything else falls through to the stock _cast.
# NOTE(review): the `else:` line between these branches is elided here.
124 if (type(value) != type('') and
125 (schema_type == 'object' or schema_type == 'array')):
126 return json.dumps(value)
128 return _cast_orig(value, schema_type)
# Install the wrapper (module-level side effect at import time).
129 apiclient.discovery._cast = _cast_objects_too
# Return a (cached) apiclient service object for the requested Arvados
# API version. `services` is a module-level cache keyed by version; the
# branch that picks a default `apiVersion` when version is None is
# partially elided in this view.
131 def api(version=None):
132 global services, config
133 if not services.get(version):
# Nudge callers to request an explicit version instead of the default.
137 logging.info("Using default API version. " +
138 "Call arvados.api('%s') instead." %
140 if 'ARVADOS_API_HOST' not in config:
141 raise Exception("ARVADOS_API_HOST is not set. Aborting.")
# Discovery document URL template; {api}/{apiVersion} are filled in by
# apiclient.discovery.build below.
142 url = ('https://%s/discovery/v1/apis/{api}/{apiVersion}/rest' %
143 config['ARVADOS_API_HOST'])
144 credentials = CredentialsFromEnv()
146 # Use system's CA certificates (if we find them) instead of httplib2's
147 ca_certs = '/etc/ssl/certs/ca-certificates.crt'
148 if not os.path.exists(ca_certs):
149 ca_certs = None # use httplib2 default
151 http = httplib2.Http(ca_certs=ca_certs)
152 http = credentials.authorize(http)
# Accept self-signed certs only when explicitly requested via config.
153 if re.match(r'(?i)^(true|1|yes)$',
154 config.get('ARVADOS_API_HOST_INSECURE', 'no')):
155 http.disable_ssl_certificate_validation=True
# `apiVersion` is assigned on an elided line above — verify in full source.
156 services[version] = apiclient.discovery.build(
157 'arvados', apiVersion, http=http, discoveryServiceUrl=url)
158 return services[version]
# Apparently a stub/debug constructor for a job task description; only
# this print is visible here (Python 2 print statement — this module is
# py2 throughout). Any further body is elided from this view.
160 class JobTask(object):
161 def __init__(self, parameters=dict(), runtime_constraints=dict()):
# NOTE(review): dict() defaults are shared mutable defaults; harmless
# only if never mutated — confirm in the full source.
162 print "init jobtask %s %s" % (parameters, runtime_constraints)
# Crunch fan-out helper: at job sequence `if_sequence`, queue one new
# job task per input *file* in the job's input collection, then (when
# and_end_task is true) mark the current task successful. Several lines
# (the early return, new_task_attrs dict open, exit) are elided here.
166 def one_task_per_input_file(if_sequence=0, and_end_task=True):
# Only the task at the designated sequence performs the fan-out.
167 if if_sequence != current_task()['sequence']:
169 job_input = current_job()['script_parameters']['input']
170 cr = CollectionReader(job_input)
171 for s in cr.all_streams():
172 for f in s.all_files():
# Each child task's input is a single-file manifest.
173 task_input = f.as_manifest()
175 'job_uuid': current_job()['uuid'],
176 'created_by_job_task_uuid': current_task()['uuid'],
177 'sequence': if_sequence + 1,
182 api('v1').job_tasks().create(body=new_task_attrs).execute()
# Retire the dispatching task so the children can run.
184 api('v1').job_tasks().update(uuid=current_task()['uuid'],
185 body={'success':True}
# Same fan-out pattern as one_task_per_input_file, but one new task per
# input *stream* (directory) rather than per file. Several lines are
# elided in this view, mirroring the gaps in the function above.
190 def one_task_per_input_stream(if_sequence=0, and_end_task=True):
191 if if_sequence != current_task()['sequence']:
193 job_input = current_job()['script_parameters']['input']
194 cr = CollectionReader(job_input)
195 for s in cr.all_streams():
# The raw manifest tokens for the stream become the child task's input.
196 task_input = s.tokens()
198 'job_uuid': current_job()['uuid'],
199 'created_by_job_task_uuid': current_task()['uuid'],
200 'sequence': if_sequence + 1,
205 api('v1').job_tasks().create(body=new_task_attrs).execute()
207 api('v1').job_tasks().update(uuid=current_task()['uuid'],
208 body={'success':True}
214 def clear_tmpdir(path=None):
216 Ensure the given directory (or TASK_TMPDIR if none given)
# Default to the current task's work dir when no path supplied.
220 path = current_task().tmpdir
221 if os.path.exists(path):
# Shell out to `rm -rf` rather than shutil.rmtree.
222 p = subprocess.Popen(['rm', '-rf', path])
223 stdout, stderr = p.communicate(None)
224 if p.returncode != 0:
# NOTE(review): stderr is None here (no stderr=PIPE above), so the
# message interpolates "None" — probably meant to capture stderr.
225 raise Exception('rm -rf %s: %s' % (path, stderr))
def run_command(execargs, **kwargs):
    """Run a child process and return its (stdout, stderr) output.

    execargs -- argv list for subprocess.Popen (shell=False by default)
    kwargs   -- forwarded to Popen; sensible defaults are filled in for
                stdin/stdout (pipes), stderr (inherit), close_fds, shell.

    Raises errors.CommandFailedError if the process exits non-zero,
    including the captured stderr (None unless the caller piped it).
    """
    popen_defaults = {
        'stdin': subprocess.PIPE,
        'stdout': subprocess.PIPE,
        'stderr': sys.stderr,
        'close_fds': True,
        'shell': False,
    }
    for opt, default in popen_defaults.items():
        kwargs.setdefault(opt, default)
    proc = subprocess.Popen(execargs, **kwargs)
    out, err = proc.communicate(None)
    if proc.returncode != 0:
        raise errors.CommandFailedError(
            "run_command %s exit %d:\n%s" %
            (execargs, proc.returncode, err))
    return out, err
# Clone (if needed) and check out `version` of the git repo at `url`
# into `path` (relative paths are rooted at the job temp dir). The
# trailing kwargs of the checkout call and the return are elided here.
244 def git_checkout(url, version, path):
# Treat any path not starting with '/' as relative to the job tmpdir.
245 if not re.search('^/', path):
246 path = os.path.join(current_job().tmpdir, path)
247 if not os.path.exists(path):
248 util.run_command(["git", "clone", url, path],
249 cwd=os.path.dirname(path))
250 util.run_command(["git", "checkout", version],
# Start a `tar -x` child whose stdin the caller will feed with archive
# bytes. decompress_flag is '', 'z' or 'j' (see tarball_extract below).
# The "-C path" arguments are on lines elided from this view.
255 def tar_extractor(path, decompress_flag):
256 return subprocess.Popen(["tar",
257 ("-x%sf" % decompress_flag),
261 stdin=subprocess.PIPE, stderr=sys.stderr,
262 shell=False, close_fds=True)
265 def tarball_extract(tarball, path):
266 """Retrieve a tarball from Keep and extract it to a local
267 directory. Return the absolute path where the tarball was
268 extracted. If the top level of the tarball contained just one
269 file or directory, return the absolute path of that single
272 tarball -- collection locator
273 path -- where to extract the tarball: absolute, or relative to job tmp
# Serialize concurrent extraction attempts with an flock on path.lock.
# NOTE(review): the lock is never explicitly released in the visible
# lines; presumably dropped when the process exits/file is closed.
275 if not re.search('^/', path):
276 path = os.path.join(current_job().tmpdir, path)
277 lockfile = open(path + '.lock', 'w')
278 fcntl.flock(lockfile, fcntl.LOCK_EX)
# A .locator symlink inside path records which tarball was last
# extracted there; reuse the directory if it matches.
283 already_have_it = False
285 if os.readlink(os.path.join(path, '.locator')) == tarball:
286 already_have_it = True
289 if not already_have_it:
291 # emulate "rm -f" (i.e., if the file does not exist, we win)
293 os.unlink(os.path.join(path, '.locator'))
295 if os.path.exists(os.path.join(path, '.locator')):
296 os.unlink(os.path.join(path, '.locator'))
# Pick the tar decompression flag from the file extension and stream
# the file's bytes from Keep into the tar child (writes elided).
298 for f in CollectionReader(tarball).all_files():
299 if re.search('\.(tbz|tar.bz2)$', f.name()):
300 p = util.tar_extractor(path, 'j')
301 elif re.search('\.(tgz|tar.gz)$', f.name()):
302 p = util.tar_extractor(path, 'z')
303 elif re.search('\.tar$', f.name()):
304 p = util.tar_extractor(path, '')
306 raise errors.AssertionError(
307 "tarball_extract cannot handle filename %s" % f.name())
315 if p.returncode != 0:
317 raise errors.CommandFailedError(
318 "tar exited %d" % p.returncode)
# Record success, then collapse single-entry extractions to that entry.
319 os.symlink(tarball, os.path.join(path, '.locator'))
320 tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
322 if len(tld_extracts) == 1:
323 return os.path.join(path, tld_extracts[0])
327 def zipball_extract(zipball, path):
328 """Retrieve a zip archive from Keep and extract it to a local
329 directory. Return the absolute path where the archive was
330 extracted. If the top level of the archive contained just one
331 file or directory, return the absolute path of that single
334 zipball -- collection locator
335 path -- where to extract the archive: absolute, or relative to job tmp
# Same locking and .locator-reuse protocol as tarball_extract above.
337 if not re.search('^/', path):
338 path = os.path.join(current_job().tmpdir, path)
339 lockfile = open(path + '.lock', 'w')
340 fcntl.flock(lockfile, fcntl.LOCK_EX)
345 already_have_it = False
347 if os.readlink(os.path.join(path, '.locator')) == zipball:
348 already_have_it = True
351 if not already_have_it:
353 # emulate "rm -f" (i.e., if the file does not exist, we win)
355 os.unlink(os.path.join(path, '.locator'))
357 if os.path.exists(os.path.join(path, '.locator')):
358 os.unlink(os.path.join(path, '.locator'))
# Unlike tar, unzip cannot read from a pipe: copy each .zip member to a
# local file first (the write loop is elided), then run unzip on it.
360 for f in CollectionReader(zipball).all_files():
361 if not re.search('\.zip$', f.name()):
362 raise errors.NotImplementedError(
363 "zipball_extract cannot handle filename %s" % f.name())
364 zip_filename = os.path.join(path, os.path.basename(f.name()))
365 zip_file = open(zip_filename, 'wb')
373 p = subprocess.Popen(["unzip",
378 stdin=None, stderr=sys.stderr,
379 shell=False, close_fds=True)
381 if p.returncode != 0:
383 raise errors.CommandFailedError(
384 "unzip exited %d" % p.returncode)
385 os.unlink(zip_filename)
386 os.symlink(zipball, os.path.join(path, '.locator'))
387 tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
389 if len(tld_extracts) == 1:
390 return os.path.join(path, tld_extracts[0])
394 def collection_extract(collection, path, files=[], decompress=True):
395 """Retrieve a collection from Keep and extract it to a local
396 directory. Return the absolute path where the collection was
399 collection -- collection locator
400 path -- where to extract: absolute, or relative to job tmp
# NOTE(review): files=[] is a shared mutable default — safe only while
# it is never mutated; the visible code only reads it.
# Accept either a bare hash(+hints) locator or raw manifest text (in
# which case the cache key is the md5 of the text).
402 matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
404 collection_hash = matches.group(1)
406 collection_hash = hashlib.md5(collection).hexdigest()
# Locking and .locator reuse, as in tarball_extract above.
407 if not re.search('^/', path):
408 path = os.path.join(current_job().tmpdir, path)
409 lockfile = open(path + '.lock', 'w')
410 fcntl.flock(lockfile, fcntl.LOCK_EX)
415 already_have_it = False
417 if os.readlink(os.path.join(path, '.locator')) == collection_hash:
418 already_have_it = True
422 # emulate "rm -f" (i.e., if the file does not exist, we win)
424 os.unlink(os.path.join(path, '.locator'))
426 if os.path.exists(os.path.join(path, '.locator')):
427 os.unlink(os.path.join(path, '.locator'))
# Copy each wanted file out of the collection (empty `files` list means
# everything). The condition's opening line and the write loop tail are
# elided from this view.
430 for s in CollectionReader(collection).all_streams():
431 stream_name = s.name()
432 for f in s.all_files():
434 ((f.name() not in files_got) and
435 (f.name() in files or
436 (decompress and f.decompressed_name() in files)))):
437 outname = f.decompressed_name() if decompress else f.name()
438 files_got += [outname]
439 if os.path.exists(os.path.join(path, stream_name, outname)):
441 util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
442 outfile = open(os.path.join(path, stream_name, outname), 'wb')
443 for buf in (f.readall_decompressed() if decompress
# Sanity check: every explicitly requested file must have been found.
447 if len(files_got) < len(files):
448 raise errors.AssertionError(
449 "Wanted files %s but only got %s from %s" %
451 [z.name() for z in CollectionReader(collection).all_files()]))
452 os.symlink(collection_hash, os.path.join(path, '.locator'))
# Recursive "mkdir -p": create parents first, then the leaf (the
# os.mkdir call and its EEXIST-race handling are on lines elided here).
458 def mkdir_dash_p(path):
459 if not os.path.exists(path):
460 util.mkdir_dash_p(os.path.dirname(path))
# Re-check: another process may have created it between the two tests.
464 if not os.path.exists(path):
468 def stream_extract(stream, path, files=[], decompress=True):
469 """Retrieve a stream from Keep and extract it to a local
470 directory. Return the absolute path where the stream was
473 stream -- StreamReader object
474 path -- where to extract: absolute, or relative to job tmp
# Same shape as collection_extract but for a single stream and without
# the .locator cache; lock while writing.
476 if not re.search('^/', path):
477 path = os.path.join(current_job().tmpdir, path)
478 lockfile = open(path + '.lock', 'w')
479 fcntl.flock(lockfile, fcntl.LOCK_EX)
# Selection condition's opening line and the write-loop tail are elided.
486 for f in stream.all_files():
488 ((f.name() not in files_got) and
489 (f.name() in files or
490 (decompress and f.decompressed_name() in files)))):
491 outname = f.decompressed_name() if decompress else f.name()
492 files_got += [outname]
493 if os.path.exists(os.path.join(path, outname)):
494 os.unlink(os.path.join(path, outname))
495 util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
496 outfile = open(os.path.join(path, outname), 'wb')
497 for buf in (f.readall_decompressed() if decompress
# Every explicitly requested file must have been produced.
501 if len(files_got) < len(files):
502 raise errors.AssertionError(
503 "Wanted files %s but only got %s from %s" %
504 (files, files_got, [z.name() for z in stream.all_files()]))
# Return a sorted list of relative paths of the files under dirname,
# recursing into subdirectories; `base` is the relative-path prefix
# carried through the recursion. The allfiles initialization, the else
# branch marker, and the return are on lines elided from this view.
509 def listdir_recursive(dirname, base=None):
511 for ent in sorted(os.listdir(dirname)):
512 ent_path = os.path.join(dirname, ent)
513 ent_base = os.path.join(base, ent) if base else ent
514 if os.path.isdir(ent_path):
515 allfiles += util.listdir_recursive(ent_path, ent_base)
517 allfiles += [ent_base]
# A read-only view of one file within a StreamReader: the file occupies
# `size` bytes starting at offset `pos` in the stream. Several method
# bodies in this class have lines elided from this view.
520 class StreamFileReader(object):
521 def __init__(self, stream, pos, size, name):
522 self._stream = stream
# Filename with a trailing .bz2/.gz extension stripped.
529 def decompressed_name(self):
530 return re.sub('\.(bz2|gz)$', '', self._name)
533 def stream_name(self):
534 return self._stream.name()
# Read up to `size` bytes at the current file position; advances
# _filepos by the amount actually read (return line elided).
535 def read(self, size, **kwargs):
536 self._stream.seek(self._pos + self._filepos)
537 data = self._stream.read(min(size, self._size - self._filepos))
538 self._filepos += len(data)
# Generator yielding successive chunks of at most `size` bytes.
540 def readall(self, size=2**20, **kwargs):
542 data = self.read(size, **kwargs)
# Generators yielding decompressed chunks; empty decompressor output
# for a chunk is skipped (yield lines elided).
546 def bunzip2(self, size):
547 decompressor = bz2.BZ2Decompressor()
548 for chunk in self.readall(size):
549 data = decompressor.decompress(chunk)
550 if data and data != '':
552 def gunzip(self, size):
# 16+MAX_WBITS tells zlib to expect a gzip header/trailer.
553 decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
554 for chunk in self.readall(size):
555 data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
556 if data and data != '':
# Choose the decompressor by filename extension; plain readall() for
# anything that is not .bz2/.gz.
558 def readall_decompressed(self, size=2**20):
559 self._stream.seek(self._pos + self._filepos)
560 if re.search('\.bz2$', self._name):
561 return self.bunzip2(size)
562 elif re.search('\.gz$', self._name):
563 return self.gunzip(size)
565 return self.readall(size)
# Generator yielding one line at a time, reassembling lines that span
# chunk boundaries (buffering code partially elided).
566 def readlines(self, decompress=True):
568 datasource = self.readall_decompressed()
570 self._stream.seek(self._pos + self._filepos)
571 datasource = self.readall()
573 for newdata in datasource:
577 eol = string.find(data, "\n", sol)
580 yield data[sol:eol+1]
# Render this single file as a one-stream manifest fragment; an empty
# file is represented with the well-known empty-block locator.
585 def as_manifest(self):
587 return ("%s %s 0:0:%s\n"
588 % (self._stream.name(), EMPTY_BLOCK_LOCATOR, self.name()))
589 return string.join(self._stream.tokens_for_range(self._pos, self._size),
# Parses and reads one manifest stream: a stream name token, then data
# block locators, then pos:size:name file tokens. Provides a seekable
# byte-stream interface over the concatenated data blocks. Several
# lines in this class are elided from this view.
592 class StreamReader(object):
593 def __init__(self, tokens):
594 self._tokens = tokens
# Lazily-fetched current data block and read-position bookkeeping.
595 self._current_datablock_data = None
596 self._current_datablock_pos = 0
597 self._current_datablock_index = -1
600 self._stream_name = None
601 self.data_locators = []
# Token classification: first token is the stream name ('\040' encodes
# a space), 32-hex tokens are block locators, pos:size:name are files.
604 for tok in self._tokens:
605 if self._stream_name == None:
606 self._stream_name = tok.replace('\\040', ' ')
607 elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
608 self.data_locators += [tok]
609 elif re.search(r'^\d+:\d+:\S+', tok):
610 pos, size, name = tok.split(':',2)
611 self.files += [[int(pos), int(size), name.replace('\\040', ' ')]]
613 raise errors.SyntaxError("Invalid manifest format")
# Emit the subset of manifest tokens needed to reproduce the byte range
# [range_start, range_start+range_size): the stream name, the locators
# overlapping the range, and re-based file tokens. If any locator lacks
# a +size hint we cannot compute offsets, so all tokens are returned.
617 def tokens_for_range(self, range_start, range_size):
618 resp = [self._stream_name]
619 return_all_tokens = False
621 token_bytes_skipped = 0
622 for locator in self.data_locators:
623 sizehint = re.search(r'\+(\d+)', locator)
625 return_all_tokens = True
626 if return_all_tokens:
# group(0) is e.g. '+1234'; int() tolerates the leading '+'.
629 blocksize = int(sizehint.group(0))
630 if range_start + range_size <= block_start:
632 if range_start < block_start + blocksize:
635 token_bytes_skipped += blocksize
636 block_start += blocksize
# Include each file token that overlaps the range, with its position
# shifted by the bytes skipped before the first emitted locator.
638 if ((f[0] < range_start + range_size)
640 (f[0] + f[1] > range_start)
643 resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
646 return self._stream_name
# all_files(): yield a StreamFileReader per parsed file entry.
650 yield StreamFileReader(self, pos, size, name)
# Advance the block cursor to the next data block.
651 def nextdatablock(self):
652 if self._current_datablock_index < 0:
653 self._current_datablock_pos = 0
654 self._current_datablock_index = 0
656 self._current_datablock_pos += self.current_datablock_size()
657 self._current_datablock_index += 1
658 self._current_datablock_data = None
# Fetch the current block's bytes from Keep on first access.
659 def current_datablock_data(self):
660 if self._current_datablock_data == None:
661 self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
662 return self._current_datablock_data
# Prefer the +size locator hint; fall back to fetching the block.
663 def current_datablock_size(self):
664 if self._current_datablock_index < 0:
666 sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
668 return int(sizehint.group(0))
669 return len(self.current_datablock_data())
671 """Set the position of the next read operation."""
673 def really_seek(self):
674 """Find and load the appropriate data block, so the byte at
# Walk the block cursor forward/back until _pos falls inside the
# current block (rewinds to the start if seeking backwards).
677 if self._pos == self._current_datablock_pos:
679 if (self._current_datablock_pos != None and
680 self._pos >= self._current_datablock_pos and
681 self._pos <= self._current_datablock_pos + self.current_datablock_size()):
683 if self._pos < self._current_datablock_pos:
684 self._current_datablock_index = -1
686 while (self._pos > self._current_datablock_pos and
687 self._pos > self._current_datablock_pos + self.current_datablock_size()):
689 def read(self, size):
690 """Read no more than size bytes -- but at least one byte,
691 unless _pos is already at the end of the stream.
# Skip over exhausted blocks; stop at end of stream.
696 while self._pos >= self._current_datablock_pos + self.current_datablock_size():
698 if self._current_datablock_index >= len(self.data_locators):
# Slice the request out of the current block only — a read never spans
# a block boundary, hence "no more than size bytes".
700 data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
701 self._pos += len(data)
# Read access to a collection given either raw manifest text or a
# locator (resolved through Keep on first use). Some lines of this
# class are elided from this view.
704 class CollectionReader(object):
705 def __init__(self, manifest_locator_or_text):
# Heuristic: input that already looks like a manifest line (name,
# locators, file tokens, newline) is treated as literal manifest text.
706 if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
707 self._manifest_text = manifest_locator_or_text
708 self._manifest_locator = None
710 self._manifest_locator = manifest_locator_or_text
711 self._manifest_text = None
# Lazy parse: fetch manifest text from Keep if needed, then split each
# non-empty line into its token list (one list per stream).
718 if self._streams != None:
720 if not self._manifest_text:
721 self._manifest_text = Keep.get(self._manifest_locator)
723 for stream_line in self._manifest_text.split("\n"):
724 if stream_line != '':
725 stream_tokens = stream_line.split()
726 self._streams += [stream_tokens]
# Wrap each parsed token list in a StreamReader.
727 def all_streams(self):
730 for s in self._streams:
731 resp += [StreamReader(s)]
# all_files(): flatten every stream's files into one generator.
734 for s in self.all_streams():
735 for f in s.all_files():
737 def manifest_text(self):
739 return self._manifest_text
# Incrementally builds a collection: callers open streams and files,
# write() data, and the writer flushes full Keep blocks as it goes,
# finally rendering a manifest. Some lines are elided from this view.
741 class CollectionWriter(object):
# Data is buffered and shipped to Keep in 64 MiB blocks.
742 KEEP_BLOCK_SIZE = 2**26
# (__init__, signature elided) — per-stream and per-file state:
744 self._data_buffer = []
745 self._data_buffer_len = 0
746 self._current_stream_files = []
747 self._current_stream_length = 0
748 self._current_stream_locators = []
749 self._current_stream_name = '.'
750 self._current_file_name = None
751 self._current_file_pos = 0
752 self._finished_streams = []
# Write a directory tree as one stream per directory level, recursing
# until max_manifest_depth is exhausted (then the remainder is flattened
# into the current stream via listdir_recursive).
757 def write_directory_tree(self,
758 path, stream_name='.', max_manifest_depth=-1):
759 self.start_new_stream(stream_name)
761 if max_manifest_depth == 0:
762 dirents = sorted(util.listdir_recursive(path))
764 dirents = sorted(os.listdir(path))
765 for dirent in dirents:
766 target = os.path.join(path, dirent)
767 if os.path.isdir(target):
# Subdirectories are queued and recursed after this stream is finished.
769 os.path.join(stream_name, dirent),
770 max_manifest_depth-1]]
772 self.start_new_file(dirent)
773 with open(target, 'rb') as f:
779 self.finish_current_stream()
780 map(lambda x: self.write_directory_tree(*x), todo)
# Append data to the current file; flush whenever a full Keep block has
# accumulated. Iterables of chunks are handled on lines elided here.
782 def write(self, newdata):
783 if hasattr(newdata, '__iter__'):
787 self._data_buffer += [newdata]
788 self._data_buffer_len += len(newdata)
789 self._current_stream_length += len(newdata)
790 while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
# Ship the first KEEP_BLOCK_SIZE bytes of the buffer to Keep and keep
# the remainder buffered.
792 def flush_data(self):
793 data_buffer = ''.join(self._data_buffer)
794 if data_buffer != '':
795 self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
796 self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
797 self._data_buffer_len = len(self._data_buffer[0])
798 def start_new_file(self, newfilename=None):
799 self.finish_current_file()
800 self.set_current_file_name(newfilename)
# Manifests are whitespace-delimited, so names may not contain tab/NL.
801 def set_current_file_name(self, newfilename):
802 if re.search(r'[\t\n]', newfilename):
803 raise errors.AssertionError(
804 "Manifest filenames cannot contain whitespace: %s" %
806 self._current_file_name = newfilename
807 def current_file_name(self):
808 return self._current_file_name
# Close out the current file: record its (pos, length, name) extent in
# the current stream. Unnamed data that was actually written is a
# caller error.
809 def finish_current_file(self):
810 if self._current_file_name == None:
811 if self._current_file_pos == self._current_stream_length:
813 raise errors.AssertionError(
814 "Cannot finish an unnamed file " +
815 "(%d bytes at offset %d in '%s' stream)" %
816 (self._current_stream_length - self._current_file_pos,
817 self._current_file_pos,
818 self._current_stream_name))
819 self._current_stream_files += [[self._current_file_pos,
820 self._current_stream_length - self._current_file_pos,
821 self._current_file_name]]
822 self._current_file_pos = self._current_stream_length
823 def start_new_stream(self, newstreamname='.'):
824 self.finish_current_stream()
825 self.set_current_stream_name(newstreamname)
826 def set_current_stream_name(self, newstreamname):
827 if re.search(r'[\t\n]', newstreamname):
828 raise errors.AssertionError(
829 "Manifest stream names cannot contain whitespace")
# Empty stream name means the root stream '.'.
830 self._current_stream_name = '.' if newstreamname=='' else newstreamname
831 def current_stream_name(self):
832 return self._current_stream_name
# Close out the current stream: flush buffered data, record the stream
# triple (name, locators, files), and reset per-stream state. A stream
# with files but no data gets the well-known empty-block locator.
833 def finish_current_stream(self):
834 self.finish_current_file()
836 if len(self._current_stream_files) == 0:
838 elif self._current_stream_name == None:
839 raise errors.AssertionError(
840 "Cannot finish an unnamed stream (%d bytes in %d files)" %
841 (self._current_stream_length, len(self._current_stream_files)))
843 if len(self._current_stream_locators) == 0:
844 self._current_stream_locators += [EMPTY_BLOCK_LOCATOR]
845 self._finished_streams += [[self._current_stream_name,
846 self._current_stream_locators,
847 self._current_stream_files]]
848 self._current_stream_files = []
849 self._current_stream_length = 0
850 self._current_stream_locators = []
851 self._current_stream_name = None
852 self._current_file_pos = 0
853 self._current_file_name = None
# (finish(), signature elided): store the manifest itself in Keep.
855 return Keep.put(self.manifest_text())
# Render all finished streams as manifest text; spaces in names are
# escaped as '\040'. NOTE(review): the branch after the non-'.'-prefix
# check is elided here — confirm how relative stream names are handled.
856 def manifest_text(self):
857 self.finish_current_stream()
859 for stream in self._finished_streams:
860 if not re.search(r'^\.(/.*)?$', stream[0]):
862 manifest += stream[0].replace(' ', '\\040')
863 for locator in stream[1]:
864 manifest += " %s" % locator
865 for sfile in stream[2]:
866 manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2].replace(' ', '\\040'))
# Collect every data locator across finished streams (body elided).
869 def data_locators(self):
871 for name, locators, files in self._finished_streams:
# Process-wide KeepClient singleton. The accessors below are referenced
# elsewhere as Keep.get()/Keep.put(), so they presumably live inside a
# `class Keep` with @staticmethod decorators on lines elided from this
# view — verify against the full source.
874 global_client_object = None
# Lazily construct (and cache in the module global) the shared client.
# NOTE(review): this callable shares its name with the module global it
# rebinds — intentional here, but confusing; do not "fix" casually.
879 def global_client_object():
880 global global_client_object
881 if global_client_object == None:
882 global_client_object = KeepClient()
883 return global_client_object
# Convenience wrappers delegating to the shared client.
886 def get(locator, **kwargs):
887 return Keep.global_client_object().get(locator, **kwargs)
890 def put(data, **kwargs):
891 return Keep.global_client_object().put(data, **kwargs)
# Client for the Keep content-addressed block store: GET fetches a block
# by md5 locator with checksum verification; PUT writes a block to
# several Keep servers in parallel threads until the desired replica
# count is reached. Many lines are elided from this view.
893 class KeepClient(object):
895 class ThreadLimiter(object):
897 Limit the number of threads running at a given time to
898 {desired successes} minus {successes reported}. When successes
899 reported == desired, wake up the remaining threads and tell
902 Should be used in a "with" block.
904 def __init__(self, todo):
# Semaphore admits at most `todo` writers at once; _done_lock guards
# the success counter.
907 self._todo_lock = threading.Semaphore(todo)
908 self._done_lock = threading.Lock()
910 self._todo_lock.acquire()
912 def __exit__(self, type, value, traceback):
913 self._todo_lock.release()
914 def shall_i_proceed(self):
916 Return true if the current thread should do stuff. Return
917 false if the current thread should just stop.
919 with self._done_lock:
920 return (self._done < self._todo)
921 def increment_done(self):
923 Report that the current thread was successful.
925 with self._done_lock:
# done(): (signature elided)
929 Return how many successes were reported.
931 with self._done_lock:
934 class KeepWriterThread(threading.Thread):
936 Write a blob of data to the given Keep server. Call
937 increment_done() of the given ThreadLimiter if the write
940 def __init__(self, **kwargs):
941 super(KeepClient.KeepWriterThread, self).__init__()
# run(): acquire a slot from the limiter; skip if enough replicas
# already succeeded, otherwise PUT the block to this server.
945 with self.args['thread_limiter'] as limiter:
946 if not limiter.shall_i_proceed():
947 # My turn arrived, but the job has been done without
950 logging.debug("KeepWriterThread %s proceeding %s %s" %
951 (str(threading.current_thread()),
952 self.args['data_hash'],
953 self.args['service_root']))
955 url = self.args['service_root'] + self.args['data_hash']
956 api_token = config['ARVADOS_API_TOKEN']
957 headers = {'Authorization': "OAuth2 %s" % api_token}
959 resp, content = h.request(url.encode('utf-8'), 'PUT',
961 body=self.args['data'])
# Compatibility path: older Keep servers want a (dummy) PGP-signed
# body when they reject with a timestamp-verification 401.
962 if (resp['status'] == '401' and
963 re.match(r'Timestamp verification failed', content)):
964 body = KeepClient.sign_for_old_server(
965 self.args['data_hash'],
968 resp, content = h.request(url.encode('utf-8'), 'PUT',
971 if re.match(r'^2\d\d$', resp['status']):
972 logging.debug("KeepWriterThread %s succeeded %s %s" %
973 (str(threading.current_thread()),
974 self.args['data_hash'],
975 self.args['service_root']))
976 return limiter.increment_done()
977 logging.warning("Request fail: PUT %s => %s %s" %
978 (url, resp['status'], content))
979 except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
980 logging.warning("Request fail: PUT %s => %s: %s" %
981 (url, type(e), str(e)))
# (__init__, signature elided): lazy service-root discovery state.
984 self.lock = threading.Lock()
985 self.service_roots = None
# Deterministically order Keep servers for a given block hash: discover
# roots from the API once, then repeatedly draw from the pool using
# hex digits derived from the hash as a pseudo-random probe sequence,
# so every client probes servers in the same order for the same block.
987 def shuffled_service_roots(self, hash):
988 if self.service_roots == None:
990 keep_disks = api().keep_disks().list().execute()['items']
991 roots = (("http%s://%s:%d/" %
992 ('s' if f['service_ssl_flag'] else '',
996 self.service_roots = sorted(set(roots))
997 logging.debug(str(self.service_roots))
1000 pool = self.service_roots[:]
1002 while len(pool) > 0:
1004 if len(pseq) < len(hash) / 4: # first time around
1005 seed = hash[-4:] + hash
1008 probe = int(seed[0:8], 16) % len(pool)
1009 pseq += [pool[probe]]
1010 pool = pool[:probe] + pool[probe+1:]
1012 logging.debug(str(pseq))
# Fetch a block: try each server in probe order; verify the md5 of the
# received content against the locator before returning it.
1015 def get(self, locator):
# Comma-separated locators are fetched piecewise and concatenated.
1017 if re.search(r',', locator):
1018 return ''.join(self.get(x) for x in locator.split(','))
# KEEP_LOCAL_STORE diverts all traffic to an on-disk store (testing).
1019 if 'KEEP_LOCAL_STORE' in os.environ:
1020 return KeepClient.local_store_get(locator)
1021 expect_hash = re.sub(r'\+.*', '', locator)
1022 for service_root in self.shuffled_service_roots(expect_hash):
1024 url = service_root + expect_hash
1025 api_token = config['ARVADOS_API_TOKEN']
1026 headers = {'Authorization': "OAuth2 %s" % api_token,
1027 'Accept': 'application/octet-stream'}
1029 resp, content = h.request(url.encode('utf-8'), 'GET',
1031 if re.match(r'^2\d\d$', resp['status']):
1032 m = hashlib.new('md5')
1035 if md5 == expect_hash:
1037 logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
# Per-server failures are logged and the next server is tried.
1038 except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
1039 logging.info("Request fail: GET %s => %s: %s" %
1040 (url, type(e), str(e)))
1041 raise errors.NotFoundError("Block not found: %s" % expect_hash)
# Store a block: hash it, then spawn one writer thread per server in
# probe order until `copies` replicas succeed (default 2). Returns the
# full locator hash+size on success.
1043 def put(self, data, **kwargs):
1044 if 'KEEP_LOCAL_STORE' in os.environ:
1045 return KeepClient.local_store_put(data)
1046 m = hashlib.new('md5')
1048 data_hash = m.hexdigest()
1050 want_copies = kwargs.get('copies', 2)
1051 if not (want_copies > 0):
1054 thread_limiter = KeepClient.ThreadLimiter(want_copies)
1055 for service_root in self.shuffled_service_roots(data_hash):
1056 t = KeepClient.KeepWriterThread(data=data,
1057 data_hash=data_hash,
1058 service_root=service_root,
1059 thread_limiter=thread_limiter)
1064 have_copies = thread_limiter.done()
1065 if have_copies == want_copies:
1066 return (data_hash + '+' + str(len(data)))
1067 raise errors.KeepWriteError(
1068 "Write fail for %s: wanted %d but wrote %d" %
1069 (data_hash, want_copies, have_copies))
# Wrap data in a dummy clear-signed PGP envelope for legacy servers
# (decorator on the elided line above is presumably @staticmethod).
1072 def sign_for_old_server(data_hash, data):
1073 return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
# Local filesystem stand-ins for Keep (used when KEEP_LOCAL_STORE is
# set): write to <store>/<md5>.tmp then rename for atomicity; read back
# by locator hash, with a special case for the empty block.
1077 def local_store_put(data):
1078 m = hashlib.new('md5')
1081 locator = '%s+%d' % (md5, len(data))
1082 with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
1084 os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
1085 os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
1088 def local_store_get(locator):
1089 r = re.search('^([0-9a-f]{32,})', locator)
1091 raise errors.NotFoundError(
1092 "Invalid data locator: '%s'" % locator)
1093 if r.group(0) == EMPTY_BLOCK_LOCATOR.split('+')[0]:
1095 with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
1098 # We really shouldn't do this but some clients still use
1099 # arvados.service.* directly instead of arvados.api().*