Merge branch '9570-cwl-runner-fixes' closes #9570 (again)
[arvados.git] / sdk / python / arvados / keep.py
1 import cStringIO
2 import datetime
3 import hashlib
4 import logging
5 import math
6 import os
7 import pycurl
8 import Queue
9 import re
10 import socket
11 import ssl
12 import sys
13 import threading
14 import timer
15
16 import arvados
17 import arvados.config as config
18 import arvados.errors
19 import arvados.retry as retry
20 import arvados.util
21
# Logger used by the code in this module.
_logger = logging.getLogger('arvados.keep')
# Shared KeepClient instance; created on demand by Keep.global_client_object().
global_client_object = None


# Monkey patch TCP constants when not available (apple). Values sourced from:
# http://www.opensource.apple.com/source/xnu/xnu-2422.115.4/bsd/netinet/tcp.h
if sys.platform == 'darwin':
    for _tcp_name, _tcp_value in (('TCP_KEEPALIVE', 0x010),
                                  ('TCP_KEEPINTVL', 0x101),
                                  ('TCP_KEEPCNT', 0x102)):
        # Only fill in constants the socket module does not already define.
        if not hasattr(socket, _tcp_name):
            setattr(socket, _tcp_name, _tcp_value)
    del _tcp_name, _tcp_value
