9e9fb00833a20cd3ceb4fc8937a3c6b0bd1724f9
[arvados.git] / sdk / python / arvados / keep.py
1 import cStringIO
2 import datetime
3 import hashlib
4 import logging
5 import math
6 import os
7 import pycurl
8 import Queue
9 import re
10 import socket
11 import ssl
12 import threading
13 import timer
14
15 import arvados
16 import arvados.config as config
17 import arvados.errors
18 import arvados.retry as retry
19 import arvados.util
20
21 _logger = logging.getLogger('arvados.keep')
22 global_client_object = None
23
24
class KeepLocator(object):
    """Parse, validate, and reassemble a Keep block locator string.

    A locator looks like "<md5sum>[+<size>][+<hint>...]".  Hints beginning
    with 'A' are permission (signature) hints and are parsed into
    perm_sig/perm_expiry; any other hints are kept verbatim in self.hints.
    """
    # Reference point for converting perm_expiry back to a Unix timestamp.
    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
    HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')

    def __init__(self, locator_str):
        self.hints = []
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        self.md5sum = next(pieces)
        try:
            # The size, when present, is assumed to be the second
            # '+'-separated field (a hint in that position would raise
            # ValueError here -- TODO confirm locators always carry a size
            # before any hints).
            self.size = int(next(pieces))
        except StopIteration:
            # Bare "<md5sum>" locator with no size field.
            self.size = None
        for hint in pieces:
            if self.HINT_RE.match(hint) is None:
                raise ValueError("invalid hint format: {}".format(hint))
            elif hint.startswith('A'):
                self.parse_permission_hint(hint)
            else:
                self.hints.append(hint)

    def __str__(self):
        # Reassemble the locator, omitting any missing (None) fields.
        return '+'.join(
            str(s) for s in [self.md5sum, self.size,
                             self.permission_hint()] + self.hints
            if s is not None)

    def stripped(self):
        """Return the locator without any hints: "md5sum[+size]"."""
        if self.size is not None:
            return "%s+%i" % (self.md5sum, self.size)
        else:
            return self.md5sum

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not arvados.util.is_hex(hex_str, length):
                raise ValueError("{} is not a {}-digit hex string: {}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        # The hint carries the expiry as a 1-8 digit hex Unix timestamp;
        # store it as a naive UTC datetime.
        if not arvados.util.is_hex(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        """Return the "A<sig>@<hex expiry>" hint, or None if unsigned."""
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        """Parse an "A<signature>@<hex timestamp>" permission hint.

        Raises ValueError if the hint is malformed or the fields fail
        the hex-string validation done by the property setters.
        """
        # Split and length-check before assigning: a hint with no '@'
        # yields a 1-element list, and unpacking that raises ValueError.
        # The previous "except IndexError" could never fire (s[1:] cannot
        # raise IndexError), so the intended error message was unreachable.
        parts = s[1:].split('@', 1)
        if len(parts) != 2:
            raise ValueError("bad permission hint {}".format(s))
        self.perm_sig, self.perm_expiry = parts

    def permission_expired(self, as_of_dt=None):
        """Return True if the permission hint has expired as of as_of_dt.

        A locator without a permission hint never expires.
        """
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            # perm_expiry is naive UTC (built with utcfromtimestamp), so
            # compare against UTC now.  The old datetime.now() used local
            # time and mis-judged expiry in any non-UTC timezone.
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt
106
107
class Keep(object):
    """Simple interface to a global KeepClient object.

    THIS CLASS IS DEPRECATED.  Please instantiate your own KeepClient with your
    own API client.  The global KeepClient will build an API client from the
    current Arvados configuration, which may not match the one you built.
    """
    # Snapshot of the configuration values the shared client was built
    # with, so we can detect when it needs rebuilding.
    _last_key = None

    @classmethod
    def global_client_object(cls):
        """Return the shared KeepClient, rebuilding it if config changed."""
        global global_client_object
        # KeepClient used to adapt to these settings at runtime; emulate
        # that by rebuilding the shared client whenever any of them differ
        # from the values seen on the previous call.
        key = (config.get('ARVADOS_API_HOST'),
               config.get('ARVADOS_API_TOKEN'),
               config.flag_is_true('ARVADOS_API_HOST_INSECURE'),
               config.get('ARVADOS_KEEP_PROXY'),
               config.get('ARVADOS_EXTERNAL_CLIENT') == 'true',
               os.environ.get('KEEP_LOCAL_STORE'))
        if global_client_object is None or key != cls._last_key:
            global_client_object = KeepClient()
            cls._last_key = key
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch a block through the shared client (deprecated)."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store a block through the shared client (deprecated)."""
        return Keep.global_client_object().put(data, **kwargs)
142
class KeepBlockCache(object):
    """In-memory cache of Keep block contents, most-recently-used first."""

    # Default RAM cache is 256MiB
    def __init__(self, cache_max=(256 * 1024 * 1024)):
        self.cache_max = cache_max
        self._cache = []  # list of CacheSlot, most recently used at index 0
        self._cache_lock = threading.Lock()

    class CacheSlot(object):
        """Holder for one block; lets readers wait for an in-flight fetch."""
        __slots__ = ("locator", "ready", "content")

        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            """Block until content is available, then return it."""
            self.ready.wait()
            return self.content

        def set(self, value):
            """Store the block content and wake any waiting readers."""
            self.content = value
            self.ready.set()

        def size(self):
            """Bytes held by this slot (0 while the fetch is in flight)."""
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        with self._cache_lock:
            # Select all slots except those where ready.is_set() and content is
            # None (that means there was an error reading the block).
            self._cache = [c for c in self._cache
                           if not (c.ready.is_set() and c.content is None)]
            # Track the total size incrementally instead of re-summing the
            # whole list after every eviction (the old loop was O(n^2)).
            sm = sum(slot.size() for slot in self._cache)
            while len(self._cache) > 0 and sm > self.cache_max:
                evicted = False
                # Evict the least recently used slot whose content has
                # arrived; unready slots may still have waiting readers.
                for i in range(len(self._cache)-1, -1, -1):
                    if self._cache[i].ready.is_set():
                        sm -= self._cache[i].size()
                        del self._cache[i]
                        evicted = True
                        break
                if not evicted:
                    # No evictable slot exists right now.  Bail out instead
                    # of spinning on the same list while holding the lock
                    # (the old loop re-scanned forever in this situation).
                    break

    def _get(self, locator):
        # Test if the locator is already in the cache.  Caller must hold
        # self._cache_lock.
        for i in range(0, len(self._cache)):
            if self._cache[i].locator == locator:
                n = self._cache[i]
                if i != 0:
                    # move it to the front
                    del self._cache[i]
                    self._cache.insert(0, n)
                return n
        return None

    def get(self, locator):
        """Return the CacheSlot for locator, or None if not cached."""
        with self._cache_lock:
            return self._get(locator)

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        with self._cache_lock:
            n = self._get(locator)
            if n:
                return n, False
            else:
                # Add a new cache slot for the locator
                n = KeepBlockCache.CacheSlot(locator)
                self._cache.insert(0, n)
                return n, True
214
class Counter(object):
    """A simple integer counter that is safe to update from many threads."""

    def __init__(self, v=0):
        # One lock guards every read and update so concurrent callers
        # always observe a consistent value.
        self._lock = threading.Lock()
        self._value = v

    def add(self, v):
        """Atomically add v (which may be negative) to the counter."""
        with self._lock:
            self._value += v

    def get(self):
        """Return the counter's current value."""
        with self._lock:
            return self._value
227
228
229 class KeepClient(object):
230
231     # Default Keep server connection timeout:  2 seconds
232     # Default Keep server read timeout:       256 seconds
233     # Default Keep server bandwidth minimum:  32768 bytes per second
234     # Default Keep proxy connection timeout:  20 seconds
235     # Default Keep proxy read timeout:        256 seconds
236     # Default Keep proxy bandwidth minimum:   32768 bytes per second
237     DEFAULT_TIMEOUT = (2, 256, 32768)
238     DEFAULT_PROXY_TIMEOUT = (20, 256, 32768)
239
    class ThreadLimiter(object):
        """Limit the number of threads writing to Keep at once.

        This ensures that only a number of writer threads that could
        potentially achieve the desired replication level run at once.
        Once the desired replication level is achieved, queued threads
        are instructed not to run.

        Should be used in a "with" block.
        """
        def __init__(self, want_copies, max_service_replicas):
            # Number of threads that have entered the "with" block so far;
            # used to honor set_sequence() ordering in __enter__.
            self._started = 0
            self._want_copies = want_copies
            # Total replicas reported stored by successful threads.
            self._done = 0
            # Failures reported via save_response(..., 0) that have not yet
            # been compensated by releasing an extra semaphore permit.
            self._thread_failures = 0
            self._response = None
            self._start_lock = threading.Condition()
            # One concurrent writer suffices when a single service can store
            # every copy we want; otherwise run enough writers that their
            # combined per-service replica limits cover want_copies.
            if (not max_service_replicas) or (max_service_replicas >= want_copies):
                max_threads = 1
            else:
                max_threads = math.ceil(float(want_copies) / max_service_replicas)
            _logger.debug("Limiter max threads is %d", max_threads)
            # Semaphore permits bound the number of concurrently writing
            # threads; a thread holds one permit for the whole "with" block.
            self._todo_lock = threading.Semaphore(max_threads)
            self._done_lock = threading.Lock()
            self._thread_failures_lock = threading.Lock()
            # Thread-local storage for the optional start-order sequence.
            self._local = threading.local()

        def __enter__(self):
            self._start_lock.acquire()
            if getattr(self._local, 'sequence', None) is not None:
                # If the calling thread has used set_sequence(N), then
                # we wait here until N other threads have started.
                while self._started < self._local.sequence:
                    self._start_lock.wait()
            # Block until a writer permit is available.
            self._todo_lock.acquire()
            self._started += 1
            self._start_lock.notifyAll()
            self._start_lock.release()
            return self

        def __exit__(self, type, value, traceback):
            with self._thread_failures_lock:
                if self._thread_failures > 0:
                    # Some thread reported a failed write: hand back one
                    # extra permit so a queued replacement thread may start.
                    self._thread_failures -= 1
                    self._todo_lock.release()

            # If work is finished, release all pending threads: each thread
            # released here sees shall_i_proceed() == False, returns without
            # writing, and releases another permit in its own __exit__, so
            # the wake-up cascades through the whole queue.
            if not self.shall_i_proceed():
                self._todo_lock.release()

        def set_sequence(self, sequence):
            # Ask __enter__ to delay this thread until `sequence` other
            # threads have entered the "with" block.
            self._local.sequence = sequence

        def shall_i_proceed(self):
            """
            Return true if the current thread should write to Keep.
            Return false otherwise.
            """
            with self._done_lock:
                return (self._done < self._want_copies)

        def save_response(self, response_body, replicas_stored):
            """
            Records a response body (a locator, possibly signed) returned by
            the Keep server, and the number of replicas it stored.
            """
            if replicas_stored == 0:
                # Failure notification, should start a new thread to try to reach full replication
                with self._thread_failures_lock:
                    self._thread_failures += 1
            else:
                with self._done_lock:
                    self._done += replicas_stored
                    self._response = response_body

        def response(self):
            """Return the body from the response to a PUT request."""
            with self._done_lock:
                return self._response

        def done(self):
            """Return the total number of replicas successfully stored."""
            with self._done_lock:
                return self._done
324
325     class KeepService(object):
326         """Make requests to a single Keep service, and track results.
327
328         A KeepService is intended to last long enough to perform one
329         transaction (GET or PUT) against one Keep service. This can
330         involve calling either get() or put() multiple times in order
331         to retry after transient failures. However, calling both get()
332         and put() on a single instance -- or using the same instance
333         to access two different Keep services -- will not produce
334         sensible behavior.
335         """
336
337         HTTP_ERRORS = (
338             socket.error,
339             ssl.SSLError,
340             arvados.errors.HttpError,
341         )
342
343         def __init__(self, root, user_agent_pool=Queue.LifoQueue(),
344                      upload_counter=None,
345                      download_counter=None, **headers):
346             self.root = root
347             self._user_agent_pool = user_agent_pool
348             self._result = {'error': None}
349             self._usable = True
350             self._session = None
351             self.get_headers = {'Accept': 'application/octet-stream'}
352             self.get_headers.update(headers)
353             self.put_headers = headers
354             self.upload_counter = upload_counter
355             self.download_counter = download_counter
356
357         def usable(self):
358             """Is it worth attempting a request?"""
359             return self._usable
360
361         def finished(self):
362             """Did the request succeed or encounter permanent failure?"""
363             return self._result['error'] == False or not self._usable
364
365         def last_result(self):
366             return self._result
367
368         def _get_user_agent(self):
369             try:
370                 return self._user_agent_pool.get(False)
371             except Queue.Empty:
372                 return pycurl.Curl()
373
374         def _put_user_agent(self, ua):
375             try:
376                 ua.reset()
377                 self._user_agent_pool.put(ua, False)
378             except:
379                 ua.close()
380
381         @staticmethod
382         def _socket_open(family, socktype, protocol, address=None):
383             """Because pycurl doesn't have CURLOPT_TCP_KEEPALIVE"""
384             s = socket.socket(family, socktype, protocol)
385             s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
386             s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 75)
387             s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 75)
388             return s
389
390         def get(self, locator, method="GET", timeout=None):
391             # locator is a KeepLocator object.
392             url = self.root + str(locator)
393             _logger.debug("Request: %s %s", method, url)
394             curl = self._get_user_agent()
395             ok = None
396             try:
397                 with timer.Timer() as t:
398                     self._headers = {}
399                     response_body = cStringIO.StringIO()
400                     curl.setopt(pycurl.NOSIGNAL, 1)
401                     curl.setopt(pycurl.OPENSOCKETFUNCTION, self._socket_open)
402                     curl.setopt(pycurl.URL, url.encode('utf-8'))
403                     curl.setopt(pycurl.HTTPHEADER, [
404                         '{}: {}'.format(k,v) for k,v in self.get_headers.iteritems()])
405                     curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
406                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
407                     if method == "HEAD":
408                         curl.setopt(pycurl.NOBODY, True)
409                     self._setcurltimeouts(curl, timeout)
410
411                     try:
412                         curl.perform()
413                     except Exception as e:
414                         raise arvados.errors.HttpError(0, str(e))
415                     self._result = {
416                         'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
417                         'body': response_body.getvalue(),
418                         'headers': self._headers,
419                         'error': False,
420                     }
421
422                 ok = retry.check_http_response_success(self._result['status_code'])
423                 if not ok:
424                     self._result['error'] = arvados.errors.HttpError(
425                         self._result['status_code'],
426                         self._headers.get('x-status-line', 'Error'))
427             except self.HTTP_ERRORS as e:
428                 self._result = {
429                     'error': e,
430                 }
431             self._usable = ok != False
432             if self._result.get('status_code', None):
433                 # The client worked well enough to get an HTTP status
434                 # code, so presumably any problems are just on the
435                 # server side and it's OK to reuse the client.
436                 self._put_user_agent(curl)
437             else:
438                 # Don't return this client to the pool, in case it's
439                 # broken.
440                 curl.close()
441             if not ok:
442                 _logger.debug("Request fail: GET %s => %s: %s",
443                               url, type(self._result['error']), str(self._result['error']))
444                 return None
445             if method == "HEAD":
446                 _logger.info("HEAD %s: %s bytes",
447                          self._result['status_code'],
448                          self._result.get('content-length'))
449                 return True
450
451             _logger.info("GET %s: %s bytes in %s msec (%.3f MiB/sec)",
452                          self._result['status_code'],
453                          len(self._result['body']),
454                          t.msecs,
455                          (len(self._result['body'])/(1024.0*1024))/t.secs if t.secs > 0 else 0)
456
457             if self.download_counter:
458                 self.download_counter.add(len(self._result['body']))
459             resp_md5 = hashlib.md5(self._result['body']).hexdigest()
460             if resp_md5 != locator.md5sum:
461                 _logger.warning("Checksum fail: md5(%s) = %s",
462                                 url, resp_md5)
463                 self._result['error'] = arvados.errors.HttpError(
464                     0, 'Checksum fail')
465                 return None
466             return self._result['body']
467
468         def put(self, hash_s, body, timeout=None):
469             url = self.root + hash_s
470             _logger.debug("Request: PUT %s", url)
471             curl = self._get_user_agent()
472             ok = None
473             try:
474                 with timer.Timer() as t:
475                     self._headers = {}
476                     body_reader = cStringIO.StringIO(body)
477                     response_body = cStringIO.StringIO()
478                     curl.setopt(pycurl.NOSIGNAL, 1)
479                     curl.setopt(pycurl.OPENSOCKETFUNCTION, self._socket_open)
480                     curl.setopt(pycurl.URL, url.encode('utf-8'))
481                     # Using UPLOAD tells cURL to wait for a "go ahead" from the
482                     # Keep server (in the form of a HTTP/1.1 "100 Continue"
483                     # response) instead of sending the request body immediately.
484                     # This allows the server to reject the request if the request
485                     # is invalid or the server is read-only, without waiting for
486                     # the client to send the entire block.
487                     curl.setopt(pycurl.UPLOAD, True)
488                     curl.setopt(pycurl.INFILESIZE, len(body))
489                     curl.setopt(pycurl.READFUNCTION, body_reader.read)
490                     curl.setopt(pycurl.HTTPHEADER, [
491                         '{}: {}'.format(k,v) for k,v in self.put_headers.iteritems()])
492                     curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
493                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
494                     self._setcurltimeouts(curl, timeout)
495                     try:
496                         curl.perform()
497                     except Exception as e:
498                         raise arvados.errors.HttpError(0, str(e))
499                     self._result = {
500                         'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
501                         'body': response_body.getvalue(),
502                         'headers': self._headers,
503                         'error': False,
504                     }
505                 ok = retry.check_http_response_success(self._result['status_code'])
506                 if not ok:
507                     self._result['error'] = arvados.errors.HttpError(
508                         self._result['status_code'],
509                         self._headers.get('x-status-line', 'Error'))
510             except self.HTTP_ERRORS as e:
511                 self._result = {
512                     'error': e,
513                 }
514             self._usable = ok != False # still usable if ok is True or None
515             if self._result.get('status_code', None):
516                 # Client is functional. See comment in get().
517                 self._put_user_agent(curl)
518             else:
519                 curl.close()
520             if not ok:
521                 _logger.debug("Request fail: PUT %s => %s: %s",
522                               url, type(self._result['error']), str(self._result['error']))
523                 return False
524             _logger.info("PUT %s: %s bytes in %s msec (%.3f MiB/sec)",
525                          self._result['status_code'],
526                          len(body),
527                          t.msecs,
528                          (len(body)/(1024.0*1024))/t.secs if t.secs > 0 else 0)
529             if self.upload_counter:
530                 self.upload_counter.add(len(body))
531             return True
532
533         def _setcurltimeouts(self, curl, timeouts):
534             if not timeouts:
535                 return
536             elif isinstance(timeouts, tuple):
537                 if len(timeouts) == 2:
538                     conn_t, xfer_t = timeouts
539                     bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
540                 else:
541                     conn_t, xfer_t, bandwidth_bps = timeouts
542             else:
543                 conn_t, xfer_t = (timeouts, timeouts)
544                 bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
545             curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(conn_t*1000))
546             curl.setopt(pycurl.LOW_SPEED_TIME, int(math.ceil(xfer_t)))
547             curl.setopt(pycurl.LOW_SPEED_LIMIT, int(math.ceil(bandwidth_bps)))
548
549         def _headerfunction(self, header_line):
550             header_line = header_line.decode('iso-8859-1')
551             if ':' in header_line:
552                 name, value = header_line.split(':', 1)
553                 name = name.strip().lower()
554                 value = value.strip()
555             elif self._headers:
556                 name = self._lastheadername
557                 value = self._headers[name] + ' ' + header_line.strip()
558             elif header_line.startswith('HTTP/'):
559                 name = 'x-status-line'
560                 value = header_line
561             else:
562                 _logger.error("Unexpected header line: %s", header_line)
563                 return
564             self._lastheadername = name
565             self._headers[name] = value
566             # Returning None implies all bytes were written
567
568
569     class KeepWriterThread(threading.Thread):
570         """
571         Write a blob of data to the given Keep server. On success, call
572         save_response() of the given ThreadLimiter to save the returned
573         locator.
574         """
575         def __init__(self, keep_service, **kwargs):
576             super(KeepClient.KeepWriterThread, self).__init__()
577             self.service = keep_service
578             self.args = kwargs
579             self._success = False
580
581         def success(self):
582             return self._success
583
584         def run(self):
585             limiter = self.args['thread_limiter']
586             sequence = self.args['thread_sequence']
587             if sequence is not None:
588                 limiter.set_sequence(sequence)
589             with limiter:
590                 if not limiter.shall_i_proceed():
591                     # My turn arrived, but the job has been done without
592                     # me.
593                     return
594                 self.run_with_limiter(limiter)
595
596         def run_with_limiter(self, limiter):
597             if self.service.finished():
598                 return
599             _logger.debug("KeepWriterThread %s proceeding %s+%i %s",
600                           str(threading.current_thread()),
601                           self.args['data_hash'],
602                           len(self.args['data']),
603                           self.args['service_root'])
604             self._success = bool(self.service.put(
605                 self.args['data_hash'],
606                 self.args['data'],
607                 timeout=self.args.get('timeout', None)))
608             result = self.service.last_result()
609             if self._success:
610                 _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
611                               str(threading.current_thread()),
612                               self.args['data_hash'],
613                               len(self.args['data']),
614                               self.args['service_root'])
615                 # Tick the 'done' counter for the number of replica
616                 # reported stored by the server, for the case that
617                 # we're talking to a proxy or other backend that
618                 # stores to multiple copies for us.
619                 try:
620                     replicas_stored = int(result['headers']['x-keep-replicas-stored'])
621                 except (KeyError, ValueError):
622                     replicas_stored = 1
623                 limiter.save_response(result['body'].strip(), replicas_stored)
624             elif result.get('status_code', None):
625                 _logger.debug("Request fail: PUT %s => %s %s",
626                               self.args['data_hash'],
627                               result['status_code'],
628                               result['body'])
629             if not self._success:
630                 # Notify the failure so that the Thread limiter allows
631                 # a new one to run.
632                 limiter.save_response(None, 0)
633
634
    def __init__(self, api_client=None, proxy=None,
                 timeout=DEFAULT_TIMEOUT, proxy_timeout=DEFAULT_PROXY_TIMEOUT,
                 api_token=None, local_store=None, block_cache=None,
                 num_retries=0, session=None):
        """Initialize a new KeepClient.

        Arguments:
        :api_client:
          The API client to use to find Keep services.  If not
          provided, KeepClient will build one from available Arvados
          configuration.

        :proxy:
          If specified, this KeepClient will send requests to this Keep
          proxy.  Otherwise, KeepClient will fall back to the setting of the
          ARVADOS_KEEP_PROXY configuration setting.  If you want to ensure
          KeepClient does not use a proxy, pass in an empty string.

        :timeout:
          The initial timeout (in seconds) for HTTP requests to Keep
          non-proxy servers.  A tuple of three floats is interpreted as
          (connection_timeout, read_timeout, minimum_bandwidth). A connection
          will be aborted if the average traffic rate falls below
          minimum_bandwidth bytes per second over an interval of read_timeout
          seconds. Because timeouts are often a result of transient server
          load, the actual connection timeout will be increased by a factor
          of two on each retry.
          Default: (2, 256, 32768).

        :proxy_timeout:
          The initial timeout (in seconds) for HTTP requests to
          Keep proxies. A tuple of three floats is interpreted as
          (connection_timeout, read_timeout, minimum_bandwidth). The behavior
          described above for adjusting connection timeouts on retry also
          applies.
          Default: (20, 256, 32768).

        :api_token:
          If you're not using an API client, but only talking
          directly to a Keep proxy, this parameter specifies an API token
          to authenticate Keep requests.  It is an error to specify both
          api_client and api_token.  If you specify neither, KeepClient
          will use one available from the Arvados configuration.

        :local_store:
          If specified, this KeepClient will bypass Keep
          services, and save data to the named directory.  If unspecified,
          KeepClient will fall back to the setting of the $KEEP_LOCAL_STORE
          environment variable.  If you want to ensure KeepClient does not
          use local storage, pass in an empty string.  This is primarily
          intended to mock a server for testing.

        :num_retries:
          The default number of times to retry failed requests.
          This will be used as the default num_retries value when get() and
          put() are called.  Default 0.

        :block_cache:
          KeepBlockCache to use for caching fetched blocks.  If not
          provided, a new KeepBlockCache with the default size is created.

        :session:
          NOTE(review): accepted but not referenced anywhere in this
          constructor -- presumably kept for interface compatibility.
          Confirm before relying on it.
        """
        self.lock = threading.Lock()
        # Fall back to the Arvados configuration / environment for any
        # settings the caller did not specify explicitly.
        if proxy is None:
            proxy = config.get('ARVADOS_KEEP_PROXY')
        if api_token is None:
            if api_client is None:
                api_token = config.get('ARVADOS_API_TOKEN')
            else:
                api_token = api_client.api_token
        elif api_client is not None:
            # Both were given explicitly: ambiguous, so refuse.
            raise ValueError(
                "can't build KeepClient with both API client and token")
        if local_store is None:
            local_store = os.environ.get('KEEP_LOCAL_STORE')

        self.block_cache = block_cache if block_cache else KeepBlockCache()
        self.timeout = timeout
        self.proxy_timeout = proxy_timeout
        # Pool of idle pycurl handles shared by this client's requests.
        self._user_agent_pool = Queue.LifoQueue()
        self.upload_counter = Counter()
        self.download_counter = Counter()
        self.put_counter = Counter()
        self.get_counter = Counter()
        self.hits_counter = Counter()
        self.misses_counter = Counter()

        if local_store:
            # Local-store mode: replace get/put with filesystem-backed
            # implementations and skip Keep service setup entirely.
            self.local_store = local_store
            self.get = self.local_store_get
            self.put = self.local_store_put
        else:
            self.num_retries = num_retries
            self.max_replicas_per_service = None
            if proxy:
                if not proxy.endswith('/'):
                    proxy += '/'
                self.api_token = api_token
                self._gateway_services = {}
                # A proxy is modeled as a single static "service" entry,
                # so the usual service-discovery machinery is bypassed.
                self._keep_services = [{
                    'uuid': 'proxy',
                    'service_type': 'proxy',
                    '_service_root': proxy,
                    }]
                self._writable_services = self._keep_services
                self.using_proxy = True
                self._static_services_list = True
            else:
                # It's important to avoid instantiating an API client
                # unless we actually need one, for testing's sake.
                if api_client is None:
                    api_client = arvados.api('v1')
                self.api_client = api_client
                self.api_token = api_client.api_token
                self._gateway_services = {}
                # Service lists are discovered lazily from the API server.
                self._keep_services = None
                self._writable_services = None
                self.using_proxy = None
                self._static_services_list = False
