11308: Fix string/bytes confusion for Python 3
[arvados.git] / sdk / python / arvados / keep.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from future import standard_library
4 standard_library.install_aliases()
5 from builtins import next
6 from builtins import str
7 from builtins import range
8 from builtins import object
9 import io
10 import datetime
11 import hashlib
12 import logging
13 import math
14 import os
15 import pycurl
16 import queue
17 import re
18 import socket
19 import ssl
20 import sys
21 import threading
22 from . import timer
23 import urllib.parse
24
25 if sys.version_info >= (3, 0):
26     from io import BytesIO
27 else:
28     from cStringIO import StringIO as BytesIO
29
30 import arvados
31 import arvados.config as config
32 import arvados.errors
33 import arvados.retry as retry
34 import arvados.util
35
_logger = logging.getLogger('arvados.keep')
# Lazily-built singleton KeepClient used by the deprecated Keep class;
# created/replaced by Keep.global_client_object() when configuration changes.
global_client_object = None


# Monkey patch TCP constants when not available (apple). Values sourced from:
# http://www.opensource.apple.com/source/xnu/xnu-2422.115.4/bsd/netinet/tcp.h
if sys.platform == 'darwin':
    if not hasattr(socket, 'TCP_KEEPALIVE'):
        socket.TCP_KEEPALIVE = 0x010
    if not hasattr(socket, 'TCP_KEEPINTVL'):
        socket.TCP_KEEPINTVL = 0x101
    if not hasattr(socket, 'TCP_KEEPCNT'):
        socket.TCP_KEEPCNT = 0x102