35
36
class KeepLocator(object):
    """Parsed representation of a Keep block locator string.

    A locator has the form ``md5sum[+size[+hint...]]``.  Hints beginning
    with ``A`` are permission hints (``A<signature>@<hex expiry>``) and are
    stored in perm_sig / perm_expiry; every other hint is kept verbatim, in
    order, in self.hints.
    """
    # Naive datetime for the Unix epoch, in UTC.
    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
    # A hint is one capital letter followed by one or more of [A-Za-z0-9@_-].
    HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')

    def __init__(self, locator_str):
        """Parse locator_str into md5sum, size and hints.

        Raises ValueError when a hint does not match HINT_RE, when the
        permission hint is malformed, or when the md5sum or size pieces are
        rejected by their validators.
        """
        self.hints = []
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        self.md5sum = next(pieces)
        try:
            self.size = int(next(pieces))
        except StopIteration:
            # Locator had no size component.
            self.size = None
        for hint in pieces:
            if self.HINT_RE.match(hint) is None:
                raise ValueError("invalid hint format: {}".format(hint))
            elif hint.startswith('A'):
                self.parse_permission_hint(hint)
            else:
                self.hints.append(hint)

    def __str__(self):
        """Rebuild the locator string, omitting components that are None."""
        return '+'.join(
            str(s) for s in [self.md5sum, self.size,
                             self.permission_hint()] + self.hints
            if s is not None)

    def stripped(self):
        """Return the locator without hints: "md5sum+size", or just the
        md5sum when the size is unknown."""
        if self.size is not None:
            return "%s+%i" % (self.md5sum, self.size)
        else:
            return self.md5sum

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not arvados.util.is_hex(hex_str, length):
                raise ValueError("{} is not a {}-digit hex string: {}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        """Permission expiry as a naive UTC datetime, or None if unset."""
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        # value is the hex timestamp text taken from the permission hint.
        if not arvados.util.is_hex(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        """Return the "A<sig>@<expiry>" hint, or None when either the
        signature or the expiry is missing."""
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        # Convert the expiry datetime back to whole seconds since the epoch.
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        """Set perm_sig and perm_expiry from a hint like "A<sig>@<expiry>".

        Raises ValueError if the hint has no '@' separator, or if either
        part is rejected by its property setter.
        """
        # Check the shape explicitly: unpacking a one-element list raises
        # ValueError (not IndexError), which would bypass this message.
        parts = s[1:].split('@', 1)
        if len(parts) != 2:
            raise ValueError("bad permission hint {}".format(s))
        self.perm_sig, self.perm_expiry = parts

    def permission_expired(self, as_of_dt=None):
        """Return True if the permission hint has expired as of as_of_dt.

        as_of_dt is a naive UTC datetime and defaults to the current UTC
        time.  A locator without an expiry never counts as expired.
        """
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            # perm_expiry is built with utcfromtimestamp(), so "now" must be
            # UTC as well; local time would skew the result by the UTC offset.
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt
118
119
class Keep(object):
    """Convenience front end for one module-wide KeepClient.

    THIS CLASS IS DEPRECATED.  Please instantiate your own KeepClient with your
    own API client.  The global KeepClient will build an API client from the
    current Arvados configuration, which may not match the one you built.
    """
    # Configuration snapshot the current global client was built for.
    _last_key = None

    @classmethod
    def global_client_object(cls):
        """Return the shared KeepClient, rebuilding it if settings changed."""
        global global_client_object
        # KeepClient used to react to these settings at runtime.  To keep
        # that behavior, take a snapshot of them on every call and build a
        # fresh KeepClient whenever the snapshot differs from the last one.
        settings = (
            config.get('ARVADOS_API_HOST'),
            config.get('ARVADOS_API_TOKEN'),
            config.flag_is_true('ARVADOS_API_HOST_INSECURE'),
            config.get('ARVADOS_KEEP_PROXY'),
            config.get('ARVADOS_EXTERNAL_CLIENT') == 'true',
            os.environ.get('KEEP_LOCAL_STORE'),
        )
        if global_client_object is not None and cls._last_key == settings:
            return global_client_object
        global_client_object = KeepClient()
        cls._last_key = settings
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Forward to get() on the shared KeepClient."""
        client = Keep.global_client_object()
        return client.get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Forward to put() on the shared KeepClient."""
        client = Keep.global_client_object()
        return client.put(data, **kwargs)
154
class KeepBlockCache(object):
    """In-memory cache of Keep blocks, most recently used first.

    Slots are kept in a list with the most recently requested slot at
    index 0.  cap_cache() evicts from the far end of the list.
    """
    # Default RAM cache is 256MiB
    def __init__(self, cache_max=(256 * 1024 * 1024)):
        """cache_max is the cap, in bytes, enforced by cap_cache()."""
        self.cache_max = cache_max
        self._cache = []
        self._cache_lock = threading.Lock()

    class CacheSlot(object):
        """Placeholder for one block; readers block in get() until set()."""
        __slots__ = ("locator", "ready", "content")

        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            """Wait until the slot has been filled, then return its content."""
            self.ready.wait()
            return self.content

        def set(self, value):
            """Store value (None signals a failed read) and wake up waiters."""
            self.content = value
            self.ready.set()

        def size(self):
            """Return the content length in bytes, or 0 when there is none."""
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        with self._cache_lock:
            # Select all slots except those where ready.is_set() and content is
            # None (that means there was an error reading the block).
            self._cache = [c for c in self._cache if not (c.ready.is_set() and c.content is None)]
            total = sum(slot.size() for slot in self._cache)
            while self._cache and total > self.cache_max:
                # Evict the least recently used slot that has finished
                # loading; slots still being filled are left alone.
                for i in range(len(self._cache) - 1, -1, -1):
                    victim = self._cache[i]
                    if victim.ready.is_set():
                        # Adjust the running total instead of re-summing
                        # the whole cache after each eviction.
                        total -= victim.size()
                        del self._cache[i]
                        break
                else:
                    # Nothing is evictable right now; stop rather than
                    # spinning while holding the lock.
                    break

    def _get(self, locator):
        # Caller must hold self._cache_lock.
        # Test if the locator is already in the cache
        for i, slot in enumerate(self._cache):
            if slot.locator == locator:
                if i != 0:
                    # move it to the front
                    del self._cache[i]
                    self._cache.insert(0, slot)
                return slot
        return None

    def get(self, locator):
        """Return the slot for locator (marking it most recently used), or
        None if it is not cached."""
        with self._cache_lock:
            return self._get(locator)

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.

        Returns (slot, created) where created is True only when a new,
        empty slot was added.'''
        with self._cache_lock:
            n = self._get(locator)
            if n:
                return n, False
            else:
                # Add a new cache slot for the locator
                n = KeepBlockCache.CacheSlot(locator)
                self._cache.insert(0, n)
                return n, True
226
class Counter(object):
    """Running total whose updates and reads are serialized by a lock."""

    def __init__(self, v=0):
        self._val = v
        self._lk = threading.Lock()

    def add(self, v):
        """Increase the total by v."""
        self._lk.acquire()
        try:
            self._val += v
        finally:
            self._lk.release()

    def get(self):
        """Return the current total."""
        self._lk.acquire()
        try:
            current = self._val
        finally:
            self._lk.release()
        return current
239
240
241 class KeepClient(object):
242
    # Each default is a (connection timeout in seconds, read timeout in
    # seconds, minimum bandwidth in bytes per second) tuple.
    #
    # Default Keep server connection timeout:  2 seconds
    # Default Keep server read timeout:       256 seconds
    # Default Keep server bandwidth minimum:  32768 bytes per second
    # Default Keep proxy connection timeout:  20 seconds
    # Default Keep proxy read timeout:        256 seconds
    # Default Keep proxy bandwidth minimum:   32768 bytes per second
    DEFAULT_TIMEOUT = (2, 256, 32768)
    DEFAULT_PROXY_TIMEOUT = (20, 256, 32768)
251
252
253     class KeepService(object):
254         """Make requests to a single Keep service, and track results.
255
256         A KeepService is intended to last long enough to perform one
257         transaction (GET or PUT) against one Keep service. This can
258         involve calling either get() or put() multiple times in order
259         to retry after transient failures. However, calling both get()
260         and put() on a single instance -- or using the same instance
261         to access two different Keep services -- will not produce
262         sensible behavior.
263         """
264
        # Exception types that get() and put() catch and record in
        # self._result['error'] instead of letting them propagate.
        HTTP_ERRORS = (
            socket.error,
            ssl.SSLError,
            arvados.errors.HttpError,
        )
270
271         def __init__(self, root, user_agent_pool=Queue.LifoQueue(),
272                      upload_counter=None,
273                      download_counter=None, **headers):
274             self.root = root
275             self._user_agent_pool = user_agent_pool
276             self._result = {'error': None}
277             self._usable = True
278             self._session = None
279             self.get_headers = {'Accept': 'application/octet-stream'}
280             self.get_headers.update(headers)
281             self.put_headers = headers
282             self.upload_counter = upload_counter
283             self.download_counter = download_counter
284
        def usable(self):
            """Is it worth attempting a request?

            Starts out True.  get() and put() set it to False when the
            response check (retry.check_http_response_success) returns
            False for the status code received.
            """
            return self._usable
288
289         def finished(self):
290             """Did the request succeed or encounter permanent failure?"""
291             return self._result['error'] == False or not self._usable
292
        def last_result(self):
            """Return the result dict left by the most recent get() or put().

            The dict always has an 'error' key.  After a completed HTTP
            exchange it also has 'status_code', 'body' and 'headers'.
            """
            return self._result
295
296         def _get_user_agent(self):
297             try:
298                 return self._user_agent_pool.get(False)
299             except Queue.Empty:
300                 return pycurl.Curl()
301
302         def _put_user_agent(self, ua):
303             try:
304                 ua.reset()
305                 self._user_agent_pool.put(ua, False)
306             except:
307                 ua.close()
308
309         @staticmethod
310         def _socket_open(family, socktype, protocol, address=None):
311             """Because pycurl doesn't have CURLOPT_TCP_KEEPALIVE"""
312             s = socket.socket(family, socktype, protocol)
313             s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
314             # Will throw invalid protocol error on mac. This test prevents that.
315             if hasattr(socket, 'TCP_KEEPIDLE'):
316                 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 75)
317             s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 75)
318             return s
319
320         def get(self, locator, method="GET", timeout=None):
321             # locator is a KeepLocator object.
322             url = self.root + str(locator)
323             _logger.debug("Request: %s %s", method, url)
324             curl = self._get_user_agent()
325             ok = None
326             try:
327                 with timer.Timer() as t:
328                     self._headers = {}
329                     response_body = cStringIO.StringIO()
330                     curl.setopt(pycurl.NOSIGNAL, 1)
331                     curl.setopt(pycurl.OPENSOCKETFUNCTION, self._socket_open)
332                     curl.setopt(pycurl.URL, url.encode('utf-8'))
333                     curl.setopt(pycurl.HTTPHEADER, [
334                         '{}: {}'.format(k,v) for k,v in self.get_headers.iteritems()])
335                     curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
336                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
337                     if method == "HEAD":
338                         curl.setopt(pycurl.NOBODY, True)
339                     self._setcurltimeouts(curl, timeout)
340
341                     try:
342                         curl.perform()
343                     except Exception as e:
344                         raise arvados.errors.HttpError(0, str(e))
345                     self._result = {
346                         'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
347                         'body': response_body.getvalue(),
348                         'headers': self._headers,
349                         'error': False,
350                     }
351
352                 ok = retry.check_http_response_success(self._result['status_code'])
353                 if not ok:
354                     self._result['error'] = arvados.errors.HttpError(
355                         self._result['status_code'],
356                         self._headers.get('x-status-line', 'Error'))
357             except self.HTTP_ERRORS as e:
358                 self._result = {
359                     'error': e,
360                 }
361             self._usable = ok != False
362             if self._result.get('status_code', None):
363                 # The client worked well enough to get an HTTP status
364                 # code, so presumably any problems are just on the
365                 # server side and it's OK to reuse the client.
366                 self._put_user_agent(curl)
367             else:
368                 # Don't return this client to the pool, in case it's
369                 # broken.
370                 curl.close()
371             if not ok:
372                 _logger.debug("Request fail: GET %s => %s: %s",
373                               url, type(self._result['error']), str(self._result['error']))
374                 return None
375             if method == "HEAD":
376                 _logger.info("HEAD %s: %s bytes",
377                          self._result['status_code'],
378                          self._result.get('content-length'))
379                 return True
380
381             _logger.info("GET %s: %s bytes in %s msec (%.3f MiB/sec)",
382                          self._result['status_code'],
383                          len(self._result['body']),
384                          t.msecs,
385                          (len(self._result['body'])/(1024.0*1024))/t.secs if t.secs > 0 else 0)
386
387             if self.download_counter:
388                 self.download_counter.add(len(self._result['body']))
389             resp_md5 = hashlib.md5(self._result['body']).hexdigest()
390             if resp_md5 != locator.md5sum:
391                 _logger.warning("Checksum fail: md5(%s) = %s",
392                                 url, resp_md5)
393                 self._result['error'] = arvados.errors.HttpError(
394                     0, 'Checksum fail')
395                 return None
396             return self._result['body']
397
398         def put(self, hash_s, body, timeout=None):
399             url = self.root + hash_s
400             _logger.debug("Request: PUT %s", url)
401             curl = self._get_user_agent()
402             ok = None
403             try:
404                 with timer.Timer() as t:
405                     self._headers = {}
406                     body_reader = cStringIO.StringIO(body)
407                     response_body = cStringIO.StringIO()
408                     curl.setopt(pycurl.NOSIGNAL, 1)
409                     curl.setopt(pycurl.OPENSOCKETFUNCTION, self._socket_open)
410                     curl.setopt(pycurl.URL, url.encode('utf-8'))
411                     # Using UPLOAD tells cURL to wait for a "go ahead" from the
412                     # Keep server (in the form of a HTTP/1.1 "100 Continue"
413                     # response) instead of sending the request body immediately.
414                     # This allows the server to reject the request if the request
415                     # is invalid or the server is read-only, without waiting for
416                     # the client to send the entire block.
417                     curl.setopt(pycurl.UPLOAD, True)
418                     curl.setopt(pycurl.INFILESIZE, len(body))
419                     curl.setopt(pycurl.READFUNCTION, body_reader.read)
420                     curl.setopt(pycurl.HTTPHEADER, [
421                         '{}: {}'.format(k,v) for k,v in self.put_headers.iteritems()])
422                     curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
423                     curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
424                     self._setcurltimeouts(curl, timeout)
425                     try:
426                         curl.perform()
427                     except Exception as e:
428                         raise arvados.errors.HttpError(0, str(e))
429                     self._result = {
430                         'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
431                         'body': response_body.getvalue(),
432                         'headers': self._headers,
433                         'error': False,
434                     }
435                 ok = retry.check_http_response_success(self._result['status_code'])
436                 if not ok:
437                     self._result['error'] = arvados.errors.HttpError(
438                         self._result['status_code'],
439                         self._headers.get('x-status-line', 'Error'))
440             except self.HTTP_ERRORS as e:
441                 self._result = {
442                     'error': e,
443                 }
444             self._usable = ok != False # still usable if ok is True or None
445             if self._result.get('status_code', None):
446                 # Client is functional. See comment in get().
447                 self._put_user_agent(curl)
448             else:
449                 curl.close()
450             if not ok:
451                 _logger.debug("Request fail: PUT %s => %s: %s",
452                               url, type(self._result['error']), str(self._result['error']))
453                 return False
454             _logger.info("PUT %s: %s bytes in %s msec (%.3f MiB/sec)",
455                          self._result['status_code'],
456                          len(body),
457                          t.msecs,
458                          (len(body)/(1024.0*1024))/t.secs if t.secs > 0 else 0)
459             if self.upload_counter:
460                 self.upload_counter.add(len(body))
461             return True
462
463         def _setcurltimeouts(self, curl, timeouts):
464             if not timeouts:
465                 return
466             elif isinstance(timeouts, tuple):
467                 if len(timeouts) == 2:
468                     conn_t, xfer_t = timeouts
469                     bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
470                 else:
471                     conn_t, xfer_t, bandwidth_bps = timeouts
472             else:
473                 conn_t, xfer_t = (timeouts, timeouts)
474                 bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
475             curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(conn_t*1000))
476             curl.setopt(pycurl.LOW_SPEED_TIME, int(math.ceil(xfer_t)))
477             curl.setopt(pycurl.LOW_SPEED_LIMIT, int(math.ceil(bandwidth_bps)))
478
479         def _headerfunction(self, header_line):
480             header_line = header_line.decode('iso-8859-1')
481             if ':' in header_line:
482                 name, value = header_line.split(':', 1)
483                 name = name.strip().lower()
484                 value = value.strip()
485             elif self._headers:
486                 name = self._lastheadername
487                 value = self._headers[name] + ' ' + header_line.strip()
488             elif header_line.startswith('HTTP/'):
489                 name = 'x-status-line'
490                 value = header_line
491             else:
492                 _logger.error("Unexpected header line: %s", header_line)
493                 return
494             self._lastheadername = name
495             self._headers[name] = value
496             # Returning None implies all bytes were written
497     
498
499     class KeepWriterQueue(Queue.Queue):
500         def __init__(self, copies):
501             Queue.Queue.__init__(self) # Old-style superclass
502             self.wanted_copies = copies
503             self.successful_copies = 0
504             self.response = None
505             self.successful_copies_lock = threading.Lock()
506             self.pending_tries = copies
507             self.pending_tries_notification = threading.Condition()
508         
509         def write_success(self, response, replicas_nr):
510             with self.successful_copies_lock:
511                 self.successful_copies += replicas_nr
512                 self.response = response
513         
514         def write_fail(self, ks, status_code):
515             with self.pending_tries_notification:
516                 self.pending_tries += 1
517                 self.pending_tries_notification.notify()
518         
519         def pending_copies(self):
520             with self.successful_copies_lock:
521                 return self.wanted_copies - self.successful_copies
522     
523     
524     class KeepWriterThreadPool(object):
525         def __init__(self, data, data_hash, copies, max_service_replicas, timeout=None):
526             self.total_task_nr = 0
527             self.wanted_copies = copies
528             if (not max_service_replicas) or (max_service_replicas >= copies):
529                 num_threads = 1
530             else:
531                 num_threads = int(math.ceil(float(copies) / max_service_replicas))
532             _logger.debug("Pool max threads is %d", num_threads)
533             self.workers = []
534             self.queue = KeepClient.KeepWriterQueue(copies)
535             # Create workers
536             for _ in range(num_threads):
537                 w = KeepClient.KeepWriterThread(self.queue, data, data_hash, timeout)
538                 self.workers.append(w)
539         
540         def add_task(self, ks, service_root):
541             self.queue.put((ks, service_root))
542             self.total_task_nr += 1
543         
544         def done(self):
545             return self.queue.successful_copies
546         
547         def join(self):
548             # Start workers
549             for worker in self.workers:
550                 worker.start()
551             # Wait for finished work
552             self.queue.join()
553             with self.queue.pending_tries_notification:
554                 self.queue.pending_tries_notification.notify_all()
555             for worker in self.workers:
556                 worker.join()
557         
558         def response(self):
559             return self.queue.response
560     
561     
    class KeepWriterThread(threading.Thread):
        """Worker thread that takes (service, service_root) tasks from a
        KeepWriterQueue and tries to PUT one block to each."""

        def __init__(self, queue, data, data_hash, timeout=None):
            """Remember the shared queue, the block data, its hash, and the
            timeout to pass to each service.put() call."""
            super(KeepClient.KeepWriterThread, self).__init__()
            self.timeout = timeout
            self.queue = queue
            self.data = data
            self.data_hash = data_hash

        def run(self):
            """Process tasks until the queue is empty.

            While copies are still missing, each task is attempted only
            after claiming one of the queue's pending_tries.  Once enough
            copies exist, remaining tasks are removed without being tried.
            """
            while not self.queue.empty():
                if self.queue.pending_copies() > 0:
                    # Avoid overreplication, wait for some needed re-attempt
                    with self.queue.pending_tries_notification:
                        if self.queue.pending_tries <= 0:
                            self.queue.pending_tries_notification.wait()
                            continue # try again when awake
                        self.queue.pending_tries -= 1

                    # Get to work
                    # NOTE(review): the pending_tries slot claimed above is
                    # not given back on the two early `continue` paths
                    # below -- confirm that this is intended.
                    try:
                        service, service_root = self.queue.get_nowait()
                    except Queue.Empty:
                        continue
                    if service.finished():
                        self.queue.task_done()
                        continue
                    success = bool(service.put(self.data_hash,
                                                self.data,
                                                timeout=self.timeout))
                    result = service.last_result()
                    if success:
                        _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
                                      str(threading.current_thread()),
                                      self.data_hash,
                                      len(self.data),
                                      service_root)
                        # Use the replica count from the response header,
                        # or assume 1 if it is missing or not an integer.
                        try:
                            replicas_stored = int(result['headers']['x-keep-replicas-stored'])
                        except (KeyError, ValueError):
                            replicas_stored = 1

                        self.queue.write_success(result['body'].strip(), replicas_stored)
                    else:
                        # Only results with a status code carry a 'body'.
                        if result.get('status_code', None):
                            _logger.debug("Request fail: PUT %s => %s %s",
                                          self.data_hash,
                                          result['status_code'],
                                          result['body'])
                        self.queue.write_fail(service, result.get('status_code', None)) # Schedule a re-attempt with next service
                    # Mark as done so the queue can be join()ed
                    self.queue.task_done()
                else:
                    # Remove the task from the queue anyways
                    try:
                        self.queue.get_nowait()
                        # Mark as done so the queue can be join()ed
                        self.queue.task_done()
                    except Queue.Empty:
                        continue