749
750     def current_timeout(self, attempt_number):
751         """Return the appropriate timeout to use for this client.
752
753         The proxy timeout setting if the backend service is currently a proxy,
754         the regular timeout setting otherwise.  The `attempt_number` indicates
755         how many times the operation has been tried already (starting from 0
756         for the first try), and scales the connection timeout portion of the
757         return value accordingly.
758
759         """
760         # TODO(twp): the timeout should be a property of a
761         # KeepService, not a KeepClient. See #4488.
762         t = self.proxy_timeout if self.using_proxy else self.timeout
763         if len(t) == 2:
764             return (t[0] * (1 << attempt_number), t[1])
765         else:
766             return (t[0] * (1 << attempt_number), t[1], t[2])
767     def _any_nondisk_services(self, service_list):
768         return any(ks.get('service_type', 'disk') != 'disk'
769                    for ks in service_list)
770
    def build_services_list(self, force_rebuild=False):
        """Fetch the list of accessible Keep services from the API server.

        Populates self._gateway_services (by UUID), self._keep_services,
        self._writable_services, and self.max_replicas_per_service.
        Does nothing when the service list is static (proxy mode) or has
        already been built, unless force_rebuild is true.
        """
        # NOTE(review): this check runs outside self.lock, so two threads
        # may both fall through to the rebuild below; the second one just
        # repeats the work -- confirm that's acceptable.
        if (self._static_services_list or
              (self._keep_services and not force_rebuild)):
            return
        with self.lock:
            try:
                keep_services = self.api_client.keep_services().accessible()
            except Exception:  # API server predates Keep services.
                keep_services = self.api_client.keep_disks().list()

            # Gateway services are only used when specified by UUID,
            # so there's nothing to gain by filtering them by
            # service_type.
            self._gateway_services = {ks['uuid']: ks for ks in
                                      keep_services.execute()['items']}
            if not self._gateway_services:
                raise arvados.errors.NoKeepServersError()

            # Precompute the base URI for each service.
            for r in self._gateway_services.itervalues():
                host = r['service_host']
                if not host.startswith('[') and host.find(':') >= 0:
                    # IPv6 URIs must be formatted like http://[::1]:80/...
                    host = '[' + host + ']'
                r['_service_root'] = "{}://{}:{:d}/".format(
                    'https' if r['service_ssl_flag'] else 'http',
                    host,
                    r['service_port'])

            _logger.debug(str(self._gateway_services))
            # gateway:-type services are reachable only via explicit
            # +K@uuid hints, so exclude them from the general lists.
            self._keep_services = [
                ks for ks in self._gateway_services.itervalues()
                if not ks.get('service_type', '').startswith('gateway:')]
            self._writable_services = [ks for ks in self._keep_services
                                       if not ks.get('read_only')]

            # For disk type services, max_replicas_per_service is 1
            # It is unknown (unlimited) for other service types.
            if self._any_nondisk_services(self._writable_services):
                self.max_replicas_per_service = None
            else:
                self.max_replicas_per_service = 1
