# 17465: Adds storage classes support to PySDK put().
# [arvados.git] / sdk / python / arvados / keep.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from __future__ import absolute_import
6 from __future__ import division
7 from future import standard_library
8 from future.utils import native_str
9 standard_library.install_aliases()
10 from builtins import next
11 from builtins import str
12 from builtins import range
13 from builtins import object
14 import collections
15 import datetime
16 import hashlib
17 import io
18 import logging
19 import math
20 import os
21 import pycurl
22 import queue
23 import re
24 import socket
25 import ssl
26 import sys
27 import threading
28 from . import timer
29 import urllib.parse
30
31 if sys.version_info >= (3, 0):
32     from io import BytesIO
33 else:
34     from cStringIO import StringIO as BytesIO
35
36 import arvados
37 import arvados.config as config
38 import arvados.errors
39 import arvados.retry as retry
40 import arvados.util
41
# Module-level logger; child of the package-wide 'arvados' logger.
_logger = logging.getLogger('arvados.keep')
# Process-wide KeepClient singleton managed by the deprecated Keep class below.
global_client_object = None


# Monkey patch TCP constants when not available (apple). Values sourced from:
# http://www.opensource.apple.com/source/xnu/xnu-2422.115.4/bsd/netinet/tcp.h
if sys.platform == 'darwin':
    if not hasattr(socket, 'TCP_KEEPALIVE'):
        socket.TCP_KEEPALIVE = 0x010
    if not hasattr(socket, 'TCP_KEEPINTVL'):
        socket.TCP_KEEPINTVL = 0x101
    if not hasattr(socket, 'TCP_KEEPCNT'):
        socket.TCP_KEEPCNT = 0x102