49
50
class KeepLocator(object):
    """Parse, validate, and re-serialize one Keep block locator.

    A locator string looks like
    "<md5sum>[+<size>][+A<signature>@<hex expiry>][+<other hints>...]".
    The 'A' hint carries a permission signature and a hex Unix timestamp;
    it is decomposed into perm_sig/perm_expiry rather than kept verbatim.
    """

    # Reference point for converting perm_expiry back to a hex timestamp.
    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
    # A hint is one uppercase letter followed by hint-specific data.
    HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')

    def __init__(self, locator_str):
        """Parse locator_str; raises ValueError on a malformed hint."""
        self.hints = []
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        self.md5sum = next(pieces)
        try:
            self.size = int(next(pieces))
        except StopIteration:
            # A bare "<md5sum>" locator has no size part.
            self.size = None
        for hint in pieces:
            if self.HINT_RE.match(hint) is None:
                raise ValueError("invalid hint format: {}".format(hint))
            elif hint.startswith('A'):
                self.parse_permission_hint(hint)
            else:
                self.hints.append(hint)

    def __str__(self):
        # Reassemble the canonical "+"-joined form, skipping absent parts.
        return '+'.join(
            str(s) for s in [self.md5sum, self.size,
                             self.permission_hint()] + self.hints
            if s is not None)

    def stripped(self):
        """Return the locator without hints: "md5sum+size" or "md5sum"."""
        if self.size is not None:
            return "%s+%i" % (self.md5sum, self.size)
        else:
            return self.md5sum

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not arvados.util.is_hex(hex_str, length):
                raise ValueError("{} is not a {}-digit hex string: {!r}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        if not arvados.util.is_hex(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        # Stored as a naive UTC datetime.
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        """Return the "A<sig>@<hex expiry>" hint, or None if unsigned."""
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        """Set perm_sig and perm_expiry from an "A<sig>@<expiry>" hint.

        Raises ValueError if the hint has no '@' separator.
        """
        # Bugfix: a hint without '@' made the original tuple unpacking raise
        # ValueError, not IndexError, so its `except IndexError` clause could
        # never fire and the friendly message was dead code.  Check the split
        # result explicitly instead (and leave the setters' own ValueErrors,
        # which report bad hex, unmasked).
        parts = s[1:].split('@', 1)
        if len(parts) != 2:
            raise ValueError("bad permission hint {}".format(s))
        self.perm_sig, self.perm_expiry = parts

    def permission_expired(self, as_of_dt=None):
        """Return True if the permission signature has expired.

        Compares against as_of_dt, or the current UTC time if not given.
        """
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            # Bugfix: perm_expiry is a naive UTC datetime (utcfromtimestamp),
            # so the default comparison point must also be UTC;
            # datetime.now() would be local time and skew the check by the
            # host's UTC offset.
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt
132
133
class Keep(object):
    """Simple interface to a global KeepClient object.

    THIS CLASS IS DEPRECATED.  Please instantiate your own KeepClient with your
    own API client.  The global KeepClient will build an API client from the
    current Arvados configuration, which may not match the one you built.
    """
    # Snapshot of the configuration used to build the current global client.
    _last_key = None

    @classmethod
    def global_client_object(cls):
        """Return the shared KeepClient, rebuilding it if config changed."""
        global global_client_object
        # KeepClient used to change its behavior at runtime based on these
        # settings.  Simulate that legacy behavior by handing back a brand
        # new KeepClient whenever any of them differs from the last snapshot.
        current = (config.get('ARVADOS_API_HOST'),
                   config.get('ARVADOS_API_TOKEN'),
                   config.flag_is_true('ARVADOS_API_HOST_INSECURE'),
                   config.get('ARVADOS_KEEP_PROXY'),
                   config.get('ARVADOS_EXTERNAL_CLIENT') == 'true',
                   os.environ.get('KEEP_LOCAL_STORE'))
        if global_client_object is None or cls._last_key != current:
            global_client_object = KeepClient()
            cls._last_key = current
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch a block through the shared client (deprecated)."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store a block through the shared client (deprecated)."""
        return Keep.global_client_object().put(data, **kwargs)
168
class KeepBlockCache(object):
    """In-memory LRU cache of Keep blocks, bounded by total byte size."""

    # Default RAM cache is 256MiB
    def __init__(self, cache_max=(256 * 1024 * 1024)):
        self.cache_max = cache_max
        # Most-recently-used slots live at the front of the list.
        self._cache = []
        self._cache_lock = threading.Lock()

    class CacheSlot(object):
        """One cache entry; readers block on `ready` until content arrives."""
        __slots__ = ("locator", "ready", "content")

        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            # Block until set() has stored the block (or recorded failure).
            self.ready.wait()
            return self.content

        def set(self, value):
            self.content = value
            self.ready.set()

        def size(self):
            # Unfilled slots count as zero bytes.
            return 0 if self.content is None else len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        with self._cache_lock:
            # Drop slots that finished with an error (ready, but no content).
            self._cache = [slot for slot in self._cache
                           if not (slot.ready.is_set() and slot.content is None)]
            total = sum(slot.size() for slot in self._cache)
            while self._cache and total > self.cache_max:
                # Evict the least recently used ready slot (closest to the
                # tail of the list), then re-measure.
                for idx in range(len(self._cache) - 1, -1, -1):
                    if self._cache[idx].ready.is_set():
                        del self._cache[idx]
                        break
                total = sum(slot.size() for slot in self._cache)

    def _get(self, locator):
        # Linear scan; a hit is promoted to the front (LRU bookkeeping).
        # Caller must hold self._cache_lock.
        for idx, slot in enumerate(self._cache):
            if slot.locator == locator:
                if idx != 0:
                    del self._cache[idx]
                    self._cache.insert(0, slot)
                return slot
        return None

    def get(self, locator):
        """Return the cache slot for locator, or None on a miss."""
        with self._cache_lock:
            return self._get(locator)

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        with self._cache_lock:
            slot = self._get(locator)
            if slot is not None:
                return slot, False
            # Add a new cache slot for the locator
            slot = KeepBlockCache.CacheSlot(locator)
            self._cache.insert(0, slot)
            return slot, True
240
class Counter(object):
    """A thread-safe running total."""

    def __init__(self, v=0):
        # Lock guarding every read and write of the total.
        self._lock = threading.Lock()
        self._value = v

    def add(self, v):
        """Atomically increase the total by v."""
        with self._lock:
            self._value = self._value + v

    def get(self):
        """Atomically read the current total."""
        with self._lock:
            return self._value
253
254
255 class KeepClient(object):
256
257     # Default Keep server connection timeout:  2 seconds
258     # Default Keep server read timeout:       256 seconds
259     # Default Keep server bandwidth minimum:  32768 bytes per second
260     # Default Keep proxy connection timeout:  20 seconds
261     # Default Keep proxy read timeout:        256 seconds
262     # Default Keep proxy bandwidth minimum:   32768 bytes per second
263     DEFAULT_TIMEOUT = (2, 256, 32768)
264     DEFAULT_PROXY_TIMEOUT = (20, 256, 32768)
265
266
    class KeepService(object):
        """Make requests to a single Keep service, and track results.

        A KeepService is intended to last long enough to perform one
        transaction (GET or PUT) against one Keep service. This can
        involve calling either get() or put() multiple times in order
        to retry after transient failures. However, calling both get()
        and put() on a single instance -- or using the same instance
        to access two different Keep services -- will not produce
        sensible behavior.
        """

        # Exceptions treated as transport-level failures of a single
        # request, as opposed to programming errors.
        HTTP_ERRORS = (
            socket.error,
            ssl.SSLError,
            arvados.errors.HttpError,
        )

        def __init__(self, root, user_agent_pool=queue.LifoQueue(),
                     upload_counter=None,
                     download_counter=None, **headers):
            # NOTE(review): the default LifoQueue above is created once and
            # shared by every instance constructed without an explicit pool --
            # presumably intentional so idle pycurl handles get reused across
            # transactions, but worth confirming.
            self.root = root
            self._user_agent_pool = user_agent_pool
            # Outcome of the most recent request; 'error' stays None until a
            # request has been attempted.
            self._result = {'error': None}
            self._usable = True
            self._session = None
            # Extra **headers (e.g. authorization) apply to both GET and PUT.
            self.get_headers = {'Accept': 'application/octet-stream'}
            self.get_headers.update(headers)
            self.put_headers = headers
            self.upload_counter = upload_counter
            self.download_counter = download_counter

        def usable(self):
            """Is it worth attempting a request?"""
            return self._usable

        def finished(self):
            """Did the request succeed or encounter permanent failure?"""
            # 'error' == False means the last request completed successfully;
            # None means no request has completed yet.
            return self._result['error'] == False or not self._usable

        def last_result(self):
            # Dict with 'error' plus, after a completed request,
            # 'status_code', 'body', and 'headers'.
            return self._result

        def _get_user_agent(self):
            # Reuse an idle pooled pycurl handle if available, else make one.
            try:
                return self._user_agent_pool.get(block=False)
            except queue.Empty:
                return pycurl.Curl()

        def _put_user_agent(self, ua):
            # Return a handle to the pool; if reset/put fails for any reason,
            # close it rather than pool a possibly-broken handle.
            try:
                ua.reset()
                self._user_agent_pool.put(ua, block=False)
            except:
                ua.close()

        @staticmethod
        def _socket_open(family, socktype, protocol, address=None):
            """Because pycurl doesn't have CURLOPT_TCP_KEEPALIVE"""
            s = socket.socket(family, socktype, protocol)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            # Will throw invalid protocol error on mac. This test prevents that.
            if hasattr(socket, 'TCP_KEEPIDLE'):
                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 75)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 75)
            return s

        def get(self, locator, method="GET", timeout=None):
            """GET (or HEAD) one block from this service.

            Returns the block data (bytes), True for a successful HEAD, or
            None on failure; details are left in last_result().
            """
            # locator is a KeepLocator object.
            url = self.root + str(locator)
            _logger.debug("Request: %s %s", method, url)
            curl = self._get_user_agent()
            # ok stays None if we never got as far as an HTTP status code.
            ok = None
            try:
                with timer.Timer() as t:
                    self._headers = {}
                    response_body = BytesIO()
                    curl.setopt(pycurl.NOSIGNAL, 1)
                    curl.setopt(pycurl.OPENSOCKETFUNCTION, self._socket_open)
                    curl.setopt(pycurl.URL, url.encode('utf-8'))
                    curl.setopt(pycurl.HTTPHEADER, [
                        '{}: {}'.format(k,v) for k,v in self.get_headers.items()])
                    curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
                    curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
                    if method == "HEAD":
                        curl.setopt(pycurl.NOBODY, True)
                    self._setcurltimeouts(curl, timeout)

                    try:
                        curl.perform()
                    except Exception as e:
                        # Normalize any pycurl failure into HttpError so the
                        # outer except clause handles it uniformly.
                        raise arvados.errors.HttpError(0, str(e))
                    # 'body' is kept as bytes here (unlike put(), which
                    # decodes its response body).
                    self._result = {
                        'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
                        'body': response_body.getvalue(),
                        'headers': self._headers,
                        'error': False,
                    }

                ok = retry.check_http_response_success(self._result['status_code'])
                if not ok:
                    self._result['error'] = arvados.errors.HttpError(
                        self._result['status_code'],
                        self._headers.get('x-status-line', 'Error'))
            except self.HTTP_ERRORS as e:
                self._result = {
                    'error': e,
                }
            # Only ok == False (an HTTP-level failure) marks the service
            # unusable; None (no status at all) leaves it usable for retry.
            self._usable = ok != False
            if self._result.get('status_code', None):
                # The client worked well enough to get an HTTP status
                # code, so presumably any problems are just on the
                # server side and it's OK to reuse the client.
                self._put_user_agent(curl)
            else:
                # Don't return this client to the pool, in case it's
                # broken.
                curl.close()
            if not ok:
                _logger.debug("Request fail: GET %s => %s: %s",
                              url, type(self._result['error']), str(self._result['error']))
                return None
            if method == "HEAD":
                # NOTE(review): _result has no 'content-length' key (headers
                # live under _result['headers']), so the second arg always
                # logs None -- confirm whether
                # self._headers.get('content-length') was intended.
                _logger.info("HEAD %s: %s bytes",
                         self._result['status_code'],
                         self._result.get('content-length'))
                return True

            _logger.info("GET %s: %s bytes in %s msec (%.3f MiB/sec)",
                         self._result['status_code'],
                         len(self._result['body']),
                         t.msecs,
                         1.0*len(self._result['body'])/2**20/t.secs if t.secs > 0 else 0)

            if self.download_counter:
                self.download_counter.add(len(self._result['body']))
            # Verify the payload against the md5 embedded in the locator.
            resp_md5 = hashlib.md5(self._result['body']).hexdigest()
            if resp_md5 != locator.md5sum:
                _logger.warning("Checksum fail: md5(%s) = %s",
                                url, resp_md5)
                self._result['error'] = arvados.errors.HttpError(
                    0, 'Checksum fail')
                return None
            return self._result['body']

        def put(self, hash_s, body, timeout=None):
            """PUT one block (bytes) under hash_s; returns True on success.

            Returns False on failure; details are left in last_result().
            """
            url = self.root + hash_s
            _logger.debug("Request: PUT %s", url)
            curl = self._get_user_agent()
            ok = None
            try:
                with timer.Timer() as t:
                    self._headers = {}
                    body_reader = BytesIO(body)
                    response_body = BytesIO()
                    curl.setopt(pycurl.NOSIGNAL, 1)
                    curl.setopt(pycurl.OPENSOCKETFUNCTION, self._socket_open)
                    curl.setopt(pycurl.URL, url.encode('utf-8'))
                    # Using UPLOAD tells cURL to wait for a "go ahead" from the
                    # Keep server (in the form of a HTTP/1.1 "100 Continue"
                    # response) instead of sending the request body immediately.
                    # This allows the server to reject the request if the request
                    # is invalid or the server is read-only, without waiting for
                    # the client to send the entire block.
                    curl.setopt(pycurl.UPLOAD, True)
                    curl.setopt(pycurl.INFILESIZE, len(body))
                    curl.setopt(pycurl.READFUNCTION, body_reader.read)
                    curl.setopt(pycurl.HTTPHEADER, [
                        '{}: {}'.format(k,v) for k,v in self.put_headers.items()])
                    curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
                    curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
                    self._setcurltimeouts(curl, timeout)
                    try:
                        curl.perform()
                    except Exception as e:
                        raise arvados.errors.HttpError(0, str(e))
                    # The PUT response body (the signed locator) is decoded
                    # to text here; get() keeps its body as bytes.
                    self._result = {
                        'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
                        'body': response_body.getvalue().decode('utf-8'),
                        'headers': self._headers,
                        'error': False,
                    }
                ok = retry.check_http_response_success(self._result['status_code'])
                if not ok:
                    self._result['error'] = arvados.errors.HttpError(
                        self._result['status_code'],
                        self._headers.get('x-status-line', 'Error'))
            except self.HTTP_ERRORS as e:
                self._result = {
                    'error': e,
                }
            self._usable = ok != False # still usable if ok is True or None
            if self._result.get('status_code', None):
                # Client is functional. See comment in get().
                self._put_user_agent(curl)
            else:
                curl.close()
            if not ok:
                _logger.debug("Request fail: PUT %s => %s: %s",
                              url, type(self._result['error']), str(self._result['error']))
                return False
            _logger.info("PUT %s: %s bytes in %s msec (%.3f MiB/sec)",
                         self._result['status_code'],
                         len(body),
                         t.msecs,
                         1.0*len(body)/2**20/t.secs if t.secs > 0 else 0)
            if self.upload_counter:
                self.upload_counter.add(len(body))
            return True

        def _setcurltimeouts(self, curl, timeouts):
            """Apply (connect, read, bandwidth) timeouts to a curl handle.

            timeouts may be None/falsy (no limits), a scalar (used for both
            connect and transfer), a 2-tuple (connect, transfer), or a
            3-tuple (connect, transfer, minimum bytes/sec).
            """
            if not timeouts:
                return
            elif isinstance(timeouts, tuple):
                if len(timeouts) == 2:
                    conn_t, xfer_t = timeouts
                    bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
                else:
                    conn_t, xfer_t, bandwidth_bps = timeouts
            else:
                conn_t, xfer_t = (timeouts, timeouts)
                bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
            curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(conn_t*1000))
            # Abort if the transfer rate stays below bandwidth_bps for
            # xfer_t seconds (curl's LOW_SPEED mechanism).
            curl.setopt(pycurl.LOW_SPEED_TIME, int(math.ceil(xfer_t)))
            curl.setopt(pycurl.LOW_SPEED_LIMIT, int(math.ceil(bandwidth_bps)))

        def _headerfunction(self, header_line):
            """pycurl HEADERFUNCTION callback: accumulate self._headers."""
            if isinstance(header_line, bytes):
                # pycurl delivers header lines as bytes; decode as latin-1,
                # the conventional HTTP header encoding.
                header_line = header_line.decode('iso-8859-1')
            if ':' in header_line:
                name, value = header_line.split(':', 1)
                # Header names are case-insensitive; store them lowercased.
                name = name.strip().lower()
                value = value.strip()
            elif self._headers:
                # No colon and headers already seen: a continuation line,
                # appended to the previous header's value.
                name = self._lastheadername
                value = self._headers[name] + ' ' + header_line.strip()
            elif header_line.startswith('HTTP/'):
                # The status line is stashed under a synthetic header name.
                name = 'x-status-line'
                value = header_line
            else:
                _logger.error("Unexpected header line: %s", header_line)
                return
            self._lastheadername = name
            self._headers[name] = value
            # Returning None implies all bytes were written