813
814     def _service_weight(self, data_hash, service_uuid):
815         """Compute the weight of a Keep service endpoint for a data
816         block with a known hash.
817
818         The weight is md5(h + u) where u is the last 15 characters of
819         the service endpoint's UUID.
820         """
821         return hashlib.md5(data_hash + service_uuid[-15:]).hexdigest()
822
823     def weighted_service_roots(self, locator, force_rebuild=False, need_writable=False):
824         """Return an array of Keep service endpoints, in the order in
825         which they should be probed when reading or writing data with
826         the given hash+hints.
827         """
828         self.build_services_list(force_rebuild)
829
830         sorted_roots = []
831         # Use the services indicated by the given +K@... remote
832         # service hints, if any are present and can be resolved to a
833         # URI.
834         for hint in locator.hints:
835             if hint.startswith('K@'):
836                 if len(hint) == 7:
837                     sorted_roots.append(
838                         "https://keep.{}.arvadosapi.com/".format(hint[2:]))
839                 elif len(hint) == 29:
840                     svc = self._gateway_services.get(hint[2:])
841                     if svc:
842                         sorted_roots.append(svc['_service_root'])
843
844         # Sort the available local services by weight (heaviest first)
845         # for this locator, and return their service_roots (base URIs)
846         # in that order.
847         use_services = self._keep_services
848         if need_writable:
849             use_services = self._writable_services
850         self.using_proxy = self._any_nondisk_services(use_services)
851         sorted_roots.extend([
852             svc['_service_root'] for svc in sorted(
853                 use_services,
854                 reverse=True,
855                 key=lambda svc: self._service_weight(locator.md5sum, svc['uuid']))])
856         _logger.debug("{}: {}".format(locator, sorted_roots))
857         return sorted_roots
858
859     def map_new_services(self, roots_map, locator, force_rebuild, need_writable, **headers):
860         # roots_map is a dictionary, mapping Keep service root strings
861         # to KeepService objects.  Poll for Keep services, and add any
862         # new ones to roots_map.  Return the current list of local
863         # root strings.
864         headers.setdefault('Authorization', "OAuth2 %s" % (self.api_token,))
865         local_roots = self.weighted_service_roots(locator, force_rebuild, need_writable)
866         for root in local_roots:
867             if root not in roots_map:
868                 roots_map[root] = self.KeepService(
869                     root, self._user_agent_pool,
870                     upload_counter=self.upload_counter,
871                     download_counter=self.download_counter,
872                     **headers)
873         return local_roots
874
875     @staticmethod
876     def _check_loop_result(result):
877         # KeepClient RetryLoops should save results as a 2-tuple: the
878         # actual result of the request, and the number of servers available
879         # to receive the request this round.
880         # This method returns True if there's a real result, False if
881         # there are no more servers available, otherwise None.
882         if isinstance(result, Exception):
883             return None
884         result, tried_server_count = result
885         if (result is not None) and (result is not False):
886             return True
887         elif tried_server_count < 1:
888             _logger.info("No more Keep services to try; giving up")
889             return False
890         else:
891             return None
892
893     def get_from_cache(self, loc):
894         """Fetch a block only if is in the cache, otherwise return None."""
895         slot = self.block_cache.get(loc)
896         if slot is not None and slot.ready.is_set():
897             return slot.get()
898         else:
899             return None
900
901     @retry.retry_method
902     def head(self, loc_s, num_retries=None):
903         return self._get_or_head(loc_s, method="HEAD", num_retries=num_retries)
904
905     @retry.retry_method
906     def get(self, loc_s, num_retries=None):
907         return self._get_or_head(loc_s, method="GET", num_retries=num_retries)
908
909     def _get_or_head(self, loc_s, method="GET", num_retries=None):
910         """Get data from Keep.
911
912         This method fetches one or more blocks of data from Keep.  It
913         sends a request each Keep service registered with the API
914         server (or the proxy provided when this client was
915         instantiated), then each service named in location hints, in
916         sequence.  As soon as one service provides the data, it's
917         returned.
918
919         Arguments:
920         * loc_s: A string of one or more comma-separated locators to fetch.
921           This method returns the concatenation of these blocks.
922         * num_retries: The number of times to retry GET requests to
923           *each* Keep server if it returns temporary failures, with
924           exponential backoff.  Note that, in each loop, the method may try
925           to fetch data from every available Keep service, along with any
926           that are named in location hints in the locator.  The default value
927           is set when the KeepClient is initialized.
928         """
929         if ',' in loc_s:
930             return ''.join(self.get(x) for x in loc_s.split(','))
931
932         self.get_counter.add(1)
933
934         locator = KeepLocator(loc_s)
935         if method == "GET":
936             slot, first = self.block_cache.reserve_cache(locator.md5sum)
937             if not first:
938                 self.hits_counter.add(1)
939                 v = slot.get()
940                 return v
941
942         self.misses_counter.add(1)
943
944         # If the locator has hints specifying a prefix (indicating a
945         # remote keepproxy) or the UUID of a local gateway service,
946         # read data from the indicated service(s) instead of the usual
947         # list of local disk services.
948         hint_roots = ['http://keep.{}.arvadosapi.com/'.format(hint[2:])
949                       for hint in locator.hints if hint.startswith('K@') and len(hint) == 7]
950         hint_roots.extend([self._gateway_services[hint[2:]]['_service_root']
951                            for hint in locator.hints if (
952                                    hint.startswith('K@') and
953                                    len(hint) == 29 and
954                                    self._gateway_services.get(hint[2:])
955                                    )])
956         # Map root URLs to their KeepService objects.
957         roots_map = {
958             root: self.KeepService(root, self._user_agent_pool,
959                                    upload_counter=self.upload_counter,
960                                    download_counter=self.download_counter)
961             for root in hint_roots
962         }
963
964         # See #3147 for a discussion of the loop implementation.  Highlights:
965         # * Refresh the list of Keep services after each failure, in case
966         #   it's being updated.
967         # * Retry until we succeed, we're out of retries, or every available
968         #   service has returned permanent failure.
969         sorted_roots = []
970         roots_map = {}
971         blob = None
972         loop = retry.RetryLoop(num_retries, self._check_loop_result,
973                                backoff_start=2)
974         for tries_left in loop:
975             try:
976                 sorted_roots = self.map_new_services(
977                     roots_map, locator,
978                     force_rebuild=(tries_left < num_retries),
979                     need_writable=False)
980             except Exception as error:
981                 loop.save_result(error)
982                 continue
983
984             # Query KeepService objects that haven't returned
985             # permanent failure, in our specified shuffle order.
986             services_to_try = [roots_map[root]
987                                for root in sorted_roots
988                                if roots_map[root].usable()]
989             for keep_service in services_to_try:
990                 blob = keep_service.get(locator, method=method, timeout=self.current_timeout(num_retries-tries_left))
991                 if blob is not None:
992                     break
993             loop.save_result((blob, len(services_to_try)))
994
995         # Always cache the result, then return it if we succeeded.
996         if method == "GET":
997             slot.set(blob)
998             self.block_cache.cap_cache()
999         if loop.success():
1000             if method == "HEAD":
1001                 return True
1002             else:
1003                 return blob
1004
1005         # Q: Including 403 is necessary for the Keep tests to continue
1006         # passing, but maybe they should expect KeepReadError instead?
1007         not_founds = sum(1 for key in sorted_roots
1008                          if roots_map[key].last_result().get('status_code', None) in {403, 404, 410})
1009         service_errors = ((key, roots_map[key].last_result()['error'])
1010                           for key in sorted_roots)
1011         if not roots_map:
1012             raise arvados.errors.KeepReadError(
1013                 "failed to read {}: no Keep services available ({})".format(
1014                     loc_s, loop.last_result()))
1015         elif not_founds == len(sorted_roots):
1016             raise arvados.errors.NotFoundError(
1017                 "{} not found".format(loc_s), service_errors)
1018         else:
1019             raise arvados.errors.KeepReadError(
1020                 "failed to read {}".format(loc_s), service_errors, label="service")
1021
    @retry.retry_method
    def put(self, data, copies=2, num_retries=None):
        """Save data in Keep.

        This method will get a list of Keep services from the API server, and
        send the data to each one simultaneously in a new thread.  Once the
        uploads are finished, if enough copies are saved, this method returns
        the most recent HTTP response body.  If requests fail to upload
        enough copies, this method raises KeepWriteError.

        Arguments:
        * data: The string of data to upload.
        * copies: The number of copies that the user requires be saved.
          Default 2.
        * num_retries: The number of times to retry PUT requests to
          *each* Keep server if it returns temporary failures, with
          exponential backoff.  The default value is set when the
          KeepClient is initialized.
        """

        if isinstance(data, unicode):
            data = data.encode("ascii")
        elif not isinstance(data, str):
            raise arvados.errors.ArgumentError("Argument 'data' to KeepClient.put is not type 'str'")

        self.put_counter.add(1)

        data_hash = hashlib.md5(data).hexdigest()
        loc_s = data_hash + '+' + str(len(data))
        if copies < 1:
            # Nothing to store, but still return a valid locator.
            return loc_s
        locator = KeepLocator(loc_s)

        headers = {}
        # Tell the proxy how many copies we want it to store
        headers['X-Keep-Desired-Replication'] = str(copies)
        roots_map = {}
        loop = retry.RetryLoop(num_retries, self._check_loop_result,
                               backoff_start=2)
        # Replicas confirmed written so far, accumulated across retries.
        done = 0
        for tries_left in loop:
            try:
                sorted_roots = self.map_new_services(
                    roots_map, locator,
                    force_rebuild=(tries_left < num_retries), need_writable=True, **headers)
            except Exception as error:
                loop.save_result(error)
                continue

            # Only request the replicas still missing this round.
            thread_limiter = KeepClient.ThreadLimiter(
                copies - done, self.max_replicas_per_service)
            threads = []
            for service_root, ks in [(root, roots_map[root])
                                     for root in sorted_roots]:
                if ks.finished():
                    # This service already gave a definitive answer in a
                    # previous round; don't contact it again.
                    continue
                t = KeepClient.KeepWriterThread(
                    ks,
                    data=data,
                    data_hash=data_hash,
                    service_root=service_root,
                    thread_limiter=thread_limiter,
                    timeout=self.current_timeout(num_retries-tries_left),
                    thread_sequence=len(threads))
                t.start()
                threads.append(t)
            for t in threads:
                t.join()
            done += thread_limiter.done()
            loop.save_result((done >= copies, len(threads)))

        if loop.success():
            return thread_limiter.response()
        if not roots_map:
            raise arvados.errors.KeepWriteError(
                "failed to write {}: no Keep services available ({})".format(
                    data_hash, loop.last_result()))
        else:
            service_errors = ((key, roots_map[key].last_result()['error'])
                              for key in sorted_roots
                              if roots_map[key].last_result()['error'])
            raise arvados.errors.KeepWriteError(
                "failed to write {} (wanted {} copies but wrote {})".format(
                    data_hash, copies, thread_limiter.done()), service_errors, label="service")