55
56
class KeepLocator(object):
    """Parse and re-serialize a Keep block locator string.

    A locator has the form ``<md5sum>[+<size>][+<hint>...]``, e.g.
    ``d41d8cd98f00b204e9800998ecf8427e+0+A<sig>@<hex timestamp>``.
    Permission ('A') hints are parsed into signature/expiry fields;
    all other hints are kept verbatim in self.hints.
    """

    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
    HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')

    def __init__(self, locator_str):
        """Parse locator_str; raises ValueError on malformed input."""
        self.hints = []
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        self.md5sum = next(pieces)
        try:
            self.size = int(next(pieces))
        except StopIteration:
            # Bare hash with no size component.
            self.size = None
        for hint in pieces:
            if self.HINT_RE.match(hint) is None:
                raise ValueError("invalid hint format: {}".format(hint))
            elif hint.startswith('A'):
                self.parse_permission_hint(hint)
            else:
                self.hints.append(hint)

    def __str__(self):
        # Re-assemble the locator, skipping components that are unset.
        return '+'.join(
            native_str(s)
            for s in [self.md5sum, self.size,
                      self.permission_hint()] + self.hints
            if s is not None)

    def stripped(self):
        """Return the locator without any hints (hash, or hash+size)."""
        if self.size is not None:
            return "%s+%i" % (self.md5sum, self.size)
        else:
            return self.md5sum

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not arvados.util.is_hex(hex_str, length):
                raise ValueError("{} is not a {}-digit hex string: {!r}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        # Accepts 1 to 8 hex digits (a Unix timestamp, UTC).
        if not arvados.util.is_hex(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        """Return the 'A' hint string, or None if sig/expiry unset."""
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        """Parse an 'A<sig>@<timestamp>' hint into sig/expiry fields."""
        try:
            # BUGFIX: a hint without '@' makes split() return a single
            # element, which raises ValueError (not IndexError) from the
            # tuple unpacking -- the old `except IndexError` never fired.
            self.perm_sig, self.perm_expiry = s[1:].split('@', 1)
        except ValueError:
            raise ValueError("bad permission hint {}".format(s))

    def permission_expired(self, as_of_dt=None):
        """True if the permission hint has expired (as of as_of_dt)."""
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            # BUGFIX: perm_expiry is a naive UTC datetime (built with
            # utcfromtimestamp), so compare against naive UTC "now",
            # not local time.
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt
139
140
class Keep(object):
    """Deprecated module-level convenience wrapper around one KeepClient.

    THIS CLASS IS DEPRECATED.  Instantiate your own KeepClient with your
    own API client instead.  The shared KeepClient built here derives its
    API client from the current Arvados configuration, which may not match
    the one you built.
    """
    _last_key = None

    @classmethod
    def global_client_object(cls):
        global global_client_object
        # KeepClient used to adapt at runtime when these settings changed.
        # Emulate that here: rebuild the shared client whenever the relevant
        # configuration differs from what it was on the previous call.
        key = (
            config.get('ARVADOS_API_HOST'),
            config.get('ARVADOS_API_TOKEN'),
            config.flag_is_true('ARVADOS_API_HOST_INSECURE'),
            config.get('ARVADOS_KEEP_PROXY'),
            config.get('ARVADOS_EXTERNAL_CLIENT') == 'true',
            os.environ.get('KEEP_LOCAL_STORE'),
        )
        if global_client_object is None or key != cls._last_key:
            global_client_object = KeepClient()
            cls._last_key = key
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch a block through the shared client (deprecated)."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store a block through the shared client (deprecated)."""
        return Keep.global_client_object().put(data, **kwargs)
175
class KeepBlockCache(object):
    """In-memory cache of Keep block contents with LRU-ish eviction.

    Slots are kept in a list ordered most-recently-used first.  A slot
    whose content has not arrived yet ('not ready') is never evicted.
    """

    # Default RAM cache is 256MiB
    def __init__(self, cache_max=(256 * 1024 * 1024)):
        self.cache_max = cache_max
        self._cache = []
        self._cache_lock = threading.Lock()

    class CacheSlot(object):
        """One cache entry: readers block on `ready` until content is set."""
        __slots__ = ("locator", "ready", "content")

        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            """Block until content has been set, then return it."""
            self.ready.wait()
            return self.content

        def set(self, value):
            """Store the block content and release any waiting readers."""
            self.content = value
            self.ready.set()

        def size(self):
            """Bytes held by this slot (0 while empty or on read error)."""
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        with self._cache_lock:
            # Select all slots except those where ready.is_set() and content is
            # None (that means there was an error reading the block).
            self._cache = [c for c in self._cache
                           if not (c.ready.is_set() and c.content is None)]
            # Track the total incrementally instead of re-summing every
            # pass (the old code was O(n^2) in the number of slots).
            total = sum(slot.size() for slot in self._cache)
            while self._cache and total > self.cache_max:
                # Evict the least-recently-used slot that is ready.
                for i in range(len(self._cache) - 1, -1, -1):
                    if self._cache[i].ready.is_set():
                        total -= self._cache[i].size()
                        del self._cache[i]
                        break
                else:
                    # No slot was evictable; stop rather than loop forever
                    # (defensive -- unready slots all have size 0, so this
                    # should not normally be reachable).
                    break

    def _get(self, locator):
        # Test if the locator is already in the cache; caller must hold
        # self._cache_lock.
        for i, slot in enumerate(self._cache):
            if slot.locator == locator:
                if i != 0:
                    # Move the hit to the front (most recently used).
                    del self._cache[i]
                    self._cache.insert(0, slot)
                return slot
        return None

    def get(self, locator):
        """Return the cache slot for locator, or None if absent."""
        with self._cache_lock:
            return self._get(locator)

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        with self._cache_lock:
            n = self._get(locator)
            if n:
                return n, False
            else:
                # Add a new cache slot for the locator
                n = KeepBlockCache.CacheSlot(locator)
                self._cache.insert(0, n)
                return n, True
247
class Counter(object):
    """A simple thread-safe integer counter."""

    def __init__(self, v=0):
        self._val = v
        self._lk = threading.Lock()

    def add(self, v):
        """Atomically add v (may be negative) to the counter."""
        with self._lk:
            self._val = self._val + v

    def get(self):
        """Return the current value under the lock."""
        with self._lk:
            return self._val
260
261
262 class KeepClient(object):
263
264     # Default Keep server connection timeout:  2 seconds
265     # Default Keep server read timeout:       256 seconds
266     # Default Keep server bandwidth minimum:  32768 bytes per second
267     # Default Keep proxy connection timeout:  20 seconds
268     # Default Keep proxy read timeout:        256 seconds
269     # Default Keep proxy bandwidth minimum:   32768 bytes per second
270     DEFAULT_TIMEOUT = (2, 256, 32768)
271     DEFAULT_PROXY_TIMEOUT = (20, 256, 32768)
272
273
274     class KeepService(object):
275         """Make requests to a single Keep service, and track results.
276
277         A KeepService is intended to last long enough to perform one
278         transaction (GET or PUT) against one Keep service. This can
279         involve calling either get() or put() multiple times in order
280         to retry after transient failures. However, calling both get()
281         and put() on a single instance -- or using the same instance
282         to access two different Keep services -- will not produce
283         sensible behavior.
284         """
285
286         HTTP_ERRORS = (
287             socket.error,
288             ssl.SSLError,
289             arvados.errors.HttpError,
290         )
291
292         def __init__(self, root, user_agent_pool=queue.LifoQueue(),
293                      upload_counter=None,
294                      download_counter=None,
295                      headers={},
296                      insecure=False):
297             self.root = root
298             self._user_agent_pool = user_agent_pool
299             self._result = {'error': None}
300             self._usable = True
301             self._session = None
302             self._socket = None
303             self.get_headers = {'Accept': 'application/octet-stream'}
304             self.get_headers.update(headers)
305             self.put_headers = headers
306             self.upload_counter = upload_counter
307             self.download_counter = download_counter
308             self.insecure = insecure
309
310         def usable(self):
311             """Is it worth attempting a request?"""
312             return self._usable
313
314         def finished(self):
315             """Did the request succeed or encounter permanent failure?"""
316             return self._result['error'] == False or not self._usable
317
318         def last_result(self):
319             return self._result
320
321         def _get_user_agent(self):
322             try:
323                 return self._user_agent_pool.get(block=False)
324             except queue.Empty:
325                 return pycurl.Curl()
326
327         def _put_user_agent(self, ua):
328             try:
329                 ua.reset()
330                 self._user_agent_pool.put(ua, block=False)
331             except:
332                 ua.close()
333
334         def _socket_open(self, *args, **kwargs):
335             if len(args) + len(kwargs) == 2:
336                 return self._socket_open_pycurl_7_21_5(*args, **kwargs)
337             else:
338                 return self._socket_open_pycurl_7_19_3(*args, **kwargs)
339
340         def _socket_open_pycurl_7_19_3(self, family, socktype, protocol, address=None):
341             return self._socket_open_pycurl_7_21_5(
342                 purpose=None,
343                 address=collections.namedtuple(
344                     'Address', ['family', 'socktype', 'protocol', 'addr'],
345                 )(family, socktype, protocol, address))
346
347         def _socket_open_pycurl_7_21_5(self, purpose, address):
348             """Because pycurl doesn't have CURLOPT_TCP_KEEPALIVE"""
349             s = socket.socket(address.family, address.socktype, address.protocol)
350             s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
351             # Will throw invalid protocol error on mac. This test prevents that.
352             if hasattr(socket, 'TCP_KEEPIDLE'):
353                 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 75)
354             s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 75)
355             self._socket = s
356             return s
357
358         def get(self, locator, method="GET", timeout=None):
359             # locator is a KeepLocator object.
360             url = self.root + str(locator)
361             _logger.debug("Request: %s %s", method, url)
362             curl = self._get_user_agent()
363             ok = None
364             try:
365                 with timer.Timer() as t:
366                     self._headers = {}
367                     response_body = BytesIO()
368                     curl.setopt(pycurl.NOSIGNAL, 1)
369                     curl.setopt(pycurl.OPENSOCKETFUNCTION,
370                                 lambda *args, **kwargs: self._socket_open(*args, **kwargs))
371                     curl.setopt(pycurl.URL, url.encode('utf-8'))
372                     curl.setopt(pycurl.HTTPHEADER, [
373                         '{}: {}'.format(k,v) for k,v in self.get_headers.items()])
374                     curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
375                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
376                     if self.insecure:
377                         curl.setopt(pycurl.SSL_VERIFYPEER, 0)
378                     else:
379                         curl.setopt(pycurl.CAINFO, arvados.util.ca_certs_path())
380                     if method == "HEAD":
381                         curl.setopt(pycurl.NOBODY, True)
382                     self._setcurltimeouts(curl, timeout, method=="HEAD")
383
384                     try:
385                         curl.perform()
386                     except Exception as e:
387                         raise arvados.errors.HttpError(0, str(e))
388                     finally:
389                         if self._socket:
390                             self._socket.close()
391                             self._socket = None
392                     self._result = {
393                         'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
394                         'body': response_body.getvalue(),
395                         'headers': self._headers,
396                         'error': False,
397                     }
398
399                 ok = retry.check_http_response_success(self._result['status_code'])
400                 if not ok:
401                     self._result['error'] = arvados.errors.HttpError(
402                         self._result['status_code'],
403                         self._headers.get('x-status-line', 'Error'))
404             except self.HTTP_ERRORS as e:
405                 self._result = {
406                     'error': e,
407                 }
408             self._usable = ok != False
409             if self._result.get('status_code', None):
410                 # The client worked well enough to get an HTTP status
411                 # code, so presumably any problems are just on the
412                 # server side and it's OK to reuse the client.
413                 self._put_user_agent(curl)
414             else:
415                 # Don't return this client to the pool, in case it's
416                 # broken.
417                 curl.close()
418             if not ok:
419                 _logger.debug("Request fail: GET %s => %s: %s",
420                               url, type(self._result['error']), str(self._result['error']))
421                 return None
422             if method == "HEAD":
423                 _logger.info("HEAD %s: %s bytes",
424                          self._result['status_code'],
425                          self._result.get('content-length'))
426                 if self._result['headers'].get('x-keep-locator'):
427                     # This is a response to a remote block copy request, return
428                     # the local copy block locator.
429                     return self._result['headers'].get('x-keep-locator')
430                 return True
431
432             _logger.info("GET %s: %s bytes in %s msec (%.3f MiB/sec)",
433                          self._result['status_code'],
434                          len(self._result['body']),
435                          t.msecs,
436                          1.0*len(self._result['body'])/2**20/t.secs if t.secs > 0 else 0)
437
438             if self.download_counter:
439                 self.download_counter.add(len(self._result['body']))
440             resp_md5 = hashlib.md5(self._result['body']).hexdigest()
441             if resp_md5 != locator.md5sum:
442                 _logger.warning("Checksum fail: md5(%s) = %s",
443                                 url, resp_md5)
444                 self._result['error'] = arvados.errors.HttpError(
445                     0, 'Checksum fail')
446                 return None
447             return self._result['body']
448
449         def put(self, hash_s, body, timeout=None):
450             url = self.root + hash_s
451             _logger.debug("Request: PUT %s", url)
452             curl = self._get_user_agent()
453             ok = None
454             try:
455                 with timer.Timer() as t:
456                     self._headers = {}
457                     body_reader = BytesIO(body)
458                     response_body = BytesIO()
459                     curl.setopt(pycurl.NOSIGNAL, 1)
460                     curl.setopt(pycurl.OPENSOCKETFUNCTION,
461                                 lambda *args, **kwargs: self._socket_open(*args, **kwargs))
462                     curl.setopt(pycurl.URL, url.encode('utf-8'))
463                     # Using UPLOAD tells cURL to wait for a "go ahead" from the
464                     # Keep server (in the form of a HTTP/1.1 "100 Continue"
465                     # response) instead of sending the request body immediately.
466                     # This allows the server to reject the request if the request
467                     # is invalid or the server is read-only, without waiting for
468                     # the client to send the entire block.
469                     curl.setopt(pycurl.UPLOAD, True)
470                     curl.setopt(pycurl.INFILESIZE, len(body))
471                     curl.setopt(pycurl.READFUNCTION, body_reader.read)
472                     curl.setopt(pycurl.HTTPHEADER, [
473                         '{}: {}'.format(k,v) for k,v in self.put_headers.items()])
474                     curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
475                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
476                     if self.insecure:
477                         curl.setopt(pycurl.SSL_VERIFYPEER, 0)
478                     else:
479                         curl.setopt(pycurl.CAINFO, arvados.util.ca_certs_path())
480                     self._setcurltimeouts(curl, timeout)
481                     try:
482                         curl.perform()
483                     except Exception as e:
484                         raise arvados.errors.HttpError(0, str(e))
485                     finally:
486                         if self._socket:
487                             self._socket.close()
488                             self._socket = None
489                     self._result = {
490                         'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
491                         'body': response_body.getvalue().decode('utf-8'),
492                         'headers': self._headers,
493                         'error': False,
494                     }
495                 ok = retry.check_http_response_success(self._result['status_code'])
496                 if not ok:
497                     self._result['error'] = arvados.errors.HttpError(
498                         self._result['status_code'],
499                         self._headers.get('x-status-line', 'Error'))
500             except self.HTTP_ERRORS as e:
501                 self._result = {
502                     'error': e,
503                 }
504             self._usable = ok != False # still usable if ok is True or None
505             if self._result.get('status_code', None):
506                 # Client is functional. See comment in get().
507                 self._put_user_agent(curl)
508             else:
509                 curl.close()
510             if not ok:
511                 _logger.debug("Request fail: PUT %s => %s: %s",
512                               url, type(self._result['error']), str(self._result['error']))
513                 return False
514             _logger.info("PUT %s: %s bytes in %s msec (%.3f MiB/sec)",
515                          self._result['status_code'],
516                          len(body),
517                          t.msecs,
518                          1.0*len(body)/2**20/t.secs if t.secs > 0 else 0)
519             if self.upload_counter:
520                 self.upload_counter.add(len(body))
521             return True
522
523         def _setcurltimeouts(self, curl, timeouts, ignore_bandwidth=False):
524             if not timeouts:
525                 return
526             elif isinstance(timeouts, tuple):
527                 if len(timeouts) == 2:
528                     conn_t, xfer_t = timeouts
529                     bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
530                 else:
531                     conn_t, xfer_t, bandwidth_bps = timeouts
532             else:
533                 conn_t, xfer_t = (timeouts, timeouts)
534                 bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
535             curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(conn_t*1000))
536             if not ignore_bandwidth:
537                 curl.setopt(pycurl.LOW_SPEED_TIME, int(math.ceil(xfer_t)))
538                 curl.setopt(pycurl.LOW_SPEED_LIMIT, int(math.ceil(bandwidth_bps)))
539
540         def _headerfunction(self, header_line):
541             if isinstance(header_line, bytes):
542                 header_line = header_line.decode('iso-8859-1')
543             if ':' in header_line:
544                 name, value = header_line.split(':', 1)
545                 name = name.strip().lower()
546                 value = value.strip()
547             elif self._headers:
548                 name = self._lastheadername
549                 value = self._headers[name] + ' ' + header_line.strip()
550             elif header_line.startswith('HTTP/'):
551                 name = 'x-status-line'
552                 value = header_line
553             else:
554                 _logger.error("Unexpected header line: %s", header_line)
555                 return
556             self._lastheadername = name
557             self._headers[name] = value
558             # Returning None implies all bytes were written
559
560
561     class KeepWriterQueue(queue.Queue):
562         def __init__(self, copies, classes=[]):
563             queue.Queue.__init__(self) # Old-style superclass
564             self.wanted_copies = copies
565             self.wanted_storage_classes = classes
566             self.successful_copies = 0
567             self.confirmed_storage_classes = {}
568             self.response = None
569             self.queue_data_lock = threading.Lock()
570             self.pending_tries = copies
571             self.pending_tries_notification = threading.Condition()
572
573         def write_success(self, response, replicas_nr, classes_confirmed):
574             with self.queue_data_lock:
575                 self.successful_copies += replicas_nr
576                 for st_class, st_copies in classes_confirmed.items():
577                     try:
578                         self.confirmed_storage_classes[st_class] += st_copies
579                     except KeyError:
580                         self.confirmed_storage_classes[st_class] = st_copies
581                 self.response = response
582             with self.pending_tries_notification:
583                 self.pending_tries_notification.notify_all()
584
585         def write_fail(self, ks):
586             with self.pending_tries_notification:
587                 self.pending_tries += 1
588                 self.pending_tries_notification.notify()
589
590         def pending_copies(self):
591             with self.queue_data_lock:
592                 return self.wanted_copies - self.successful_copies
593
594         def pending_classes(self):
595             with self.queue_data_lock:
596                 unsatisfied_classes = []
597                 for st_class, st_copies in self.confirmed_storage_classes.items():
598                     if st_class in self.wanted_storage_classes and st_copies < self.wanted_copies:
599                         unsatisfied_classes.append(st_class)
600                 return unsatisfied_classes
601
602         def get_next_task(self):
603             with self.pending_tries_notification:
604                 while True:
605                     if self.pending_copies() < 1 and len(self.pending_classes()) == 0:
606                         # This notify_all() is unnecessary --
607                         # write_success() already called notify_all()
608                         # when pending<1 became true, so it's not
609                         # possible for any other thread to be in
610                         # wait() now -- but it's cheap insurance
611                         # against deadlock so we do it anyway:
612                         self.pending_tries_notification.notify_all()
613                         # Drain the queue and then raise Queue.Empty
614                         while True:
615                             self.get_nowait()
616                             self.task_done()
617                     elif self.pending_tries > 0 or len(self.pending_classes()) > 0:
618                         service, service_root = self.get_nowait()
619                         if service.finished():
620                             self.task_done()
621                             continue
622                         self.pending_tries -= 1
623                         return service, service_root
624                     elif self.empty():
625                         self.pending_tries_notification.notify_all()
626                         raise queue.Empty
627                     else:
628                         self.pending_tries_notification.wait()
629
630
631     class KeepWriterThreadPool(object):
632         def __init__(self, data, data_hash, copies, max_service_replicas, timeout=None, classes=[]):
633             self.total_task_nr = 0
634             if (not max_service_replicas) or (max_service_replicas >= copies):
635                 num_threads = 1
636             else:
637                 num_threads = int(math.ceil(1.0*copies/max_service_replicas))
638             _logger.debug("Pool max threads is %d", num_threads)
639             self.workers = []
640             self.queue = KeepClient.KeepWriterQueue(copies, classes)
641             # Create workers
642             for _ in range(num_threads):
643                 w = KeepClient.KeepWriterThread(self.queue, data, data_hash, timeout)
644                 self.workers.append(w)
645
646         def add_task(self, ks, service_root):
647             self.queue.put((ks, service_root))
648             self.total_task_nr += 1
649
650         def done(self):
651             return self.queue.successful_copies
652
653         def join(self):
654             # Start workers
655             for worker in self.workers:
656                 worker.start()
657             # Wait for finished work
658             self.queue.join()
659
660         def response(self):
661             return self.queue.response
662
663
    class KeepWriterThread(threading.Thread):
        """Daemon thread that PUTs one block to Keep services.

        Pulls (service, service_root) tasks from a KeepWriterQueue,
        attempts the PUT, and reports success (with replica/storage-class
        counts) or failure back to the queue.
        """

        # Raised by do_task() when a PUT attempt failed with a response
        # that has already been logged; tells run() to skip the traceback.
        class TaskFailed(RuntimeError): pass

        def __init__(self, queue, data, data_hash, timeout=None):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.timeout = timeout
            self.queue = queue
            self.data = data
            self.data_hash = data_hash
            # Daemon thread: don't block interpreter exit.
            self.daemon = True

        def run(self):
            # Consume tasks until the queue signals completion by raising
            # queue.Empty (see KeepWriterQueue.get_next_task).
            while True:
                try:
                    service, service_root = self.queue.get_next_task()
                except queue.Empty:
                    return
                try:
                    locator, copies, classes = self.do_task(service, service_root)
                except Exception as e:
                    if not isinstance(e, self.TaskFailed):
                        # TaskFailed was already logged by do_task().
                        _logger.exception("Exception in KeepWriterThread")
                    self.queue.write_fail(service)
                else:
                    self.queue.write_success(locator, copies, classes)
                finally:
                    # Always mark the task done so queue.join() can return.
                    self.queue.task_done()

        def do_task(self, service, service_root):
            """PUT the block to one service.

            Returns (locator, replicas_stored, classes_confirmed) on
            success; raises TaskFailed after a failed request.
            """
            success = bool(service.put(self.data_hash,
                                        self.data,
                                        timeout=self.timeout))
            result = service.last_result()

            if not success:
                if result.get('status_code', None):
                    _logger.debug("Request fail: PUT %s => %s %s",
                                  self.data_hash,
                                  result['status_code'],
                                  result['body'])
                raise self.TaskFailed()

            _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
                          str(threading.current_thread()),
                          self.data_hash,
                          len(self.data),
                          service_root)
            try:
                replicas_stored = int(result['headers']['x-keep-replicas-stored'])
            except (KeyError, ValueError):
                # Header missing or malformed: assume a single replica.
                replicas_stored = 1

            # Parse "class1=2, class2=1"-style confirmations into a dict.
            classes_confirmed = {}
            try:
                scch = result['headers']['x-keep-storage-classes-confirmed']
                for confirmation in scch.replace(' ', '').split(','):
                    if '=' in confirmation:
                        stored_class, stored_copies = confirmation.split('=')[:2]
                        classes_confirmed[stored_class] = int(stored_copies)
            except (KeyError, ValueError):
                # Storage classes confirmed header missing or corrupt
                classes_confirmed = {}

            return result['body'].strip(), replicas_stored, classes_confirmed
