2872: Merge branch 'master' into 2872-folder-nav
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21 import datetime
22
# Module-wide KeepClient instance shared by the Keep convenience class.
# It stays None until Keep.global_client_object() creates it on first use.
global_client_object = None
24
25 from api import *
26 import config
27 import arvados.errors
28 import arvados.util
29
class KeepLocator(object):
    """Parsed form of a Keep block locator string.

    A locator is a '+'-separated string.  The first field is the block's
    32-digit hex MD5 checksum; every following field is a hint:

    * an all-digit hint is the block size (stored as an int in ``size``);
    * a hint starting with 'A' is a permission hint of the form
      ``A<40 hex digit signature>@<hex Unix timestamp>``;
    * a hint starting with 'K' is a location hint, stored verbatim in
      ``loc_hint`` (only the last one seen is kept).

    Any other hint makes the constructor raise ValueError.
    """
    # Naive UTC datetime for Unix time 0; used to turn perm_expiry back
    # into a number of seconds when formatting the permission hint.
    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)

    def __init__(self, locator_str):
        """Parse locator_str; raise ValueError if any field is malformed."""
        self.size = None
        self.loc_hint = None
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        self.md5sum = next(pieces)
        for hint in pieces:
            if hint.startswith('A'):
                self.parse_permission_hint(hint)
            elif hint.startswith('K'):
                self.loc_hint = hint  # FIXME
            elif hint.isdigit():
                self.size = int(hint)
            else:
                raise ValueError("unrecognized hint data {}".format(hint))

    def __str__(self):
        """Rebuild the locator string from the fields that are set."""
        return '+'.join(
            str(s) for s in [self.md5sum, self.size, self.loc_hint,
                             self.permission_hint()]
            if s is not None)

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.  The value is stored
        # on the instance under '_<name>'.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not arvados.util.is_hex(hex_str, length):
                raise ValueError("{} must be a {}-digit hex string: {}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        """Permission expiry as a naive UTC datetime, or None if unset."""
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        # The setter takes the hex timestamp text found in the locator and
        # stores the corresponding naive UTC datetime.
        if not arvados.util.is_hex(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        """Return the 'A<sig>@<hex timestamp>' hint, or None if incomplete."""
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        """Set perm_sig and perm_expiry from a hint like 'A<sig>@<expiry>'.

        Raises ValueError if the hint has no '@' separator, or if either
        part fails the validation done by the property setters.
        """
        # A hint without '@' makes the two-name unpacking fail with
        # ValueError (not IndexError), so that is what must be caught to
        # report the malformed hint.  Only the split is guarded, so the
        # more specific messages raised by the setters are not masked.
        try:
            sig, expiry = s[1:].split('@', 1)
        except ValueError:
            raise ValueError("bad permission hint {}".format(s))
        self.perm_sig = sig
        self.perm_expiry = expiry

    def permission_expired(self, as_of_dt=None):
        """Return True if the permission hint has expired as of as_of_dt.

        as_of_dt is a naive UTC datetime and defaults to the current UTC
        time.  A locator without an expiry never counts as expired.
        """
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            # perm_expiry is built with utcfromtimestamp(), so the default
            # reference time must be UTC as well, not local time.
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt
103
104
class Keep:
    """Convenience entry points that share one module-wide KeepClient."""

    @staticmethod
    def global_client_object():
        """Return the shared KeepClient, creating it on first use."""
        global global_client_object
        client = global_client_object
        if client is None:
            client = KeepClient()
            global_client_object = client
        return client

    @staticmethod
    def get(locator, **kwargs):
        """Forward to get() on the shared KeepClient."""
        client = Keep.global_client_object()
        return client.get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Forward to put() on the shared KeepClient."""
        client = Keep.global_client_object()
        return client.put(data, **kwargs)
120
121 class KeepClient(object):
122
123     class ThreadLimiter(object):
124         """
125         Limit the number of threads running at a given time to
126         {desired successes} minus {successes reported}. When successes
127         reported == desired, wake up the remaining threads and tell
128         them to quit.
129
130         Should be used in a "with" block.
131         """
132         def __init__(self, todo):
133             self._todo = todo
134             self._done = 0
135             self._response = None
136             self._todo_lock = threading.Semaphore(todo)
137             self._done_lock = threading.Lock()
138
139         def __enter__(self):
140             self._todo_lock.acquire()
141             return self
142
143         def __exit__(self, type, value, traceback):
144             self._todo_lock.release()
145
146         def shall_i_proceed(self):
147             """
148             Return true if the current thread should do stuff. Return
149             false if the current thread should just stop.
150             """
151             with self._done_lock:
152                 return (self._done < self._todo)
153
154         def save_response(self, response_body, replicas_stored):
155             """
156             Records a response body (a locator, possibly signed) returned by
157             the Keep server.  It is not necessary to save more than
158             one response, since we presume that any locator returned
159             in response to a successful request is valid.
160             """
161             with self._done_lock:
162                 self._done += replicas_stored
163                 self._response = response_body
164
165         def response(self):
166             """
167             Returns the body from the response to a PUT request.
168             """
169             with self._done_lock:
170                 return self._response
171
172         def done(self):
173             """
174             Return how many successes were reported.
175             """
176             with self._done_lock:
177                 return self._done
178
179     class KeepWriterThread(threading.Thread):
180         """
181         Write a blob of data to the given Keep server. On success, call
182         save_response() of the given ThreadLimiter to save the returned
183         locator.
184         """
185         def __init__(self, **kwargs):
186             super(KeepClient.KeepWriterThread, self).__init__()
187             self.args = kwargs
188
189         def run(self):
190             with self.args['thread_limiter'] as limiter:
191                 if not limiter.shall_i_proceed():
192                     # My turn arrived, but the job has been done without
193                     # me.
194                     return
195                 logging.debug("KeepWriterThread %s proceeding %s %s" %
196                               (str(threading.current_thread()),
197                                self.args['data_hash'],
198                                self.args['service_root']))
199                 h = httplib2.Http()
200                 url = self.args['service_root'] + self.args['data_hash']
201                 api_token = config.get('ARVADOS_API_TOKEN')
202                 headers = {'Authorization': "OAuth2 %s" % api_token}
203
204                 if self.args['using_proxy']:
205                     # We're using a proxy, so tell the proxy how many copies we
206                     # want it to store
207                     headers['X-Keep-Desired-Replication'] = str(self.args['want_copies'])
208
209                 try:
210                     logging.debug("Uploading to {}".format(url))
211                     resp, content = h.request(url.encode('utf-8'), 'PUT',
212                                               headers=headers,
213                                               body=self.args['data'])
214                     if (resp['status'] == '401' and
215                         re.match(r'Timestamp verification failed', content)):
216                         body = KeepClient.sign_for_old_server(
217                             self.args['data_hash'],
218                             self.args['data'])
219                         h = httplib2.Http()
220                         resp, content = h.request(url.encode('utf-8'), 'PUT',
221                                                   headers=headers,
222                                                   body=body)
223                     if re.match(r'^2\d\d$', resp['status']):
224                         logging.debug("KeepWriterThread %s succeeded %s %s" %
225                                       (str(threading.current_thread()),
226                                        self.args['data_hash'],
227                                        self.args['service_root']))
228                         replicas_stored = 1
229                         if 'x-keep-replicas-stored' in resp:
230                             # Tick the 'done' counter for the number of replica
231                             # reported stored by the server, for the case that
232                             # we're talking to a proxy or other backend that
233                             # stores to multiple copies for us.
234                             try:
235                                 replicas_stored = int(resp['x-keep-replicas-stored'])
236                             except ValueError:
237                                 pass
238                         return limiter.save_response(content.strip(), replicas_stored)
239
240                     logging.warning("Request fail: PUT %s => %s %s" %
241                                     (url, resp['status'], content))
242                 except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
243                     logging.warning("Request fail: PUT %s => %s: %s" %
244                                     (url, type(e), str(e)))
245
246     def __init__(self):
247         self.lock = threading.Lock()
248         self.service_roots = None
249         self._cache_lock = threading.Lock()
250         self._cache = []
251         # default 256 megabyte cache
252         self.cache_max = 256 * 1024 * 1024
253         self.using_proxy = False
254
255     def shuffled_service_roots(self, hash):
256         if self.service_roots == None:
257             self.lock.acquire()
258
259             # Override normal keep disk lookup with an explict proxy
260             # configuration.
261             keep_proxy_env = config.get("ARVADOS_KEEP_PROXY")
262             if keep_proxy_env != None and len(keep_proxy_env) > 0:
263
264                 if keep_proxy_env[-1:] != '/':
265                     keep_proxy_env += "/"
266                 self.service_roots = [keep_proxy_env]
267                 self.using_proxy = True
268             else:
269                 try:
270                     try:
271                         keep_services = arvados.api().keep_services().accessible().execute()['items']
272                     except Exception:
273                         keep_services = arvados.api().keep_disks().list().execute()['items']
274
275                     if len(keep_services) == 0:
276                         raise arvados.errors.NoKeepServersError()
277
278                     if 'service_type' in keep_services[0] and keep_services[0]['service_type'] == 'proxy':
279                         self.using_proxy = True
280
281                     roots = (("http%s://%s:%d/" %
282                               ('s' if f['service_ssl_flag'] else '',
283                                f['service_host'],
284                                f['service_port']))
285                              for f in keep_services)
286                     self.service_roots = sorted(set(roots))
287                     logging.debug(str(self.service_roots))
288                 finally:
289                     self.lock.release()
290
291         # Build an ordering with which to query the Keep servers based on the
292         # contents of the hash.
293         # "hash" is a hex-encoded number at least 8 digits
294         # (32 bits) long
295
296         # seed used to calculate the next keep server from 'pool'
297         # to be added to 'pseq'
298         seed = hash
299
300         # Keep servers still to be added to the ordering
301         pool = self.service_roots[:]
302
303         # output probe sequence
304         pseq = []
305
306         # iterate while there are servers left to be assigned
307         while len(pool) > 0:
308             if len(seed) < 8:
309                 # ran out of digits in the seed
310                 if len(pseq) < len(hash) / 4:
311                     # the number of servers added to the probe sequence is less
312                     # than the number of 4-digit slices in 'hash' so refill the
313                     # seed with the last 4 digits and then append the contents
314                     # of 'hash'.
315                     seed = hash[-4:] + hash
316                 else:
317                     # refill the seed with the contents of 'hash'
318                     seed += hash
319
320             # Take the next 8 digits (32 bytes) and interpret as an integer,
321             # then modulus with the size of the remaining pool to get the next
322             # selected server.
323             probe = int(seed[0:8], 16) % len(pool)
324
325             # Append the selected server to the probe sequence and remove it
326             # from the pool.
327             pseq += [pool[probe]]
328             pool = pool[:probe] + pool[probe+1:]
329
330             # Remove the digits just used from the seed
331             seed = seed[8:]
332         logging.debug(str(pseq))
333         return pseq
334
335     class CacheSlot(object):
336         def __init__(self, locator):
337             self.locator = locator
338             self.ready = threading.Event()
339             self.content = None
340
341         def get(self):
342             self.ready.wait()
343             return self.content
344
345         def set(self, value):
346             self.content = value
347             self.ready.set()
348
349         def size(self):
350             if self.content == None:
351                 return 0
352             else:
353                 return len(self.content)
354
355     def cap_cache(self):
356         '''Cap the cache size to self.cache_max'''
357         self._cache_lock.acquire()
358         try:
359             self._cache = filter(lambda c: not (c.ready.is_set() and c.content == None), self._cache)
360             sm = sum([slot.size() for slot in self._cache])
361             while sm > self.cache_max:
362                 del self._cache[-1]
363                 sm = sum([slot.size() for a in self._cache])
364         finally:
365             self._cache_lock.release()
366
367     def reserve_cache(self, locator):
368         '''Reserve a cache slot for the specified locator,
369         or return the existing slot.'''
370         self._cache_lock.acquire()
371         try:
372             # Test if the locator is already in the cache
373             for i in xrange(0, len(self._cache)):
374                 if self._cache[i].locator == locator:
375                     n = self._cache[i]
376                     if i != 0:
377                         # move it to the front
378                         del self._cache[i]
379                         self._cache.insert(0, n)
380                     return n, False
381
382             # Add a new cache slot for the locator
383             n = KeepClient.CacheSlot(locator)
384             self._cache.insert(0, n)
385             return n, True
386         finally:
387             self._cache_lock.release()
388
389     def get(self, locator):
390         #logging.debug("Keep.get %s" % (locator))
391
392         if re.search(r',', locator):
393             return ''.join(self.get(x) for x in locator.split(','))
394         if 'KEEP_LOCAL_STORE' in os.environ:
395             return KeepClient.local_store_get(locator)
396         expect_hash = re.sub(r'\+.*', '', locator)
397
398         slot, first = self.reserve_cache(expect_hash)
399         #logging.debug("%s %s %s" % (slot, first, expect_hash))
400
401         if not first:
402             v = slot.get()
403             return v
404
405         try:
406             for service_root in self.shuffled_service_roots(expect_hash):
407                 url = service_root + locator
408                 api_token = config.get('ARVADOS_API_TOKEN')
409                 headers = {'Authorization': "OAuth2 %s" % api_token,
410                            'Accept': 'application/octet-stream'}
411                 blob = self.get_url(url, headers, expect_hash)
412                 if blob:
413                     slot.set(blob)
414                     self.cap_cache()
415                     return blob
416
417             for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
418                 instance = location_hint.group(1)
419                 url = 'http://keep.' + instance + '.arvadosapi.com/' + locator
420                 blob = self.get_url(url, {}, expect_hash)
421                 if blob:
422                     slot.set(blob)
423                     self.cap_cache()
424                     return blob
425         except:
426             slot.set(None)
427             self.cap_cache()
428             raise
429
430         slot.set(None)
431         self.cap_cache()
432         raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)
433
434     def get_url(self, url, headers, expect_hash):
435         h = httplib2.Http()
436         try:
437             logging.info("Request: GET %s" % (url))
438             with timer.Timer() as t:
439                 resp, content = h.request(url.encode('utf-8'), 'GET',
440                                           headers=headers)
441             logging.info("Received %s bytes in %s msec (%s MiB/sec)" % (len(content),
442                                                                         t.msecs,
443                                                                         (len(content)/(1024*1024))/t.secs))
444             if re.match(r'^2\d\d$', resp['status']):
445                 m = hashlib.new('md5')
446                 m.update(content)
447                 md5 = m.hexdigest()
448                 if md5 == expect_hash:
449                     return content
450                 logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
451         except Exception as e:
452             logging.info("Request fail: GET %s => %s: %s" %
453                          (url, type(e), str(e)))
454         return None
455
456     def put(self, data, **kwargs):
457         if 'KEEP_LOCAL_STORE' in os.environ:
458             return KeepClient.local_store_put(data)
459         m = hashlib.new('md5')
460         m.update(data)
461         data_hash = m.hexdigest()
462         have_copies = 0
463         want_copies = kwargs.get('copies', 2)
464         if not (want_copies > 0):
465             return data_hash
466         threads = []
467         thread_limiter = KeepClient.ThreadLimiter(want_copies)
468         for service_root in self.shuffled_service_roots(data_hash):
469             t = KeepClient.KeepWriterThread(data=data,
470                                             data_hash=data_hash,
471                                             service_root=service_root,
472                                             thread_limiter=thread_limiter,
473                                             using_proxy=self.using_proxy,
474                                             want_copies=(want_copies if self.using_proxy else 1))
475             t.start()
476             threads += [t]
477         for t in threads:
478             t.join()
479         have_copies = thread_limiter.done()
480         # If we're done, return the response from Keep
481         if have_copies >= want_copies:
482             return thread_limiter.response()
483         raise arvados.errors.KeepWriteError(
484             "Write fail for %s: wanted %d but wrote %d" %
485             (data_hash, want_copies, have_copies))
486
487     @staticmethod
488     def sign_for_old_server(data_hash, data):
489         return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
490
491
492     @staticmethod
493     def local_store_put(data):
494         m = hashlib.new('md5')
495         m.update(data)
496         md5 = m.hexdigest()
497         locator = '%s+%d' % (md5, len(data))
498         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
499             f.write(data)
500         os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
501                   os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
502         return locator
503
504     @staticmethod
505     def local_store_get(locator):
506         r = re.search('^([0-9a-f]{32,})', locator)
507         if not r:
508             raise arvados.errors.NotFoundError(
509                 "Invalid data locator: '%s'" % locator)
510         if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
511             return ''
512         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
513             return f.read()