1106
1107     def local_store_put(self, data, copies=1, num_retries=None):
1108         """A stub for put().
1109
1110         This method is used in place of the real put() method when
1111         using local storage (see constructor's local_store argument).
1112
1113         copies and num_retries arguments are ignored: they are here
1114         only for the sake of offering the same call signature as
1115         put().
1116
1117         Data stored this way can be retrieved via local_store_get().
1118         """
1119         md5 = hashlib.md5(data).hexdigest()
1120         locator = '%s+%d' % (md5, len(data))
1121         with open(os.path.join(self.local_store, md5 + '.tmp'), 'w') as f:
1122             f.write(data)
1123         os.rename(os.path.join(self.local_store, md5 + '.tmp'),
1124                   os.path.join(self.local_store, md5))
1125         return locator
1126
1127     def local_store_get(self, loc_s, num_retries=None):
1128         """Companion to local_store_put()."""
1129         try:
1130             locator = KeepLocator(loc_s)
1131         except ValueError:
1132             raise arvados.errors.NotFoundError(
1133                 "Invalid data locator: '%s'" % loc_s)
1134         if locator.md5sum == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
1135             return ''
1136         with open(os.path.join(self.local_store, locator.md5sum), 'r') as f:
1137             return f.read()
1138
1139     def is_cached(self, locator):
1140         return self.block_cache.reserve_cache(expect_hash)