Merge branch '13645-fix-arvados-api-host-insecure'
[arvados.git] / sdk / python / arvados / keep.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from __future__ import absolute_import
6 from __future__ import division
7 from future import standard_library
8 from future.utils import native_str
9 standard_library.install_aliases()
10 from builtins import next
11 from builtins import str
12 from builtins import range
13 from builtins import object
14 import collections
15 import datetime
16 import hashlib
17 import io
18 import logging
19 import math
20 import os
21 import pycurl
22 import queue
23 import re
24 import socket
25 import ssl
26 import sys
27 import threading
28 from . import timer
29 import urllib.parse
30
31 if sys.version_info >= (3, 0):
32     from io import BytesIO
33 else:
34     from cStringIO import StringIO as BytesIO
35
36 import arvados
37 import arvados.config as config
38 import arvados.errors
39 import arvados.retry as retry
40 import arvados.util
41
# Lazily-built KeepClient handed out by the deprecated Keep facade; it is
# replaced whenever the configuration it was built from changes.
global_client_object = None
# Logger used throughout this module.
_logger = logging.getLogger('arvados.keep')
44
45
# Apple's socket module may not expose these TCP keepalive option numbers,
# so define any that are missing.  Values sourced from:
# http://www.opensource.apple.com/source/xnu/xnu-2422.115.4/bsd/netinet/tcp.h
if sys.platform == 'darwin':
    for _tcp_opt_name, _tcp_opt_value in (('TCP_KEEPALIVE', 0x010),
                                          ('TCP_KEEPINTVL', 0x101),
                                          ('TCP_KEEPCNT', 0x102)):
        if not hasattr(socket, _tcp_opt_name):
            setattr(socket, _tcp_opt_name, _tcp_opt_value)
    # Don't leave the loop variables behind as module globals.
    del _tcp_opt_name, _tcp_opt_value