728
729
    def __init__(self, api_client=None, proxy=None,
                 timeout=DEFAULT_TIMEOUT, proxy_timeout=DEFAULT_PROXY_TIMEOUT,
                 api_token=None, local_store=None, block_cache=None,
                 num_retries=0, session=None):
        """Initialize a new KeepClient.

        Arguments:
        :api_client:
          The API client to use to find Keep services.  If not
          provided, KeepClient will build one from available Arvados
          configuration.

        :proxy:
          If specified, this KeepClient will send requests to this Keep
          proxy.  Otherwise, KeepClient will fall back to the setting of the
          ARVADOS_KEEP_SERVICES or ARVADOS_KEEP_PROXY configuration settings.
          If you want to ensure KeepClient does not use a proxy, pass in an
          empty string.

        :timeout:
          The initial timeout (in seconds) for HTTP requests to Keep
          non-proxy servers.  A tuple of three floats is interpreted as
          (connection_timeout, read_timeout, minimum_bandwidth). A connection
          will be aborted if the average traffic rate falls below
          minimum_bandwidth bytes per second over an interval of read_timeout
          seconds. Because timeouts are often a result of transient server
          load, the actual connection timeout will be increased by a factor
          of two on each retry.
          Default: (2, 256, 32768).

        :proxy_timeout:
          The initial timeout (in seconds) for HTTP requests to
          Keep proxies. A tuple of three floats is interpreted as
          (connection_timeout, read_timeout, minimum_bandwidth). The behavior
          described above for adjusting connection timeouts on retry also
          applies.
          Default: (20, 256, 32768).

        :api_token:
          If you're not using an API client, but only talking
          directly to a Keep proxy, this parameter specifies an API token
          to authenticate Keep requests.  It is an error to specify both
          api_client and api_token.  If you specify neither, KeepClient
          will use one available from the Arvados configuration.

        :local_store:
          If specified, this KeepClient will bypass Keep
          services, and save data to the named directory.  If unspecified,
          KeepClient will fall back to the setting of the $KEEP_LOCAL_STORE
          environment variable.  If you want to ensure KeepClient does not
          use local storage, pass in an empty string.  This is primarily
          intended to mock a server for testing.

        :num_retries:
          The default number of times to retry failed requests.
          This will be used as the default num_retries value when get() and
          put() are called.  Default 0.
        """
        # NOTE(review): `session` is accepted but not used anywhere in this
        # constructor — presumably kept for call-signature compatibility;
        # confirm against callers before removing.
        self.lock = threading.Lock()
        # Resolve the proxy setting from configuration when not given
        # explicitly; ARVADOS_KEEP_SERVICES takes precedence over
        # ARVADOS_KEEP_PROXY.
        if proxy is None:
            if config.get('ARVADOS_KEEP_SERVICES'):
                proxy = config.get('ARVADOS_KEEP_SERVICES')
            else:
                proxy = config.get('ARVADOS_KEEP_PROXY')
        # Resolve the API token: from the API client when one was supplied,
        # otherwise from configuration.  Supplying both a client and a token
        # is ambiguous and rejected.
        if api_token is None:
            if api_client is None:
                api_token = config.get('ARVADOS_API_TOKEN')
            else:
                api_token = api_client.api_token
        elif api_client is not None:
            raise ValueError(
                "can't build KeepClient with both API client and token")
        if local_store is None:
            local_store = os.environ.get('KEEP_LOCAL_STORE')

        # TLS verification behavior follows the API client when available,
        # else the ARVADOS_API_HOST_INSECURE config flag.
        if api_client is None:
            self.insecure = config.flag_is_true('ARVADOS_API_HOST_INSECURE')
        else:
            self.insecure = api_client.insecure

        self.block_cache = block_cache if block_cache else KeepBlockCache()
        self.timeout = timeout
        self.proxy_timeout = proxy_timeout
        # Pool of reusable pycurl handles shared across requests.
        self._user_agent_pool = queue.LifoQueue()
        # Request/traffic statistics counters.
        self.upload_counter = Counter()
        self.download_counter = Counter()
        self.put_counter = Counter()
        self.get_counter = Counter()
        self.hits_counter = Counter()
        self.misses_counter = Counter()

        if local_store:
            # Local-directory mode: rebind the public methods to the
            # local_store_* stubs so all I/O bypasses Keep services.
            self.local_store = local_store
            self.head = self.local_store_head
            self.get = self.local_store_get
            self.put = self.local_store_put
        else:
            self.num_retries = num_retries
            self.max_replicas_per_service = None
            if proxy:
                # Proxy mode: synthesize a static service list from the
                # whitespace-separated proxy URIs; no API calls needed.
                proxy_uris = proxy.split()
                for i in range(len(proxy_uris)):
                    if not proxy_uris[i].endswith('/'):
                        proxy_uris[i] += '/'
                    # URL validation
                    url = urllib.parse.urlparse(proxy_uris[i])
                    if not (url.scheme and url.netloc):
                        raise arvados.errors.ArgumentError("Invalid proxy URI: {}".format(proxy_uris[i]))
                self.api_token = api_token
                self._gateway_services = {}
                self._keep_services = [{
                    'uuid': "00000-bi6l4-%015d" % idx,
                    'service_type': 'proxy',
                    '_service_root': uri,
                    } for idx, uri in enumerate(proxy_uris)]
                self._writable_services = self._keep_services
                self.using_proxy = True
                self._static_services_list = True
            else:
                # It's important to avoid instantiating an API client
                # unless we actually need one, for testing's sake.
                if api_client is None:
                    api_client = arvados.api('v1')
                self.api_client = api_client
                self.api_token = api_client.api_token
                self._gateway_services = {}
                # Service lists are populated lazily by build_services_list().
                self._keep_services = None
                self._writable_services = None
                self.using_proxy = None
                self._static_services_list = False