512     
513
    class KeepWriterQueue(queue.Queue):
        """Work queue coordinating KeepWriterThread workers for one put().

        Each queued item is a (KeepService, service_root) pair.  The queue
        also tracks how many replicas have been confirmed and how many
        attempts may proceed concurrently, so workers know when to stop.
        """

        def __init__(self, copies):
            queue.Queue.__init__(self) # Old-style superclass
            # Number of replicas the caller asked for.
            self.wanted_copies = copies
            # Replicas confirmed written so far.
            self.successful_copies = 0
            # Most recent successful response body (the signed locator).
            self.response = None
            self.successful_copies_lock = threading.Lock()
            # Attempts currently allowed to proceed; write_fail() bumps this
            # so another service can be tried in place of the failed one.
            self.pending_tries = copies
            self.pending_tries_notification = threading.Condition()
        
        def write_success(self, response, replicas_nr):
            """Record that a PUT stored replicas_nr replicas; wake workers."""
            with self.successful_copies_lock:
                self.successful_copies += replicas_nr
                self.response = response
            with self.pending_tries_notification:
                self.pending_tries_notification.notify_all()
        
        def write_fail(self, ks):
            """Record a failed PUT: allow one more try and wake one worker."""
            with self.pending_tries_notification:
                self.pending_tries += 1
                self.pending_tries_notification.notify()
        
        def pending_copies(self):
            """Replicas still needed; <= 0 once enough have been written."""
            with self.successful_copies_lock:
                return self.wanted_copies - self.successful_copies

        def get_next_task(self):
            """Block until a (service, service_root) task is available.

            Raises queue.Empty when no more work is needed (enough copies
            written) or no services remain to try.
            """
            with self.pending_tries_notification:
                while True:
                    if self.pending_copies() < 1:
                        # This notify_all() is unnecessary --
                        # write_success() already called notify_all()
                        # when pending<1 became true, so it's not
                        # possible for any other thread to be in
                        # wait() now -- but it's cheap insurance
                        # against deadlock so we do it anyway:
                        self.pending_tries_notification.notify_all()
                        # Drain the queue and then raise Queue.Empty
                        while True:
                            self.get_nowait()
                            self.task_done()
                    elif self.pending_tries > 0:
                        service, service_root = self.get_nowait()
                        if service.finished():
                            # Service already succeeded or failed for good;
                            # skip it and look for another task.
                            self.task_done()
                            continue
                        self.pending_tries -= 1
                        return service, service_root
                    elif self.empty():
                        self.pending_tries_notification.notify_all()
                        raise queue.Empty
                    else:
                        # Tasks exist but no tries are allowed right now;
                        # wait for a success or failure notification.
                        self.pending_tries_notification.wait()
