2800: Improve spec conformance of Python SDK KeepLocator.
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21 import datetime
22 import ssl
23
24 _logger = logging.getLogger('arvados.keep')
25 global_client_object = None
26
27 from api import *
28 import config
29 import arvados.errors
30 import arvados.util
31
class KeepLocator(object):
    """Parse and re-serialize a Keep block locator.

    A locator has the form ``md5sum[+size][+hint...]``.  A permission
    hint (``A<signature>@<hex timestamp>``) is parsed into
    ``perm_sig``/``perm_expiry``; all other hints are stored verbatim in
    ``self.hints``.  ``str(locator)`` reassembles the canonical form.
    """
    # All permission timestamps are relative to the Unix epoch, in UTC.
    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
    # Hints start with an uppercase letter followed by hint data.
    HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')

    def __init__(self, locator_str):
        """Parse locator_str; raises ValueError on malformed input."""
        self.hints = []
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        self.md5sum = next(pieces)
        try:
            self.size = int(next(pieces))
        except StopIteration:
            # Bare "md5sum" locators have no size field.
            self.size = None
        for hint in pieces:
            if self.HINT_RE.match(hint) is None:
                raise ValueError("unrecognized hint data {}".format(hint))
            elif hint.startswith('A'):
                self.parse_permission_hint(hint)
            else:
                self.hints.append(hint)

    def __str__(self):
        # Omit any field that is unset (None) from the canonical form.
        return '+'.join(
            str(s) for s in [self.md5sum, self.size,
                             self.permission_hint()] + self.hints
            if s is not None)

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not arvados.util.is_hex(hex_str, length):
                raise ValueError("{} must be a {}-digit hex string: {}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        # Accepts a 1- to 8-digit hex Unix timestamp; stored as a UTC
        # datetime.
        if not arvados.util.is_hex(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        """Return the serialized permission hint, or None if unsigned."""
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        # Convert the expiry back to a hex Unix timestamp.
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        """Parse an 'A<sig>@<timestamp>' hint into sig and expiry.

        Raises ValueError for malformed hints.  (The tuple unpack below
        raises ValueError -- not IndexError, as previously caught -- when
        the '@' separator is missing.)
        """
        try:
            self.perm_sig, self.perm_expiry = s[1:].split('@', 1)
        except ValueError:
            raise ValueError("bad permission hint {}".format(s))

    def permission_expired(self, as_of_dt=None):
        """Return True if the permission signature has expired.

        as_of_dt defaults to the current UTC time; perm_expiry is stored
        as UTC, so comparing against local datetime.now() (as the code
        previously did) gave wrong answers off-UTC.
        """
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt
107
108
class Keep:
    """Module-level convenience wrappers around a shared KeepClient.

    The singleton is created lazily on first use and reused by all
    callers of Keep.get()/Keep.put().
    """
    @staticmethod
    def global_client_object():
        global global_client_object
        # Use identity comparison with None (PEP 8), not ==.
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch a data block via the shared client."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store a data block via the shared client."""
        return Keep.global_client_object().put(data, **kwargs)
124
class KeepClient(object):
    """Store and retrieve data blocks from Keep services.

    Maintains a per-hash shuffled probe order over the discovered
    service roots, an LRU content cache for reads, and parallel writer
    threads for replicated writes.
    """

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            # todo: total number of replica successes wanted.
            self._todo = todo
            self._done = 0
            self._response = None
            # Allow at most `todo` writers to run concurrently.
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def save_response(self, response_body, replicas_stored):
            """
            Records a response body (a locator, possibly signed) returned by
            the Keep server.  It is not necessary to save more than
            one response, since we presume that any locator returned
            in response to a successful request is valid.
            """
            with self._done_lock:
                self._done += replicas_stored
                self._response = response_body

        def response(self):
            """
            Returns the body from the response to a PUT request.
            """
            with self._done_lock:
                return self._response

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. On success, call
        save_response() of the given ThreadLimiter to save the returned
        locator.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs
            self._success = False

        def success(self):
            # True once a PUT to this service returned a 2xx status.
            return self._success

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                self.run_with_limiter(limiter)

        def run_with_limiter(self, limiter):
            """PUT this thread's block to its service root, recording a
            success (and the returned locator) with the limiter."""
            _logger.debug("KeepWriterThread %s proceeding %s %s",
                          str(threading.current_thread()),
                          self.args['data_hash'],
                          self.args['service_root'])
            h = httplib2.Http(timeout=self.args.get('timeout', None))
            url = self.args['service_root'] + self.args['data_hash']
            api_token = config.get('ARVADOS_API_TOKEN')
            headers = {'Authorization': "OAuth2 %s" % api_token}

            if self.args['using_proxy']:
                # We're using a proxy, so tell the proxy how many copies we
                # want it to store
                headers['X-Keep-Desired-Replication'] = str(self.args['want_copies'])

            try:
                _logger.debug("Uploading to {}".format(url))
                resp, content = h.request(url.encode('utf-8'), 'PUT',
                                          headers=headers,
                                          body=self.args['data'])
                if (resp['status'] == '401' and
                    re.match(r'Timestamp verification failed', content)):
                    # Old servers want the data wrapped in a PGP-signed
                    # message; retry in that format.
                    body = KeepClient.sign_for_old_server(
                        self.args['data_hash'],
                        self.args['data'])
                    h = httplib2.Http(timeout=self.args.get('timeout', None))
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=body)
                if re.match(r'^2\d\d$', resp['status']):
                    self._success = True
                    _logger.debug("KeepWriterThread %s succeeded %s %s",
                                  str(threading.current_thread()),
                                  self.args['data_hash'],
                                  self.args['service_root'])
                    replicas_stored = 1
                    if 'x-keep-replicas-stored' in resp:
                        # Tick the 'done' counter for the number of replica
                        # reported stored by the server, for the case that
                        # we're talking to a proxy or other backend that
                        # stores to multiple copies for us.
                        try:
                            replicas_stored = int(resp['x-keep-replicas-stored'])
                        except ValueError:
                            pass
                    limiter.save_response(content.strip(), replicas_stored)
                else:
                    _logger.debug("Request fail: PUT %s => %s %s",
                                    url, resp['status'], content)
            except (httplib2.HttpLib2Error,
                    httplib.HTTPException,
                    ssl.SSLError) as e:
                # When using https, timeouts look like ssl.SSLError from here.
                # "SSLError: The write operation timed out"
                _logger.debug("Request fail: PUT %s => %s: %s",
                                url, type(e), str(e))

    def __init__(self, **kwargs):
        """Create a client.  Accepts timeout=<seconds> (default 60)."""
        self.lock = threading.Lock()          # guards service root discovery
        self.service_roots = None
        self._cache_lock = threading.Lock()   # guards _cache
        self._cache = []                      # LRU: most recent at index 0
        # default 256 megabyte cache
        self.cache_max = 256 * 1024 * 1024
        self.using_proxy = False
        self.timeout = kwargs.get('timeout', 60)

    def shuffled_service_roots(self, hash):
        """Return the Keep service roots ordered by a probe sequence
        derived deterministically from the given hex hash string."""
        if self.service_roots is None:
            self.lock.acquire()
            try:
                # Re-check under the lock: another thread may have
                # completed discovery while we waited.
                if self.service_roots is None:
                    # Override normal keep disk lookup with an explict proxy
                    # configuration.
                    keep_proxy_env = config.get("ARVADOS_KEEP_PROXY")
                    if keep_proxy_env:
                        if not keep_proxy_env.endswith('/'):
                            keep_proxy_env += "/"
                        self.service_roots = [keep_proxy_env]
                        self.using_proxy = True
                    else:
                        try:
                            keep_services = arvados.api().keep_services().accessible().execute()['items']
                        except Exception:
                            # Fall back to the older keep_disks API.
                            keep_services = arvados.api().keep_disks().list().execute()['items']

                        if len(keep_services) == 0:
                            raise arvados.errors.NoKeepServersError()

                        if keep_services[0].get('service_type') == 'proxy':
                            self.using_proxy = True

                        roots = (("http%s://%s:%d/" %
                                  ('s' if f['service_ssl_flag'] else '',
                                   f['service_host'],
                                   f['service_port']))
                                 for f in keep_services)
                        self.service_roots = sorted(set(roots))
                        _logger.debug(str(self.service_roots))
            finally:
                # BUGFIX: previously the lock was released only on the
                # non-proxy path, so the proxy branch leaked the lock.
                self.lock.release()

        # Build an ordering with which to query the Keep servers based on the
        # contents of the hash.
        # "hash" is a hex-encoded number at least 8 digits
        # (32 bits) long

        # seed used to calculate the next keep server from 'pool'
        # to be added to 'pseq'
        seed = hash

        # Keep servers still to be added to the ordering
        pool = self.service_roots[:]

        # output probe sequence
        pseq = []

        # iterate while there are servers left to be assigned
        while len(pool) > 0:
            if len(seed) < 8:
                # ran out of digits in the seed
                if len(pseq) < len(hash) / 4:
                    # the number of servers added to the probe sequence is less
                    # than the number of 4-digit slices in 'hash' so refill the
                    # seed with the last 4 digits and then append the contents
                    # of 'hash'.
                    seed = hash[-4:] + hash
                else:
                    # refill the seed with the contents of 'hash'
                    seed += hash

            # Take the next 8 digits (32 bytes) and interpret as an integer,
            # then modulus with the size of the remaining pool to get the next
            # selected server.
            probe = int(seed[0:8], 16) % len(pool)

            # Append the selected server to the probe sequence and remove it
            # from the pool.
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]

            # Remove the digits just used from the seed
            seed = seed[8:]
        _logger.debug(str(pseq))
        return pseq

    class CacheSlot(object):
        """One cache entry: readers block on `ready` until a fetcher
        calls set() with the content (or None on failure)."""
        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            # Block until content is available (or the fetch failed).
            self.ready.wait()
            return self.content

        def set(self, value):
            self.content = value
            self.ready.set()

        def size(self):
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        self._cache_lock.acquire()
        try:
            # Drop slots whose fetch completed with no content.
            # (List comprehension instead of filter(): identical on
            # Python 2, and still a list on Python 3.)
            self._cache = [c for c in self._cache
                           if not (c.ready.is_set() and c.content is None)]
            sm = sum(slot.size() for slot in self._cache)
            # Evict least-recently-used entries (list tail) until under
            # the cap.  BUGFIX: the old code recomputed the total as
            # "sum([slot.size() for a in self._cache])", summing one
            # leaked `slot` variable instead of each entry, so the total
            # never tracked evictions correctly.
            while sm > self.cache_max:
                sm -= self._cache[-1].size()
                del self._cache[-1]
        finally:
            self._cache_lock.release()

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        self._cache_lock.acquire()
        try:
            # Test if the locator is already in the cache
            for i, slot in enumerate(self._cache):
                if slot.locator == locator:
                    if i != 0:
                        # Move it to the front (most recently used).
                        del self._cache[i]
                        self._cache.insert(0, slot)
                    return slot, False

            # Add a new cache slot for the locator
            slot = KeepClient.CacheSlot(locator)
            self._cache.insert(0, slot)
            return slot, True
        finally:
            self._cache_lock.release()

    def get(self, locator):
        """Fetch and return the data block named by locator.

        Tries each Keep service in probe order, then any +K@instance
        location hints.  Raises arvados.errors.NotFoundError if no
        service returns content matching the expected md5.
        """
        if re.search(r',', locator):
            # Comma-separated list of locators: concatenate the blocks.
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)

        slot, first = self.reserve_cache(expect_hash)

        if not first:
            # Another thread owns the fetch; wait for its result.
            v = slot.get()
            return v

        try:
            for service_root in self.shuffled_service_roots(expect_hash):
                url = service_root + locator
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token,
                           'Accept': 'application/octet-stream'}
                blob = self.get_url(url, headers, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob

            # Last resort: follow +K@instance location hints.
            for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
                instance = location_hint.group(1)
                url = 'http://keep.' + instance + '.arvadosapi.com/' + locator
                blob = self.get_url(url, {}, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob
        except:
            # Mark the slot failed so blocked readers wake up, then
            # propagate the original exception.
            slot.set(None)
            self.cap_cache()
            raise

        slot.set(None)
        self.cap_cache()
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        """GET url; return the body if its md5 matches expect_hash,
        otherwise None.  Network errors are logged and swallowed."""
        h = httplib2.Http()
        try:
            _logger.debug("Request: GET %s", url)
            with timer.Timer() as t:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
            if t.secs > 0:
                # Guard: t.secs can be 0 on very fast responses, which
                # previously raised ZeroDivisionError here.
                _logger.info("Received %s bytes in %s msec (%s MiB/sec)",
                             len(content), t.msecs,
                             (len(content)/(1024*1024))/t.secs)
            if re.match(r'^2\d\d$', resp['status']):
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                _logger.warning("Checksum fail: md5(%s) = %s", url, md5)
        except Exception as e:
            _logger.debug("Request fail: GET %s => %s: %s",
                         url, type(e), str(e))
        return None

    def put(self, data, **kwargs):
        """Store data in Keep with kwargs.get('copies', 2) replicas.

        Returns the (possibly signed) locator reported by the server.
        Raises arvados.errors.KeepWriteError if fewer than the wanted
        number of replicas could be written after one retry pass.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            # Zero replicas wanted: nothing to do.
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(
                data=data,
                data_hash=data_hash,
                service_root=service_root,
                thread_limiter=thread_limiter,
                timeout=self.timeout,
                using_proxy=self.using_proxy,
                # A proxy stores all wanted copies itself; direct
                # services store one copy each.
                want_copies=(want_copies if self.using_proxy else 1))
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        if thread_limiter.done() < want_copies:
            # Retry the threads (i.e., services) that failed the first
            # time around.
            threads_retry = []
            for t in threads:
                if not t.success():
                    _logger.debug("Retrying: PUT %s %s",
                                    t.args['service_root'],
                                    t.args['data_hash'])
                    retry_with_args = t.args.copy()
                    t_retry = KeepClient.KeepWriterThread(**retry_with_args)
                    t_retry.start()
                    threads_retry += [t_retry]
            for t in threads_retry:
                t.join()
        have_copies = thread_limiter.done()
        # If we're done, return the response from Keep
        if have_copies >= want_copies:
            return thread_limiter.response()
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        """Wrap data in the PGP-signed-message envelope expected by
        pre-timestamp-verification Keep servers."""
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        """Write data to $KEEP_LOCAL_STORE (testing backend); return its
        locator.  The temp-file-then-rename makes the write atomic."""
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Read a block from $KEEP_LOCAL_STORE (testing backend)."""
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            # The well-known empty block needs no file on disk.
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()