import bz2
import fcntl
import hashlib
import httplib
import httplib2
import json
import logging
import os
import re
import string
import subprocess
import sys
import threading
import time
import types
import UserDict
import zlib

import apiclient.discovery

# Arvados configuration settings are taken from $HOME/.config/arvados.
# Environment variables override settings in the config file.

class ArvadosConfig(dict):
    def __init__(self, config_file):
        dict.__init__(self)
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                for config_line in f:
                    var, val = config_line.rstrip().split('=', 1)
                    self[var] = val
        for var in os.environ:
            if var.startswith('ARVADOS_'):
                self[var] = os.environ[var]

services = {}

config = ArvadosConfig(os.environ['HOME'] + '/.config/arvados')

if 'ARVADOS_DEBUG' in config:
    logging.basicConfig(level=logging.DEBUG)

EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'

# Exceptions raised by this module, namespaced so callers can write
# errors.NotFoundError and so on without shadowing the builtins.
class errors:
    class SyntaxError(Exception):
        pass
    class AssertionError(Exception):
        pass
    class NotFoundError(Exception):
        pass
    class CommandFailedError(Exception):
        pass
    class KeepWriteError(Exception):
        pass
    class NotImplementedError(Exception):
        pass

class CredentialsFromEnv(object):
    @staticmethod
    def http_request(self, uri, **kwargs):
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % config.get('ARVADOS_API_TOKEN', 'ARVADOS_API_TOKEN_not_set')
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

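# A minimal usage sketch (illustrative, not part of the SDK API): wrap
# an httplib2.Http instance so every request carries the token from
# config. The helper name below is made up.
def _example_authorized_http():
    http = httplib2.Http()
    http = CredentialsFromEnv().authorize(http)
    # http.request() now injects the Authorization header, and retries
    # once if httplib reports a stale keep-alive connection.
    return http
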
def task_set_output(self, s):
    api('v1').job_tasks().update(uuid=self['uuid'],
                                 body={
                                     'output': s,
                                     'success': True,
                                     'progress': 1.0
                                 }).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = api('v1').job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = api('v1').jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def getjobparam(*args):
    return current_job()['script_parameters'].get(*args)

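# Hypothetical crunch-script snippet showing how these helpers fit
# together; the script parameter name 'input' is just an example.
def _example_log_task_context():
    job = current_job()      # cached UserDict for $JOB_UUID
    task = current_task()    # cached UserDict for $TASK_UUID
    logging.info("job %s task %s input %s",
                 job['uuid'], task['uuid'], getjobparam('input'))
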
# Monkey patch discovery._cast() so objects and arrays get serialized
# with json.dumps() instead of str().
_cast_orig = apiclient.discovery._cast
def _cast_objects_too(value, schema_type):
    global _cast_orig
    if (type(value) != type('') and
        (schema_type == 'object' or schema_type == 'array')):
        return json.dumps(value)
    else:
        return _cast_orig(value, schema_type)
apiclient.discovery._cast = _cast_objects_too

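# Quick illustration of the patched behavior (hypothetical, not part of
# the SDK): a dict bound for an 'object' parameter is JSON-encoded
# rather than str()-ified.
def _example_cast():
    assert apiclient.discovery._cast({'foo': 1}, 'object') == '{"foo": 1}'
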
def api(version=None):
    global services, config
    if not services.get(version):
        apiVersion = version
        if not version:
            apiVersion = 'v1'
            logging.info("Using default API version. " +
                         "Call arvados.api('%s') instead." %
                         apiVersion)
        if 'ARVADOS_API_HOST' not in config:
            raise Exception("ARVADOS_API_HOST is not set. Aborting.")
        url = ('https://%s/discovery/v1/apis/{api}/{apiVersion}/rest' %
               config['ARVADOS_API_HOST'])
        credentials = CredentialsFromEnv()

        # Use system's CA certificates (if we find them) instead of httplib2's
        ca_certs = '/etc/ssl/certs/ca-certificates.crt'
        if not os.path.exists(ca_certs):
            ca_certs = None  # use httplib2 default

        http = httplib2.Http(ca_certs=ca_certs)
        http = credentials.authorize(http)
        if re.match(r'(?i)^(true|1|yes)$',
                    config.get('ARVADOS_API_HOST_INSECURE', 'no')):
            http.disable_ssl_certificate_validation = True
        services[version] = apiclient.discovery.build(
            'arvados', apiVersion, http=http, discoveryServiceUrl=url)
    return services[version]

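# Typical client usage: build the 'v1' service object once, then issue
# discovery-based calls through it. (The list call and its 'limit'
# parameter follow the standard apiclient pattern; treat this as an
# illustrative sketch.)
def _example_list_jobs():
    for j in api('v1').jobs().list(limit=10).execute()['items']:
        print j['uuid']
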
class JobTask(object):
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)

def one_task_per_input_file(if_sequence=0, and_end_task=True):
    if if_sequence != current_task()['sequence']:
        return
    job_input = current_job()['script_parameters']['input']
    cr = CollectionReader(job_input)
    for s in cr.all_streams():
        for f in s.all_files():
            task_input = f.as_manifest()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input': task_input
                }
            }
            api('v1').job_tasks().create(body=new_task_attrs).execute()
    if and_end_task:
        api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                     body={'success': True}
                                     ).execute()
        exit(0)

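# Hypothetical map-style crunch script built on the helper above:
# sequence 0 queues one new task per input file and exits; each
# sequence-1 task then processes the one-file manifest it was given.
def _example_one_file_per_task_script():
    one_task_per_input_file(if_sequence=0, and_end_task=True)
    # Only sequence-1 tasks reach this point.
    for f in CollectionReader(current_task()['parameters']['input']).all_files():
        logging.info("this task's input file: %s", f.name())
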
def one_task_per_input_stream(if_sequence=0, and_end_task=True):
    if if_sequence != current_task()['sequence']:
        return
    job_input = current_job()['script_parameters']['input']
    cr = CollectionReader(job_input)
    for s in cr.all_streams():
        task_input = s.tokens()
        new_task_attrs = {
            'job_uuid': current_job()['uuid'],
            'created_by_job_task_uuid': current_task()['uuid'],
            'sequence': if_sequence + 1,
            'parameters': {
                'input': task_input
            }
        }
        api('v1').job_tasks().create(body=new_task_attrs).execute()
    if and_end_task:
        api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                     body={'success': True}
                                     ).execute()
        exit(0)

class util:
    @staticmethod
    def clear_tmpdir(path=None):
        """
        Ensure the given directory (or TASK_TMPDIR if none given)
        exists and is empty.
        """
        if path == None:
            path = current_task().tmpdir
        if os.path.exists(path):
            p = subprocess.Popen(['rm', '-rf', path])
            stdout, stderr = p.communicate(None)
            if p.returncode != 0:
                raise Exception('rm -rf %s: %s' % (path, stderr))
        os.mkdir(path)

    @staticmethod
    def run_command(execargs, **kwargs):
        kwargs.setdefault('stdin', subprocess.PIPE)
        kwargs.setdefault('stdout', subprocess.PIPE)
        kwargs.setdefault('stderr', sys.stderr)
        kwargs.setdefault('close_fds', True)
        kwargs.setdefault('shell', False)
        p = subprocess.Popen(execargs, **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise errors.CommandFailedError(
                "run_command %s exit %d:\n%s" %
                (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

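    # Example (illustrative): run a tool and capture its stdout; a
    # non-zero exit status raises errors.CommandFailedError.
    #
    #   stdoutdata, _ = util.run_command(
    #       ["md5sum"], stdin=open("/etc/hosts", "rb"))
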
    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory. Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search('\.(tbz|tar.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search('\.(tgz|tar.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search('\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise errors.AssertionError(
                        "tarball_extract cannot handle filename %s" % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

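    # Example (illustrative, made-up parameter name): unpack a reference
    # tarball once per job; the .lock/.locator bookkeeping above lets
    # concurrent tasks on one node share the cached copy.
    #
    #   ref_dir = util.tarball_extract(
    #       tarball=getjobparam('reference_tarball'),
    #       path='reference')
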
    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory. Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search('\.zip$', f.name()):
                    raise errors.NotImplementedError(
                        "zipball_extract cannot handle filename %s" % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory. Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
        if matches:
            collection_hash = matches.group(1)
        else:
            collection_hash = hashlib.md5(collection).hexdigest()
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection_hash:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got,
                 [z.name() for z in CollectionReader(collection).all_files()]))
        os.symlink(collection_hash, os.path.join(path, '.locator'))
        lockfile.close()
        return path

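    # Example (illustrative): materialize two named files from a
    # collection, decompressing foo.gz to foo on the way out.
    #
    #   out_dir = util.collection_extract(
    #       collection=getjobparam('input'),
    #       path='input',
    #       files=['foo', 'bar.txt'],
    #       decompress=True)
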
    @staticmethod
    def mkdir_dash_p(path):
        if not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory. Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got, [z.name() for z in stream.all_files()]))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

class StreamFileReader(object):
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub('\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search('\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search('\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol + 1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s %s 0:0:%s\n"
                    % (self._stream.name(), EMPTY_BLOCK_LOCATOR, self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"

class StreamReader(object):
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name == None:
                self._stream_name = tok.replace('\\040', ' ')
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name.replace('\\040', ' ')]]
            else:
                raise errors.SyntaxError("Invalid manifest format")

    def tokens(self):
        return self._tokens
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(0))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data == None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(0))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos != None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

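# A manifest stream is one whitespace-tokenized line: a stream name,
# one or more block locators, then pos:size:filename triples. A small
# sketch of feeding tokens to StreamReader (using the well-known empty
# block locator, so no Keep access is needed to list the file):
def _example_parse_stream():
    s = StreamReader(['.', EMPTY_BLOCK_LOCATOR, '0:0:empty.txt'])
    for f in s.all_files():
        print "%s/%s (%d bytes)" % (s.name(), f.name(), f.size())
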
class CollectionReader(object):
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def _populate(self):
        if self._streams != None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text

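# Reading a collection end to end (hypothetical locator variable):
# CollectionReader accepts either a manifest locator or literal
# manifest text, and fetches the manifest from Keep only on demand.
def _example_walk_collection(locator):
    for s in CollectionReader(locator).all_streams():
        for f in s.all_files():
            print "%s/%s" % (s.name(), f.name())
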
class CollectionWriter(object):
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        self.start_new_stream(stream_name)
        todo = []
        if max_manifest_depth == 0:
            dirents = sorted(util.listdir_recursive(path))
        else:
            dirents = sorted(os.listdir(path))
        for dirent in dirents:
            target = os.path.join(path, dirent)
            if os.path.isdir(target):
                todo += [[target,
                          os.path.join(stream_name, dirent),
                          max_manifest_depth - 1]]
            else:
                self.start_new_file(dirent)
                with open(target, 'rb') as f:
                    while True:
                        buf = f.read(2**26)
                        if len(buf) == 0:
                            break
                        self.write(buf)
        self.finish_current_stream()
        map(lambda x: self.write_directory_tree(*x), todo)

    def write(self, newdata):
        if hasattr(newdata, '__iter__'):
            for s in newdata:
                self.write(s)
            return
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        if re.search(r'[\t\n]', newfilename):
            raise errors.AssertionError(
                "Manifest filenames cannot contain whitespace: %s" %
                newfilename)
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name == None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise errors.AssertionError(
                "Cannot finish an unnamed file " +
                "(%d bytes at offset %d in '%s' stream)" %
                (self._current_stream_length - self._current_file_pos,
                 self._current_file_pos,
                 self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname='.'):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[\t\n]', newstreamname):
            raise errors.AssertionError(
                "Manifest stream names cannot contain whitespace")
        self._current_stream_name = '.' if newstreamname == '' else newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name == None:
            raise errors.AssertionError(
                "Cannot finish an unnamed stream (%d bytes in %d files)" %
                (self._current_stream_length, len(self._current_stream_files)))
        else:
            if len(self._current_stream_locators) == 0:
                self._current_stream_locators += [EMPTY_BLOCK_LOCATOR]
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0].replace(' ', '\\040')
            for locator in stream[1]:
                manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2].replace(' ', '\\040'))
            manifest += "\n"
        return manifest
    def data_locators(self):
        ret = []
        for name, locators, files in self._finished_streams:
            ret += locators
        return ret

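# Round-trip sketch: write one small file into a new collection and get
# back the manifest locator. With KEEP_LOCAL_STORE set (see KeepClient
# below) this runs without a live Keep service.
def _example_write_collection():
    cw = CollectionWriter()
    cw.start_new_file('hello.txt')
    cw.write('hello world\n')
    print cw.finish()    # stores the manifest, returns its locator
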
global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object == None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)

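# Keep.put/Keep.get round trip. Pointing KEEP_LOCAL_STORE at a writable
# directory (an assumption for this sketch) makes KeepClient use local
# files instead of contacting Keep servers, which is handy for tests.
def _example_keep_roundtrip():
    os.environ.setdefault('KEEP_LOCAL_STORE', '/tmp')  # example path
    locator = Keep.put('some data')
    assert Keep.get(locator) == 'some data'
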
class KeepClient(object):

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()
        def __enter__(self):
            self._todo_lock.acquire()
            return self
        def __exit__(self, type, value, traceback):
            self._todo_lock.release()
        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)
        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1
        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs
        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config['ARVADOS_API_TOKEN']
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None

    def shuffled_service_roots(self, hash):
        if self.service_roots == None:
            self.lock.acquire()
            keep_disks = api().keep_disks().list().execute()['items']
            roots = (("http%s://%s:%d/" %
                      ('s' if f['service_ssl_flag'] else '',
                       f['service_host'],
                       f['service_port']))
                     for f in keep_disks)
            self.service_roots = sorted(set(roots))
            logging.debug(str(self.service_roots))
            self.lock.release()

        # Use successive 8-hex-digit chunks of the block hash to pick a
        # probe order through the list of service roots.
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    def get(self, locator):
        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = config['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise errors.NotFoundError("Block not found: %s" % expect_hash)

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies == want_copies:
            return (data_hash + '+' + str(len(data)))
        raise errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)

    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if r == None:
            raise errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()

# We really shouldn't do this but some clients still use
# arvados.service.* directly instead of arvados.api().*
service = api()