567
568
569     class KeepWriterThreadPool(object):
570         def __init__(self, data, data_hash, copies, max_service_replicas, timeout=None):
571             self.total_task_nr = 0
572             self.wanted_copies = copies
573             if (not max_service_replicas) or (max_service_replicas >= copies):
574                 num_threads = 1
575             else:
576                 num_threads = int(math.ceil(1.0*copies/max_service_replicas))
577             _logger.debug("Pool max threads is %d", num_threads)
578             self.workers = []
579             self.queue = KeepClient.KeepWriterQueue(copies)
580             # Create workers
581             for _ in range(num_threads):
582                 w = KeepClient.KeepWriterThread(self.queue, data, data_hash, timeout)
583                 self.workers.append(w)
584         
585         def add_task(self, ks, service_root):
586             self.queue.put((ks, service_root))
587             self.total_task_nr += 1
588         
589         def done(self):
590             return self.queue.successful_copies
591         
592         def join(self):
593             # Start workers
594             for worker in self.workers:
595                 worker.start()
596             # Wait for finished work
597             self.queue.join()
598         
599         def response(self):
600             return self.queue.response
601     
602     
603     class KeepWriterThread(threading.Thread):
604         TaskFailed = RuntimeError()
605
606         def __init__(self, queue, data, data_hash, timeout=None):
607             super(KeepClient.KeepWriterThread, self).__init__()
608             self.timeout = timeout
609             self.queue = queue
610             self.data = data
611             self.data_hash = data_hash
612             self.daemon = True
613
614         def run(self):
615             while True:
616                 try:
617                     service, service_root = self.queue.get_next_task()
618                 except queue.Empty:
619                     return
620                 try:
621                     locator, copies = self.do_task(service, service_root)
622                 except Exception as e:
623                     if e is not self.TaskFailed:
624                         _logger.exception("Exception in KeepWriterThread")
625                     self.queue.write_fail(service)
626                 else:
627                     self.queue.write_success(locator, copies)
628                 finally:
629                     self.queue.task_done()
630
631         def do_task(self, service, service_root):
632             success = bool(service.put(self.data_hash,
633                                         self.data,
634                                         timeout=self.timeout))
635             result = service.last_result()
636
637             if not success:
638                 if result.get('status_code', None):
639                     _logger.debug("Request fail: PUT %s => %s %s",
640                                   self.data_hash,
641                                   result['status_code'],
642                                   result['body'])
643                 raise self.TaskFailed
644
645             _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
646                           str(threading.current_thread()),
647                           self.data_hash,
648                           len(self.data),
649                           service_root)
650             try:
651                 replicas_stored = int(result['headers']['x-keep-replicas-stored'])
652             except (KeyError, ValueError):
653                 replicas_stored = 1
654
655             return result['body'].strip(), replicas_stored
656
657
658     def __init__(self, api_client=None, proxy=None,
659                  timeout=DEFAULT_TIMEOUT, proxy_timeout=DEFAULT_PROXY_TIMEOUT,
660                  api_token=None, local_store=None, block_cache=None,
661                  num_retries=0, session=None):
662         """Initialize a new KeepClient.
663
664         Arguments:
665         :api_client:
666           The API client to use to find Keep services.  If not
667           provided, KeepClient will build one from available Arvados
668           configuration.
669
670         :proxy:
671           If specified, this KeepClient will send requests to this Keep
672           proxy.  Otherwise, KeepClient will fall back to the setting of the
673           ARVADOS_KEEP_SERVICES or ARVADOS_KEEP_PROXY configuration settings.
674           If you want to KeepClient does not use a proxy, pass in an empty
675           string.
676
677         :timeout:
678           The initial timeout (in seconds) for HTTP requests to Keep
679           non-proxy servers.  A tuple of three floats is interpreted as
680           (connection_timeout, read_timeout, minimum_bandwidth). A connection
681           will be aborted if the average traffic rate falls below
682           minimum_bandwidth bytes per second over an interval of read_timeout
683           seconds. Because timeouts are often a result of transient server
684           load, the actual connection timeout will be increased by a factor
685           of two on each retry.
686           Default: (2, 256, 32768).
687
688         :proxy_timeout:
689           The initial timeout (in seconds) for HTTP requests to
690           Keep proxies. A tuple of three floats is interpreted as
691           (connection_timeout, read_timeout, minimum_bandwidth). The behavior
692           described above for adjusting connection timeouts on retry also
693           applies.
694           Default: (20, 256, 32768).
695
696         :api_token:
697           If you're not using an API client, but only talking
698           directly to a Keep proxy, this parameter specifies an API token
699           to authenticate Keep requests.  It is an error to specify both
700           api_client and api_token.  If you specify neither, KeepClient
701           will use one available from the Arvados configuration.
702
703         :local_store:
704           If specified, this KeepClient will bypass Keep
705           services, and save data to the named directory.  If unspecified,
706           KeepClient will fall back to the setting of the $KEEP_LOCAL_STORE
707           environment variable.  If you want to ensure KeepClient does not
708           use local storage, pass in an empty string.  This is primarily
709           intended to mock a server for testing.
710
711         :num_retries:
712           The default number of times to retry failed requests.
713           This will be used as the default num_retries value when get() and
714           put() are called.  Default 0.
715         """
716         self.lock = threading.Lock()
717         if proxy is None:
718             if config.get('ARVADOS_KEEP_SERVICES'):
719                 proxy = config.get('ARVADOS_KEEP_SERVICES')
720             else:
721                 proxy = config.get('ARVADOS_KEEP_PROXY')
722         if api_token is None:
723             if api_client is None:
724                 api_token = config.get('ARVADOS_API_TOKEN')
725             else:
726                 api_token = api_client.api_token
727         elif api_client is not None:
728             raise ValueError(
729                 "can't build KeepClient with both API client and token")
730         if local_store is None:
731             local_store = os.environ.get('KEEP_LOCAL_STORE')
732
733         self.block_cache = block_cache if block_cache else KeepBlockCache()
734         self.timeout = timeout
735         self.proxy_timeout = proxy_timeout
736         self._user_agent_pool = queue.LifoQueue()
737         self.upload_counter = Counter()
738         self.download_counter = Counter()
739         self.put_counter = Counter()
740         self.get_counter = Counter()
741         self.hits_counter = Counter()
742         self.misses_counter = Counter()
743
744         if local_store:
745             self.local_store = local_store
746             self.get = self.local_store_get
747             self.put = self.local_store_put
748         else:
749             self.num_retries = num_retries
750             self.max_replicas_per_service = None
751             if proxy:
752                 proxy_uris = proxy.split()
753                 for i in range(len(proxy_uris)):
754                     if not proxy_uris[i].endswith('/'):
755                         proxy_uris[i] += '/'
756                     # URL validation
757                     url = urllib.parse.urlparse(proxy_uris[i])
758                     if not (url.scheme and url.netloc):
759                         raise arvados.errors.ArgumentError("Invalid proxy URI: {}".format(proxy_uris[i]))
760                 self.api_token = api_token
761                 self._gateway_services = {}
762                 self._keep_services = [{
763                     'uuid': "00000-bi6l4-%015d" % idx,
764                     'service_type': 'proxy',
765                     '_service_root': uri,
766                     } for idx, uri in enumerate(proxy_uris)]
767                 self._writable_services = self._keep_services
768                 self.using_proxy = True
769                 self._static_services_list = True
770             else:
771                 # It's important to avoid instantiating an API client
772                 # unless we actually need one, for testing's sake.
773                 if api_client is None:
774                     api_client = arvados.api('v1')
775                 self.api_client = api_client
776                 self.api_token = api_client.api_token
777                 self._gateway_services = {}
778                 self._keep_services = None
779                 self._writable_services = None
780                 self.using_proxy = None
781                 self._static_services_list = False
782
783     def current_timeout(self, attempt_number):
784         """Return the appropriate timeout to use for this client.
785
786         The proxy timeout setting if the backend service is currently a proxy,
787         the regular timeout setting otherwise.  The `attempt_number` indicates
788         how many times the operation has been tried already (starting from 0
789         for the first try), and scales the connection timeout portion of the
790         return value accordingly.
791
792         """
793         # TODO(twp): the timeout should be a property of a
794         # KeepService, not a KeepClient. See #4488.
795         t = self.proxy_timeout if self.using_proxy else self.timeout
796         if len(t) == 2:
797             return (t[0] * (1 << attempt_number), t[1])
798         else:
799             return (t[0] * (1 << attempt_number), t[1], t[2])
800     def _any_nondisk_services(self, service_list):
801         return any(ks.get('service_type', 'disk') != 'disk'
802                    for ks in service_list)
803
    def build_services_list(self, force_rebuild=False):
        """Refresh the cached lists of Keep services from the API server.

        Fills in self._gateway_services, self._keep_services, and
        self._writable_services, and sets self.max_replicas_per_service.
        Does nothing when the client was configured with a static
        (proxy) service list, or when services are already cached and
        force_rebuild is false.
        """
        if (self._static_services_list or
              (self._keep_services and not force_rebuild)):
            return
        with self.lock:
            try:
                keep_services = self.api_client.keep_services().accessible()
            except Exception:  # API server predates Keep services.
                keep_services = self.api_client.keep_disks().list()

            # Gateway services are only used when specified by UUID,
            # so there's nothing to gain by filtering them by
            # service_type.
            self._gateway_services = {ks['uuid']: ks for ks in
                                      keep_services.execute()['items']}
            if not self._gateway_services:
                raise arvados.errors.NoKeepServersError()

            # Precompute the base URI for each service.
            for r in self._gateway_services.values():
                host = r['service_host']
                if not host.startswith('[') and host.find(':') >= 0:
                    # IPv6 URIs must be formatted like http://[::1]:80/...
                    host = '[' + host + ']'
                r['_service_root'] = "{}://{}:{:d}/".format(
                    'https' if r['service_ssl_flag'] else 'http',
                    host,
                    r['service_port'])

            _logger.debug(str(self._gateway_services))
            # "gateway:" services are reachable only via explicit UUID
            # hints, so exclude them from the general probe list.
            self._keep_services = [
                ks for ks in self._gateway_services.values()
                if not ks.get('service_type', '').startswith('gateway:')]
            self._writable_services = [ks for ks in self._keep_services
                                       if not ks.get('read_only')]

            # For disk type services, max_replicas_per_service is 1
            # It is unknown (unlimited) for other service types.
            if self._any_nondisk_services(self._writable_services):
                self.max_replicas_per_service = None
            else:
                self.max_replicas_per_service = 1
