Merge branch 'master' into 3505-virtual-work-dir
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21 import datetime
22 import ssl
23
# Logger for all Keep client activity in this module.
_logger = logging.getLogger('arvados.keep')
# Lazily-created singleton KeepClient shared by Keep.get()/Keep.put();
# initialized by Keep.global_client_object().
global_client_object = None
26
27 from api import *
28 import config
29 import arvados.errors
30 import arvados.util
31
class KeepLocator(object):
    """Parse and re-serialize a Keep block locator string.

    A locator has the form "md5sum[+size][+Khint][+Asignature@timestamp]".
    The md5sum, permission signature, and permission expiry are validated
    on assignment; size and location hints are kept as given.
    """

    # Reference point for converting permission-expiry datetimes back to
    # Unix timestamps in permission_hint().
    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)

    def __init__(self, locator_str):
        self.size = None
        self.loc_hint = None
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        # The first '+'-separated piece is always the md5 checksum.
        self.md5sum = next(pieces)
        for hint in pieces:
            if hint.startswith('A'):
                self.parse_permission_hint(hint)
            elif hint.startswith('K'):
                self.loc_hint = hint  # FIXME
            elif hint.isdigit():
                self.size = int(hint)
            else:
                raise ValueError("unrecognized hint data {}".format(hint))

    def __str__(self):
        """Serialize back to locator form, omitting unset components."""
        return '+'.join(
            str(s) for s in [self.md5sum, self.size, self.loc_hint,
                             self.permission_hint()]
            if s is not None)

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not arvados.util.is_hex(hex_str, length):
                raise ValueError("{} must be a {}-digit hex string: {}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        # Accepts a 1-8 digit hex Unix timestamp; stored as a naive UTC
        # datetime.
        if not arvados.util.is_hex(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        """Return the "A<sig>@<hex timestamp>" hint, or None if unset."""
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        """Parse an "A<sig>@<timestamp>" hint into perm_sig/perm_expiry.

        Raises ValueError if the hint is malformed.
        """
        parts = s[1:].split('@', 1)
        if len(parts) != 2:
            # BUGFIX: a failed two-value unpack of split() raises
            # ValueError, not IndexError, so the old `except IndexError`
            # handler never fired and callers got an unhelpful message.
            raise ValueError("bad permission hint {}".format(s))
        self.perm_sig, self.perm_expiry = parts

    def permission_expired(self, as_of_dt=None):
        """Return True if the permission signature has expired as of
        `as_of_dt` (a naive UTC datetime; defaults to now)."""
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            # BUGFIX: perm_expiry is naive UTC (utcfromtimestamp), so it
            # must be compared against UTC now, not local time.
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt
105
106
class Keep:
    """Convenience wrapper exposing get/put on a shared KeepClient."""

    @staticmethod
    def global_client_object():
        """Return the module-wide KeepClient, creating it on first use."""
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch a data block via the shared client."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store a data block via the shared client."""
        return Keep.global_client_object().put(data, **kwargs)
122
class KeepClient(object):
    """Stores and retrieves data blocks from Keep servers.

    Selects a per-block probe order of storage servers, performs
    parallel writes until a replica-count target is met, and keeps a
    bounded in-memory cache of recently read blocks.
    """

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            # todo: number of replica confirmations still wanted.
            self._todo = todo
            self._done = 0
            self._response = None
            # At most `todo` writer threads proceed concurrently.
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def save_response(self, response_body, replicas_stored):
            """
            Records a response body (a locator, possibly signed) returned by
            the Keep server.  It is not necessary to save more than
            one response, since we presume that any locator returned
            in response to a successful request is valid.
            """
            with self._done_lock:
                self._done += replicas_stored
                self._response = response_body

        def response(self):
            """
            Returns the body from the response to a PUT request.
            """
            with self._done_lock:
                return self._response

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. On success, call
        save_response() of the given ThreadLimiter to save the returned
        locator.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            # Expected kwargs: data, data_hash, service_root,
            # thread_limiter, timeout, using_proxy, want_copies.
            self.args = kwargs
            self._success = False

        def success(self):
            """Return True if this thread's PUT succeeded."""
            return self._success

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                self.run_with_limiter(limiter)

        def run_with_limiter(self, limiter):
            _logger.debug("KeepWriterThread %s proceeding %s %s",
                          str(threading.current_thread()),
                          self.args['data_hash'],
                          self.args['service_root'])
            h = httplib2.Http(timeout=self.args.get('timeout', None))
            url = self.args['service_root'] + self.args['data_hash']
            api_token = config.get('ARVADOS_API_TOKEN')
            headers = {'Authorization': "OAuth2 %s" % api_token}

            if self.args['using_proxy']:
                # We're using a proxy, so tell the proxy how many copies we
                # want it to store
                headers['X-Keep-Desired-Replication'] = str(self.args['want_copies'])

            try:
                _logger.debug("Uploading to {}".format(url))
                resp, content = h.request(url.encode('utf-8'), 'PUT',
                                          headers=headers,
                                          body=self.args['data'])
                if (resp['status'] == '401' and
                    re.match(r'Timestamp verification failed', content)):
                    # Old servers expect the block wrapped in a (dummy)
                    # PGP-signed message; retry once in that format.
                    body = KeepClient.sign_for_old_server(
                        self.args['data_hash'],
                        self.args['data'])
                    h = httplib2.Http(timeout=self.args.get('timeout', None))
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=body)
                if re.match(r'^2\d\d$', resp['status']):
                    self._success = True
                    _logger.debug("KeepWriterThread %s succeeded %s %s",
                                  str(threading.current_thread()),
                                  self.args['data_hash'],
                                  self.args['service_root'])
                    replicas_stored = 1
                    if 'x-keep-replicas-stored' in resp:
                        # Tick the 'done' counter for the number of replica
                        # reported stored by the server, for the case that
                        # we're talking to a proxy or other backend that
                        # stores to multiple copies for us.
                        try:
                            replicas_stored = int(resp['x-keep-replicas-stored'])
                        except ValueError:
                            pass
                    limiter.save_response(content.strip(), replicas_stored)
                else:
                    _logger.warning("Request fail: PUT %s => %s %s",
                                    url, resp['status'], content)
            except (httplib2.HttpLib2Error,
                    httplib.HTTPException,
                    ssl.SSLError) as e:
                # When using https, timeouts look like ssl.SSLError from here.
                # "SSLError: The write operation timed out"
                _logger.warning("Request fail: PUT %s => %s: %s",
                                url, type(e), str(e))

    def __init__(self, **kwargs):
        """Create a client.

        Accepts an optional `timeout` keyword argument (seconds,
        default 60) applied to each storage-server request.
        """
        self.lock = threading.Lock()
        self.service_roots = None
        self._cache_lock = threading.Lock()
        self._cache = []
        # default 256 megabyte cache
        self.cache_max = 256 * 1024 * 1024
        self.using_proxy = False
        self.timeout = kwargs.get('timeout', 60)

    def shuffled_service_roots(self, hash):
        """Return the Keep service root URLs ordered into the probe
        sequence determined by `hash` (a hex string at least 8 digits
        long).  Discovers the service list on first use.
        """
        if self.service_roots is None:
            self.lock.acquire()
            try:
                # Re-check under the lock: another thread may have filled
                # in the list while we were waiting.
                if self.service_roots is None:
                    # Override normal keep disk lookup with an explict proxy
                    # configuration.
                    keep_proxy_env = config.get("ARVADOS_KEEP_PROXY")
                    if keep_proxy_env is not None and len(keep_proxy_env) > 0:
                        if keep_proxy_env[-1:] != '/':
                            keep_proxy_env += "/"
                        self.service_roots = [keep_proxy_env]
                        self.using_proxy = True
                    else:
                        try:
                            keep_services = arvados.api().keep_services().accessible().execute()['items']
                        except Exception:
                            # Fall back to the older API for old servers.
                            keep_services = arvados.api().keep_disks().list().execute()['items']

                        if len(keep_services) == 0:
                            raise arvados.errors.NoKeepServersError()

                        if 'service_type' in keep_services[0] and keep_services[0]['service_type'] == 'proxy':
                            self.using_proxy = True

                        roots = (("http%s://%s:%d/" %
                                  ('s' if f['service_ssl_flag'] else '',
                                   f['service_host'],
                                   f['service_port']))
                                 for f in keep_services)
                        self.service_roots = sorted(set(roots))
                        _logger.debug(str(self.service_roots))
            finally:
                # BUGFIX: the old code acquired self.lock and only released
                # it on the non-proxy branch, leaking the lock (and
                # deadlocking later callers) whenever ARVADOS_KEEP_PROXY was
                # set.  Releasing in a finally covers every path, including
                # exceptions raised during discovery.
                self.lock.release()

        # Build an ordering with which to query the Keep servers based on the
        # contents of the hash.
        # "hash" is a hex-encoded number at least 8 digits
        # (32 bits) long

        # seed used to calculate the next keep server from 'pool'
        # to be added to 'pseq'
        seed = hash

        # Keep servers still to be added to the ordering
        pool = self.service_roots[:]

        # output probe sequence
        pseq = []

        # iterate while there are servers left to be assigned
        while len(pool) > 0:
            if len(seed) < 8:
                # ran out of digits in the seed
                if len(pseq) < len(hash) / 4:
                    # the number of servers added to the probe sequence is less
                    # than the number of 4-digit slices in 'hash' so refill the
                    # seed with the last 4 digits and then append the contents
                    # of 'hash'.
                    seed = hash[-4:] + hash
                else:
                    # refill the seed with the contents of 'hash'
                    seed += hash

            # Take the next 8 digits (32 bytes) and interpret as an integer,
            # then modulus with the size of the remaining pool to get the next
            # selected server.
            probe = int(seed[0:8], 16) % len(pool)

            # Append the selected server to the probe sequence and remove it
            # from the pool.
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]

            # Remove the digits just used from the seed
            seed = seed[8:]
        _logger.debug(str(pseq))
        return pseq

    class CacheSlot(object):
        """One cache entry: holds a block's content once it is fetched,
        and lets other threads wait for the fetch to finish."""
        def __init__(self, locator):
            self.locator = locator
            # Set once content is available (or the fetch failed).
            self.ready = threading.Event()
            self.content = None

        def get(self):
            """Block until content is available, then return it (may be
            None if the fetch failed)."""
            self.ready.wait()
            return self.content

        def set(self, value):
            """Store the fetched content and wake any waiters."""
            self.content = value
            self.ready.set()

        def size(self):
            """Return the number of cached bytes (0 while unfetched)."""
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        self._cache_lock.acquire()
        try:
            # Drop slots whose fetch completed with no content (failed
            # reads).  A list comprehension keeps self._cache indexable.
            self._cache = [c for c in self._cache
                           if not (c.ready.is_set() and c.content is None)]
            sm = sum([slot.size() for slot in self._cache])
            while sm > self.cache_max:
                # Evict the least recently used slot; reserve_cache keeps
                # the list in most-recently-used-first order.
                del self._cache[-1]
                # BUGFIX: the old code recomputed the total as
                # `sum([slot.size() for a in self._cache])`, summing a
                # stale `slot` instead of each remaining entry, so the
                # eviction loop used a wrong size total.
                sm = sum([slot.size() for slot in self._cache])
        finally:
            self._cache_lock.release()

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.

        Returns a (slot, first) tuple; `first` is True when the caller
        is responsible for fetching the block and filling the slot.
        '''
        self._cache_lock.acquire()
        try:
            # Test if the locator is already in the cache
            for i, slot in enumerate(self._cache):
                if slot.locator == locator:
                    if i != 0:
                        # move it to the front
                        del self._cache[i]
                        self._cache.insert(0, slot)
                    return slot, False

            # Add a new cache slot for the locator
            n = KeepClient.CacheSlot(locator)
            self._cache.insert(0, n)
            return n, True
        finally:
            self._cache_lock.release()

    def get(self, locator):
        """Fetch the block identified by `locator` and return its data.

        Raises arvados.errors.NotFoundError if no service can supply a
        block matching the locator's checksum.
        """
        # A comma-separated locator list means: concatenate the blocks.
        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)

        slot, first = self.reserve_cache(expect_hash)

        if not first:
            # The block is cached, or another thread is fetching it:
            # wait for the slot to fill and return its content.
            v = slot.get()
            return v

        try:
            for service_root in self.shuffled_service_roots(expect_hash):
                url = service_root + locator
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token,
                           'Accept': 'application/octet-stream'}
                blob = self.get_url(url, headers, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob

            # Fall back to any +K@instance location hints in the locator.
            for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
                instance = location_hint.group(1)
                url = 'http://keep.' + instance + '.arvadosapi.com/' + locator
                blob = self.get_url(url, {}, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob
        except:
            # Fill the slot with None so threads waiting on it unblock,
            # then propagate the error to our caller.
            slot.set(None)
            self.cap_cache()
            raise

        slot.set(None)
        self.cap_cache()
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        """GET `url` and return the response body if its md5 digest
        equals `expect_hash`; return None on any failure."""
        h = httplib2.Http()
        try:
            _logger.info("Request: GET %s", url)
            with timer.Timer() as t:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
            _logger.info("Received %s bytes in %s msec (%s MiB/sec)",
                         len(content), t.msecs,
                         (len(content)/(1024*1024))/t.secs)
            if re.match(r'^2\d\d$', resp['status']):
                # Verify the content against the locator's checksum before
                # trusting it.
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                _logger.warning("Checksum fail: md5(%s) = %s", url, md5)
        except Exception as e:
            _logger.info("Request fail: GET %s => %s: %s",
                         url, type(e), str(e))
        return None

    def put(self, data, **kwargs):
        """Store `data` in Keep with at least `copies` replicas
        (keyword argument, default 2).

        Returns the (possibly signed) locator reported by the server.
        Raises arvados.errors.KeepWriteError if fewer than the wanted
        number of replicas could be written.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(
                data=data,
                data_hash=data_hash,
                service_root=service_root,
                thread_limiter=thread_limiter,
                timeout=self.timeout,
                using_proxy=self.using_proxy,
                # A proxy stores all the wanted copies itself; direct
                # storage servers each store one.
                want_copies=(want_copies if self.using_proxy else 1))
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        if thread_limiter.done() < want_copies:
            # Retry the threads (i.e., services) that failed the first
            # time around.
            threads_retry = []
            for t in threads:
                if not t.success():
                    _logger.warning("Retrying: PUT %s %s",
                                    t.args['service_root'],
                                    t.args['data_hash'])
                    retry_with_args = t.args.copy()
                    t_retry = KeepClient.KeepWriterThread(**retry_with_args)
                    t_retry.start()
                    threads_retry += [t_retry]
            for t in threads_retry:
                t.join()
        have_copies = thread_limiter.done()
        # If we're done, return the response from Keep
        if have_copies >= want_copies:
            return thread_limiter.response()
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        """Wrap `data` in the dummy PGP-signed-message envelope that
        pre-timestamp Keep servers require for PUT requests."""
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        """Write `data` into the KEEP_LOCAL_STORE directory (test mode)
        and return its locator.  Writes to a .tmp file first, then
        renames, so readers never see a partial block."""
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Read a block from the KEEP_LOCAL_STORE directory (test mode)."""
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            # The well-known empty block is never stored on disk.
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()