2752: Add KeepLocator class to Python SDK.
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21 import datetime
22
23 global_client_object = None
24
25 from api import *
26 import config
27 import arvados.errors
28
class KeepLocator(object):
    """Parse, validate, and reassemble a Keep block locator string.

    A locator has the form
    "<md5sum>[+<size>][+A<signature>@<expiry>][+K<location hint>]":
    the 32-digit hex MD5 of the block, optionally followed by the block
    size in bytes, a permission (signing) hint, and/or a location hint.
    """
    # Reference point for converting permission expiry times to and
    # from hex Unix timestamps.
    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
    HEX_RE = re.compile(r'^[0-9a-fA-F]+$')

    def __init__(self, locator_str):
        """Parse locator_str; raise ValueError on malformed input."""
        self.size = None
        self.loc_hint = None
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        self.md5sum = next(pieces)
        for hint in pieces:
            if hint.startswith('A'):
                self.parse_permission_hint(hint)
            elif hint.startswith('K'):
                self.loc_hint = hint  # FIXME
            elif hint.isdigit():
                self.size = int(hint)
            else:
                raise ValueError("unrecognized hint data {}".format(hint))

    def __str__(self):
        # Reassemble the locator, omitting any components that are unset.
        return '+'.join(
            str(s) for s in [self.md5sum, self.size, self.loc_hint,
                             self.permission_hint()]
            if s is not None)

    def _is_hex_length(self, s, *size_spec):
        """Return true if s is a hex string of the given length.

        size_spec is either one exact length, or an inclusive
        (minimum, maximum) pair.
        """
        if len(size_spec) == 1:
            good_len = (len(s) == size_spec[0])
        else:
            good_len = (size_spec[0] <= len(s) <= size_spec[1])
        return good_len and self.HEX_RE.match(s)

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not self._is_hex_length(hex_str, length):
                raise ValueError("{} must be a {}-digit hex string: {}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        """Set the expiry from a 1- to 8-digit hex Unix timestamp."""
        if not self._is_hex_length(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        """Return the "A<sig>@<expiry>" hint, or None if unsigned."""
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        # Convert the expiry datetime back to a hex Unix timestamp.
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        """Parse an "A<sig>@<expiry>" hint into perm_sig/perm_expiry."""
        try:
            # split() always returns at least one element, so a hint
            # with no '@' fails the 2-tuple unpacking with ValueError
            # (the original code caught IndexError, which never fires).
            sig, expiry = s[1:].split('@', 1)
        except ValueError:
            raise ValueError("bad permission hint {}".format(s))
        self.perm_sig = sig
        self.perm_expiry = expiry

    def permission_expired(self, as_of_dt=None):
        """Return true if the permission hint has expired.

        Blocks with no permission hint never expire.  as_of_dt, if
        given, must be a naive UTC datetime (perm_expiry is stored as
        naive UTC, so the comparison must use utcnow, not local now).
        """
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt
110
111
class Keep:
    """Convenience wrappers around a shared, lazily created KeepClient."""

    @staticmethod
    def global_client_object():
        """Return the module-wide KeepClient, creating it on first use."""
        global global_client_object
        # Identity comparison with None per PEP 8 (was '== None').
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch a block through the shared KeepClient."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store data through the shared KeepClient."""
        return Keep.global_client_object().put(data, **kwargs)
127
class KeepClient(object):
    """Client for storing and retrieving data blocks on Keep servers.

    Handles service discovery (or a single proxy configured via
    ARVADOS_KEEP_PROXY), a per-hash probe ordering of servers, parallel
    replicated writes, and an in-memory read cache.
    """

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._response = None
            # At most 'todo' writer threads work concurrently.
            self._todo_lock = threading.Semaphore(todo)
            # Guards _done and _response.
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def save_response(self, response_body, replicas_stored):
            """
            Records a response body (a locator, possibly signed) returned by
            the Keep server.  It is not necessary to save more than
            one response, since we presume that any locator returned
            in response to a successful request is valid.
            """
            with self._done_lock:
                self._done += replicas_stored
                self._response = response_body

        def response(self):
            """
            Returns the body from the response to a PUT request.
            """
            with self._done_lock:
                return self._response

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. On success, call
        save_response() of the given ThreadLimiter to save the returned
        locator.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token}

                if self.args['using_proxy']:
                    # We're using a proxy, so tell the proxy how many copies we
                    # want it to store
                    headers['X-Keep-Desired-Replication'] = str(self.args['want_copies'])

                try:
                    logging.debug("Uploading to {}".format(url))
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # Fall back to the legacy PGP-clearsigned body
                        # format expected by old Keep servers.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        replicas_stored = 1
                        if 'x-keep-replicas-stored' in resp:
                            # Tick the 'done' counter for the number of replica
                            # reported stored by the server, for the case that
                            # we're talking to a proxy or other backend that
                            # stores to multiple copies for us.
                            try:
                                replicas_stored = int(resp['x-keep-replicas-stored'])
                            except ValueError:
                                pass
                        return limiter.save_response(content.strip(), replicas_stored)

                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        # Guards lazy initialization of service_roots.
        self.lock = threading.Lock()
        self.service_roots = None
        # Guards _cache; the cache list is kept most-recently-used first.
        self._cache_lock = threading.Lock()
        self._cache = []
        # default 256 megabyte cache
        self.cache_max = 256 * 1024 * 1024
        self.using_proxy = False

    def shuffled_service_roots(self, hash):
        """Return Keep service URLs in the probe order for 'hash'.

        'hash' is a hex string at least 8 digits (32 bits) long.  The
        ordering is a deterministic pseudo-random permutation of the
        known services derived from the digits of 'hash', so all
        clients probe servers for a given block in the same sequence.
        """
        if self.service_roots == None:
            self.lock.acquire()
            try:
                # Re-check under the lock: another thread may have
                # completed discovery while we waited to acquire it.
                if self.service_roots == None:
                    # Override normal keep disk lookup with an explict proxy
                    # configuration.
                    keep_proxy_env = config.get("ARVADOS_KEEP_PROXY")
                    if keep_proxy_env != None and len(keep_proxy_env) > 0:
                        if keep_proxy_env[-1:] != '/':
                            keep_proxy_env += "/"
                        self.service_roots = [keep_proxy_env]
                        self.using_proxy = True
                    else:
                        # Prefer the keep_services API; fall back to the
                        # older keep_disks API on any failure.
                        try:
                            keep_services = arvados.api().keep_services().accessible().execute()['items']
                        except Exception:
                            keep_services = arvados.api().keep_disks().list().execute()['items']

                        if len(keep_services) == 0:
                            raise arvados.errors.NoKeepServersError()

                        if 'service_type' in keep_services[0] and keep_services[0]['service_type'] == 'proxy':
                            self.using_proxy = True

                        roots = (("http%s://%s:%d/" %
                                  ('s' if f['service_ssl_flag'] else '',
                                   f['service_host'],
                                   f['service_port']))
                                 for f in keep_services)
                        self.service_roots = sorted(set(roots))
                        logging.debug(str(self.service_roots))
            finally:
                # Release on every path.  (The previous version leaked
                # the lock when ARVADOS_KEEP_PROXY was set, which could
                # deadlock concurrent first use.)
                self.lock.release()

        # Build an ordering with which to query the Keep servers based on the
        # contents of the hash.
        # "hash" is a hex-encoded number at least 8 digits
        # (32 bits) long

        # seed used to calculate the next keep server from 'pool'
        # to be added to 'pseq'
        seed = hash

        # Keep servers still to be added to the ordering
        pool = self.service_roots[:]

        # output probe sequence
        pseq = []

        # iterate while there are servers left to be assigned
        while len(pool) > 0:
            if len(seed) < 8:
                # ran out of digits in the seed
                if len(pseq) < len(hash) / 4:
                    # the number of servers added to the probe sequence is less
                    # than the number of 4-digit slices in 'hash' so refill the
                    # seed with the last 4 digits and then append the contents
                    # of 'hash'.
                    seed = hash[-4:] + hash
                else:
                    # refill the seed with the contents of 'hash'
                    seed += hash

            # Take the next 8 digits (32 bytes) and interpret as an integer,
            # then modulus with the size of the remaining pool to get the next
            # selected server.
            probe = int(seed[0:8], 16) % len(pool)

            # Append the selected server to the probe sequence and remove it
            # from the pool.
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]

            # Remove the digits just used from the seed
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    class CacheSlot(object):
        """One cache entry: a locator plus content that may still be
        in flight.  get() blocks until another thread calls set()."""
        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            self.ready.wait()
            return self.content

        def set(self, value):
            self.content = value
            self.ready.set()

        def size(self):
            if self.content == None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        self._cache_lock.acquire()
        try:
            # Drop slots that completed with no content (failed fetches).
            self._cache = [c for c in self._cache
                           if not (c.ready.is_set() and c.content == None)]
            sm = sum(slot.size() for slot in self._cache)
            while sm > self.cache_max:
                # Evict from the tail: reserve_cache keeps the list in
                # MRU-first order, so [-1] is the least recently used.
                # (The previous version recomputed the total from a
                # stale leaked loop variable, so the cap was enforced
                # against the wrong size.)
                del self._cache[-1]
                sm = sum(slot.size() for slot in self._cache)
        finally:
            self._cache_lock.release()

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.

        Returns (slot, first): 'first' is true if this call created the
        slot, i.e. the caller is responsible for fetching the data and
        calling slot.set().'''
        self._cache_lock.acquire()
        try:
            # Test if the locator is already in the cache
            for i, n in enumerate(self._cache):
                if n.locator == locator:
                    if i != 0:
                        # move it to the front
                        del self._cache[i]
                        self._cache.insert(0, n)
                    return n, False

            # Add a new cache slot for the locator
            n = KeepClient.CacheSlot(locator)
            self._cache.insert(0, n)
            return n, True
        finally:
            self._cache_lock.release()

    def get(self, locator):
        """Fetch the block named by 'locator' and return it as a string.

        A comma-separated locator list returns the concatenation of the
        blocks.  Results are cached; concurrent requests for the same
        block share one fetch.  Raises arvados.errors.NotFoundError if
        no server can supply the block.
        """
        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)

        slot, first = self.reserve_cache(expect_hash)

        if not first:
            # Another thread owns this slot; wait for its result.
            v = slot.get()
            return v

        try:
            for service_root in self.shuffled_service_roots(expect_hash):
                url = service_root + locator
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token,
                           'Accept': 'application/octet-stream'}
                blob = self.get_url(url, headers, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob

            # No regular service had it; try any +K@instance hints.
            for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
                instance = location_hint.group(1)
                url = 'http://keep.' + instance + '.arvadosapi.com/' + locator
                blob = self.get_url(url, {}, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob
        except:
            # Wake any threads waiting on this slot before propagating.
            slot.set(None)
            self.cap_cache()
            raise

        slot.set(None)
        self.cap_cache()
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        """GET 'url' and return the body if its MD5 matches expect_hash.

        Returns None on any HTTP error, network failure, or checksum
        mismatch (failures are logged, not raised).
        """
        h = httplib2.Http()
        try:
            logging.info("Request: GET %s" % (url))
            with timer.Timer() as t:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
            logging.info("Received %s bytes in %s msec (%s MiB/sec)" % (len(content),
                                                                        t.msecs,
                                                                        (len(content)/(1024*1024))/t.secs))
            if re.match(r'^2\d\d$', resp['status']):
                # Verify the content against the hash before trusting it.
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
        except Exception as e:
            logging.info("Request fail: GET %s => %s: %s" %
                         (url, type(e), str(e)))
        return None

    def put(self, data, **kwargs):
        """Store 'data' on at least kwargs['copies'] (default 2) servers.

        Writes proceed in parallel, one thread per candidate server, in
        probe order.  Returns the locator reported by Keep; raises
        arvados.errors.KeepWriteError if fewer than the requested
        number of replicas were stored.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            # When talking to a proxy, one request asks for all copies;
            # otherwise each server stores a single copy.
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter,
                                            using_proxy=self.using_proxy,
                                            want_copies=(want_copies if self.using_proxy else 1))
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        # If we're done, return the response from Keep
        if have_copies >= want_copies:
            return thread_limiter.response()
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        """Wrap data in the PGP-clearsigned envelope old servers expect."""
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        """Store data in the KEEP_LOCAL_STORE directory; return its locator.

        Writes to a .tmp file first, then renames, so readers never see
        a partial block.
        """
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Read the block named by 'locator' from KEEP_LOCAL_STORE."""
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        # The well-known empty block is served without touching disk.
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()