55
56
class KeepLocator(object):
    """Parsed representation of a Keep block locator.

    A locator string looks like ``md5sum[+size][+hint...]``.  Every hint must
    match HINT_RE.  A hint starting with ``A`` is parsed as a permission
    signature ``A<sig>@<hex expiry>``.  All other hints are kept, in order, in
    ``self.hints``.
    """
    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
    HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')

    def __init__(self, locator_str):
        """Parse *locator_str*.

        Raises ValueError if the md5sum is not 32 hex digits, the size is not
        an integer, or a hint is malformed.
        """
        self.hints = []
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        self.md5sum = next(pieces)
        try:
            self.size = int(next(pieces))
        except StopIteration:
            # A bare md5sum with no size is allowed.
            self.size = None
        for hint in pieces:
            if self.HINT_RE.match(hint) is None:
                raise ValueError("invalid hint format: {}".format(hint))
            elif hint.startswith('A'):
                self.parse_permission_hint(hint)
            else:
                self.hints.append(hint)

    def __str__(self):
        """Rebuild the locator string, omitting any missing parts."""
        return '+'.join(
            native_str(s)
            for s in [self.md5sum, self.size,
                      self.permission_hint()] + self.hints
            if s is not None)

    def stripped(self):
        """Return ``md5sum+size`` (or just md5sum if size is unknown), without hints."""
        if self.size is not None:
            return "%s+%i" % (self.md5sum, self.size)
        else:
            return self.md5sum

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.  Called at class-creation
        # time below, so it takes no ``self``.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not arvados.util.is_hex(hex_str, length):
                raise ValueError("{} is not a {}-digit hex string: {!r}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        """Permission expiry as a naive UTC datetime, or None."""
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        # *value* is a hex Unix timestamp of 1-8 digits.
        if not arvados.util.is_hex(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        """Return the ``A<sig>@<hex expiry>`` hint, or None if either part is unset."""
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        """Parse a permission hint ``A<sig>@<hex expiry>`` into perm_sig/perm_expiry.

        Raises ValueError if there is no ``@`` separator or if either part
        fails validation.
        """
        # split()+unpack raised a generic "not enough values to unpack"
        # ValueError here (never IndexError), so the intended message was
        # never produced.  partition() lets us detect the missing '@'.
        sig, sep, expiry = s[1:].partition('@')
        if not sep:
            raise ValueError("bad permission hint {}".format(s))
        self.perm_sig = sig
        self.perm_expiry = expiry

    def permission_expired(self, as_of_dt=None):
        """Return True if the permission signature has expired as of *as_of_dt*.

        perm_expiry is a naive UTC datetime (built with utcfromtimestamp), so
        *as_of_dt* should also be a naive UTC datetime.  It defaults to the
        current UTC time.  A locator without a permission hint never expires.
        """
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            # Must be UTC to match perm_expiry.  datetime.now() returns local
            # time and would skew the result by the host's UTC offset.
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt
139
140
class Keep(object):
    """Simple interface to a global KeepClient object.

    THIS CLASS IS DEPRECATED.  Please instantiate your own KeepClient with your
    own API client.  The global KeepClient will build an API client from the
    current Arvados configuration, which may not match the one you built.
    """
    # Configuration snapshot that the current global client was built from.
    _last_key = None

    @classmethod
    def global_client_object(cls):
        """Return the module-wide KeepClient, rebuilding it if settings changed.

        KeepClient used to change its behavior at runtime based on these
        configuration values.  To emulate that, a fresh KeepClient is created
        whenever any of them differs from the snapshot taken when the current
        client was built.
        """
        global global_client_object
        current_key = (
            config.get('ARVADOS_API_HOST'),
            config.get('ARVADOS_API_TOKEN'),
            config.flag_is_true('ARVADOS_API_HOST_INSECURE'),
            config.get('ARVADOS_KEEP_PROXY'),
            config.get('ARVADOS_EXTERNAL_CLIENT') == 'true',
            os.environ.get('KEEP_LOCAL_STORE'),
        )
        stale = global_client_object is None or current_key != cls._last_key
        if stale:
            global_client_object = KeepClient()
            cls._last_key = current_key
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch *locator* through the shared global KeepClient."""
        client = Keep.global_client_object()
        return client.get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store *data* through the shared global KeepClient."""
        client = Keep.global_client_object()
        return client.put(data, **kwargs)
175
class KeepBlockCache(object):
    """In-memory cache of block contents, most recently used first.

    Slots are kept in a list with the most recently used slot at index 0.
    A slot is reserved (reserve_cache) before its content is available.
    Readers block in CacheSlot.get() until some thread calls CacheSlot.set().
    """
    # Default RAM cache is 256MiB
    def __init__(self, cache_max=(256 * 1024 * 1024)):
        self.cache_max = cache_max
        self._cache = []
        self._cache_lock = threading.Lock()

    class CacheSlot(object):
        """One cache entry: a locator plus content that is filled in once."""
        __slots__ = ("locator", "ready", "content")

        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            """Block until set() has been called, then return the content."""
            self.ready.wait()
            return self.content

        def set(self, value):
            """Store *value* (None marks a failed read) and wake all waiters."""
            self.content = value
            self.ready.set()

        def size(self):
            """Return the number of bytes held, or 0 if there is no content."""
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        with self._cache_lock:
            # Select all slots except those where ready.is_set() and content is
            # None (that means there was an error reading the block).
            self._cache = [c for c in self._cache if not (c.ready.is_set() and c.content is None)]
            sm = sum(slot.size() for slot in self._cache)
            while self._cache and sm > self.cache_max:
                # Evict the least recently used slot that is ready.  Slots
                # still being filled are never evicted.
                for i in range(len(self._cache) - 1, -1, -1):
                    if self._cache[i].ready.is_set():
                        del self._cache[i]
                        break
                else:
                    # Nothing is evictable (every remaining slot is still
                    # pending).  Stop instead of busy-looping while holding
                    # the lock; a later call will finish the trimming.
                    break
                sm = sum(slot.size() for slot in self._cache)

    def _get(self, locator):
        # Callers in this class hold self._cache_lock.  On a hit, the slot
        # is moved to the front so it is evicted last.
        for i, slot in enumerate(self._cache):
            if slot.locator == locator:
                if i != 0:
                    # Safe despite mutating the list: we return right away.
                    del self._cache[i]
                    self._cache.insert(0, slot)
                return slot
        return None

    def get(self, locator):
        """Return the slot for *locator* (marking it most recent), or None."""
        with self._cache_lock:
            return self._get(locator)

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.

        Returns (slot, created), where created is True only when a new
        slot was added.'''
        with self._cache_lock:
            n = self._get(locator)
            if n is not None:
                return n, False
            # Add a new cache slot for the locator
            n = KeepBlockCache.CacheSlot(locator)
            self._cache.insert(0, n)
            return n, True
247
class Counter(object):
    """Running numeric total whose updates and reads are lock-protected."""

    def __init__(self, v=0):
        self._val = v
        self._lk = threading.Lock()

    def add(self, v):
        """Increase the total by *v*."""
        self._lk.acquire()
        try:
            self._val += v
        finally:
            self._lk.release()

    def get(self):
        """Return the current total."""
        self._lk.acquire()
        try:
            return self._val
        finally:
            self._lk.release()
260
261
262 class KeepClient(object):
263
264     # Default Keep server connection timeout:  2 seconds
265     # Default Keep server read timeout:       256 seconds
266     # Default Keep server bandwidth minimum:  32768 bytes per second
267     # Default Keep proxy connection timeout:  20 seconds
268     # Default Keep proxy read timeout:        256 seconds
269     # Default Keep proxy bandwidth minimum:   32768 bytes per second
270     DEFAULT_TIMEOUT = (2, 256, 32768)
271     DEFAULT_PROXY_TIMEOUT = (20, 256, 32768)
272
273
274     class KeepService(object):
275         """Make requests to a single Keep service, and track results.
276
277         A KeepService is intended to last long enough to perform one
278         transaction (GET or PUT) against one Keep service. This can
279         involve calling either get() or put() multiple times in order
280         to retry after transient failures. However, calling both get()
281         and put() on a single instance -- or using the same instance
282         to access two different Keep services -- will not produce
283         sensible behavior.
284         """
285
286         HTTP_ERRORS = (
287             socket.error,
288             ssl.SSLError,
289             arvados.errors.HttpError,
290         )
291
292         def __init__(self, root, user_agent_pool=queue.LifoQueue(),
293                      upload_counter=None,
294                      download_counter=None,
295                      headers={},
296                      insecure=False):
297             self.root = root
298             self._user_agent_pool = user_agent_pool
299             self._result = {'error': None}
300             self._usable = True
301             self._session = None
302             self._socket = None
303             self.get_headers = {'Accept': 'application/octet-stream'}
304             self.get_headers.update(headers)
305             self.put_headers = headers
306             self.upload_counter = upload_counter
307             self.download_counter = download_counter
308             self.insecure = insecure
309
310         def usable(self):
311             """Is it worth attempting a request?"""
312             return self._usable
313
314         def finished(self):
315             """Did the request succeed or encounter permanent failure?"""
316             return self._result['error'] == False or not self._usable
317
318         def last_result(self):
319             return self._result
320
321         def _get_user_agent(self):
322             try:
323                 return self._user_agent_pool.get(block=False)
324             except queue.Empty:
325                 return pycurl.Curl()
326
327         def _put_user_agent(self, ua):
328             try:
329                 ua.reset()
330                 self._user_agent_pool.put(ua, block=False)
331             except:
332                 ua.close()
333
334         def _socket_open(self, *args, **kwargs):
335             if len(args) + len(kwargs) == 2:
336                 return self._socket_open_pycurl_7_21_5(*args, **kwargs)
337             else:
338                 return self._socket_open_pycurl_7_19_3(*args, **kwargs)
339
340         def _socket_open_pycurl_7_19_3(self, family, socktype, protocol, address=None):
341             return self._socket_open_pycurl_7_21_5(
342                 purpose=None,
343                 address=collections.namedtuple(
344                     'Address', ['family', 'socktype', 'protocol', 'addr'],
345                 )(family, socktype, protocol, address))
346
347         def _socket_open_pycurl_7_21_5(self, purpose, address):
348             """Because pycurl doesn't have CURLOPT_TCP_KEEPALIVE"""
349             s = socket.socket(address.family, address.socktype, address.protocol)
350             s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
351             # Will throw invalid protocol error on mac. This test prevents that.
352             if hasattr(socket, 'TCP_KEEPIDLE'):
353                 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 75)
354             s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 75)
355             self._socket = s
356             return s
357
358         def get(self, locator, method="GET", timeout=None):
359             # locator is a KeepLocator object.
360             url = self.root + str(locator)
361             _logger.debug("Request: %s %s", method, url)
362             curl = self._get_user_agent()
363             ok = None
364             try:
365                 with timer.Timer() as t:
366                     self._headers = {}
367                     response_body = BytesIO()
368                     curl.setopt(pycurl.NOSIGNAL, 1)
369                     curl.setopt(pycurl.OPENSOCKETFUNCTION,
370                                 lambda *args, **kwargs: self._socket_open(*args, **kwargs))
371                     curl.setopt(pycurl.URL, url.encode('utf-8'))
372                     curl.setopt(pycurl.HTTPHEADER, [
373                         '{}: {}'.format(k,v) for k,v in self.get_headers.items()])
374                     curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
375                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
376                     if self.insecure:
377                         curl.setopt(pycurl.SSL_VERIFYPEER, 0)
378                     if method == "HEAD":
379                         curl.setopt(pycurl.NOBODY, True)
380                     self._setcurltimeouts(curl, timeout)
381
382                     try:
383                         curl.perform()
384                     except Exception as e:
385                         raise arvados.errors.HttpError(0, str(e))
386                     finally:
387                         if self._socket:
388                             self._socket.close()
389                             self._socket = None
390                     self._result = {
391                         'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
392                         'body': response_body.getvalue(),
393                         'headers': self._headers,
394                         'error': False,
395                     }
396
397                 ok = retry.check_http_response_success(self._result['status_code'])
398                 if not ok:
399                     self._result['error'] = arvados.errors.HttpError(
400                         self._result['status_code'],
401                         self._headers.get('x-status-line', 'Error'))
402             except self.HTTP_ERRORS as e:
403                 self._result = {
404                     'error': e,
405                 }
406             self._usable = ok != False
407             if self._result.get('status_code', None):
408                 # The client worked well enough to get an HTTP status
409                 # code, so presumably any problems are just on the
410                 # server side and it's OK to reuse the client.
411                 self._put_user_agent(curl)
412             else:
413                 # Don't return this client to the pool, in case it's
414                 # broken.
415                 curl.close()
416             if not ok:
417                 _logger.debug("Request fail: GET %s => %s: %s",
418                               url, type(self._result['error']), str(self._result['error']))
419                 return None
420             if method == "HEAD":
421                 _logger.info("HEAD %s: %s bytes",
422                          self._result['status_code'],
423                          self._result.get('content-length'))
424                 return True
425
426             _logger.info("GET %s: %s bytes in %s msec (%.3f MiB/sec)",
427                          self._result['status_code'],
428                          len(self._result['body']),
429                          t.msecs,
430                          1.0*len(self._result['body'])/2**20/t.secs if t.secs > 0 else 0)
431
432             if self.download_counter:
433                 self.download_counter.add(len(self._result['body']))
434             resp_md5 = hashlib.md5(self._result['body']).hexdigest()
435             if resp_md5 != locator.md5sum:
436                 _logger.warning("Checksum fail: md5(%s) = %s",
437                                 url, resp_md5)
438                 self._result['error'] = arvados.errors.HttpError(
439                     0, 'Checksum fail')
440                 return None
441             return self._result['body']
442
443         def put(self, hash_s, body, timeout=None):
444             url = self.root + hash_s
445             _logger.debug("Request: PUT %s", url)
446             curl = self._get_user_agent()
447             ok = None
448             try:
449                 with timer.Timer() as t:
450                     self._headers = {}
451                     body_reader = BytesIO(body)
452                     response_body = BytesIO()
453                     curl.setopt(pycurl.NOSIGNAL, 1)
454                     curl.setopt(pycurl.OPENSOCKETFUNCTION,
455                                 lambda *args, **kwargs: self._socket_open(*args, **kwargs))
456                     curl.setopt(pycurl.URL, url.encode('utf-8'))
457                     # Using UPLOAD tells cURL to wait for a "go ahead" from the
458                     # Keep server (in the form of a HTTP/1.1 "100 Continue"
459                     # response) instead of sending the request body immediately.
460                     # This allows the server to reject the request if the request
461                     # is invalid or the server is read-only, without waiting for
462                     # the client to send the entire block.
463                     curl.setopt(pycurl.UPLOAD, True)
464                     curl.setopt(pycurl.INFILESIZE, len(body))
465                     curl.setopt(pycurl.READFUNCTION, body_reader.read)
466                     curl.setopt(pycurl.HTTPHEADER, [
467                         '{}: {}'.format(k,v) for k,v in self.put_headers.items()])
468                     curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
469                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
470                     if self.insecure:
471                         curl.setopt(pycurl.SSL_VERIFYPEER, 0)
472                     self._setcurltimeouts(curl, timeout)
473                     try:
474                         curl.perform()
475                     except Exception as e:
476                         raise arvados.errors.HttpError(0, str(e))
477                     finally:
478                         if self._socket:
479                             self._socket.close()
480                             self._socket = None
481                     self._result = {
482                         'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
483                         'body': response_body.getvalue().decode('utf-8'),
484                         'headers': self._headers,
485                         'error': False,
486                     }
487                 ok = retry.check_http_response_success(self._result['status_code'])
488                 if not ok:
489                     self._result['error'] = arvados.errors.HttpError(
490                         self._result['status_code'],
491                         self._headers.get('x-status-line', 'Error'))
492             except self.HTTP_ERRORS as e:
493                 self._result = {
494                     'error': e,
495                 }
496             self._usable = ok != False # still usable if ok is True or None
497             if self._result.get('status_code', None):
498                 # Client is functional. See comment in get().
499                 self._put_user_agent(curl)
500             else:
501                 curl.close()
502             if not ok:
503                 _logger.debug("Request fail: PUT %s => %s: %s",
504                               url, type(self._result['error']), str(self._result['error']))
505                 return False
506             _logger.info("PUT %s: %s bytes in %s msec (%.3f MiB/sec)",
507                          self._result['status_code'],
508                          len(body),
509                          t.msecs,
510                          1.0*len(body)/2**20/t.secs if t.secs > 0 else 0)
511             if self.upload_counter:
512                 self.upload_counter.add(len(body))
513             return True
514
515         def _setcurltimeouts(self, curl, timeouts):
516             if not timeouts:
517                 return
518             elif isinstance(timeouts, tuple):
519                 if len(timeouts) == 2:
520                     conn_t, xfer_t = timeouts
521                     bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
522                 else:
523                     conn_t, xfer_t, bandwidth_bps = timeouts
524             else:
525                 conn_t, xfer_t = (timeouts, timeouts)
526                 bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
527             curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(conn_t*1000))
528             curl.setopt(pycurl.LOW_SPEED_TIME, int(math.ceil(xfer_t)))
529             curl.setopt(pycurl.LOW_SPEED_LIMIT, int(math.ceil(bandwidth_bps)))
530
531         def _headerfunction(self, header_line):
532             if isinstance(header_line, bytes):
533                 header_line = header_line.decode('iso-8859-1')
534             if ':' in header_line:
535                 name, value = header_line.split(':', 1)
536                 name = name.strip().lower()
537                 value = value.strip()
538             elif self._headers:
539                 name = self._lastheadername
540                 value = self._headers[name] + ' ' + header_line.strip()
541             elif header_line.startswith('HTTP/'):
542                 name = 'x-status-line'
543                 value = header_line
544             else:
545                 _logger.error("Unexpected header line: %s", header_line)
546                 return
547             self._lastheadername = name
548             self._headers[name] = value
549             # Returning None implies all bytes were written
550
551
    class KeepWriterQueue(queue.Queue):
        """Queue of (service, service_root) write tasks plus copy accounting.

        Worker threads take tasks with get_next_task() and report outcomes
        with write_success()/write_fail().  pending_tries limits how many
        tasks may be handed out.  It starts at the number of wanted copies,
        is decremented each time a task is handed out, and is incremented
        again when a task fails.
        """
        def __init__(self, copies):
            queue.Queue.__init__(self) # Old-style superclass
            self.wanted_copies = copies
            self.successful_copies = 0
            # The response passed to the most recent write_success() call.
            self.response = None
            self.successful_copies_lock = threading.Lock()
            self.pending_tries = copies
            # Guards pending_tries; get_next_task() waits on it until a
            # write succeeds or fails.
            self.pending_tries_notification = threading.Condition()

        def write_success(self, response, replicas_nr):
            """Record *replicas_nr* stored copies and *response*; wake all waiting workers."""
            with self.successful_copies_lock:
                self.successful_copies += replicas_nr
                self.response = response
            with self.pending_tries_notification:
                self.pending_tries_notification.notify_all()

        def write_fail(self, ks):
            """Record a failed attempt: allow one more try and wake one waiting worker.

            *ks* (the failed service) is not used here.
            """
            with self.pending_tries_notification:
                self.pending_tries += 1
                self.pending_tries_notification.notify()

        def pending_copies(self):
            """Return how many more copies are wanted (zero or negative when done)."""
            with self.successful_copies_lock:
                return self.wanted_copies - self.successful_copies

        def get_next_task(self):
            """Return the next (service, service_root) to try, or raise queue.Empty.

            queue.Empty is raised when enough copies have been stored (after
            draining the remaining tasks) or when no tasks remain.  Services
            whose finished() is true are skipped.  If tasks remain but no
            tries are available, this blocks until a worker reports success
            or failure.
            """
            with self.pending_tries_notification:
                while True:
                    if self.pending_copies() < 1:
                        # This notify_all() is unnecessary --
                        # write_success() already called notify_all()
                        # when pending<1 became true, so it's not
                        # possible for any other thread to be in
                        # wait() now -- but it's cheap insurance
                        # against deadlock so we do it anyway:
                        self.pending_tries_notification.notify_all()
                        # Drain the queue and then raise Queue.Empty
                        while True:
                            self.get_nowait()
                            self.task_done()
                    elif self.pending_tries > 0:
                        # get_nowait() raises queue.Empty if no tasks are
                        # left, which ends the calling worker's loop.
                        service, service_root = self.get_nowait()
                        if service.finished():
                            # Already succeeded or permanently failed:
                            # mark the task done and look for another.
                            self.task_done()
                            continue
                        self.pending_tries -= 1
                        return service, service_root
                    elif self.empty():
                        self.pending_tries_notification.notify_all()
                        raise queue.Empty
                    else:
                        # Tasks remain but every try is in flight: wait
                        # for write_success()/write_fail().
                        self.pending_tries_notification.wait()
