[arvados.git] / sdk / python / arvados / keep.py
# Copyright (C) The Arvados Authors. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import absolute_import
from __future__ import division
from future import standard_library
from future.utils import native_str
standard_library.install_aliases()
from builtins import next
from builtins import str
from builtins import range
from builtins import object
import collections
import datetime
import hashlib
import io
import logging
import math
import os
import pycurl
import queue
import re
import socket
import ssl
import sys
import threading
from . import timer
import urllib.parse

if sys.version_info >= (3, 0):
    from io import BytesIO
else:
    from cStringIO import StringIO as BytesIO

import arvados
import arvados.config as config
import arvados.errors
import arvados.retry as retry
import arvados.util

_logger = logging.getLogger('arvados.keep')
global_client_object = None


# Monkey patch TCP constants when not available (apple). Values sourced from:
# http://www.opensource.apple.com/source/xnu/xnu-2422.115.4/bsd/netinet/tcp.h
if sys.platform == 'darwin':
    if not hasattr(socket, 'TCP_KEEPALIVE'):
        socket.TCP_KEEPALIVE = 0x010
    if not hasattr(socket, 'TCP_KEEPINTVL'):
        socket.TCP_KEEPINTVL = 0x101
    if not hasattr(socket, 'TCP_KEEPCNT'):
        socket.TCP_KEEPCNT = 0x102


class KeepLocator(object):
    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
    HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')

    def __init__(self, locator_str):
        self.hints = []
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        self.md5sum = next(pieces)
        try:
            self.size = int(next(pieces))
        except StopIteration:
            self.size = None
        for hint in pieces:
            if self.HINT_RE.match(hint) is None:
                raise ValueError("invalid hint format: {}".format(hint))
            elif hint.startswith('A'):
                self.parse_permission_hint(hint)
            else:
                self.hints.append(hint)

    def __str__(self):
        return '+'.join(
            native_str(s)
            for s in [self.md5sum, self.size,
                      self.permission_hint()] + self.hints
            if s is not None)

    def stripped(self):
        if self.size is not None:
            return "%s+%i" % (self.md5sum, self.size)
        else:
            return self.md5sum

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not arvados.util.is_hex(hex_str, length):
                raise ValueError("{} is not a {}-digit hex string: {!r}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        if not arvados.util.is_hex(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        try:
            self.perm_sig, self.perm_expiry = s[1:].split('@', 1)
        except ValueError:
            # A hint with no '@' makes the unpacking above raise ValueError.
            raise ValueError("bad permission hint {}".format(s))

    def permission_expired(self, as_of_dt=None):
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            # perm_expiry is a naive UTC datetime, so compare against UTC.
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt

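# Illustrative sketch (not part of the library API): how a signed locator
# string maps onto KeepLocator fields.  The hash, signature, and expiry below
# are made-up placeholder values.
#
#     loc = KeepLocator(
#         "acbd18db4cc2f85cedef654fccc4a4d8+3"
#         "+A89118b78732c33104a4d6231e8b5a5fa7ce0d652@565f7b0a")
#     loc.md5sum      # "acbd18db4cc2f85cedef654fccc4a4d8"
#     loc.size        # 3
#     str(loc)        # round-trips to the original locator string
#     loc.stripped()  # "acbd18db4cc2f85cedef654fccc4a4d8+3" (hints dropped)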

class Keep(object):
    """Simple interface to a global KeepClient object.

    THIS CLASS IS DEPRECATED.  Please instantiate your own KeepClient with your
    own API client.  The global KeepClient will build an API client from the
    current Arvados configuration, which may not match the one you built.
    """
    _last_key = None

    @classmethod
    def global_client_object(cls):
        global global_client_object
        # Previously, KeepClient would change its behavior at runtime based
        # on these configuration settings.  We simulate that behavior here
        # by checking the values and returning a new KeepClient if any of
        # them have changed.
        key = (config.get('ARVADOS_API_HOST'),
               config.get('ARVADOS_API_TOKEN'),
               config.flag_is_true('ARVADOS_API_HOST_INSECURE'),
               config.get('ARVADOS_KEEP_PROXY'),
               config.get('ARVADOS_EXTERNAL_CLIENT') == 'true',
               os.environ.get('KEEP_LOCAL_STORE'))
        if (global_client_object is None) or (cls._last_key != key):
            global_client_object = KeepClient()
            cls._last_key = key
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)

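# Sketch of the deprecated interface above, assuming a configured cluster is
# reachable; the preferred equivalent follows.
#
#     loc = Keep.put(b'foo')    # store via the implicit global client
#     data = Keep.get(loc)      # fetch it back
#
#     # Preferred: build your own client from an explicit API client.
#     kc = KeepClient(api_client=arvados.api('v1'))
#     loc = kc.put(b'foo')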
class KeepBlockCache(object):
    # Default RAM cache is 256MiB
    def __init__(self, cache_max=(256 * 1024 * 1024)):
        self.cache_max = cache_max
        self._cache = []
        self._cache_lock = threading.Lock()

    class CacheSlot(object):
        __slots__ = ("locator", "ready", "content")

        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            self.ready.wait()
            return self.content

        def set(self, value):
            self.content = value
            self.ready.set()

        def size(self):
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        with self._cache_lock:
            # Select all slots except those where ready.is_set() and content is
            # None (that means there was an error reading the block).
            self._cache = [c for c in self._cache if not (c.ready.is_set() and c.content is None)]
            sm = sum([slot.size() for slot in self._cache])
            while len(self._cache) > 0 and sm > self.cache_max:
                for i in range(len(self._cache)-1, -1, -1):
                    if self._cache[i].ready.is_set():
                        del self._cache[i]
                        break
                sm = sum([slot.size() for slot in self._cache])

    def _get(self, locator):
        # Test if the locator is already in the cache
        for i in range(0, len(self._cache)):
            if self._cache[i].locator == locator:
                n = self._cache[i]
                if i != 0:
                    # move it to the front
                    del self._cache[i]
                    self._cache.insert(0, n)
                return n
        return None

    def get(self, locator):
        with self._cache_lock:
            return self._get(locator)

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        with self._cache_lock:
            n = self._get(locator)
            if n:
                return n, False
            else:
                # Add a new cache slot for the locator
                n = KeepBlockCache.CacheSlot(locator)
                self._cache.insert(0, n)
                return n, True

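# Usage sketch for the cache above: reserve_cache() either returns an existing
# slot or creates one; the thread that created the slot fills it with set()
# while other threads block in get() until the content is ready.
#
#     cache = KeepBlockCache()
#     slot, first = cache.reserve_cache("acbd18db4cc2f85cedef654fccc4a4d8")
#     if first:
#         slot.set(b'foo')     # fetcher thread stores the block
#     else:
#         data = slot.get()    # other threads wait for the fetcher
#     cache.cap_cache()        # then trim the cache back under cache_max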
class Counter(object):
    def __init__(self, v=0):
        self._lk = threading.Lock()
        self._val = v

    def add(self, v):
        with self._lk:
            self._val += v

    def get(self):
        with self._lk:
            return self._val


class KeepClient(object):

    # Default Keep server connection timeout:  2 seconds
    # Default Keep server read timeout:       256 seconds
    # Default Keep server bandwidth minimum:  32768 bytes per second
    # Default Keep proxy connection timeout:  20 seconds
    # Default Keep proxy read timeout:        256 seconds
    # Default Keep proxy bandwidth minimum:   32768 bytes per second
    DEFAULT_TIMEOUT = (2, 256, 32768)
    DEFAULT_PROXY_TIMEOUT = (20, 256, 32768)


    class KeepService(object):
        """Make requests to a single Keep service, and track results.

        A KeepService is intended to last long enough to perform one
        transaction (GET or PUT) against one Keep service. This can
        involve calling either get() or put() multiple times in order
        to retry after transient failures. However, calling both get()
        and put() on a single instance -- or using the same instance
        to access two different Keep services -- will not produce
        sensible behavior.
        """

        HTTP_ERRORS = (
            socket.error,
            ssl.SSLError,
            arvados.errors.HttpError,
        )

        def __init__(self, root, user_agent_pool=queue.LifoQueue(),
                     upload_counter=None,
                     download_counter=None,
                     headers={}):
            self.root = root
            self._user_agent_pool = user_agent_pool
            self._result = {'error': None}
            self._usable = True
            self._session = None
            self._socket = None
            self.get_headers = {'Accept': 'application/octet-stream'}
            self.get_headers.update(headers)
            self.put_headers = headers
            self.upload_counter = upload_counter
            self.download_counter = download_counter

        def usable(self):
            """Is it worth attempting a request?"""
            return self._usable

        def finished(self):
            """Did the request succeed or encounter permanent failure?"""
            return self._result['error'] == False or not self._usable

        def last_result(self):
            return self._result

        def _get_user_agent(self):
            try:
                return self._user_agent_pool.get(block=False)
            except queue.Empty:
                return pycurl.Curl()

        def _put_user_agent(self, ua):
            try:
                ua.reset()
                self._user_agent_pool.put(ua, block=False)
            except:
                ua.close()

        def _socket_open(self, *args, **kwargs):
            if len(args) + len(kwargs) == 2:
                return self._socket_open_pycurl_7_21_5(*args, **kwargs)
            else:
                return self._socket_open_pycurl_7_19_3(*args, **kwargs)

        def _socket_open_pycurl_7_19_3(self, family, socktype, protocol, address=None):
            return self._socket_open_pycurl_7_21_5(
                purpose=None,
                address=collections.namedtuple(
                    'Address', ['family', 'socktype', 'protocol', 'addr'],
                )(family, socktype, protocol, address))

        def _socket_open_pycurl_7_21_5(self, purpose, address):
            """Because pycurl doesn't have CURLOPT_TCP_KEEPALIVE"""
            s = socket.socket(address.family, address.socktype, address.protocol)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            # Setting TCP_KEEPIDLE throws an "invalid protocol" error on
            # macOS, which lacks the option; this check prevents that.
            if hasattr(socket, 'TCP_KEEPIDLE'):
                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 75)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 75)
            self._socket = s
            return s

        def get(self, locator, method="GET", timeout=None):
            # locator is a KeepLocator object.
            url = self.root + str(locator)
            _logger.debug("Request: %s %s", method, url)
            curl = self._get_user_agent()
            ok = None
            try:
                with timer.Timer() as t:
                    self._headers = {}
                    response_body = BytesIO()
                    curl.setopt(pycurl.NOSIGNAL, 1)
                    curl.setopt(pycurl.OPENSOCKETFUNCTION,
                                lambda *args, **kwargs: self._socket_open(*args, **kwargs))
                    curl.setopt(pycurl.URL, url.encode('utf-8'))
                    curl.setopt(pycurl.HTTPHEADER, [
                        '{}: {}'.format(k,v) for k,v in self.get_headers.items()])
                    curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
                    curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
                    if method == "HEAD":
                        curl.setopt(pycurl.NOBODY, True)
                    self._setcurltimeouts(curl, timeout)

                    try:
                        curl.perform()
                    except Exception as e:
                        raise arvados.errors.HttpError(0, str(e))
                    finally:
                        if self._socket:
                            self._socket.close()
                            self._socket = None
                    self._result = {
                        'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
                        'body': response_body.getvalue(),
                        'headers': self._headers,
                        'error': False,
                    }

                ok = retry.check_http_response_success(self._result['status_code'])
                if not ok:
                    self._result['error'] = arvados.errors.HttpError(
                        self._result['status_code'],
                        self._headers.get('x-status-line', 'Error'))
            except self.HTTP_ERRORS as e:
                self._result = {
                    'error': e,
                }
            self._usable = ok != False
            if self._result.get('status_code', None):
                # The client worked well enough to get an HTTP status
                # code, so presumably any problems are just on the
                # server side and it's OK to reuse the client.
                self._put_user_agent(curl)
            else:
                # Don't return this client to the pool, in case it's
                # broken.
                curl.close()
            if not ok:
                _logger.debug("Request fail: GET %s => %s: %s",
                              url, type(self._result['error']), str(self._result['error']))
                return None
            if method == "HEAD":
                _logger.info("HEAD %s: %s bytes",
                         self._result['status_code'],
                         self._result['headers'].get('content-length'))
                return True

            _logger.info("GET %s: %s bytes in %s msec (%.3f MiB/sec)",
                         self._result['status_code'],
                         len(self._result['body']),
                         t.msecs,
                         1.0*len(self._result['body'])/2**20/t.secs if t.secs > 0 else 0)

            if self.download_counter:
                self.download_counter.add(len(self._result['body']))
            resp_md5 = hashlib.md5(self._result['body']).hexdigest()
            if resp_md5 != locator.md5sum:
                _logger.warning("Checksum fail: md5(%s) = %s",
                                url, resp_md5)
                self._result['error'] = arvados.errors.HttpError(
                    0, 'Checksum fail')
                return None
            return self._result['body']

        def put(self, hash_s, body, timeout=None):
            url = self.root + hash_s
            _logger.debug("Request: PUT %s", url)
            curl = self._get_user_agent()
            ok = None
            try:
                with timer.Timer() as t:
                    self._headers = {}
                    body_reader = BytesIO(body)
                    response_body = BytesIO()
                    curl.setopt(pycurl.NOSIGNAL, 1)
                    curl.setopt(pycurl.OPENSOCKETFUNCTION,
                                lambda *args, **kwargs: self._socket_open(*args, **kwargs))
                    curl.setopt(pycurl.URL, url.encode('utf-8'))
                    # Using UPLOAD tells cURL to wait for a "go ahead" from the
                    # Keep server (in the form of a HTTP/1.1 "100 Continue"
                    # response) instead of sending the request body immediately.
                    # This allows the server to reject the request if the request
                    # is invalid or the server is read-only, without waiting for
                    # the client to send the entire block.
                    curl.setopt(pycurl.UPLOAD, True)
                    curl.setopt(pycurl.INFILESIZE, len(body))
                    curl.setopt(pycurl.READFUNCTION, body_reader.read)
                    curl.setopt(pycurl.HTTPHEADER, [
                        '{}: {}'.format(k,v) for k,v in self.put_headers.items()])
                    curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
                    curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
                    self._setcurltimeouts(curl, timeout)
                    try:
                        curl.perform()
                    except Exception as e:
                        raise arvados.errors.HttpError(0, str(e))
                    finally:
                        if self._socket:
                            self._socket.close()
                            self._socket = None
                    self._result = {
                        'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
                        'body': response_body.getvalue().decode('utf-8'),
                        'headers': self._headers,
                        'error': False,
                    }
                ok = retry.check_http_response_success(self._result['status_code'])
                if not ok:
                    self._result['error'] = arvados.errors.HttpError(
                        self._result['status_code'],
                        self._headers.get('x-status-line', 'Error'))
            except self.HTTP_ERRORS as e:
                self._result = {
                    'error': e,
                }
            self._usable = ok != False # still usable if ok is True or None
            if self._result.get('status_code', None):
                # Client is functional. See comment in get().
                self._put_user_agent(curl)
            else:
                curl.close()
            if not ok:
                _logger.debug("Request fail: PUT %s => %s: %s",
                              url, type(self._result['error']), str(self._result['error']))
                return False
            _logger.info("PUT %s: %s bytes in %s msec (%.3f MiB/sec)",
                         self._result['status_code'],
                         len(body),
                         t.msecs,
                         1.0*len(body)/2**20/t.secs if t.secs > 0 else 0)
            if self.upload_counter:
                self.upload_counter.add(len(body))
            return True

        def _setcurltimeouts(self, curl, timeouts):
            if not timeouts:
                return
            elif isinstance(timeouts, tuple):
                if len(timeouts) == 2:
                    conn_t, xfer_t = timeouts
                    bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
                else:
                    conn_t, xfer_t, bandwidth_bps = timeouts
            else:
                conn_t, xfer_t = (timeouts, timeouts)
                bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
            curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(conn_t*1000))
            curl.setopt(pycurl.LOW_SPEED_TIME, int(math.ceil(xfer_t)))
            curl.setopt(pycurl.LOW_SPEED_LIMIT, int(math.ceil(bandwidth_bps)))

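        # Sketch of how a timeout tuple maps onto the curl options set above,
        # using the client default (2, 256, 32768):
        #
        #     CONNECTTIMEOUT_MS = 2000   # 2 s to establish the connection
        #     LOW_SPEED_TIME    = 256    # abort if throughput stays below...
        #     LOW_SPEED_LIMIT   = 32768  # ...32768 bytes/s for 256 s
        #
        # A bare number sets both the connection and transfer timeouts to
        # that value and keeps the default bandwidth floor.
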
        def _headerfunction(self, header_line):
            if isinstance(header_line, bytes):
                header_line = header_line.decode('iso-8859-1')
            if ':' in header_line:
                name, value = header_line.split(':', 1)
                name = name.strip().lower()
                value = value.strip()
            elif self._headers:
                # A folded continuation line: append it to the value of
                # the most recently seen header.
                name = self._lastheadername
                value = self._headers[name] + ' ' + header_line.strip()
            elif header_line.startswith('HTTP/'):
                name = 'x-status-line'
                value = header_line
            else:
                _logger.error("Unexpected header line: %s", header_line)
                return
            self._lastheadername = name
            self._headers[name] = value
            # Returning None implies all bytes were written


    class KeepWriterQueue(queue.Queue):
        def __init__(self, copies):
            queue.Queue.__init__(self) # Old-style superclass
            self.wanted_copies = copies
            self.successful_copies = 0
            self.response = None
            self.successful_copies_lock = threading.Lock()
            self.pending_tries = copies
            self.pending_tries_notification = threading.Condition()

        def write_success(self, response, replicas_nr):
            with self.successful_copies_lock:
                self.successful_copies += replicas_nr
                self.response = response
            with self.pending_tries_notification:
                self.pending_tries_notification.notify_all()

        def write_fail(self, ks):
            with self.pending_tries_notification:
                self.pending_tries += 1
                self.pending_tries_notification.notify()

        def pending_copies(self):
            with self.successful_copies_lock:
                return self.wanted_copies - self.successful_copies

        def get_next_task(self):
            with self.pending_tries_notification:
                while True:
                    if self.pending_copies() < 1:
                        # This notify_all() is unnecessary --
                        # write_success() already called notify_all()
                        # when pending<1 became true, so it's not
                        # possible for any other thread to be in
                        # wait() now -- but it's cheap insurance
                        # against deadlock so we do it anyway:
                        self.pending_tries_notification.notify_all()
                        # Drain the queue and then raise Queue.Empty
                        while True:
                            self.get_nowait()
                            self.task_done()
                    elif self.pending_tries > 0:
                        service, service_root = self.get_nowait()
                        if service.finished():
                            self.task_done()
                            continue
                        self.pending_tries -= 1
                        return service, service_root
                    elif self.empty():
                        self.pending_tries_notification.notify_all()
                        raise queue.Empty
                    else:
                        self.pending_tries_notification.wait()


    class KeepWriterThreadPool(object):
        def __init__(self, data, data_hash, copies, max_service_replicas, timeout=None):
            self.total_task_nr = 0
            self.wanted_copies = copies
            if (not max_service_replicas) or (max_service_replicas >= copies):
                num_threads = 1
            else:
                num_threads = int(math.ceil(1.0*copies/max_service_replicas))
            _logger.debug("Pool max threads is %d", num_threads)
            self.workers = []
            self.queue = KeepClient.KeepWriterQueue(copies)
            # Create workers
            for _ in range(num_threads):
                w = KeepClient.KeepWriterThread(self.queue, data, data_hash, timeout)
                self.workers.append(w)

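        # Worker-count example for the math above: with copies=3 and
        # max_service_replicas=1 (a cluster of disk services), ceil(3/1) = 3
        # threads write in parallel; with max_service_replicas unset (e.g. a
        # proxy that can store all copies itself), one thread suffices.
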
        def add_task(self, ks, service_root):
            self.queue.put((ks, service_root))
            self.total_task_nr += 1

        def done(self):
            return self.queue.successful_copies

        def join(self):
            # Start workers
            for worker in self.workers:
                worker.start()
            # Wait for finished work
            self.queue.join()

        def response(self):
            return self.queue.response


    class KeepWriterThread(threading.Thread):
        TaskFailed = RuntimeError()

        def __init__(self, queue, data, data_hash, timeout=None):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.timeout = timeout
            self.queue = queue
            self.data = data
            self.data_hash = data_hash
            self.daemon = True

        def run(self):
            while True:
                try:
                    service, service_root = self.queue.get_next_task()
                except queue.Empty:
                    return
                try:
                    locator, copies = self.do_task(service, service_root)
                except Exception as e:
                    if e is not self.TaskFailed:
                        _logger.exception("Exception in KeepWriterThread")
                    self.queue.write_fail(service)
                else:
                    self.queue.write_success(locator, copies)
                finally:
                    self.queue.task_done()

        def do_task(self, service, service_root):
            success = bool(service.put(self.data_hash,
                                       self.data,
                                       timeout=self.timeout))
            result = service.last_result()

            if not success:
                if result.get('status_code', None):
                    _logger.debug("Request fail: PUT %s => %s %s",
                                  self.data_hash,
                                  result['status_code'],
                                  result['body'])
                raise self.TaskFailed

            _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
                          str(threading.current_thread()),
                          self.data_hash,
                          len(self.data),
                          service_root)
            try:
                replicas_stored = int(result['headers']['x-keep-replicas-stored'])
            except (KeyError, ValueError):
                replicas_stored = 1

            return result['body'].strip(), replicas_stored


    def __init__(self, api_client=None, proxy=None,
                 timeout=DEFAULT_TIMEOUT, proxy_timeout=DEFAULT_PROXY_TIMEOUT,
                 api_token=None, local_store=None, block_cache=None,
                 num_retries=0, session=None):
        """Initialize a new KeepClient.

        Arguments:
        :api_client:
          The API client to use to find Keep services.  If not
          provided, KeepClient will build one from available Arvados
          configuration.

        :proxy:
          If specified, this KeepClient will send requests to this Keep
          proxy.  Otherwise, KeepClient will fall back to the setting of the
          ARVADOS_KEEP_SERVICES or ARVADOS_KEEP_PROXY configuration settings.
          If you want to ensure KeepClient does not use a proxy, pass in an
          empty string.

        :timeout:
          The initial timeout (in seconds) for HTTP requests to Keep
          non-proxy servers.  A tuple of three floats is interpreted as
          (connection_timeout, read_timeout, minimum_bandwidth). A connection
          will be aborted if the average traffic rate falls below
          minimum_bandwidth bytes per second over an interval of read_timeout
          seconds. Because timeouts are often a result of transient server
          load, the actual connection timeout will be increased by a factor
          of two on each retry.
          Default: (2, 256, 32768).

        :proxy_timeout:
          The initial timeout (in seconds) for HTTP requests to
          Keep proxies. A tuple of three floats is interpreted as
          (connection_timeout, read_timeout, minimum_bandwidth). The behavior
          described above for adjusting connection timeouts on retry also
          applies.
          Default: (20, 256, 32768).

        :api_token:
          If you're not using an API client, but only talking
          directly to a Keep proxy, this parameter specifies an API token
          to authenticate Keep requests.  It is an error to specify both
          api_client and api_token.  If you specify neither, KeepClient
          will use one available from the Arvados configuration.

        :local_store:
          If specified, this KeepClient will bypass Keep
          services, and save data to the named directory.  If unspecified,
          KeepClient will fall back to the setting of the $KEEP_LOCAL_STORE
          environment variable.  If you want to ensure KeepClient does not
          use local storage, pass in an empty string.  This is primarily
          intended to mock a server for testing.

        :num_retries:
          The default number of times to retry failed requests.
          This will be used as the default num_retries value when get() and
          put() are called.  Default 0.
        """
        self.lock = threading.Lock()
        if proxy is None:
            if config.get('ARVADOS_KEEP_SERVICES'):
                proxy = config.get('ARVADOS_KEEP_SERVICES')
            else:
                proxy = config.get('ARVADOS_KEEP_PROXY')
        if api_token is None:
            if api_client is None:
                api_token = config.get('ARVADOS_API_TOKEN')
            else:
                api_token = api_client.api_token
        elif api_client is not None:
            raise ValueError(
                "can't build KeepClient with both API client and token")
        if local_store is None:
            local_store = os.environ.get('KEEP_LOCAL_STORE')

        self.block_cache = block_cache if block_cache else KeepBlockCache()
        self.timeout = timeout
        self.proxy_timeout = proxy_timeout
        self._user_agent_pool = queue.LifoQueue()
        self.upload_counter = Counter()
        self.download_counter = Counter()
        self.put_counter = Counter()
        self.get_counter = Counter()
        self.hits_counter = Counter()
        self.misses_counter = Counter()

        if local_store:
            self.local_store = local_store
            self.get = self.local_store_get
            self.put = self.local_store_put
        else:
            self.num_retries = num_retries
            self.max_replicas_per_service = None
            if proxy:
                proxy_uris = proxy.split()
                for i in range(len(proxy_uris)):
                    if not proxy_uris[i].endswith('/'):
                        proxy_uris[i] += '/'
                    # URL validation
                    url = urllib.parse.urlparse(proxy_uris[i])
                    if not (url.scheme and url.netloc):
                        raise arvados.errors.ArgumentError("Invalid proxy URI: {}".format(proxy_uris[i]))
                self.api_token = api_token
                self._gateway_services = {}
                self._keep_services = [{
                    'uuid': "00000-bi6l4-%015d" % idx,
                    'service_type': 'proxy',
                    '_service_root': uri,
                    } for idx, uri in enumerate(proxy_uris)]
                self._writable_services = self._keep_services
                self.using_proxy = True
                self._static_services_list = True
            else:
                # It's important to avoid instantiating an API client
                # unless we actually need one, for testing's sake.
                if api_client is None:
                    api_client = arvados.api('v1')
                self.api_client = api_client
                self.api_token = api_client.api_token
                self._gateway_services = {}
                self._keep_services = None
                self._writable_services = None
                self.using_proxy = None
                self._static_services_list = False

    def current_timeout(self, attempt_number):
        """Return the appropriate timeout to use for this client.

        The proxy timeout setting if the backend service is currently a proxy,
        the regular timeout setting otherwise.  The `attempt_number` indicates
        how many times the operation has been tried already (starting from 0
        for the first try), and scales the connection timeout portion of the
        return value accordingly.

        """
        # TODO(twp): the timeout should be a property of a
        # KeepService, not a KeepClient. See #4488.
        t = self.proxy_timeout if self.using_proxy else self.timeout
        if len(t) == 2:
            return (t[0] * (1 << attempt_number), t[1])
        else:
            return (t[0] * (1 << attempt_number), t[1], t[2])
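
    # Example of the scaling above, assuming a non-proxy service list and the
    # default timeout (2, 256, 32768):
    #
    #     current_timeout(0)  # (2, 256, 32768)   first attempt
    #     current_timeout(2)  # (8, 256, 32768)   third attempt: 2 * 2**2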

    def _any_nondisk_services(self, service_list):
        return any(ks.get('service_type', 'disk') != 'disk'
                   for ks in service_list)

    def build_services_list(self, force_rebuild=False):
        if (self._static_services_list or
              (self._keep_services and not force_rebuild)):
            return
        with self.lock:
            try:
                keep_services = self.api_client.keep_services().accessible()
            except Exception:  # API server predates Keep services.
                keep_services = self.api_client.keep_disks().list()

            # Gateway services are only used when specified by UUID,
            # so there's nothing to gain by filtering them by
            # service_type.
            self._gateway_services = {ks['uuid']: ks for ks in
                                      keep_services.execute()['items']}
            if not self._gateway_services:
                raise arvados.errors.NoKeepServersError()

            # Precompute the base URI for each service.
            for r in self._gateway_services.values():
                host = r['service_host']
                if not host.startswith('[') and host.find(':') >= 0:
                    # IPv6 URIs must be formatted like http://[::1]:80/...
                    host = '[' + host + ']'
                r['_service_root'] = "{}://{}:{:d}/".format(
                    'https' if r['service_ssl_flag'] else 'http',
                    host,
                    r['service_port'])

            _logger.debug(str(self._gateway_services))
            self._keep_services = [
                ks for ks in self._gateway_services.values()
                if not ks.get('service_type', '').startswith('gateway:')]
            self._writable_services = [ks for ks in self._keep_services
                                       if not ks.get('read_only')]

            # For disk type services, max_replicas_per_service is 1
            # It is unknown (unlimited) for other service types.
            if self._any_nondisk_services(self._writable_services):
                self.max_replicas_per_service = None
            else:
                self.max_replicas_per_service = 1

    def _service_weight(self, data_hash, service_uuid):
        """Compute the weight of a Keep service endpoint for a data
        block with a known hash.

        The weight is md5(h + u) where u is the last 15 characters of
        the service endpoint's UUID.
        """
        return hashlib.md5((data_hash + service_uuid[-15:]).encode()).hexdigest()

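    # Worked example of the weighting above (values are illustrative):
    #
    #     data_hash    = "acbd18db4cc2f85cedef654fccc4a4d8"
    #     service_uuid = "zzzzz-bi6l4-000000000000001"  # last 15: "000000000000001"
    #     weight = hashlib.md5(
    #         (data_hash + "000000000000001").encode()).hexdigest()
    #
    # Sorting services by this weight (heaviest first) gives every client the
    # same probe order for a given block, without central coordination.
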
    def weighted_service_roots(self, locator, force_rebuild=False, need_writable=False):
        """Return an array of Keep service endpoints, in the order in
        which they should be probed when reading or writing data with
        the given hash+hints.
        """
        self.build_services_list(force_rebuild)

        sorted_roots = []
        # Use the services indicated by the given +K@... remote
        # service hints, if any are present and can be resolved to a
        # URI.
        for hint in locator.hints:
            if hint.startswith('K@'):
                if len(hint) == 7:
                    sorted_roots.append(
                        "https://keep.{}.arvadosapi.com/".format(hint[2:]))
                elif len(hint) == 29:
                    svc = self._gateway_services.get(hint[2:])
                    if svc:
                        sorted_roots.append(svc['_service_root'])

        # Sort the available local services by weight (heaviest first)
        # for this locator, and return their service_roots (base URIs)
        # in that order.
        use_services = self._keep_services
        if need_writable:
            use_services = self._writable_services
        self.using_proxy = self._any_nondisk_services(use_services)
        sorted_roots.extend([
            svc['_service_root'] for svc in sorted(
                use_services,
                reverse=True,
                key=lambda svc: self._service_weight(locator.md5sum, svc['uuid']))])
        _logger.debug("{}: {}".format(locator, sorted_roots))
        return sorted_roots

    def map_new_services(self, roots_map, locator, force_rebuild, need_writable, headers):
        # roots_map is a dictionary, mapping Keep service root strings
        # to KeepService objects.  Poll for Keep services, and add any
        # new ones to roots_map.  Return the current list of local
        # root strings.
        headers.setdefault('Authorization', "OAuth2 %s" % (self.api_token,))
        local_roots = self.weighted_service_roots(locator, force_rebuild, need_writable)
        for root in local_roots:
            if root not in roots_map:
                roots_map[root] = self.KeepService(
                    root, self._user_agent_pool,
                    upload_counter=self.upload_counter,
                    download_counter=self.download_counter,
                    headers=headers)
        return local_roots

    @staticmethod
    def _check_loop_result(result):
        # KeepClient RetryLoops should save results as a 2-tuple: the
        # actual result of the request, and the number of servers available
        # to receive the request this round.
        # This method returns True if there's a real result, False if
        # there are no more servers available, otherwise None.
        if isinstance(result, Exception):
            return None
        result, tried_server_count = result
        if (result is not None) and (result is not False):
            return True
        elif tried_server_count < 1:
            _logger.info("No more Keep services to try; giving up")
            return False
        else:
            return None

    def get_from_cache(self, loc):
        """Fetch a block only if it is in the cache; otherwise return None."""
        slot = self.block_cache.get(loc)
        if slot is not None and slot.ready.is_set():
            return slot.get()
        else:
            return None

    @retry.retry_method
    def head(self, loc_s, **kwargs):
        return self._get_or_head(loc_s, method="HEAD", **kwargs)

    @retry.retry_method
    def get(self, loc_s, **kwargs):
        return self._get_or_head(loc_s, method="GET", **kwargs)

    def _get_or_head(self, loc_s, method="GET", num_retries=None, request_id=None):
        """Get data from Keep.

        This method fetches one or more blocks of data from Keep.  It
        sends a request to each Keep service registered with the API
        server (or the proxy provided when this client was
        instantiated), then each service named in location hints, in
        sequence.  As soon as one service provides the data, it's
        returned.

        Arguments:
        * loc_s: A string of one or more comma-separated locators to fetch.
          This method returns the concatenation of these blocks.
        * num_retries: The number of times to retry GET requests to
          *each* Keep server if it returns temporary failures, with
          exponential backoff.  Note that, in each loop, the method may try
          to fetch data from every available Keep service, along with any
          that are named in location hints in the locator.  The default value
          is set when the KeepClient is initialized.
        """
        if ',' in loc_s:
            # Fetched blocks are bytes, so join with b'' rather than ''.
            return b''.join(self.get(x) for x in loc_s.split(','))

        self.get_counter.add(1)

        slot = None
        blob = None
        try:
            locator = KeepLocator(loc_s)
            if method == "GET":
                slot, first = self.block_cache.reserve_cache(locator.md5sum)
                if not first:
                    self.hits_counter.add(1)
                    blob = slot.get()
                    if blob is None:
                        raise arvados.errors.KeepReadError(
                            "failed to read {}".format(loc_s))
                    return blob

            self.misses_counter.add(1)

            headers = {
                'X-Request-Id': (request_id or
                                 (hasattr(self, 'api_client') and self.api_client.request_id) or
                                 arvados.util.new_request_id()),
            }

            # If the locator has hints specifying a prefix (indicating a
            # remote keepproxy) or the UUID of a local gateway service,
            # read data from the indicated service(s) instead of the usual
            # list of local disk services.
            hint_roots = ['http://keep.{}.arvadosapi.com/'.format(hint[2:])
                          for hint in locator.hints if hint.startswith('K@') and len(hint) == 7]
            hint_roots.extend([self._gateway_services[hint[2:]]['_service_root']
                               for hint in locator.hints if (
                                       hint.startswith('K@') and
                                       len(hint) == 29 and
                                       self._gateway_services.get(hint[2:])
                                       )])
            # Map root URLs to their KeepService objects.
            roots_map = {
                root: self.KeepService(root, self._user_agent_pool,
                                       upload_counter=self.upload_counter,
                                       download_counter=self.download_counter,
                                       headers=headers)
                for root in hint_roots
            }

            # See #3147 for a discussion of the loop implementation.  Highlights:
            # * Refresh the list of Keep services after each failure, in case
            #   it's being updated.
            # * Retry until we succeed, we're out of retries, or every available
            #   service has returned permanent failure.
            sorted_roots = []
            loop = retry.RetryLoop(num_retries, self._check_loop_result,
                                   backoff_start=2)
            for tries_left in loop:
                try:
                    sorted_roots = self.map_new_services(
                        roots_map, locator,
                        force_rebuild=(tries_left < num_retries),
                        need_writable=False,
                        headers=headers)
                except Exception as error:
                    loop.save_result(error)
                    continue

                # Query KeepService objects that haven't returned
                # permanent failure, in our specified shuffle order.
                services_to_try = [roots_map[root]
                                   for root in sorted_roots
                                   if roots_map[root].usable()]
                for keep_service in services_to_try:
                    blob = keep_service.get(locator, method=method, timeout=self.current_timeout(num_retries-tries_left))
                    if blob is not None:
                        break
                loop.save_result((blob, len(services_to_try)))

            # Always cache the result, then return it if we succeeded.
            if loop.success():
                if method == "HEAD":
                    return True
                else:
                    return blob
        finally:
            if slot is not None:
                slot.set(blob)
                self.block_cache.cap_cache()

        # Q: Including 403 is necessary for the Keep tests to continue
        # passing, but maybe they should expect KeepReadError instead?
        not_founds = sum(1 for key in sorted_roots
                         if roots_map[key].last_result().get('status_code', None) in {403, 404, 410})
        service_errors = ((key, roots_map[key].last_result()['error'])
                          for key in sorted_roots)
        if not roots_map:
            raise arvados.errors.KeepReadError(
                "failed to read {}: no Keep services available ({})".format(
                    loc_s, loop.last_result()))
        elif not_founds == len(sorted_roots):
            raise arvados.errors.NotFoundError(
                "{} not found".format(loc_s), service_errors)
        else:
            raise arvados.errors.KeepReadError(
                "failed to read {}".format(loc_s), service_errors, label="service")

    @retry.retry_method
    def put(self, data, copies=2, num_retries=None, request_id=None):
        """Save data in Keep.

        This method will get a list of Keep services from the API server, and
        send the data to each one simultaneously in a new thread.  Once the
        uploads are finished, if enough copies are saved, this method returns
        the most recent HTTP response body.  If requests fail to upload
        enough copies, this method raises KeepWriteError.

        Arguments:
        * data: The string of data to upload.
        * copies: The number of copies that the user requires be saved.
          Default 2.
        * num_retries: The number of times to retry PUT requests to
          *each* Keep server if it returns temporary failures, with
          exponential backoff.  The default value is set when the
          KeepClient is initialized.
        """

        if not isinstance(data, bytes):
            data = data.encode()

        self.put_counter.add(1)

        data_hash = hashlib.md5(data).hexdigest()
        loc_s = data_hash + '+' + str(len(data))
        if copies < 1:
            return loc_s
        locator = KeepLocator(loc_s)

        headers = {
            'X-Request-Id': (request_id or
                             (hasattr(self, 'api_client') and self.api_client.request_id) or
                             arvados.util.new_request_id()),
            'X-Keep-Desired-Replicas': str(copies),
        }
        roots_map = {}
        loop = retry.RetryLoop(num_retries, self._check_loop_result,
                               backoff_start=2)
        done = 0
        for tries_left in loop:
            try:
                sorted_roots = self.map_new_services(
                    roots_map, locator,
                    force_rebuild=(tries_left < num_retries),
                    need_writable=True,
                    headers=headers)
            except Exception as error:
                loop.save_result(error)
                continue

            writer_pool = KeepClient.KeepWriterThreadPool(
                data=data,
                data_hash=data_hash,
                copies=copies - done,
                max_service_replicas=self.max_replicas_per_service,
                timeout=self.current_timeout(num_retries - tries_left))
            for service_root, ks in [(root, roots_map[root])
                                     for root in sorted_roots]:
                if ks.finished():
                    continue
                writer_pool.add_task(ks, service_root)
            writer_pool.join()
            done += writer_pool.done()
            loop.save_result((done >= copies, writer_pool.total_task_nr))

        if loop.success():
            return writer_pool.response()
        if not roots_map:
            raise arvados.errors.KeepWriteError(
                "failed to write {}: no Keep services available ({})".format(
                    data_hash, loop.last_result()))
        else:
            service_errors = ((key, roots_map[key].last_result()['error'])
                              for key in sorted_roots
                              if roots_map[key].last_result()['error'])
            raise arvados.errors.KeepWriteError(
                "failed to write {} (wanted {} copies but wrote {})".format(
                    data_hash, copies, writer_pool.done()), service_errors, label="service")

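    # Round-trip sketch for put()/get() above, assuming a reachable cluster
    # and valid credentials; the locator value is abbreviated:
    #
    #     kc = KeepClient(api_client=arvados.api('v1'))
    #     loc = kc.put(b'foo', copies=2)   # e.g. "acbd18...a4d8+3+A...@..."
    #     assert kc.get(loc) == b'foo'
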
    def local_store_put(self, data, copies=1, num_retries=None):
        """A stub for put().

        This method is used in place of the real put() method when
        using local storage (see constructor's local_store argument).

        copies and num_retries arguments are ignored: they are here
        only for the sake of offering the same call signature as
        put().

        Data stored this way can be retrieved via local_store_get().
        """
        md5 = hashlib.md5(data).hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(self.local_store, md5 + '.tmp'), 'wb') as f:
            f.write(data)
        os.rename(os.path.join(self.local_store, md5 + '.tmp'),
                  os.path.join(self.local_store, md5))
        return locator

    def local_store_get(self, loc_s, num_retries=None):
        """Companion to local_store_put()."""
        try:
            locator = KeepLocator(loc_s)
        except ValueError:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % loc_s)
        if locator.md5sum == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return b''
        with open(os.path.join(self.local_store, locator.md5sum), 'rb') as f:
            return f.read()

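    # Local-store sketch: with local_store set (or $KEEP_LOCAL_STORE in the
    # environment), put()/get() are rebound to the stubs above and blocks
    # live as plain files named by their md5 hash.  The path is illustrative
    # and must already exist.
    #
    #     kc = KeepClient(local_store='/tmp/keep')
    #     loc = kc.put(b'foo')  # writes /tmp/keep/acbd18db4cc2f85cedef654fccc4a4d8
    #     kc.get(loc)           # reads it back as b'foo'
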
    def is_cached(self, locator):
        # Look up the block cache without reserving a new slot.
        return self.block_cache.get(locator) is not None