621
622
623     def __init__(self, api_client=None, proxy=None,
624                  timeout=DEFAULT_TIMEOUT, proxy_timeout=DEFAULT_PROXY_TIMEOUT,
625                  api_token=None, local_store=None, block_cache=None,
626                  num_retries=0, session=None):
627         """Initialize a new KeepClient.
628
629         Arguments:
630         :api_client:
631           The API client to use to find Keep services.  If not
632           provided, KeepClient will build one from available Arvados
633           configuration.
634
635         :proxy:
636           If specified, this KeepClient will send requests to this Keep
637           proxy.  Otherwise, KeepClient will fall back to the setting of the
638           ARVADOS_KEEP_PROXY configuration setting.  If you want to ensure
639           KeepClient does not use a proxy, pass in an empty string.
640
641         :timeout:
642           The initial timeout (in seconds) for HTTP requests to Keep
643           non-proxy servers.  A tuple of three floats is interpreted as
644           (connection_timeout, read_timeout, minimum_bandwidth). A connection
645           will be aborted if the average traffic rate falls below
646           minimum_bandwidth bytes per second over an interval of read_timeout
647           seconds. Because timeouts are often a result of transient server
648           load, the actual connection timeout will be increased by a factor
649           of two on each retry.
650           Default: (2, 256, 32768).
651
652         :proxy_timeout:
653           The initial timeout (in seconds) for HTTP requests to
654           Keep proxies. A tuple of three floats is interpreted as
655           (connection_timeout, read_timeout, minimum_bandwidth). The behavior
656           described above for adjusting connection timeouts on retry also
657           applies.
658           Default: (20, 256, 32768).
659
660         :api_token:
661           If you're not using an API client, but only talking
662           directly to a Keep proxy, this parameter specifies an API token
663           to authenticate Keep requests.  It is an error to specify both
664           api_client and api_token.  If you specify neither, KeepClient
665           will use one available from the Arvados configuration.
666
667         :local_store:
668           If specified, this KeepClient will bypass Keep
669           services, and save data to the named directory.  If unspecified,
670           KeepClient will fall back to the setting of the $KEEP_LOCAL_STORE
671           environment variable.  If you want to ensure KeepClient does not
672           use local storage, pass in an empty string.  This is primarily
673           intended to mock a server for testing.
674
675         :num_retries:
676           The default number of times to retry failed requests.
677           This will be used as the default num_retries value when get() and
678           put() are called.  Default 0.
679         """
680         self.lock = threading.Lock()
681         if proxy is None:
682             proxy = config.get('ARVADOS_KEEP_PROXY')
683         if api_token is None:
684             if api_client is None:
685                 api_token = config.get('ARVADOS_API_TOKEN')
686             else:
687                 api_token = api_client.api_token
688         elif api_client is not None:
689             raise ValueError(
690                 "can't build KeepClient with both API client and token")
691         if local_store is None:
692             local_store = os.environ.get('KEEP_LOCAL_STORE')
693
694         self.block_cache = block_cache if block_cache else KeepBlockCache()
695         self.timeout = timeout
696         self.proxy_timeout = proxy_timeout
697         self._user_agent_pool = Queue.LifoQueue()
698         self.upload_counter = Counter()
699         self.download_counter = Counter()
700         self.put_counter = Counter()
701         self.get_counter = Counter()
702         self.hits_counter = Counter()
703         self.misses_counter = Counter()
704
705         if local_store:
706             self.local_store = local_store
707             self.get = self.local_store_get
708             self.put = self.local_store_put
709         else:
710             self.num_retries = num_retries
711             self.max_replicas_per_service = None
712             if proxy:
713                 if not proxy.endswith('/'):
714                     proxy += '/'
715                 self.api_token = api_token
716                 self._gateway_services = {}
717                 self._keep_services = [{
718                     'uuid': 'proxy',
719                     'service_type': 'proxy',
720                     '_service_root': proxy,
721                     }]
722                 self._writable_services = self._keep_services
723                 self.using_proxy = True
724                 self._static_services_list = True
725             else:
726                 # It's important to avoid instantiating an API client
727                 # unless we actually need one, for testing's sake.
728                 if api_client is None:
729                     api_client = arvados.api('v1')
730                 self.api_client = api_client
731                 self.api_token = api_client.api_token
732                 self._gateway_services = {}
733                 self._keep_services = None
734                 self._writable_services = None
735                 self.using_proxy = None
736                 self._static_services_list = False
737
738     def current_timeout(self, attempt_number):
739         """Return the appropriate timeout to use for this client.
740
741         The proxy timeout setting if the backend service is currently a proxy,
742         the regular timeout setting otherwise.  The `attempt_number` indicates
743         how many times the operation has been tried already (starting from 0
744         for the first try), and scales the connection timeout portion of the
745         return value accordingly.
746
747         """
748         # TODO(twp): the timeout should be a property of a
749         # KeepService, not a KeepClient. See #4488.
750         t = self.proxy_timeout if self.using_proxy else self.timeout
751         if len(t) == 2:
752             return (t[0] * (1 << attempt_number), t[1])
753         else:
754             return (t[0] * (1 << attempt_number), t[1], t[2])
755     def _any_nondisk_services(self, service_list):
756         return any(ks.get('service_type', 'disk') != 'disk'
757                    for ks in service_list)
758
759     def build_services_list(self, force_rebuild=False):
760         if (self._static_services_list or
761               (self._keep_services and not force_rebuild)):
762             return
763         with self.lock:
764             try:
765                 keep_services = self.api_client.keep_services().accessible()
766             except Exception:  # API server predates Keep services.
767                 keep_services = self.api_client.keep_disks().list()
768
769             # Gateway services are only used when specified by UUID,
770             # so there's nothing to gain by filtering them by
771             # service_type.
772             self._gateway_services = {ks['uuid']: ks for ks in
773                                       keep_services.execute()['items']}
774             if not self._gateway_services:
775                 raise arvados.errors.NoKeepServersError()
776
777             # Precompute the base URI for each service.
778             for r in self._gateway_services.itervalues():
779                 host = r['service_host']
780                 if not host.startswith('[') and host.find(':') >= 0:
781                     # IPv6 URIs must be formatted like http://[::1]:80/...
782                     host = '[' + host + ']'
783                 r['_service_root'] = "{}://{}:{:d}/".format(
784                     'https' if r['service_ssl_flag'] else 'http',
785                     host,
786                     r['service_port'])
787
788             _logger.debug(str(self._gateway_services))
789             self._keep_services = [
790                 ks for ks in self._gateway_services.itervalues()
791                 if not ks.get('service_type', '').startswith('gateway:')]
792             self._writable_services = [ks for ks in self._keep_services
793                                        if not ks.get('read_only')]
794
795             # For disk type services, max_replicas_per_service is 1
796             # It is unknown (unlimited) for other service types.
797             if self._any_nondisk_services(self._writable_services):
798                 self.max_replicas_per_service = None
799             else:
800                 self.max_replicas_per_service = 1
801
802     def _service_weight(self, data_hash, service_uuid):
803         """Compute the weight of a Keep service endpoint for a data
804         block with a known hash.
805
806         The weight is md5(h + u) where u is the last 15 characters of
807         the service endpoint's UUID.
808         """
809         return hashlib.md5(data_hash + service_uuid[-15:]).hexdigest()
810
811     def weighted_service_roots(self, locator, force_rebuild=False, need_writable=False):
812         """Return an array of Keep service endpoints, in the order in
813         which they should be probed when reading or writing data with
814         the given hash+hints.
815         """
816         self.build_services_list(force_rebuild)
817
818         sorted_roots = []
819         # Use the services indicated by the given +K@... remote
820         # service hints, if any are present and can be resolved to a
821         # URI.
822         for hint in locator.hints:
823             if hint.startswith('K@'):
824                 if len(hint) == 7:
825                     sorted_roots.append(
826                         "https://keep.{}.arvadosapi.com/".format(hint[2:]))
827                 elif len(hint) == 29:
828                     svc = self._gateway_services.get(hint[2:])
829                     if svc:
830                         sorted_roots.append(svc['_service_root'])
831
832         # Sort the available local services by weight (heaviest first)
833         # for this locator, and return their service_roots (base URIs)
834         # in that order.
835         use_services = self._keep_services
836         if need_writable:
837             use_services = self._writable_services
838         self.using_proxy = self._any_nondisk_services(use_services)
839         sorted_roots.extend([
840             svc['_service_root'] for svc in sorted(
841                 use_services,
842                 reverse=True,
843                 key=lambda svc: self._service_weight(locator.md5sum, svc['uuid']))])
844         _logger.debug("{}: {}".format(locator, sorted_roots))
845         return sorted_roots
846
847     def map_new_services(self, roots_map, locator, force_rebuild, need_writable, **headers):
848         # roots_map is a dictionary, mapping Keep service root strings
849         # to KeepService objects.  Poll for Keep services, and add any
850         # new ones to roots_map.  Return the current list of local
851         # root strings.
852         headers.setdefault('Authorization', "OAuth2 %s" % (self.api_token,))
853         local_roots = self.weighted_service_roots(locator, force_rebuild, need_writable)
854         for root in local_roots:
855             if root not in roots_map:
856                 roots_map[root] = self.KeepService(
857                     root, self._user_agent_pool,
858                     upload_counter=self.upload_counter,
859                     download_counter=self.download_counter,
860                     **headers)
861         return local_roots
862
863     @staticmethod
864     def _check_loop_result(result):
865         # KeepClient RetryLoops should save results as a 2-tuple: the
866         # actual result of the request, and the number of servers available
867         # to receive the request this round.
868         # This method returns True if there's a real result, False if
869         # there are no more servers available, otherwise None.
870         if isinstance(result, Exception):
871             return None
872         result, tried_server_count = result
873         if (result is not None) and (result is not False):
874             return True
875         elif tried_server_count < 1:
876             _logger.info("No more Keep services to try; giving up")
877             return False
878         else:
879             return None
880
881     def get_from_cache(self, loc):
882         """Fetch a block only if is in the cache, otherwise return None."""
883         slot = self.block_cache.get(loc)
884         if slot is not None and slot.ready.is_set():
885             return slot.get()
886         else:
887             return None
888
    @retry.retry_method
    def head(self, loc_s, num_retries=None):
        """Check whether a block is available from Keep.

        Thin wrapper that delegates to _get_or_head() with
        method="HEAD".  For a single locator, _get_or_head() returns
        True once a service answers successfully, and raises
        arvados.errors.NotFoundError or arvados.errors.KeepReadError
        when none does.

        Arguments:
        * loc_s: The locator string to check.
        * num_retries: Passed through to _get_or_head().
          NOTE(review): retry.retry_method is defined outside this
          file; it presumably substitutes the client's default when
          this is None -- confirm against arvados.retry.
        """
        return self._get_or_head(loc_s, method="HEAD", num_retries=num_retries)