860
861     def current_timeout(self, attempt_number):
862         """Return the appropriate timeout to use for this client.
863
864         The proxy timeout setting if the backend service is currently a proxy,
865         the regular timeout setting otherwise.  The `attempt_number` indicates
866         how many times the operation has been tried already (starting from 0
867         for the first try), and scales the connection timeout portion of the
868         return value accordingly.
869
870         """
871         # TODO(twp): the timeout should be a property of a
872         # KeepService, not a KeepClient. See #4488.
873         t = self.proxy_timeout if self.using_proxy else self.timeout
874         if len(t) == 2:
875             return (t[0] * (1 << attempt_number), t[1])
876         else:
877             return (t[0] * (1 << attempt_number), t[1], t[2])
878     def _any_nondisk_services(self, service_list):
879         return any(ks.get('service_type', 'disk') != 'disk'
880                    for ks in service_list)
881
882     def build_services_list(self, force_rebuild=False):
883         if (self._static_services_list or
884               (self._keep_services and not force_rebuild)):
885             return
886         with self.lock:
887             try:
888                 keep_services = self.api_client.keep_services().accessible()
889             except Exception:  # API server predates Keep services.
890                 keep_services = self.api_client.keep_disks().list()
891
892             # Gateway services are only used when specified by UUID,
893             # so there's nothing to gain by filtering them by
894             # service_type.
895             self._gateway_services = {ks['uuid']: ks for ks in
896                                       keep_services.execute()['items']}
897             if not self._gateway_services:
898                 raise arvados.errors.NoKeepServersError()
899
900             # Precompute the base URI for each service.
901             for r in self._gateway_services.values():
902                 host = r['service_host']
903                 if not host.startswith('[') and host.find(':') >= 0:
904                     # IPv6 URIs must be formatted like http://[::1]:80/...
905                     host = '[' + host + ']'
906                 r['_service_root'] = "{}://{}:{:d}/".format(
907                     'https' if r['service_ssl_flag'] else 'http',
908                     host,
909                     r['service_port'])
910
911             _logger.debug(str(self._gateway_services))
912             self._keep_services = [
913                 ks for ks in self._gateway_services.values()
914                 if not ks.get('service_type', '').startswith('gateway:')]
915             self._writable_services = [ks for ks in self._keep_services
916                                        if not ks.get('read_only')]
917
918             # For disk type services, max_replicas_per_service is 1
919             # It is unknown (unlimited) for other service types.
920             if self._any_nondisk_services(self._writable_services):
921                 self.max_replicas_per_service = None
922             else:
923                 self.max_replicas_per_service = 1
924
925     def _service_weight(self, data_hash, service_uuid):
926         """Compute the weight of a Keep service endpoint for a data
927         block with a known hash.
928
929         The weight is md5(h + u) where u is the last 15 characters of
930         the service endpoint's UUID.
931         """
932         return hashlib.md5((data_hash + service_uuid[-15:]).encode()).hexdigest()
933
934     def weighted_service_roots(self, locator, force_rebuild=False, need_writable=False):
935         """Return an array of Keep service endpoints, in the order in
936         which they should be probed when reading or writing data with
937         the given hash+hints.
938         """
939         self.build_services_list(force_rebuild)
940
941         sorted_roots = []
942         # Use the services indicated by the given +K@... remote
943         # service hints, if any are present and can be resolved to a
944         # URI.
945         for hint in locator.hints:
946             if hint.startswith('K@'):
947                 if len(hint) == 7:
948                     sorted_roots.append(
949                         "https://keep.{}.arvadosapi.com/".format(hint[2:]))
950                 elif len(hint) == 29:
951                     svc = self._gateway_services.get(hint[2:])
952                     if svc:
953                         sorted_roots.append(svc['_service_root'])
954
955         # Sort the available local services by weight (heaviest first)
956         # for this locator, and return their service_roots (base URIs)
957         # in that order.
958         use_services = self._keep_services
959         if need_writable:
960             use_services = self._writable_services
961         self.using_proxy = self._any_nondisk_services(use_services)
962         sorted_roots.extend([
963             svc['_service_root'] for svc in sorted(
964                 use_services,
965                 reverse=True,
966                 key=lambda svc: self._service_weight(locator.md5sum, svc['uuid']))])
967         _logger.debug("{}: {}".format(locator, sorted_roots))
968         return sorted_roots
969
970     def map_new_services(self, roots_map, locator, force_rebuild, need_writable, headers):
971         # roots_map is a dictionary, mapping Keep service root strings
972         # to KeepService objects.  Poll for Keep services, and add any
973         # new ones to roots_map.  Return the current list of local
974         # root strings.
975         headers.setdefault('Authorization', "OAuth2 %s" % (self.api_token,))
976         local_roots = self.weighted_service_roots(locator, force_rebuild, need_writable)
977         for root in local_roots:
978             if root not in roots_map:
979                 roots_map[root] = self.KeepService(
980                     root, self._user_agent_pool,
981                     upload_counter=self.upload_counter,
982                     download_counter=self.download_counter,
983                     headers=headers,
984                     insecure=self.insecure)
985         return local_roots
986
987     @staticmethod
988     def _check_loop_result(result):
989         # KeepClient RetryLoops should save results as a 2-tuple: the
990         # actual result of the request, and the number of servers available
991         # to receive the request this round.
992         # This method returns True if there's a real result, False if
993         # there are no more servers available, otherwise None.
994         if isinstance(result, Exception):
995             return None
996         result, tried_server_count = result
997         if (result is not None) and (result is not False):
998             return True
999         elif tried_server_count < 1:
1000             _logger.info("No more Keep services to try; giving up")
1001             return False
1002         else:
1003             return None
1004
1005     def get_from_cache(self, loc):
1006         """Fetch a block only if is in the cache, otherwise return None."""
1007         slot = self.block_cache.get(loc)
1008         if slot is not None and slot.ready.is_set():
1009             return slot.get()
1010         else:
1011             return None
1012
1013     def refresh_signature(self, loc):
1014         """Ask Keep to get the remote block and return its local signature"""
1015         now = datetime.datetime.utcnow().isoformat("T") + 'Z'
1016         return self.head(loc, headers={'X-Keep-Signature': 'local, {}'.format(now)})
1017
    @retry.retry_method
    def head(self, loc_s, **kwargs):
        """Issue a HEAD request for a Keep block.

        Delegates to _get_or_head() with method="HEAD"; see that method
        for the supported arguments and retry behavior.
        """
        return self._get_or_head(loc_s, method="HEAD", **kwargs)
