# Merge remote-tracking branch 'refs/remotes/origin/3504-clients-compatible-with-3036...
# [arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21 import datetime
22 import ssl
23
# Module-level logger for this package; handlers and levels are configured by
# the embedding application.
_logger = logging.getLogger('arvados.keep')
# Shared KeepClient used by the deprecated Keep.* static methods; built
# lazily by Keep.global_client_object().
global_client_object = None
26
27 import arvados
28 import arvados.config as config
29 import arvados.errors
30 import arvados.util
31
class KeepLocator(object):
    """Parse, validate, and re-serialize one Keep locator string.

    A locator has the form "md5sum[+size][+hint...]", e.g.
    "acbd18db4cc2f85cedef654fccc4a4d8+3+A<sig>@<hexstamp>".  A hint
    beginning with 'A' is a permission hint and is parsed into
    perm_sig/perm_expiry; all other hints are kept verbatim in
    self.hints.
    """
    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
    HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')

    def __init__(self, locator_str):
        self.hints = []
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        self.md5sum = next(pieces)
        try:
            self.size = int(next(pieces))
        except StopIteration:
            # Bare "md5sum" locator with no size field.
            self.size = None
        for hint in pieces:
            if self.HINT_RE.match(hint) is None:
                raise ValueError("unrecognized hint data {}".format(hint))
            elif hint.startswith('A'):
                self.parse_permission_hint(hint)
            else:
                self.hints.append(hint)

    def __str__(self):
        # Re-serialize; omit fields that were never present/parsed.
        return '+'.join(
            str(s) for s in [self.md5sum, self.size,
                             self.permission_hint()] + self.hints
            if s is not None)

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not arvados.util.is_hex(hex_str, length):
                raise ValueError("{} must be a {}-digit hex string: {}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        # value is a hex-encoded Unix timestamp (1-8 digits).
        if not arvados.util.is_hex(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        """Return the 'A<sig>@<hexstamp>' hint, or None if unsigned."""
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        try:
            self.perm_sig, self.perm_expiry = s[1:].split('@', 1)
        except ValueError:
            # BUG FIX: a missing '@' separator makes the tuple unpacking
            # raise ValueError, not IndexError as previously caught, so
            # the intended error message below was unreachable.  This also
            # wraps validation errors raised by the property setters.
            raise ValueError("bad permission hint {}".format(s))

    def permission_expired(self, as_of_dt=None):
        """Return True if the permission hint has expired (as of as_of_dt)."""
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            # BUG FIX: perm_expiry is derived from a UTC timestamp
            # (utcfromtimestamp above), so it must be compared against
            # UTC "now", not local time.
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt
107
108
class Keep(object):
    """Simple interface to a global KeepClient object.

    THIS CLASS IS DEPRECATED.  Please instantiate your own KeepClient with your
    own API client.  The global KeepClient will build an API client from the
    current Arvados configuration, which may not match the one you built.
    """
    _last_key = None

    @classmethod
    def global_client_object(cls):
        global global_client_object
        # KeepClient once adjusted its behavior at runtime when these
        # settings changed.  Emulate that: snapshot the settings, and
        # rebuild the shared client whenever the snapshot differs from the
        # one the current client was built under.
        key = (config.get('ARVADOS_API_HOST'),
               config.get('ARVADOS_API_TOKEN'),
               config.flag_is_true('ARVADOS_API_HOST_INSECURE'),
               config.get('ARVADOS_KEEP_PROXY'),
               config.get('ARVADOS_EXTERNAL_CLIENT') == 'true',
               os.environ.get('KEEP_LOCAL_STORE'))
        if global_client_object is None or key != cls._last_key:
            cls._last_key = key
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch a block through the shared global KeepClient."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store data through the shared global KeepClient."""
        return Keep.global_client_object().put(data, **kwargs)
143
144
145 class KeepClient(object):
    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            # Number of replica writes we still want confirmed.
            self._todo = todo
            # Replicas confirmed stored so far; may grow by more than 1
            # per request when a proxy stores several copies for us.
            self._done = 0
            # Body of the first successful PUT response (a locator,
            # possibly signed).
            self._response = None
            # Admits at most 'todo' writer threads into their work
            # sections at once.
            self._todo_lock = threading.Semaphore(todo)
            # Serializes access to _done/_response across writer threads.
            self._done_lock = threading.Lock()

        def __enter__(self):
            # Blocks until one of the 'todo' slots is free.
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def save_response(self, response_body, replicas_stored):
            """
            Records a response body (a locator, possibly signed) returned by
            the Keep server.  It is not necessary to save more than
            one response, since we presume that any locator returned
            in response to a successful request is valid.
            """
            with self._done_lock:
                self._done += replicas_stored
                self._response = response_body

        def response(self):
            """
            Returns the body from the response to a PUT request.
            """
            with self._done_lock:
                return self._response

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done
202
    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. On success, call
        save_response() of the given ThreadLimiter to save the returned
        locator.
        """
        def __init__(self, api_token, **kwargs):
            # kwargs is kept whole in self.args so put() can clone this
            # thread's arguments for a retry pass (see KeepClient.put).
            super(KeepClient.KeepWriterThread, self).__init__()
            self._api_token = api_token
            self.args = kwargs
            self._success = False

        def success(self):
            # True once a PUT got a 2xx response; read by put() to decide
            # which services to retry.
            return self._success

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                self.run_with_limiter(limiter)

        def run_with_limiter(self, limiter):
            """PUT the block to one service and record the outcome."""
            _logger.debug("KeepWriterThread %s proceeding %s %s",
                          str(threading.current_thread()),
                          self.args['data_hash'],
                          self.args['service_root'])
            h = httplib2.Http(timeout=self.args.get('timeout', None))
            url = self.args['service_root'] + self.args['data_hash']
            headers = {'Authorization': "OAuth2 %s" % (self._api_token,)}

            if self.args['using_proxy']:
                # We're using a proxy, so tell the proxy how many copies we
                # want it to store
                headers['X-Keep-Desired-Replication'] = str(self.args['want_copies'])

            try:
                _logger.debug("Uploading to {}".format(url))
                # httplib2 returns (response-dict, body); status is a string.
                resp, content = h.request(url.encode('utf-8'), 'PUT',
                                          headers=headers,
                                          body=self.args['data'])
                if (resp['status'] == '401' and
                    re.match(r'Timestamp verification failed', content)):
                    # Compatibility path: older Keep servers expect the body
                    # wrapped in a pseudo-PGP signed message (see
                    # sign_for_old_server); rebuild the request and retry once.
                    body = KeepClient.sign_for_old_server(
                        self.args['data_hash'],
                        self.args['data'])
                    h = httplib2.Http(timeout=self.args.get('timeout', None))
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=body)
                if re.match(r'^2\d\d$', resp['status']):
                    self._success = True
                    _logger.debug("KeepWriterThread %s succeeded %s %s",
                                  str(threading.current_thread()),
                                  self.args['data_hash'],
                                  self.args['service_root'])
                    replicas_stored = 1
                    if 'x-keep-replicas-stored' in resp:
                        # Tick the 'done' counter for the number of replica
                        # reported stored by the server, for the case that
                        # we're talking to a proxy or other backend that
                        # stores to multiple copies for us.
                        try:
                            replicas_stored = int(resp['x-keep-replicas-stored'])
                        except ValueError:
                            pass
                    limiter.save_response(content.strip(), replicas_stored)
                else:
                    _logger.debug("Request fail: PUT %s => %s %s",
                                    url, resp['status'], content)
            except (httplib2.HttpLib2Error,
                    httplib.HTTPException,
                    ssl.SSLError) as e:
                # When using https, timeouts look like ssl.SSLError from here.
                # "SSLError: The write operation timed out"
                _logger.debug("Request fail: PUT %s => %s: %s",
                                url, type(e), str(e))
281
282
283     def __init__(self, api_client=None, proxy=None, timeout=60,
284                  api_token=None, local_store=None):
285         """Initialize a new KeepClient.
286
287         Arguments:
288         * api_client: The API client to use to find Keep services.  If not
289           provided, KeepClient will build one from available Arvados
290           configuration.
291         * proxy: If specified, this KeepClient will send requests to this
292           Keep proxy.  Otherwise, KeepClient will fall back to the setting
293           of the ARVADOS_KEEP_PROXY configuration setting.  If you want to
294           ensure KeepClient does not use a proxy, pass in an empty string.
295         * timeout: The timeout for all HTTP requests, in seconds.  Default
296           60.
297         * api_token: If you're not using an API client, but only talking
298           directly to a Keep proxy, this parameter specifies an API token
299           to authenticate Keep requests.  It is an error to specify both
300           api_client and api_token.  If you specify neither, KeepClient
301           will use one available from the Arvados configuration.
302         * local_store: If specified, this KeepClient will bypass Keep
303           services, and save data to the named directory.  If unspecified,
304           KeepClient will fall back to the setting of the $KEEP_LOCAL_STORE
305           environment variable.  If you want to ensure KeepClient does not
306           use local storage, pass in an empty string.  This is primarily
307           intended to mock a server for testing.
308         """
309         self.lock = threading.Lock()
310         if proxy is None:
311             proxy = config.get('ARVADOS_KEEP_PROXY')
312         if api_token is None:
313             api_token = config.get('ARVADOS_API_TOKEN')
314         elif api_client is not None:
315             raise ValueError(
316                 "can't build KeepClient with both API client and token")
317         if local_store is None:
318             local_store = os.environ.get('KEEP_LOCAL_STORE')
319
320         if local_store:
321             self.local_store = local_store
322             self.get = self.local_store_get
323             self.put = self.local_store_put
324         else:
325             self.timeout = timeout
326             self.cache_max = 256 * 1024 * 1024  # Cache is 256MiB
327             self._cache = []
328             self._cache_lock = threading.Lock()
329             if proxy:
330                 if not proxy.endswith('/'):
331                     proxy += '/'
332                 self.api_token = api_token
333                 self.service_roots = [proxy]
334                 self.using_proxy = True
335             else:
336                 # It's important to avoid instantiating an API client
337                 # unless we actually need one, for testing's sake.
338                 if api_client is None:
339                     api_client = arvados.api('v1')
340                 self.api_client = api_client
341                 self.api_token = api_client.api_token
342                 self.service_roots = None
343                 self.using_proxy = None
344
    def shuffled_service_roots(self, hash):
        """Return the Keep service URLs in probe order for 'hash'.

        The ordering is derived deterministically from the hex digits of
        'hash' so that every Keep client probes services for a given
        block in the same sequence.  On first use, services are
        discovered through the API server and cached in
        self.service_roots.  (NOTE: the parameter shadows the builtin
        'hash'; kept for interface compatibility.)
        """
        if self.service_roots is None:
            # NOTE(review): the None check happens outside the lock, so two
            # threads may both run discovery; the second overwrites with the
            # same result, so this looks harmless — confirm.
            with self.lock:
                try:
                    keep_services = self.api_client.keep_services().accessible()
                except Exception:  # API server predates Keep services.
                    keep_services = self.api_client.keep_disks().list()

                keep_services = keep_services.execute().get('items')
                if not keep_services:
                    raise arvados.errors.NoKeepServersError()

                # A proxy changes how many replicas one PUT can store
                # (see put()); only the first service's type is consulted.
                self.using_proxy = (keep_services[0].get('service_type') ==
                                    'proxy')

                roots = (("http%s://%s:%d/" %
                          ('s' if f['service_ssl_flag'] else '',
                           f['service_host'],
                           f['service_port']))
                         for f in keep_services)
                # Sort and deduplicate so all clients agree on the pool.
                self.service_roots = sorted(set(roots))
                _logger.debug(str(self.service_roots))

        # Build an ordering with which to query the Keep servers based on the
        # contents of the hash.
        # "hash" is a hex-encoded number at least 8 digits
        # (32 bits) long

        # seed used to calculate the next keep server from 'pool'
        # to be added to 'pseq'
        seed = hash

        # Keep servers still to be added to the ordering
        pool = self.service_roots[:]

        # output probe sequence
        pseq = []

        # iterate while there are servers left to be assigned
        while len(pool) > 0:
            if len(seed) < 8:
                # ran out of digits in the seed
                if len(pseq) < len(hash) / 4:
                    # the number of servers added to the probe sequence is less
                    # than the number of 4-digit slices in 'hash' so refill the
                    # seed with the last 4 digits and then append the contents
                    # of 'hash'.
                    seed = hash[-4:] + hash
                else:
                    # refill the seed with the contents of 'hash'
                    seed += hash

            # Take the next 8 digits (32 bytes) and interpret as an integer,
            # then modulus with the size of the remaining pool to get the next
            # selected server.
            probe = int(seed[0:8], 16) % len(pool)

            # Append the selected server to the probe sequence and remove it
            # from the pool.
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]

            # Remove the digits just used from the seed
            seed = seed[8:]
        _logger.debug(str(pseq))
        return pseq
411
412     class CacheSlot(object):
413         def __init__(self, locator):
414             self.locator = locator
415             self.ready = threading.Event()
416             self.content = None
417
418         def get(self):
419             self.ready.wait()
420             return self.content
421
422         def set(self, value):
423             self.content = value
424             self.ready.set()
425
426         def size(self):
427             if self.content == None:
428                 return 0
429             else:
430                 return len(self.content)
431
432     def cap_cache(self):
433         '''Cap the cache size to self.cache_max'''
434         self._cache_lock.acquire()
435         try:
436             self._cache = filter(lambda c: not (c.ready.is_set() and c.content == None), self._cache)
437             sm = sum([slot.size() for slot in self._cache])
438             while sm > self.cache_max:
439                 del self._cache[-1]
440                 sm = sum([slot.size() for a in self._cache])
441         finally:
442             self._cache_lock.release()
443
444     def reserve_cache(self, locator):
445         '''Reserve a cache slot for the specified locator,
446         or return the existing slot.'''
447         self._cache_lock.acquire()
448         try:
449             # Test if the locator is already in the cache
450             for i in xrange(0, len(self._cache)):
451                 if self._cache[i].locator == locator:
452                     n = self._cache[i]
453                     if i != 0:
454                         # move it to the front
455                         del self._cache[i]
456                         self._cache.insert(0, n)
457                     return n, False
458
459             # Add a new cache slot for the locator
460             n = KeepClient.CacheSlot(locator)
461             self._cache.insert(0, n)
462             return n, True
463         finally:
464             self._cache_lock.release()
465
    def get(self, loc_s):
        """Fetch the data block named by locator string loc_s.

        Accepts a comma-separated list of locators, in which case each
        block is fetched and the contents are concatenated.  Raises
        arvados.errors.NotFoundError when the block cannot be retrieved
        from any discovered Keep service or K@ locator hint.
        """
        if ',' in loc_s:
            # Multiple locators: fetch each one and join the blocks.
            return ''.join(self.get(x) for x in loc_s.split(','))
        locator = KeepLocator(loc_s)
        expect_hash = locator.md5sum

        # Reserve (or join) the cache slot for this hash.  If another
        # thread already started this fetch, just wait for its result.
        slot, first = self.reserve_cache(expect_hash)

        if not first:
            v = slot.get()
            return v

        try:
            for service_root in self.shuffled_service_roots(expect_hash):
                url = service_root + loc_s
                headers = {'Authorization': "OAuth2 %s" % (self.api_token,),
                           'Accept': 'application/octet-stream'}
                blob = self.get_url(url, headers, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob

            # Fall back to any K@<site> hints in the locator, which name
            # remote Arvados sites that may hold the block.
            for hint in locator.hints:
                if not hint.startswith('K@'):
                    continue
                url = 'http://keep.' + hint[2:] + '.arvadosapi.com/' + loc_s
                blob = self.get_url(url, {}, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob
        except:
            # Deliberately bare: on ANY failure, mark the slot as failed so
            # threads waiting in slot.get() wake up instead of blocking
            # forever, then re-raise the original exception.
            slot.set(None)
            self.cap_cache()
            raise

        # Exhausted all services and hints without finding the block.
        slot.set(None)
        self.cap_cache()
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)
506
507     def get_url(self, url, headers, expect_hash):
508         h = httplib2.Http()
509         try:
510             _logger.debug("Request: GET %s", url)
511             with timer.Timer() as t:
512                 resp, content = h.request(url.encode('utf-8'), 'GET',
513                                           headers=headers)
514             _logger.info("Received %s bytes in %s msec (%s MiB/sec)",
515                          len(content), t.msecs,
516                          (len(content)/(1024*1024))/t.secs)
517             if re.match(r'^2\d\d$', resp['status']):
518                 md5 = hashlib.md5(content).hexdigest()
519                 if md5 == expect_hash:
520                     return content
521                 _logger.warning("Checksum fail: md5(%s) = %s", url, md5)
522         except Exception as e:
523             _logger.debug("Request fail: GET %s => %s: %s",
524                          url, type(e), str(e))
525         return None
526
    def put(self, data, copies=2):
        """Store 'data' in Keep and return the response locator.

        Starts one KeepWriterThread per service (in probe order); the
        shared ThreadLimiter stops work once 'copies' replicas are
        confirmed.  Services that failed get one retry pass.  Raises
        arvados.errors.KeepWriteError if fewer than 'copies' replicas
        could be written.
        """
        data_hash = hashlib.md5(data).hexdigest()
        have_copies = 0
        want_copies = copies
        if not (want_copies > 0):
            # Zero replicas requested: nothing to write, just report the hash.
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            # The limiter admits at most want_copies uploads at a time;
            # extra threads exit early once enough replicas are stored.
            t = KeepClient.KeepWriterThread(
                self.api_token,
                data=data,
                data_hash=data_hash,
                service_root=service_root,
                thread_limiter=thread_limiter,
                timeout=self.timeout,
                using_proxy=self.using_proxy,
                # A proxy stores multiple copies per request; direct
                # services store one replica each.
                want_copies=(want_copies if self.using_proxy else 1))
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        if thread_limiter.done() < want_copies:
            # Retry the threads (i.e., services) that failed the first
            # time around.
            threads_retry = []
            for t in threads:
                if not t.success():
                    _logger.debug("Retrying: PUT %s %s",
                                    t.args['service_root'],
                                    t.args['data_hash'])
                    retry_with_args = t.args.copy()
                    t_retry = KeepClient.KeepWriterThread(self.api_token,
                                                          **retry_with_args)
                    t_retry.start()
                    threads_retry += [t_retry]
            for t in threads_retry:
                t.join()
        have_copies = thread_limiter.done()
        # If we're done, return the response from Keep
        if have_copies >= want_copies:
            return thread_limiter.response()
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))
572
573     @staticmethod
574     def sign_for_old_server(data_hash, data):
575         return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
576
577     def local_store_put(self, data):
578         md5 = hashlib.md5(data).hexdigest()
579         locator = '%s+%d' % (md5, len(data))
580         with open(os.path.join(self.local_store, md5 + '.tmp'), 'w') as f:
581             f.write(data)
582         os.rename(os.path.join(self.local_store, md5 + '.tmp'),
583                   os.path.join(self.local_store, md5))
584         return locator
585
586     def local_store_get(self, loc_s):
587         try:
588             locator = KeepLocator(loc_s)
589         except ValueError:
590             raise arvados.errors.NotFoundError(
591                 "Invalid data locator: '%s'" % loc_s)
592         if locator.md5sum == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
593             return ''
594         with open(os.path.join(self.local_store, locator.md5sum), 'r') as f:
595             return f.read()