846
847     def _service_weight(self, data_hash, service_uuid):
848         """Compute the weight of a Keep service endpoint for a data
849         block with a known hash.
850
851         The weight is md5(h + u) where u is the last 15 characters of
852         the service endpoint's UUID.
853         """
854         return hashlib.md5((data_hash + service_uuid[-15:]).encode()).hexdigest()
855
856     def weighted_service_roots(self, locator, force_rebuild=False, need_writable=False):
857         """Return an array of Keep service endpoints, in the order in
858         which they should be probed when reading or writing data with
859         the given hash+hints.
860         """
861         self.build_services_list(force_rebuild)
862
863         sorted_roots = []
864         # Use the services indicated by the given +K@... remote
865         # service hints, if any are present and can be resolved to a
866         # URI.
867         for hint in locator.hints:
868             if hint.startswith('K@'):
869                 if len(hint) == 7:
870                     sorted_roots.append(
871                         "https://keep.{}.arvadosapi.com/".format(hint[2:]))
872                 elif len(hint) == 29:
873                     svc = self._gateway_services.get(hint[2:])
874                     if svc:
875                         sorted_roots.append(svc['_service_root'])
876
877         # Sort the available local services by weight (heaviest first)
878         # for this locator, and return their service_roots (base URIs)
879         # in that order.
880         use_services = self._keep_services
881         if need_writable:
882             use_services = self._writable_services
883         self.using_proxy = self._any_nondisk_services(use_services)
884         sorted_roots.extend([
885             svc['_service_root'] for svc in sorted(
886                 use_services,
887                 reverse=True,
888                 key=lambda svc: self._service_weight(locator.md5sum, svc['uuid']))])
889         _logger.debug("{}: {}".format(locator, sorted_roots))
890         return sorted_roots
891
892     def map_new_services(self, roots_map, locator, force_rebuild, need_writable, **headers):
893         # roots_map is a dictionary, mapping Keep service root strings
894         # to KeepService objects.  Poll for Keep services, and add any
895         # new ones to roots_map.  Return the current list of local
896         # root strings.
897         headers.setdefault('Authorization', "OAuth2 %s" % (self.api_token,))
898         local_roots = self.weighted_service_roots(locator, force_rebuild, need_writable)
899         for root in local_roots:
900             if root not in roots_map:
901                 roots_map[root] = self.KeepService(
902                     root, self._user_agent_pool,
903                     upload_counter=self.upload_counter,
904                     download_counter=self.download_counter,
905                     **headers)
906         return local_roots
907
908     @staticmethod
909     def _check_loop_result(result):
910         # KeepClient RetryLoops should save results as a 2-tuple: the
911         # actual result of the request, and the number of servers available
912         # to receive the request this round.
913         # This method returns True if there's a real result, False if
914         # there are no more servers available, otherwise None.
915         if isinstance(result, Exception):
916             return None
917         result, tried_server_count = result
918         if (result is not None) and (result is not False):
919             return True
920         elif tried_server_count < 1:
921             _logger.info("No more Keep services to try; giving up")
922             return False
923         else:
924             return None
925
926     def get_from_cache(self, loc):
927         """Fetch a block only if is in the cache, otherwise return None."""
928         slot = self.block_cache.get(loc)
929         if slot is not None and slot.ready.is_set():
930             return slot.get()
931         else:
932             return None
933
    @retry.retry_method
    def head(self, loc_s, num_retries=None):
        """Check whether a block exists in Keep.

        Same arguments and retry semantics as get(), but issues HEAD
        requests and does not download block contents; returns True on
        success (see _get_or_head).
        """
        return self._get_or_head(loc_s, method="HEAD", num_retries=num_retries)