1021
    @retry.retry_method
    def get(self, loc_s, **kwargs):
        """Read a block of data from Keep.

        Delegates to _get_or_head() with method="GET"; see that method
        for the supported arguments and retry behavior.
        """
        return self._get_or_head(loc_s, method="GET", **kwargs)
1025
1026     def _get_or_head(self, loc_s, method="GET", num_retries=None, request_id=None, headers=None):
1027         """Get data from Keep.
1028
1029         This method fetches one or more blocks of data from Keep.  It
1030         sends a request each Keep service registered with the API
1031         server (or the proxy provided when this client was
1032         instantiated), then each service named in location hints, in
1033         sequence.  As soon as one service provides the data, it's
1034         returned.
1035
1036         Arguments:
1037         * loc_s: A string of one or more comma-separated locators to fetch.
1038           This method returns the concatenation of these blocks.
1039         * num_retries: The number of times to retry GET requests to
1040           *each* Keep server if it returns temporary failures, with
1041           exponential backoff.  Note that, in each loop, the method may try
1042           to fetch data from every available Keep service, along with any
1043           that are named in location hints in the locator.  The default value
1044           is set when the KeepClient is initialized.
1045         """
1046         if ',' in loc_s:
1047             return ''.join(self.get(x) for x in loc_s.split(','))
1048
1049         self.get_counter.add(1)
1050
1051         slot = None
1052         blob = None
1053         try:
1054             locator = KeepLocator(loc_s)
1055             if method == "GET":
1056                 slot, first = self.block_cache.reserve_cache(locator.md5sum)
1057                 if not first:
1058                     self.hits_counter.add(1)
1059                     blob = slot.get()
1060                     if blob is None:
1061                         raise arvados.errors.KeepReadError(
1062                             "failed to read {}".format(loc_s))
1063                     return blob
1064
1065             self.misses_counter.add(1)
1066
1067             if headers is None:
1068                 headers = {}
1069             headers['X-Request-Id'] = (request_id or
1070                                         (hasattr(self, 'api_client') and self.api_client.request_id) or
1071                                         arvados.util.new_request_id())
1072
1073             # If the locator has hints specifying a prefix (indicating a
1074             # remote keepproxy) or the UUID of a local gateway service,
1075             # read data from the indicated service(s) instead of the usual
1076             # list of local disk services.
1077             hint_roots = ['http://keep.{}.arvadosapi.com/'.format(hint[2:])
1078                           for hint in locator.hints if hint.startswith('K@') and len(hint) == 7]
1079             hint_roots.extend([self._gateway_services[hint[2:]]['_service_root']
1080                                for hint in locator.hints if (
1081                                        hint.startswith('K@') and
1082                                        len(hint) == 29 and
1083                                        self._gateway_services.get(hint[2:])
1084                                        )])
1085             # Map root URLs to their KeepService objects.
1086             roots_map = {
1087                 root: self.KeepService(root, self._user_agent_pool,
1088                                        upload_counter=self.upload_counter,
1089                                        download_counter=self.download_counter,
1090                                        headers=headers,
1091                                        insecure=self.insecure)
1092                 for root in hint_roots
1093             }
1094
1095             # See #3147 for a discussion of the loop implementation.  Highlights:
1096             # * Refresh the list of Keep services after each failure, in case
1097             #   it's being updated.
1098             # * Retry until we succeed, we're out of retries, or every available
1099             #   service has returned permanent failure.
1100             sorted_roots = []
1101             roots_map = {}
1102             loop = retry.RetryLoop(num_retries, self._check_loop_result,
1103                                    backoff_start=2)
1104             for tries_left in loop:
1105                 try:
1106                     sorted_roots = self.map_new_services(
1107                         roots_map, locator,
1108                         force_rebuild=(tries_left < num_retries),
1109                         need_writable=False,
1110                         headers=headers)
1111                 except Exception as error:
1112                     loop.save_result(error)
1113                     continue
1114
1115                 # Query KeepService objects that haven't returned
1116                 # permanent failure, in our specified shuffle order.
1117                 services_to_try = [roots_map[root]
1118                                    for root in sorted_roots
1119                                    if roots_map[root].usable()]
1120                 for keep_service in services_to_try:
1121                     blob = keep_service.get(locator, method=method, timeout=self.current_timeout(num_retries-tries_left))
1122                     if blob is not None:
1123                         break
1124                 loop.save_result((blob, len(services_to_try)))
1125
1126             # Always cache the result, then return it if we succeeded.
1127             if loop.success():
1128                 return blob
1129         finally:
1130             if slot is not None:
1131                 slot.set(blob)
1132                 self.block_cache.cap_cache()
1133
1134         # Q: Including 403 is necessary for the Keep tests to continue
1135         # passing, but maybe they should expect KeepReadError instead?
1136         not_founds = sum(1 for key in sorted_roots
1137                          if roots_map[key].last_result().get('status_code', None) in {403, 404, 410})
1138         service_errors = ((key, roots_map[key].last_result()['error'])
1139                           for key in sorted_roots)
1140         if not roots_map:
1141             raise arvados.errors.KeepReadError(
1142                 "failed to read {}: no Keep services available ({})".format(
1143                     loc_s, loop.last_result()))
1144         elif not_founds == len(sorted_roots):
1145             raise arvados.errors.NotFoundError(
1146                 "{} not found".format(loc_s), service_errors)
1147         else:
1148             raise arvados.errors.KeepReadError(
1149                 "failed to read {} after {}".format(loc_s, loop.attempts_str()), service_errors, label="service")
1150
1151     @retry.retry_method
1152     def put(self, data, copies=2, num_retries=None, request_id=None, classes=[]):
1153         """Save data in Keep.
1154
1155         This method will get a list of Keep services from the API server, and
1156         send the data to each one simultaneously in a new thread.  Once the
1157         uploads are finished, if enough copies are saved, this method returns
1158         the most recent HTTP response body.  If requests fail to upload
1159         enough copies, this method raises KeepWriteError.
1160
1161         Arguments:
1162         * data: The string of data to upload.
1163         * copies: The number of copies that the user requires be saved.
1164           Default 2.
1165         * num_retries: The number of times to retry PUT requests to
1166           *each* Keep server if it returns temporary failures, with
1167           exponential backoff.  The default value is set when the
1168           KeepClient is initialized.
1169         * classes: An optional list of storage class names where copies should
1170           be written.
1171         """
1172
1173         if not isinstance(data, bytes):
1174             data = data.encode()
1175
1176         self.put_counter.add(1)
1177
1178         data_hash = hashlib.md5(data).hexdigest()
1179         loc_s = data_hash + '+' + str(len(data))
1180         if copies < 1:
1181             return loc_s
1182         locator = KeepLocator(loc_s)
1183
1184         headers = {
1185             'X-Request-Id': (request_id or
1186                              (hasattr(self, 'api_client') and self.api_client.request_id) or
1187                              arvados.util.new_request_id()),
1188             'X-Keep-Desired-Replicas': str(copies),
1189         }
1190         if len(classes) > 0:
1191             headers['X-Keep-Storage-Classes'] = ', '.join(classes)
1192         roots_map = {}
1193         loop = retry.RetryLoop(num_retries, self._check_loop_result,
1194                                backoff_start=2)
1195         done = 0
1196         for tries_left in loop:
1197             try:
1198                 sorted_roots = self.map_new_services(
1199                     roots_map, locator,
1200                     force_rebuild=(tries_left < num_retries),
1201                     need_writable=True,
1202                     headers=headers)
1203             except Exception as error:
1204                 loop.save_result(error)
1205                 continue
1206
1207             writer_pool = KeepClient.KeepWriterThreadPool(data=data,
1208                                                         data_hash=data_hash,
1209                                                         copies=copies - done,
1210                                                         max_service_replicas=self.max_replicas_per_service,
1211                                                         timeout=self.current_timeout(num_retries - tries_left),
1212                                                         classes=classes)
1213             for service_root, ks in [(root, roots_map[root])
1214                                      for root in sorted_roots]:
1215                 if ks.finished():
1216                     continue
1217                 writer_pool.add_task(ks, service_root)
1218             writer_pool.join()
1219             done += writer_pool.done()
1220             loop.save_result((done >= copies, writer_pool.total_task_nr))
1221
1222         if loop.success():
1223             return writer_pool.response()
1224         if not roots_map:
1225             raise arvados.errors.KeepWriteError(
1226                 "failed to write {}: no Keep services available ({})".format(
1227                     data_hash, loop.last_result()))
1228         else:
1229             service_errors = ((key, roots_map[key].last_result()['error'])
1230                               for key in sorted_roots
1231                               if roots_map[key].last_result()['error'])
1232             raise arvados.errors.KeepWriteError(
1233                 "failed to write {} after {} (wanted {} copies but wrote {})".format(
1234                     data_hash, loop.attempts_str(), copies, writer_pool.done()), service_errors, label="service")
1235
1236     def local_store_put(self, data, copies=1, num_retries=None):
1237         """A stub for put().
1238
1239         This method is used in place of the real put() method when
1240         using local storage (see constructor's local_store argument).
1241
1242         copies and num_retries arguments are ignored: they are here
1243         only for the sake of offering the same call signature as
1244         put().
1245
1246         Data stored this way can be retrieved via local_store_get().
1247         """
1248         md5 = hashlib.md5(data).hexdigest()
1249         locator = '%s+%d' % (md5, len(data))
1250         with open(os.path.join(self.local_store, md5 + '.tmp'), 'wb') as f:
1251             f.write(data)
1252         os.rename(os.path.join(self.local_store, md5 + '.tmp'),
1253                   os.path.join(self.local_store, md5))
1254         return locator
1255
1256     def local_store_get(self, loc_s, num_retries=None):
1257         """Companion to local_store_put()."""
1258         try:
1259             locator = KeepLocator(loc_s)
1260         except ValueError:
1261             raise arvados.errors.NotFoundError(
1262                 "Invalid data locator: '%s'" % loc_s)
1263         if locator.md5sum == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
1264             return b''
1265         with open(os.path.join(self.local_store, locator.md5sum), 'rb') as f:
1266             return f.read()
1267
1268     def local_store_head(self, loc_s, num_retries=None):
1269         """Companion to local_store_put()."""
1270         try:
1271             locator = KeepLocator(loc_s)
1272         except ValueError:
1273             raise arvados.errors.NotFoundError(
1274                 "Invalid data locator: '%s'" % loc_s)
1275         if locator.md5sum == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
1276             return True
1277         if os.path.exists(os.path.join(self.local_store, locator.md5sum)):
1278             return True
1279
1280     def is_cached(self, locator):
1281         return self.block_cache.reserve_cache(expect_hash)