3351: Retry failed threads (servers) if replication is too low after one pass.
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21 import datetime
22
23 global_client_object = None
24
25 from api import *
26 import config
27 import arvados.errors
28 import arvados.util
29
class KeepLocator(object):
    """Parse, validate, and re-serialize a Keep block locator string.

    A locator looks like ``md5sum[+size][+Asig@expiry][+Khint]``: a 32-digit
    hex MD5, optionally followed by a decimal size, a permission hint
    (40-digit hex signature + 8-digit hex Unix timestamp), and/or a
    location hint.
    """
    # Reference point for converting permission expiry datetimes back to
    # Unix timestamps in permission_hint().
    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)

    def __init__(self, locator_str):
        """Parse locator_str; raise ValueError on any malformed piece."""
        self.size = None
        self.loc_hint = None
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        self.md5sum = next(pieces)
        for hint in pieces:
            if hint.startswith('A'):
                self.parse_permission_hint(hint)
            elif hint.startswith('K'):
                self.loc_hint = hint  # FIXME
            elif hint.isdigit():
                self.size = int(hint)
            else:
                raise ValueError("unrecognized hint data {}".format(hint))

    def __str__(self):
        # Re-serialize in canonical order, skipping absent fields.
        return '+'.join(
            str(s) for s in [self.md5sum, self.size, self.loc_hint,
                             self.permission_hint()]
            if s is not None)

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not arvados.util.is_hex(hex_str, length):
                raise ValueError("{} must be a {}-digit hex string: {}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        # Accept a 1- to 8-digit hex Unix timestamp; store it as a UTC
        # datetime so comparisons and re-serialization are easy.
        if not arvados.util.is_hex(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        """Return the 'Asig@expiry' hint string, or None if unsigned."""
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        # Convert the expiry datetime back to a zero-padded hex timestamp.
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        """Parse an 'Asig@expiry' hint; raise ValueError if malformed."""
        try:
            self.perm_sig, self.perm_expiry = s[1:].split('@', 1)
        except (IndexError, ValueError):
            # A hint with no '@' makes the tuple unpack raise ValueError
            # (the original code only caught IndexError, which split()
            # never raises here, so malformed hints escaped uncaught).
            raise ValueError("bad permission hint {}".format(s))

    def permission_expired(self, as_of_dt=None):
        """Return True if the permission hint has expired (as of as_of_dt)."""
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            # perm_expiry is stored as UTC (utcfromtimestamp), so compare
            # against UTC now -- the original used local-time now(), which
            # is wrong in any non-UTC timezone.
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt
103
104
class Keep:
    """Convenience wrappers around a lazily-created, shared KeepClient."""

    @staticmethod
    def global_client_object():
        """Return the module-wide KeepClient, creating it on first use."""
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch a data block by locator via the shared client."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store a data block via the shared client; returns its locator."""
        return Keep.global_client_object().put(data, **kwargs)
120
class KeepClient(object):
    """Client for reading and writing data blocks on Keep servers.

    Discovers Keep services (or uses an explicit proxy from
    ARVADOS_KEEP_PROXY), orders them per-block with a rendezvous-style
    probe sequence, keeps a bounded in-memory read cache, and writes
    with one thread per candidate server until the desired replication
    is reached (retrying failed servers once).
    """

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._response = None
            # At most 'todo' writer threads may run concurrently.
            self._todo_lock = threading.Semaphore(todo)
            # Guards _done and _response.
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def save_response(self, response_body, replicas_stored):
            """
            Records a response body (a locator, possibly signed) returned by
            the Keep server.  It is not necessary to save more than
            one response, since we presume that any locator returned
            in response to a successful request is valid.
            """
            with self._done_lock:
                self._done += replicas_stored
                self._response = response_body

        def response(self):
            """
            Returns the body from the response to a PUT request.
            """
            with self._done_lock:
                return self._response

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. On success, call
        save_response() of the given ThreadLimiter to save the returned
        locator.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs
            self._success = False

        def success(self):
            """Return True if this thread's PUT got a 2xx response."""
            return self._success

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                # Use the timeout the caller asked for; the original
                # hard-coded 60 and ignored the 'timeout' kwarg.
                h = httplib2.Http(timeout=self.args.get('timeout', 60))
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token}

                if self.args['using_proxy']:
                    # We're using a proxy, so tell the proxy how many copies we
                    # want it to store
                    headers['X-Keep-Desired-Replication'] = str(self.args['want_copies'])

                try:
                    logging.debug("Uploading to {}".format(url))
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # Old servers want the body wrapped in a (dummy)
                        # PGP-signed message; retry once in that format.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        self._success = True
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        replicas_stored = 1
                        if 'x-keep-replicas-stored' in resp:
                            # Tick the 'done' counter for the number of replica
                            # reported stored by the server, for the case that
                            # we're talking to a proxy or other backend that
                            # stores to multiple copies for us.
                            try:
                                replicas_stored = int(resp['x-keep-replicas-stored'])
                            except ValueError:
                                pass
                        return limiter.save_response(content.strip(), replicas_stored)

                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self, **kwargs):
        # Guards lazy discovery of service_roots.
        self.lock = threading.Lock()
        self.service_roots = None
        # Guards _cache (a most-recently-used-first list of CacheSlots).
        self._cache_lock = threading.Lock()
        self._cache = []
        # default 256 megabyte cache
        self.cache_max = 256 * 1024 * 1024
        self.using_proxy = False
        self.timeout = kwargs.get('timeout', 60)

    def shuffled_service_roots(self, hash):
        """Return service root URLs in the probe order dictated by hash.

        The first call discovers the service list (proxy override via
        ARVADOS_KEEP_PROXY, otherwise the API server's keep_services /
        keep_disks); subsequent calls reuse it.  The ordering is a
        rendezvous-style permutation derived from the hex digits of
        'hash', so every client probes servers for a given block in the
        same order.
        """
        if self.service_roots is None:
            self.lock.acquire()
            try:
                # Re-check under the lock so a losing racer doesn't redo
                # (or clobber) the discovery.  NOTE: the original code
                # leaked self.lock on the proxy branch; the try/finally
                # now covers both branches.
                if self.service_roots is None:
                    # Override normal keep disk lookup with an explict proxy
                    # configuration.
                    keep_proxy_env = config.get("ARVADOS_KEEP_PROXY")
                    if keep_proxy_env is not None and len(keep_proxy_env) > 0:
                        if keep_proxy_env[-1:] != '/':
                            keep_proxy_env += "/"
                        self.service_roots = [keep_proxy_env]
                        self.using_proxy = True
                    else:
                        try:
                            keep_services = arvados.api().keep_services().accessible().execute()['items']
                        except Exception:
                            # Fall back to the older keep_disks API.
                            keep_services = arvados.api().keep_disks().list().execute()['items']

                        if len(keep_services) == 0:
                            raise arvados.errors.NoKeepServersError()

                        if 'service_type' in keep_services[0] and keep_services[0]['service_type'] == 'proxy':
                            self.using_proxy = True

                        roots = (("http%s://%s:%d/" %
                                  ('s' if f['service_ssl_flag'] else '',
                                   f['service_host'],
                                   f['service_port']))
                                 for f in keep_services)
                        self.service_roots = sorted(set(roots))
                        logging.debug(str(self.service_roots))
            finally:
                self.lock.release()

        # Build an ordering with which to query the Keep servers based on the
        # contents of the hash.
        # "hash" is a hex-encoded number at least 8 digits
        # (32 bits) long

        # seed used to calculate the next keep server from 'pool'
        # to be added to 'pseq'
        seed = hash

        # Keep servers still to be added to the ordering
        pool = self.service_roots[:]

        # output probe sequence
        pseq = []

        # iterate while there are servers left to be assigned
        while len(pool) > 0:
            if len(seed) < 8:
                # ran out of digits in the seed
                if len(pseq) < len(hash) / 4:
                    # the number of servers added to the probe sequence is less
                    # than the number of 4-digit slices in 'hash' so refill the
                    # seed with the last 4 digits and then append the contents
                    # of 'hash'.
                    seed = hash[-4:] + hash
                else:
                    # refill the seed with the contents of 'hash'
                    seed += hash

            # Take the next 8 digits (32 bytes) and interpret as an integer,
            # then modulus with the size of the remaining pool to get the next
            # selected server.
            probe = int(seed[0:8], 16) % len(pool)

            # Append the selected server to the probe sequence and remove it
            # from the pool.
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]

            # Remove the digits just used from the seed
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    class CacheSlot(object):
        """One cache entry: a locator plus its (eventually) fetched content.

        'ready' lets other threads block in get() until the fetching
        thread calls set(); content is None after a failed fetch.
        """
        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            """Block until content is available, then return it."""
            self.ready.wait()
            return self.content

        def set(self, value):
            """Store the fetched content and wake any waiting readers."""
            self.content = value
            self.ready.set()

        def size(self):
            """Return the cached content's size in bytes (0 if absent)."""
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        self._cache_lock.acquire()
        try:
            # Drop failed fetches (slot is ready but holds no content).
            # A list comprehension (not filter()) keeps _cache a real,
            # mutable list.
            self._cache = [c for c in self._cache
                           if not (c.ready.is_set() and c.content is None)]
            sm = sum([slot.size() for slot in self._cache])
            while sm > self.cache_max:
                # Evict from the tail; reserve_cache keeps the list in
                # most-recently-used-first order.
                del self._cache[-1]
                # NOTE: the original recomputed this with a typo'd loop
                # variable ('for a in ...' but 'slot.size()'), so the
                # total never shrank.
                sm = sum([slot.size() for slot in self._cache])
        finally:
            self._cache_lock.release()

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        self._cache_lock.acquire()
        try:
            # Test if the locator is already in the cache
            for i, slot in enumerate(self._cache):
                if slot.locator == locator:
                    if i != 0:
                        # move it to the front (MRU position)
                        del self._cache[i]
                        self._cache.insert(0, slot)
                    return slot, False

            # Add a new cache slot for the locator
            n = KeepClient.CacheSlot(locator)
            self._cache.insert(0, n)
            return n, True
        finally:
            self._cache_lock.release()

    def get(self, locator):
        """Fetch and return the data block named by locator.

        Checks the local cache first, then probes Keep servers in
        shuffled order, then any +K@ location hints.  Raises
        NotFoundError if no server has the block.
        """
        #logging.debug("Keep.get %s" % (locator))

        if re.search(r',', locator):
            # Comma-separated list of locators: concatenate the blocks.
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)

        slot, first = self.reserve_cache(expect_hash)
        #logging.debug("%s %s %s" % (slot, first, expect_hash))

        if not first:
            # Another thread is (or was) fetching this block; wait for it.
            v = slot.get()
            return v

        try:
            for service_root in self.shuffled_service_roots(expect_hash):
                url = service_root + locator
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token,
                           'Accept': 'application/octet-stream'}
                blob = self.get_url(url, headers, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob

            # Last resort: follow any +K@instance location hints.
            for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
                instance = location_hint.group(1)
                url = 'http://keep.' + instance + '.arvadosapi.com/' + locator
                blob = self.get_url(url, {}, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob
        except:
            # Mark the slot failed so waiting threads wake up, then
            # propagate the original exception.
            slot.set(None)
            self.cap_cache()
            raise

        slot.set(None)
        self.cap_cache()
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        """GET url and return the body if its MD5 matches expect_hash.

        Returns None on any HTTP error, exception, or checksum mismatch.
        """
        h = httplib2.Http()
        try:
            logging.info("Request: GET %s" % (url))
            with timer.Timer() as t:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
            logging.info("Received %s bytes in %s msec (%s MiB/sec)" % (len(content),
                                                                        t.msecs,
                                                                        (len(content)/(1024*1024))/t.secs))
            if re.match(r'^2\d\d$', resp['status']):
                # Verify the content against the locator's hash before
                # trusting it.
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
        except Exception as e:
            logging.info("Request fail: GET %s => %s: %s" %
                         (url, type(e), str(e)))
        return None

    def put(self, data, **kwargs):
        """Store data with at least kwargs['copies'] (default 2) replicas.

        Writes to servers in shuffled order, one thread each; servers
        that failed are retried once if replication is still short.
        Returns the server-reported locator, or raises KeepWriteError.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(
                data=data,
                data_hash=data_hash,
                service_root=service_root,
                thread_limiter=thread_limiter,
                timeout=self.timeout,
                using_proxy=self.using_proxy,
                # A proxy stores all copies for us; a plain server only one.
                want_copies=(want_copies if self.using_proxy else 1))
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        if thread_limiter.done() < want_copies:
            # Retry the threads (i.e., services) that failed the first
            # time around.
            threads_retry = []
            for t in threads:
                if not t.success():
                    logging.warning("Retrying: PUT %s %s" % (
                        t.args['service_root'],
                        t.args['data_hash']))
                    retry_with_args = t.args.copy()
                    t_retry = KeepClient.KeepWriterThread(**retry_with_args)
                    t_retry.start()
                    threads_retry += [t_retry]
            for t in threads_retry:
                t.join()
        have_copies = thread_limiter.done()
        # If we're done, return the response from Keep
        if have_copies >= want_copies:
            return thread_limiter.response()
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        """Wrap data in the dummy PGP-signed envelope old servers expect."""
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        """Write data into $KEEP_LOCAL_STORE; return its md5+size locator."""
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        # Write to a temp name, then rename, so readers never see a
        # partial block.
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Read the block named by locator from $KEEP_LOCAL_STORE."""
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            # The well-known empty block is never stored on disk.
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()