892
    @retry.retry_method
    def get(self, loc_s, num_retries=None):
        """Fetch block data from Keep.

        Thin wrapper that delegates to _get_or_head() with
        method="GET"; see that method for how services are selected
        and retried, and for the exceptions raised on failure.

        Arguments:
        * loc_s: A string of one or more comma-separated locators.
        * num_retries: Passed through to _get_or_head().
          NOTE(review): retry.retry_method is defined outside this
          file; it presumably substitutes the client's default when
          this is None -- confirm against arvados.retry.
        """
        return self._get_or_head(loc_s, method="GET", num_retries=num_retries)
896
897     def _get_or_head(self, loc_s, method="GET", num_retries=None):
898         """Get data from Keep.
899
900         This method fetches one or more blocks of data from Keep.  It
901         sends a request each Keep service registered with the API
902         server (or the proxy provided when this client was
903         instantiated), then each service named in location hints, in
904         sequence.  As soon as one service provides the data, it's
905         returned.
906
907         Arguments:
908         * loc_s: A string of one or more comma-separated locators to fetch.
909           This method returns the concatenation of these blocks.
910         * num_retries: The number of times to retry GET requests to
911           *each* Keep server if it returns temporary failures, with
912           exponential backoff.  Note that, in each loop, the method may try
913           to fetch data from every available Keep service, along with any
914           that are named in location hints in the locator.  The default value
915           is set when the KeepClient is initialized.
916         """
917         if ',' in loc_s:
918             return ''.join(self.get(x) for x in loc_s.split(','))
919
920         self.get_counter.add(1)
921
922         locator = KeepLocator(loc_s)
923         if method == "GET":
924             slot, first = self.block_cache.reserve_cache(locator.md5sum)
925             if not first:
926                 self.hits_counter.add(1)
927                 v = slot.get()
928                 return v
929
930         self.misses_counter.add(1)
931
932         # If the locator has hints specifying a prefix (indicating a
933         # remote keepproxy) or the UUID of a local gateway service,
934         # read data from the indicated service(s) instead of the usual
935         # list of local disk services.
936         hint_roots = ['http://keep.{}.arvadosapi.com/'.format(hint[2:])
937                       for hint in locator.hints if hint.startswith('K@') and len(hint) == 7]
938         hint_roots.extend([self._gateway_services[hint[2:]]['_service_root']
939                            for hint in locator.hints if (
940                                    hint.startswith('K@') and
941                                    len(hint) == 29 and
942                                    self._gateway_services.get(hint[2:])
943                                    )])
944         # Map root URLs to their KeepService objects.
945         roots_map = {
946             root: self.KeepService(root, self._user_agent_pool,
947                                    upload_counter=self.upload_counter,
948                                    download_counter=self.download_counter)
949             for root in hint_roots
950         }
951
952         # See #3147 for a discussion of the loop implementation.  Highlights:
953         # * Refresh the list of Keep services after each failure, in case
954         #   it's being updated.
955         # * Retry until we succeed, we're out of retries, or every available
956         #   service has returned permanent failure.
957         sorted_roots = []
958         roots_map = {}
959         blob = None
960         loop = retry.RetryLoop(num_retries, self._check_loop_result,
961                                backoff_start=2)
962         for tries_left in loop:
963             try:
964                 sorted_roots = self.map_new_services(
965                     roots_map, locator,
966                     force_rebuild=(tries_left < num_retries),
967                     need_writable=False)
968             except Exception as error:
969                 loop.save_result(error)
970                 continue
971
972             # Query KeepService objects that haven't returned
973             # permanent failure, in our specified shuffle order.
974             services_to_try = [roots_map[root]
975                                for root in sorted_roots
976                                if roots_map[root].usable()]
977             for keep_service in services_to_try:
978                 blob = keep_service.get(locator, method=method, timeout=self.current_timeout(num_retries-tries_left))
979                 if blob is not None:
980                     break
981             loop.save_result((blob, len(services_to_try)))
982
983         # Always cache the result, then return it if we succeeded.
984         if method == "GET":
985             slot.set(blob)
986             self.block_cache.cap_cache()
987         if loop.success():
988             if method == "HEAD":
989                 return True
990             else:
991                 return blob
992
993         # Q: Including 403 is necessary for the Keep tests to continue
994         # passing, but maybe they should expect KeepReadError instead?
995         not_founds = sum(1 for key in sorted_roots
996                          if roots_map[key].last_result().get('status_code', None) in {403, 404, 410})
997         service_errors = ((key, roots_map[key].last_result()['error'])
998                           for key in sorted_roots)
999         if not roots_map:
1000             raise arvados.errors.KeepReadError(
1001                 "failed to read {}: no Keep services available ({})".format(
1002                     loc_s, loop.last_result()))
1003         elif not_founds == len(sorted_roots):
1004             raise arvados.errors.NotFoundError(
1005                 "{} not found".format(loc_s), service_errors)
1006         else:
1007             raise arvados.errors.KeepReadError(
1008                 "failed to read {}".format(loc_s), service_errors, label="service")
1009
1010     @retry.retry_method
1011     def put(self, data, copies=2, num_retries=None):
1012         """Save data in Keep.
1013
1014         This method will get a list of Keep services from the API server, and
1015         send the data to each one simultaneously in a new thread.  Once the
1016         uploads are finished, if enough copies are saved, this method returns
1017         the most recent HTTP response body.  If requests fail to upload
1018         enough copies, this method raises KeepWriteError.
1019
1020         Arguments:
1021         * data: The string of data to upload.
1022         * copies: The number of copies that the user requires be saved.
1023           Default 2.
1024         * num_retries: The number of times to retry PUT requests to
1025           *each* Keep server if it returns temporary failures, with
1026           exponential backoff.  The default value is set when the
1027           KeepClient is initialized.
1028         """
1029
1030         if isinstance(data, unicode):
1031             data = data.encode("ascii")
1032         elif not isinstance(data, str):
1033             raise arvados.errors.ArgumentError("Argument 'data' to KeepClient.put is not type 'str'")
1034
1035         self.put_counter.add(1)
1036
1037         data_hash = hashlib.md5(data).hexdigest()
1038         loc_s = data_hash + '+' + str(len(data))
1039         if copies < 1:
1040             return loc_s
1041         locator = KeepLocator(loc_s)
1042
1043         headers = {}
1044         # Tell the proxy how many copies we want it to store
1045         headers['X-Keep-Desired-Replicas'] = str(copies)
1046         roots_map = {}
1047         loop = retry.RetryLoop(num_retries, self._check_loop_result,
1048                                backoff_start=2)
1049         done = 0
1050         for tries_left in loop:
1051             try:
1052                 sorted_roots = self.map_new_services(
1053                     roots_map, locator,
1054                     force_rebuild=(tries_left < num_retries), need_writable=True, **headers)
1055             except Exception as error:
1056                 loop.save_result(error)
1057                 continue
1058
1059             writer_pool = KeepClient.KeepWriterThreadPool(data=data, 
1060                                                         data_hash=data_hash,
1061                                                         copies=copies - done,
1062                                                         max_service_replicas=self.max_replicas_per_service,
1063                                                         timeout=self.current_timeout(num_retries - tries_left))
1064             for service_root, ks in [(root, roots_map[root])
1065                                      for root in sorted_roots]:
1066                 if ks.finished():
1067                     continue
1068                 writer_pool.add_task(ks, service_root)
1069             writer_pool.join()
1070             done += writer_pool.done()
1071             loop.save_result((done >= copies, writer_pool.total_task_nr))
1072
1073         if loop.success():
1074             return writer_pool.response()
1075         if not roots_map:
1076             raise arvados.errors.KeepWriteError(
1077                 "failed to write {}: no Keep services available ({})".format(
1078                     data_hash, loop.last_result()))
1079         else:
1080             service_errors = ((key, roots_map[key].last_result()['error'])
1081                               for key in sorted_roots
1082                               if roots_map[key].last_result()['error'])
1083             raise arvados.errors.KeepWriteError(
1084                 "failed to write {} (wanted {} copies but wrote {})".format(
1085                     data_hash, copies, writer_pool.done()), service_errors, label="service")
1086
1087     def local_store_put(self, data, copies=1, num_retries=None):
1088         """A stub for put().
1089
1090         This method is used in place of the real put() method when
1091         using local storage (see constructor's local_store argument).
1092
1093         copies and num_retries arguments are ignored: they are here
1094         only for the sake of offering the same call signature as
1095         put().
1096
1097         Data stored this way can be retrieved via local_store_get().
1098         """
1099         md5 = hashlib.md5(data).hexdigest()
1100         locator = '%s+%d' % (md5, len(data))
1101         with open(os.path.join(self.local_store, md5 + '.tmp'), 'w') as f:
1102             f.write(data)
1103         os.rename(os.path.join(self.local_store, md5 + '.tmp'),
1104                   os.path.join(self.local_store, md5))
1105         return locator
1106
1107     def local_store_get(self, loc_s, num_retries=None):
1108         """Companion to local_store_put()."""
1109         try:
1110             locator = KeepLocator(loc_s)
1111         except ValueError:
1112             raise arvados.errors.NotFoundError(
1113                 "Invalid data locator: '%s'" % loc_s)
1114         if locator.md5sum == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
1115             return ''
1116         with open(os.path.join(self.local_store, locator.md5sum), 'r') as f:
1117             return f.read()
1118
1119     def is_cached(self, locator):
1120         return self.block_cache.reserve_cache(expect_hash)