937
    @retry.retry_method
    def get(self, loc_s, num_retries=None):
        """Fetch block data from Keep; see _get_or_head() for details."""
        return self._get_or_head(loc_s, method="GET", num_retries=num_retries)
941
942     def _get_or_head(self, loc_s, method="GET", num_retries=None):
943         """Get data from Keep.
944
945         This method fetches one or more blocks of data from Keep.  It
946         sends a request each Keep service registered with the API
947         server (or the proxy provided when this client was
948         instantiated), then each service named in location hints, in
949         sequence.  As soon as one service provides the data, it's
950         returned.
951
952         Arguments:
953         * loc_s: A string of one or more comma-separated locators to fetch.
954           This method returns the concatenation of these blocks.
955         * num_retries: The number of times to retry GET requests to
956           *each* Keep server if it returns temporary failures, with
957           exponential backoff.  Note that, in each loop, the method may try
958           to fetch data from every available Keep service, along with any
959           that are named in location hints in the locator.  The default value
960           is set when the KeepClient is initialized.
961         """
962         if ',' in loc_s:
963             return ''.join(self.get(x) for x in loc_s.split(','))
964
965         self.get_counter.add(1)
966
967         locator = KeepLocator(loc_s)
968         if method == "GET":
969             slot, first = self.block_cache.reserve_cache(locator.md5sum)
970             if not first:
971                 self.hits_counter.add(1)
972                 v = slot.get()
973                 return v
974
975         self.misses_counter.add(1)
976
977         # If the locator has hints specifying a prefix (indicating a
978         # remote keepproxy) or the UUID of a local gateway service,
979         # read data from the indicated service(s) instead of the usual
980         # list of local disk services.
981         hint_roots = ['http://keep.{}.arvadosapi.com/'.format(hint[2:])
982                       for hint in locator.hints if hint.startswith('K@') and len(hint) == 7]
983         hint_roots.extend([self._gateway_services[hint[2:]]['_service_root']
984                            for hint in locator.hints if (
985                                    hint.startswith('K@') and
986                                    len(hint) == 29 and
987                                    self._gateway_services.get(hint[2:])
988                                    )])
989         # Map root URLs to their KeepService objects.
990         roots_map = {
991             root: self.KeepService(root, self._user_agent_pool,
992                                    upload_counter=self.upload_counter,
993                                    download_counter=self.download_counter)
994             for root in hint_roots
995         }
996
997         # See #3147 for a discussion of the loop implementation.  Highlights:
998         # * Refresh the list of Keep services after each failure, in case
999         #   it's being updated.
1000         # * Retry until we succeed, we're out of retries, or every available
1001         #   service has returned permanent failure.
1002         sorted_roots = []
1003         roots_map = {}
1004         blob = None
1005         loop = retry.RetryLoop(num_retries, self._check_loop_result,
1006                                backoff_start=2)
1007         for tries_left in loop:
1008             try:
1009                 sorted_roots = self.map_new_services(
1010                     roots_map, locator,
1011                     force_rebuild=(tries_left < num_retries),
1012                     need_writable=False)
1013             except Exception as error:
1014                 loop.save_result(error)
1015                 continue
1016
1017             # Query KeepService objects that haven't returned
1018             # permanent failure, in our specified shuffle order.
1019             services_to_try = [roots_map[root]
1020                                for root in sorted_roots
1021                                if roots_map[root].usable()]
1022             for keep_service in services_to_try:
1023                 blob = keep_service.get(locator, method=method, timeout=self.current_timeout(num_retries-tries_left))
1024                 if blob is not None:
1025                     break
1026             loop.save_result((blob, len(services_to_try)))
1027
1028         # Always cache the result, then return it if we succeeded.
1029         if method == "GET":
1030             slot.set(blob)
1031             self.block_cache.cap_cache()
1032         if loop.success():
1033             if method == "HEAD":
1034                 return True
1035             else:
1036                 return blob
1037
1038         # Q: Including 403 is necessary for the Keep tests to continue
1039         # passing, but maybe they should expect KeepReadError instead?
1040         not_founds = sum(1 for key in sorted_roots
1041                          if roots_map[key].last_result().get('status_code', None) in {403, 404, 410})
1042         service_errors = ((key, roots_map[key].last_result()['error'])
1043                           for key in sorted_roots)
1044         if not roots_map:
1045             raise arvados.errors.KeepReadError(
1046                 "failed to read {}: no Keep services available ({})".format(
1047                     loc_s, loop.last_result()))
1048         elif not_founds == len(sorted_roots):
1049             raise arvados.errors.NotFoundError(
1050                 "{} not found".format(loc_s), service_errors)
1051         else:
1052             raise arvados.errors.KeepReadError(
1053                 "failed to read {}".format(loc_s), service_errors, label="service")
1054
    @retry.retry_method
    def put(self, data, copies=2, num_retries=None):
        """Save data in Keep.

        This method will get a list of Keep services from the API server, and
        send the data to each one simultaneously in a new thread.  Once the
        uploads are finished, if enough copies are saved, this method returns
        the most recent HTTP response body.  If requests fail to upload
        enough copies, this method raises KeepWriteError.

        Arguments:
        * data: The string of data to upload.
        * copies: The number of copies that the user requires be saved.
          Default 2.
        * num_retries: The number of times to retry PUT requests to
          *each* Keep server if it returns temporary failures, with
          exponential backoff.  The default value is set when the
          KeepClient is initialized.

        Raises KeepWriteError when fewer than `copies` replicas could
        be written after all retries.
        """

        # Hashing and upload operate on bytes; encode str input first.
        if not isinstance(data, bytes):
            data = data.encode()

        self.put_counter.add(1)

        data_hash = hashlib.md5(data).hexdigest()
        loc_s = data_hash + '+' + str(len(data))
        if copies < 1:
            # Zero copies requested: nothing to store, but the locator
            # is still well-defined.
            return loc_s
        locator = KeepLocator(loc_s)

        headers = {}
        # Tell the proxy how many copies we want it to store
        headers['X-Keep-Desired-Replicas'] = str(copies)
        roots_map = {}
        loop = retry.RetryLoop(num_retries, self._check_loop_result,
                               backoff_start=2)
        # Count of confirmed stored copies, accumulated across retries.
        done = 0
        for tries_left in loop:
            try:
                sorted_roots = self.map_new_services(
                    roots_map, locator,
                    force_rebuild=(tries_left < num_retries), need_writable=True, **headers)
            except Exception as error:
                loop.save_result(error)
                continue

            # Each round only needs to write the copies still missing.
            writer_pool = KeepClient.KeepWriterThreadPool(data=data, 
                                                        data_hash=data_hash,
                                                        copies=copies - done,
                                                        max_service_replicas=self.max_replicas_per_service,
                                                        timeout=self.current_timeout(num_retries - tries_left))
            for service_root, ks in [(root, roots_map[root])
                                     for root in sorted_roots]:
                if ks.finished():
                    continue
                writer_pool.add_task(ks, service_root)
            writer_pool.join()
            done += writer_pool.done()
            loop.save_result((done >= copies, writer_pool.total_task_nr))

        if loop.success():
            return writer_pool.response()
        if not roots_map:
            raise arvados.errors.KeepWriteError(
                "failed to write {}: no Keep services available ({})".format(
                    data_hash, loop.last_result()))
        else:
            service_errors = ((key, roots_map[key].last_result()['error'])
                              for key in sorted_roots
                              if roots_map[key].last_result()['error'])
            raise arvados.errors.KeepWriteError(
                "failed to write {} (wanted {} copies but wrote {})".format(
                    data_hash, copies, writer_pool.done()), service_errors, label="service")