605
606
607     class KeepWriterThreadPool(object):
608         def __init__(self, data, data_hash, copies, max_service_replicas, timeout=None):
609             self.total_task_nr = 0
610             self.wanted_copies = copies
611             if (not max_service_replicas) or (max_service_replicas >= copies):
612                 num_threads = 1
613             else:
614                 num_threads = int(math.ceil(1.0*copies/max_service_replicas))
615             _logger.debug("Pool max threads is %d", num_threads)
616             self.workers = []
617             self.queue = KeepClient.KeepWriterQueue(copies)
618             # Create workers
619             for _ in range(num_threads):
620                 w = KeepClient.KeepWriterThread(self.queue, data, data_hash, timeout)
621                 self.workers.append(w)
622
623         def add_task(self, ks, service_root):
624             self.queue.put((ks, service_root))
625             self.total_task_nr += 1
626
627         def done(self):
628             return self.queue.successful_copies
629
630         def join(self):
631             # Start workers
632             for worker in self.workers:
633                 worker.start()
634             # Wait for finished work
635             self.queue.join()
636
637         def response(self):
638             return self.queue.response
639
640
641     class KeepWriterThread(threading.Thread):
642         TaskFailed = RuntimeError()
643
644         def __init__(self, queue, data, data_hash, timeout=None):
645             super(KeepClient.KeepWriterThread, self).__init__()
646             self.timeout = timeout
647             self.queue = queue
648             self.data = data
649             self.data_hash = data_hash
650             self.daemon = True
651
652         def run(self):
653             while True:
654                 try:
655                     service, service_root = self.queue.get_next_task()
656                 except queue.Empty:
657                     return
658                 try:
659                     locator, copies = self.do_task(service, service_root)
660                 except Exception as e:
661                     if e is not self.TaskFailed:
662                         _logger.exception("Exception in KeepWriterThread")
663                     self.queue.write_fail(service)
664                 else:
665                     self.queue.write_success(locator, copies)
666                 finally:
667                     self.queue.task_done()
668
669         def do_task(self, service, service_root):
670             success = bool(service.put(self.data_hash,
671                                         self.data,
672                                         timeout=self.timeout))
673             result = service.last_result()
674
675             if not success:
676                 if result.get('status_code', None):
677                     _logger.debug("Request fail: PUT %s => %s %s",
678                                   self.data_hash,
679                                   result['status_code'],
680                                   result['body'])
681                 raise self.TaskFailed
682
683             _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
684                           str(threading.current_thread()),
685                           self.data_hash,
686                           len(self.data),
687                           service_root)
688             try:
689                 replicas_stored = int(result['headers']['x-keep-replicas-stored'])
690             except (KeyError, ValueError):
691                 replicas_stored = 1
692
693             return result['body'].strip(), replicas_stored
694
695
696     def __init__(self, api_client=None, proxy=None,
697                  timeout=DEFAULT_TIMEOUT, proxy_timeout=DEFAULT_PROXY_TIMEOUT,
698                  api_token=None, local_store=None, block_cache=None,
699                  num_retries=0, session=None):
700         """Initialize a new KeepClient.
701
702         Arguments:
703         :api_client:
704           The API client to use to find Keep services.  If not
705           provided, KeepClient will build one from available Arvados
706           configuration.
707
708         :proxy:
709           If specified, this KeepClient will send requests to this Keep
710           proxy.  Otherwise, KeepClient will fall back to the setting of the
711           ARVADOS_KEEP_SERVICES or ARVADOS_KEEP_PROXY configuration settings.
712           If you want to KeepClient does not use a proxy, pass in an empty
713           string.
714
715         :timeout:
716           The initial timeout (in seconds) for HTTP requests to Keep
717           non-proxy servers.  A tuple of three floats is interpreted as
718           (connection_timeout, read_timeout, minimum_bandwidth). A connection
719           will be aborted if the average traffic rate falls below
720           minimum_bandwidth bytes per second over an interval of read_timeout
721           seconds. Because timeouts are often a result of transient server
722           load, the actual connection timeout will be increased by a factor
723           of two on each retry.
724           Default: (2, 256, 32768).
725
726         :proxy_timeout:
727           The initial timeout (in seconds) for HTTP requests to
728           Keep proxies. A tuple of three floats is interpreted as
729           (connection_timeout, read_timeout, minimum_bandwidth). The behavior
730           described above for adjusting connection timeouts on retry also
731           applies.
732           Default: (20, 256, 32768).
733
734         :api_token:
735           If you're not using an API client, but only talking
736           directly to a Keep proxy, this parameter specifies an API token
737           to authenticate Keep requests.  It is an error to specify both
738           api_client and api_token.  If you specify neither, KeepClient
739           will use one available from the Arvados configuration.
740
741         :local_store:
742           If specified, this KeepClient will bypass Keep
743           services, and save data to the named directory.  If unspecified,
744           KeepClient will fall back to the setting of the $KEEP_LOCAL_STORE
745           environment variable.  If you want to ensure KeepClient does not
746           use local storage, pass in an empty string.  This is primarily
747           intended to mock a server for testing.
748
749         :num_retries:
750           The default number of times to retry failed requests.
751           This will be used as the default num_retries value when get() and
752           put() are called.  Default 0.
753         """
754         self.lock = threading.Lock()
755         if proxy is None:
756             if config.get('ARVADOS_KEEP_SERVICES'):
757                 proxy = config.get('ARVADOS_KEEP_SERVICES')
758             else:
759                 proxy = config.get('ARVADOS_KEEP_PROXY')
760         if api_token is None:
761             if api_client is None:
762                 api_token = config.get('ARVADOS_API_TOKEN')
763             else:
764                 api_token = api_client.api_token
765         elif api_client is not None:
766             raise ValueError(
767                 "can't build KeepClient with both API client and token")
768         if local_store is None:
769             local_store = os.environ.get('KEEP_LOCAL_STORE')
770
771         if api_client is None:
772             self.insecure = config.flag_is_true('ARVADOS_API_HOST_INSECURE')
773         else:
774             self.insecure = api_client.insecure
775
776         self.block_cache = block_cache if block_cache else KeepBlockCache()
777         self.timeout = timeout
778         self.proxy_timeout = proxy_timeout
779         self._user_agent_pool = queue.LifoQueue()
780         self.upload_counter = Counter()
781         self.download_counter = Counter()
782         self.put_counter = Counter()
783         self.get_counter = Counter()
784         self.hits_counter = Counter()
785         self.misses_counter = Counter()
786
787         if local_store:
788             self.local_store = local_store
789             self.get = self.local_store_get
790             self.put = self.local_store_put
791         else:
792             self.num_retries = num_retries
793             self.max_replicas_per_service = None
794             if proxy:
795                 proxy_uris = proxy.split()
796                 for i in range(len(proxy_uris)):
797                     if not proxy_uris[i].endswith('/'):
798                         proxy_uris[i] += '/'
799                     # URL validation
800                     url = urllib.parse.urlparse(proxy_uris[i])
801                     if not (url.scheme and url.netloc):
802                         raise arvados.errors.ArgumentError("Invalid proxy URI: {}".format(proxy_uris[i]))
803                 self.api_token = api_token
804                 self._gateway_services = {}
805                 self._keep_services = [{
806                     'uuid': "00000-bi6l4-%015d" % idx,
807                     'service_type': 'proxy',
808                     '_service_root': uri,
809                     } for idx, uri in enumerate(proxy_uris)]
810                 self._writable_services = self._keep_services
811                 self.using_proxy = True
812                 self._static_services_list = True
813             else:
814                 # It's important to avoid instantiating an API client
815                 # unless we actually need one, for testing's sake.
816                 if api_client is None:
817                     api_client = arvados.api('v1')
818                 self.api_client = api_client
819                 self.api_token = api_client.api_token
820                 self._gateway_services = {}
821                 self._keep_services = None
822                 self._writable_services = None
823                 self.using_proxy = None
824                 self._static_services_list = False
825
826     def current_timeout(self, attempt_number):
827         """Return the appropriate timeout to use for this client.
828
829         The proxy timeout setting if the backend service is currently a proxy,
830         the regular timeout setting otherwise.  The `attempt_number` indicates
831         how many times the operation has been tried already (starting from 0
832         for the first try), and scales the connection timeout portion of the
833         return value accordingly.
834
835         """
836         # TODO(twp): the timeout should be a property of a
837         # KeepService, not a KeepClient. See #4488.
838         t = self.proxy_timeout if self.using_proxy else self.timeout
839         if len(t) == 2:
840             return (t[0] * (1 << attempt_number), t[1])
841         else:
842             return (t[0] * (1 << attempt_number), t[1], t[2])
843     def _any_nondisk_services(self, service_list):
844         return any(ks.get('service_type', 'disk') != 'disk'
845                    for ks in service_list)
846
847     def build_services_list(self, force_rebuild=False):
848         if (self._static_services_list or
849               (self._keep_services and not force_rebuild)):
850             return
851         with self.lock:
852             try:
853                 keep_services = self.api_client.keep_services().accessible()
854             except Exception:  # API server predates Keep services.
855                 keep_services = self.api_client.keep_disks().list()
856
857             # Gateway services are only used when specified by UUID,
858             # so there's nothing to gain by filtering them by
859             # service_type.
860             self._gateway_services = {ks['uuid']: ks for ks in
861                                       keep_services.execute()['items']}
862             if not self._gateway_services:
863                 raise arvados.errors.NoKeepServersError()
864
865             # Precompute the base URI for each service.
866             for r in self._gateway_services.values():
867                 host = r['service_host']
868                 if not host.startswith('[') and host.find(':') >= 0:
869                     # IPv6 URIs must be formatted like http://[::1]:80/...
870                     host = '[' + host + ']'
871                 r['_service_root'] = "{}://{}:{:d}/".format(
872                     'https' if r['service_ssl_flag'] else 'http',
873                     host,
874                     r['service_port'])
875
876             _logger.debug(str(self._gateway_services))
877             self._keep_services = [
878                 ks for ks in self._gateway_services.values()
879                 if not ks.get('service_type', '').startswith('gateway:')]
880             self._writable_services = [ks for ks in self._keep_services
881                                        if not ks.get('read_only')]
882
883             # For disk type services, max_replicas_per_service is 1
884             # It is unknown (unlimited) for other service types.
885             if self._any_nondisk_services(self._writable_services):
886                 self.max_replicas_per_service = None
887             else:
888                 self.max_replicas_per_service = 1
889
890     def _service_weight(self, data_hash, service_uuid):
891         """Compute the weight of a Keep service endpoint for a data
892         block with a known hash.
893
894         The weight is md5(h + u) where u is the last 15 characters of
895         the service endpoint's UUID.
896         """
897         return hashlib.md5((data_hash + service_uuid[-15:]).encode()).hexdigest()
898
899     def weighted_service_roots(self, locator, force_rebuild=False, need_writable=False):
900         """Return an array of Keep service endpoints, in the order in
901         which they should be probed when reading or writing data with
902         the given hash+hints.
903         """
904         self.build_services_list(force_rebuild)
905
906         sorted_roots = []
907         # Use the services indicated by the given +K@... remote
908         # service hints, if any are present and can be resolved to a
909         # URI.
910         for hint in locator.hints:
911             if hint.startswith('K@'):
912                 if len(hint) == 7:
913                     sorted_roots.append(
914                         "https://keep.{}.arvadosapi.com/".format(hint[2:]))
915                 elif len(hint) == 29:
916                     svc = self._gateway_services.get(hint[2:])
917                     if svc:
918                         sorted_roots.append(svc['_service_root'])
919
920         # Sort the available local services by weight (heaviest first)
921         # for this locator, and return their service_roots (base URIs)
922         # in that order.
923         use_services = self._keep_services
924         if need_writable:
925             use_services = self._writable_services
926         self.using_proxy = self._any_nondisk_services(use_services)
927         sorted_roots.extend([
928             svc['_service_root'] for svc in sorted(
929                 use_services,
930                 reverse=True,
931                 key=lambda svc: self._service_weight(locator.md5sum, svc['uuid']))])
932         _logger.debug("{}: {}".format(locator, sorted_roots))
933         return sorted_roots
934
935     def map_new_services(self, roots_map, locator, force_rebuild, need_writable, headers):
936         # roots_map is a dictionary, mapping Keep service root strings
937         # to KeepService objects.  Poll for Keep services, and add any
938         # new ones to roots_map.  Return the current list of local
939         # root strings.
940         headers.setdefault('Authorization', "OAuth2 %s" % (self.api_token,))
941         local_roots = self.weighted_service_roots(locator, force_rebuild, need_writable)
942         for root in local_roots:
943             if root not in roots_map:
944                 roots_map[root] = self.KeepService(
945                     root, self._user_agent_pool,
946                     upload_counter=self.upload_counter,
947                     download_counter=self.download_counter,
948                     headers=headers,
949                     insecure=self.insecure)
950         return local_roots
951
952     @staticmethod
953     def _check_loop_result(result):
954         # KeepClient RetryLoops should save results as a 2-tuple: the
955         # actual result of the request, and the number of servers available
956         # to receive the request this round.
957         # This method returns True if there's a real result, False if
958         # there are no more servers available, otherwise None.
959         if isinstance(result, Exception):
960             return None
961         result, tried_server_count = result
962         if (result is not None) and (result is not False):
963             return True
964         elif tried_server_count < 1:
965             _logger.info("No more Keep services to try; giving up")
966             return False
967         else:
968             return None
969
970     def get_from_cache(self, loc):
971         """Fetch a block only if is in the cache, otherwise return None."""
972         slot = self.block_cache.get(loc)
973         if slot is not None and slot.ready.is_set():
974             return slot.get()
975         else:
976             return None
977
978     @retry.retry_method
979     def head(self, loc_s, **kwargs):
980         return self._get_or_head(loc_s, method="HEAD", **kwargs)
981
982     @retry.retry_method
983     def get(self, loc_s, **kwargs):
984         return self._get_or_head(loc_s, method="GET", **kwargs)
985
986     def _get_or_head(self, loc_s, method="GET", num_retries=None, request_id=None):
987         """Get data from Keep.
988
989         This method fetches one or more blocks of data from Keep.  It
990         sends a request each Keep service registered with the API
991         server (or the proxy provided when this client was
992         instantiated), then each service named in location hints, in
993         sequence.  As soon as one service provides the data, it's
994         returned.
995
996         Arguments:
997         * loc_s: A string of one or more comma-separated locators to fetch.
998           This method returns the concatenation of these blocks.
999         * num_retries: The number of times to retry GET requests to
1000           *each* Keep server if it returns temporary failures, with
1001           exponential backoff.  Note that, in each loop, the method may try
1002           to fetch data from every available Keep service, along with any
1003           that are named in location hints in the locator.  The default value
1004           is set when the KeepClient is initialized.
1005         """
1006         if ',' in loc_s:
1007             return ''.join(self.get(x) for x in loc_s.split(','))
1008
1009         self.get_counter.add(1)
1010
1011         slot = None
1012         blob = None
1013         try:
1014             locator = KeepLocator(loc_s)
1015             if method == "GET":
1016                 slot, first = self.block_cache.reserve_cache(locator.md5sum)
1017                 if not first:
1018                     self.hits_counter.add(1)
1019                     blob = slot.get()
1020                     if blob is None:
1021                         raise arvados.errors.KeepReadError(
1022                             "failed to read {}".format(loc_s))
1023                     return blob
1024
1025             self.misses_counter.add(1)
1026
1027             headers = {
1028                 'X-Request-Id': (request_id or
1029                                  (hasattr(self, 'api_client') and self.api_client.request_id) or
1030                                  arvados.util.new_request_id()),
1031             }
1032
1033             # If the locator has hints specifying a prefix (indicating a
1034             # remote keepproxy) or the UUID of a local gateway service,
1035             # read data from the indicated service(s) instead of the usual
1036             # list of local disk services.
1037             hint_roots = ['http://keep.{}.arvadosapi.com/'.format(hint[2:])
1038                           for hint in locator.hints if hint.startswith('K@') and len(hint) == 7]
1039             hint_roots.extend([self._gateway_services[hint[2:]]['_service_root']
1040                                for hint in locator.hints if (
1041                                        hint.startswith('K@') and
1042                                        len(hint) == 29 and
1043                                        self._gateway_services.get(hint[2:])
1044                                        )])
1045             # Map root URLs to their KeepService objects.
1046             roots_map = {
1047                 root: self.KeepService(root, self._user_agent_pool,
1048                                        upload_counter=self.upload_counter,
1049                                        download_counter=self.download_counter,
1050                                        headers=headers,
1051                                        insecure=self.insecure)
1052                 for root in hint_roots
1053             }
1054
1055             # See #3147 for a discussion of the loop implementation.  Highlights:
1056             # * Refresh the list of Keep services after each failure, in case
1057             #   it's being updated.
1058             # * Retry until we succeed, we're out of retries, or every available
1059             #   service has returned permanent failure.
1060             sorted_roots = []
1061             roots_map = {}
1062             loop = retry.RetryLoop(num_retries, self._check_loop_result,
1063                                    backoff_start=2)
1064             for tries_left in loop:
1065                 try:
1066                     sorted_roots = self.map_new_services(
1067                         roots_map, locator,
1068                         force_rebuild=(tries_left < num_retries),
1069                         need_writable=False,
1070                         headers=headers)
1071                 except Exception as error:
1072                     loop.save_result(error)
1073                     continue
1074
1075                 # Query KeepService objects that haven't returned
1076                 # permanent failure, in our specified shuffle order.
1077                 services_to_try = [roots_map[root]
1078                                    for root in sorted_roots
1079                                    if roots_map[root].usable()]
1080                 for keep_service in services_to_try:
1081                     blob = keep_service.get(locator, method=method, timeout=self.current_timeout(num_retries-tries_left))
1082                     if blob is not None:
1083                         break
1084                 loop.save_result((blob, len(services_to_try)))
1085
1086             # Always cache the result, then return it if we succeeded.
1087             if loop.success():
1088                 if method == "HEAD":
1089                     return True
1090                 else:
1091                     return blob
1092         finally:
1093             if slot is not None:
1094                 slot.set(blob)
1095                 self.block_cache.cap_cache()
1096
1097         # Q: Including 403 is necessary for the Keep tests to continue
1098         # passing, but maybe they should expect KeepReadError instead?
1099         not_founds = sum(1 for key in sorted_roots
1100                          if roots_map[key].last_result().get('status_code', None) in {403, 404, 410})
1101         service_errors = ((key, roots_map[key].last_result()['error'])
1102                           for key in sorted_roots)
1103         if not roots_map:
1104             raise arvados.errors.KeepReadError(
1105                 "failed to read {}: no Keep services available ({})".format(
1106                     loc_s, loop.last_result()))
1107         elif not_founds == len(sorted_roots):
1108             raise arvados.errors.NotFoundError(
1109                 "{} not found".format(loc_s), service_errors)
1110         else:
1111             raise arvados.errors.KeepReadError(
1112                 "failed to read {}".format(loc_s), service_errors, label="service")
1113
1114     @retry.retry_method
1115     def put(self, data, copies=2, num_retries=None, request_id=None):
1116         """Save data in Keep.
1117
1118         This method will get a list of Keep services from the API server, and
1119         send the data to each one simultaneously in a new thread.  Once the
1120         uploads are finished, if enough copies are saved, this method returns
1121         the most recent HTTP response body.  If requests fail to upload
1122         enough copies, this method raises KeepWriteError.
1123
1124         Arguments:
1125         * data: The string of data to upload.
1126         * copies: The number of copies that the user requires be saved.
1127           Default 2.
1128         * num_retries: The number of times to retry PUT requests to
1129           *each* Keep server if it returns temporary failures, with
1130           exponential backoff.  The default value is set when the
1131           KeepClient is initialized.
1132         """
1133
1134         if not isinstance(data, bytes):
1135             data = data.encode()
1136
1137         self.put_counter.add(1)
1138
1139         data_hash = hashlib.md5(data).hexdigest()
1140         loc_s = data_hash + '+' + str(len(data))
1141         if copies < 1:
1142             return loc_s
1143         locator = KeepLocator(loc_s)
1144
1145         headers = {
1146             'X-Request-Id': (request_id or
1147                              (hasattr(self, 'api_client') and self.api_client.request_id) or
1148                              arvados.util.new_request_id()),
1149             'X-Keep-Desired-Replicas': str(copies),
1150         }
1151         roots_map = {}
1152         loop = retry.RetryLoop(num_retries, self._check_loop_result,
1153                                backoff_start=2)
1154         done = 0
1155         for tries_left in loop:
1156             try:
1157                 sorted_roots = self.map_new_services(
1158                     roots_map, locator,
1159                     force_rebuild=(tries_left < num_retries),
1160                     need_writable=True,
1161                     headers=headers)
1162             except Exception as error:
1163                 loop.save_result(error)
1164                 continue
1165
1166             writer_pool = KeepClient.KeepWriterThreadPool(data=data,
1167                                                         data_hash=data_hash,
1168                                                         copies=copies - done,
1169                                                         max_service_replicas=self.max_replicas_per_service,
1170                                                         timeout=self.current_timeout(num_retries - tries_left))
1171             for service_root, ks in [(root, roots_map[root])
1172                                      for root in sorted_roots]:
1173                 if ks.finished():
1174                     continue
1175                 writer_pool.add_task(ks, service_root)
1176             writer_pool.join()
1177             done += writer_pool.done()
1178             loop.save_result((done >= copies, writer_pool.total_task_nr))
1179
1180         if loop.success():
1181             return writer_pool.response()
1182         if not roots_map:
1183             raise arvados.errors.KeepWriteError(
1184                 "failed to write {}: no Keep services available ({})".format(
1185                     data_hash, loop.last_result()))
1186         else:
1187             service_errors = ((key, roots_map[key].last_result()['error'])
1188                               for key in sorted_roots
1189                               if roots_map[key].last_result()['error'])
1190             raise arvados.errors.KeepWriteError(
1191                 "failed to write {} (wanted {} copies but wrote {})".format(
1192                     data_hash, copies, writer_pool.done()), service_errors, label="service")
1193
1194     def local_store_put(self, data, copies=1, num_retries=None):
1195         """A stub for put().
1196
1197         This method is used in place of the real put() method when
1198         using local storage (see constructor's local_store argument).
1199
1200         copies and num_retries arguments are ignored: they are here
1201         only for the sake of offering the same call signature as
1202         put().
1203
1204         Data stored this way can be retrieved via local_store_get().
1205         """
1206         md5 = hashlib.md5(data).hexdigest()
1207         locator = '%s+%d' % (md5, len(data))
1208         with open(os.path.join(self.local_store, md5 + '.tmp'), 'wb') as f:
1209             f.write(data)
1210         os.rename(os.path.join(self.local_store, md5 + '.tmp'),
1211                   os.path.join(self.local_store, md5))
1212         return locator
1213
1214     def local_store_get(self, loc_s, num_retries=None):
1215         """Companion to local_store_put()."""
1216         try:
1217             locator = KeepLocator(loc_s)
1218         except ValueError:
1219             raise arvados.errors.NotFoundError(
1220                 "Invalid data locator: '%s'" % loc_s)
1221         if locator.md5sum == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
1222             return b''
1223         with open(os.path.join(self.local_store, locator.md5sum), 'rb') as f:
1224             return f.read()
1225
1226     def is_cached(self, locator):
1227         return self.block_cache.reserve_cache(expect_hash)