# arvados.git / sdk / python / arvados / keep.py
# (imported from branch 3296-user-profile of git.curoverse.com:arvados)
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21 import datetime
22 import ssl
23
24 global_client_object = None
25
26 from api import *
27 import config
28 import arvados.errors
29 import arvados.util
30
31 class KeepLocator(object):
32     EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
33
34     def __init__(self, locator_str):
35         self.size = None
36         self.loc_hint = None
37         self._perm_sig = None
38         self._perm_expiry = None
39         pieces = iter(locator_str.split('+'))
40         self.md5sum = next(pieces)
41         for hint in pieces:
42             if hint.startswith('A'):
43                 self.parse_permission_hint(hint)
44             elif hint.startswith('K'):
45                 self.loc_hint = hint  # FIXME
46             elif hint.isdigit():
47                 self.size = int(hint)
48             else:
49                 raise ValueError("unrecognized hint data {}".format(hint))
50
51     def __str__(self):
52         return '+'.join(
53             str(s) for s in [self.md5sum, self.size, self.loc_hint,
54                              self.permission_hint()]
55             if s is not None)
56
57     def _make_hex_prop(name, length):
58         # Build and return a new property with the given name that
59         # must be a hex string of the given length.
60         data_name = '_{}'.format(name)
61         def getter(self):
62             return getattr(self, data_name)
63         def setter(self, hex_str):
64             if not arvados.util.is_hex(hex_str, length):
65                 raise ValueError("{} must be a {}-digit hex string: {}".
66                                  format(name, length, hex_str))
67             setattr(self, data_name, hex_str)
68         return property(getter, setter)
69
70     md5sum = _make_hex_prop('md5sum', 32)
71     perm_sig = _make_hex_prop('perm_sig', 40)
72
73     @property
74     def perm_expiry(self):
75         return self._perm_expiry
76
77     @perm_expiry.setter
78     def perm_expiry(self, value):
79         if not arvados.util.is_hex(value, 1, 8):
80             raise ValueError(
81                 "permission timestamp must be a hex Unix timestamp: {}".
82                 format(value))
83         self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))
84
85     def permission_hint(self):
86         data = [self.perm_sig, self.perm_expiry]
87         if None in data:
88             return None
89         data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
90         return "A{}@{:08x}".format(*data)
91
92     def parse_permission_hint(self, s):
93         try:
94             self.perm_sig, self.perm_expiry = s[1:].split('@', 1)
95         except IndexError:
96             raise ValueError("bad permission hint {}".format(s))
97
98     def permission_expired(self, as_of_dt=None):
99         if self.perm_expiry is None:
100             return False
101         elif as_of_dt is None:
102             as_of_dt = datetime.datetime.now()
103         return self.perm_expiry <= as_of_dt
104
105
class Keep:
    """Module-level convenience wrapper around a shared KeepClient.

    Keep.get()/Keep.put() lazily create one process-wide KeepClient and
    delegate to it.
    """
    @staticmethod
    def global_client_object():
        # Lazily create the shared client on first use.
        # NOTE(review): this check-then-set is not guarded by a lock;
        # presumably concurrent first calls could each build a client,
        # with the last assignment winning — confirm if that matters.
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)
121
class KeepClient(object):
    """Store and retrieve data blocks from a cluster of Keep servers.

    put() writes a block to several servers in parallel (or to a Keep
    proxy, which replicates on our behalf); get() reads a block back,
    verifying its MD5 checksum, and keeps results in a bounded in-memory
    MRU cache.  If the KEEP_LOCAL_STORE environment variable is set,
    both operations use a local directory instead of the network.
    """

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            # todo: number of replica confirmations we still want.
            self._todo = todo
            self._done = 0
            self._response = None
            # Admits at most `todo` writer threads concurrently.
            self._todo_lock = threading.Semaphore(todo)
            # Protects _done and _response.
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def save_response(self, response_body, replicas_stored):
            """
            Records a response body (a locator, possibly signed) returned by
            the Keep server.  It is not necessary to save more than
            one response, since we presume that any locator returned
            in response to a successful request is valid.
            """
            with self._done_lock:
                self._done += replicas_stored
                self._response = response_body

        def response(self):
            """
            Returns the body from the response to a PUT request.
            """
            with self._done_lock:
                return self._response

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. On success, call
        save_response() of the given ThreadLimiter to save the returned
        locator.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            # Expected kwargs: data, data_hash, service_root,
            # thread_limiter, timeout, using_proxy, want_copies.
            self.args = kwargs
            self._success = False

        def success(self):
            # True once this thread's PUT got a 2xx response.
            return self._success

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                self.run_with_limiter(limiter)

        def run_with_limiter(self, limiter):
            """PUT this thread's block to its server; report via limiter."""
            logging.debug("KeepWriterThread %s proceeding %s %s" %
                          (str(threading.current_thread()),
                           self.args['data_hash'],
                           self.args['service_root']))
            h = httplib2.Http(timeout=self.args.get('timeout', None))
            url = self.args['service_root'] + self.args['data_hash']
            api_token = config.get('ARVADOS_API_TOKEN')
            headers = {'Authorization': "OAuth2 %s" % api_token}

            if self.args['using_proxy']:
                # We're using a proxy, so tell the proxy how many copies we
                # want it to store
                headers['X-Keep-Desired-Replication'] = str(self.args['want_copies'])

            try:
                logging.debug("Uploading to {}".format(url))
                resp, content = h.request(url.encode('utf-8'), 'PUT',
                                          headers=headers,
                                          body=self.args['data'])
                if (resp['status'] == '401' and
                    re.match(r'Timestamp verification failed', content)):
                    # Old servers expect the data wrapped in a (dummy)
                    # PGP-signed message; retry once in that format.
                    body = KeepClient.sign_for_old_server(
                        self.args['data_hash'],
                        self.args['data'])
                    h = httplib2.Http(timeout=self.args.get('timeout', None))
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=body)
                if re.match(r'^2\d\d$', resp['status']):
                    self._success = True
                    logging.debug("KeepWriterThread %s succeeded %s %s" %
                                  (str(threading.current_thread()),
                                   self.args['data_hash'],
                                   self.args['service_root']))
                    replicas_stored = 1
                    if 'x-keep-replicas-stored' in resp:
                        # Tick the 'done' counter for the number of replica
                        # reported stored by the server, for the case that
                        # we're talking to a proxy or other backend that
                        # stores to multiple copies for us.
                        try:
                            replicas_stored = int(resp['x-keep-replicas-stored'])
                        except ValueError:
                            pass
                    limiter.save_response(content.strip(), replicas_stored)
                else:
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
            except (httplib2.HttpLib2Error,
                    httplib.HTTPException,
                    ssl.SSLError) as e:
                # When using https, timeouts look like ssl.SSLError from here.
                # "SSLError: The write operation timed out"
                logging.warning("Request fail: PUT %s => %s: %s" %
                                (url, type(e), str(e)))

    def __init__(self, **kwargs):
        """Create a Keep client.

        Accepted keyword arguments:
          timeout: per-request timeout in seconds for PUTs (default 60).
        """
        # Guards the one-time service discovery in shuffled_service_roots().
        self.lock = threading.Lock()
        self.service_roots = None
        self._cache_lock = threading.Lock()
        # MRU list of CacheSlot objects, most recently used first.
        self._cache = []
        # default 256 megabyte cache
        self.cache_max = 256 * 1024 * 1024
        self.using_proxy = False
        self.timeout = kwargs.get('timeout', 60)

    def shuffled_service_roots(self, hash):
        """Return Keep service URLs in the probe order for `hash`.

        Service discovery runs once, on first use.  The probe order is
        derived deterministically from the hex digits of `hash`, so all
        clients agree on which servers to try first for a given block.
        """
        if self.service_roots is None:
            with self.lock:
                # Re-check under the lock: another thread may have
                # finished discovery while we waited.  (BUGFIX: the old
                # code acquired the lock unconditionally and the proxy
                # branch returned without ever releasing it, deadlocking
                # any later concurrent caller.)
                if self.service_roots is None:
                    # Override normal keep disk lookup with an explict proxy
                    # configuration.
                    keep_proxy_env = config.get("ARVADOS_KEEP_PROXY")
                    if keep_proxy_env is not None and len(keep_proxy_env) > 0:
                        if keep_proxy_env[-1:] != '/':
                            keep_proxy_env += "/"
                        self.service_roots = [keep_proxy_env]
                        self.using_proxy = True
                    else:
                        try:
                            keep_services = arvados.api().keep_services().accessible().execute()['items']
                        except Exception:
                            # Fall back to the older keep_disks API for
                            # old API servers.
                            keep_services = arvados.api().keep_disks().list().execute()['items']

                        if len(keep_services) == 0:
                            raise arvados.errors.NoKeepServersError()

                        if 'service_type' in keep_services[0] and keep_services[0]['service_type'] == 'proxy':
                            self.using_proxy = True

                        roots = (("http%s://%s:%d/" %
                                  ('s' if f['service_ssl_flag'] else '',
                                   f['service_host'],
                                   f['service_port']))
                                 for f in keep_services)
                        self.service_roots = sorted(set(roots))
                        logging.debug(str(self.service_roots))

        # Build an ordering with which to query the Keep servers based on the
        # contents of the hash.
        # "hash" is a hex-encoded number at least 8 digits
        # (32 bits) long

        # seed used to calculate the next keep server from 'pool'
        # to be added to 'pseq'
        seed = hash

        # Keep servers still to be added to the ordering
        pool = self.service_roots[:]

        # output probe sequence
        pseq = []

        # iterate while there are servers left to be assigned
        while len(pool) > 0:
            if len(seed) < 8:
                # ran out of digits in the seed
                if len(pseq) < len(hash) / 4:
                    # the number of servers added to the probe sequence is less
                    # than the number of 4-digit slices in 'hash' so refill the
                    # seed with the last 4 digits and then append the contents
                    # of 'hash'.
                    seed = hash[-4:] + hash
                else:
                    # refill the seed with the contents of 'hash'
                    seed += hash

            # Take the next 8 digits (32 bytes) and interpret as an integer,
            # then modulus with the size of the remaining pool to get the next
            # selected server.
            probe = int(seed[0:8], 16) % len(pool)

            # Append the selected server to the probe sequence and remove it
            # from the pool.
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]

            # Remove the digits just used from the seed
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    class CacheSlot(object):
        """One cache entry: eventually holds the content for a locator.

        The first thread to ask for a locator gets a fresh slot and is
        responsible for filling it with set(); any other thread asking
        for the same locator blocks in get() until then.  A slot set to
        None marks a failed fetch.
        """
        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            # Block until the fetching thread calls set().
            self.ready.wait()
            return self.content

        def set(self, value):
            self.content = value
            self.ready.set()

        def size(self):
            # Bytes of cached content; 0 while unfilled or failed.
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        with self._cache_lock:
            # Drop slots whose fetch completed with no content (failures).
            self._cache = [c for c in self._cache
                           if not (c.ready.is_set() and c.content is None)]
            # Evict from the least-recently-used end (list tail) until
            # the total fits.  (BUGFIX: the old loop recomputed the
            # total as `sum([slot.size() for a in self._cache])` -- a
            # typo'd loop variable -- so the cap was never enforced
            # correctly.)
            sm = sum([slot.size() for slot in self._cache])
            while sm > self.cache_max:
                sm -= self._cache[-1].size()
                del self._cache[-1]

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        with self._cache_lock:
            # Test if the locator is already in the cache
            for i, n in enumerate(self._cache):
                if n.locator == locator:
                    if i != 0:
                        # move it to the front
                        del self._cache[i]
                        self._cache.insert(0, n)
                    return n, False

            # Add a new cache slot for the locator
            n = KeepClient.CacheSlot(locator)
            self._cache.insert(0, n)
            return n, True

    def get(self, locator):
        """Return the content of the block(s) named by `locator`.

        A comma-separated locator is fetched piecewise and concatenated.
        Raises arvados.errors.NotFoundError if no server has the block.
        """
        #logging.debug("Keep.get %s" % (locator))

        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        # Cache key is the bare hash, with all '+hints' stripped.
        expect_hash = re.sub(r'\+.*', '', locator)

        slot, first = self.reserve_cache(expect_hash)
        #logging.debug("%s %s %s" % (slot, first, expect_hash))

        if not first:
            # Another thread is (or was) fetching this block; wait for
            # its result instead of fetching again.
            v = slot.get()
            return v

        try:
            for service_root in self.shuffled_service_roots(expect_hash):
                url = service_root + locator
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token,
                           'Accept': 'application/octet-stream'}
                blob = self.get_url(url, headers, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob

            # Last resort: try any +K@instance location hints embedded
            # in the locator itself.
            for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
                instance = location_hint.group(1)
                url = 'http://keep.' + instance + '.arvadosapi.com/' + locator
                blob = self.get_url(url, {}, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob
        except:
            # Mark the slot as failed so waiting threads wake up, then
            # re-raise.  The bare except is deliberate: even exceptions
            # like KeyboardInterrupt must not leave other threads
            # blocked forever on this slot.
            slot.set(None)
            self.cap_cache()
            raise

        # Every server (and hint) was tried; the block is not available.
        slot.set(None)
        self.cap_cache()
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        """GET `url` and return the body if its MD5 matches expect_hash.

        Returns None on any HTTP error or checksum mismatch.
        """
        h = httplib2.Http()
        try:
            logging.info("Request: GET %s" % (url))
            with timer.Timer() as t:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
            logging.info("Received %s bytes in %s msec (%s MiB/sec)" % (len(content),
                                                                        t.msecs,
                                                                        (len(content)/(1024*1024))/t.secs))
            if re.match(r'^2\d\d$', resp['status']):
                # Verify the checksum before trusting the data.
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
        except Exception as e:
            logging.info("Request fail: GET %s => %s: %s" %
                         (url, type(e), str(e)))
        return None

    def put(self, data, **kwargs):
        """Store `data` in Keep; return the locator reported by Keep.

        Accepted keyword arguments:
          copies: desired number of replicas (default 2).

        Raises arvados.errors.KeepWriteError if fewer than `copies`
        replicas could be confirmed, even after one round of retries.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            # With a proxy, one request can yield all desired copies;
            # otherwise each server stores a single copy.
            t = KeepClient.KeepWriterThread(
                data=data,
                data_hash=data_hash,
                service_root=service_root,
                thread_limiter=thread_limiter,
                timeout=self.timeout,
                using_proxy=self.using_proxy,
                want_copies=(want_copies if self.using_proxy else 1))
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        if thread_limiter.done() < want_copies:
            # Retry the threads (i.e., services) that failed the first
            # time around.
            threads_retry = []
            for t in threads:
                if not t.success():
                    logging.warning("Retrying: PUT %s %s" % (
                        t.args['service_root'],
                        t.args['data_hash']))
                    retry_with_args = t.args.copy()
                    t_retry = KeepClient.KeepWriterThread(**retry_with_args)
                    t_retry.start()
                    threads_retry += [t_retry]
            for t in threads_retry:
                t.join()
        have_copies = thread_limiter.done()
        # If we're done, return the response from Keep
        if have_copies >= want_copies:
            return thread_limiter.response()
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        # Wrap the data in a dummy PGP-signed message, as expected by
        # Keep servers that predate timestamp signatures.
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        """Write `data` into $KEEP_LOCAL_STORE; return its locator."""
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        # Write to a temp name, then rename into place so readers never
        # see a partially written block.
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Read the block named by `locator` from $KEEP_LOCAL_STORE."""
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        # The well-known empty-block locator needs no file on disk.
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()