1129
1130     def local_store_put(self, data, copies=1, num_retries=None):
1131         """A stub for put().
1132
1133         This method is used in place of the real put() method when
1134         using local storage (see constructor's local_store argument).
1135
1136         copies and num_retries arguments are ignored: they are here
1137         only for the sake of offering the same call signature as
1138         put().
1139
1140         Data stored this way can be retrieved via local_store_get().
1141         """
1142         md5 = hashlib.md5(data).hexdigest()
1143         locator = '%s+%d' % (md5, len(data))
1144         with open(os.path.join(self.local_store, md5 + '.tmp'), 'wb') as f:
1145             f.write(data)
1146         os.rename(os.path.join(self.local_store, md5 + '.tmp'),
1147                   os.path.join(self.local_store, md5))
1148         return locator
1149
1150     def local_store_get(self, loc_s, num_retries=None):
1151         """Companion to local_store_put()."""
1152         try:
1153             locator = KeepLocator(loc_s)
1154         except ValueError:
1155             raise arvados.errors.NotFoundError(
1156                 "Invalid data locator: '%s'" % loc_s)
1157         if locator.md5sum == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
1158             return b''
1159         with open(os.path.join(self.local_store, locator.md5sum), 'rb') as f:
1160             return f.read()
1161
1162     def is_cached(self, locator):
1163         return self.block_cache.reserve_cache(expect_hash)