2752: Add KeepLocator class to Python SDK.
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21 import datetime
22
23 global_client_object = None
24
25 from api import *
26 import config
27 import arvados.errors
28
class KeepLocator(object):
    """Parse and re-serialize a Keep block locator string.

    A locator is a 32-digit md5 hex digest optionally followed by
    '+'-separated hints: a decimal block size, an 'A' permission
    signature hint ('A<40-hex sig>@<hex Unix timestamp>'), and/or a
    'K' location hint.
    """
    # Reference point used to turn perm_expiry (a naive UTC datetime)
    # back into a Unix timestamp in permission_hint().
    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
    HEX_RE = re.compile(r'^[0-9a-fA-F]+$')

    def __init__(self, locator_str):
        self.size = None
        self.loc_hint = None
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        # The first piece is always the md5 digest; the rest are hints.
        self.md5sum = next(pieces)
        for hint in pieces:
            if hint.startswith('A'):
                self.parse_permission_hint(hint)
            elif hint.startswith('K'):
                self.loc_hint = hint  # FIXME
            elif hint.isdigit():
                self.size = int(hint)
            else:
                raise ValueError("unrecognized hint data {}".format(hint))

    def __str__(self):
        # Re-serialize in canonical order, omitting absent components.
        return '+'.join(
            str(s) for s in [self.md5sum, self.size, self.loc_hint,
                             self.permission_hint()]
            if s is not None)

    def _is_hex_length(self, s, *size_spec):
        """Return true if s is a hex string of acceptable length.

        With one size_spec argument, the length must match exactly;
        with two, they are an inclusive [minimum, maximum] range.
        """
        if len(size_spec) == 1:
            good_len = (len(s) == size_spec[0])
        else:
            good_len = (size_spec[0] <= len(s) <= size_spec[1])
        return good_len and self.HEX_RE.match(s)

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not self._is_hex_length(hex_str, length):
                raise ValueError("{} must be a {}-digit hex string: {}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        # value is a hex Unix timestamp; store it as a naive UTC datetime.
        if not self._is_hex_length(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        """Return the 'A<sig>@<timestamp>' hint, or None if unsigned."""
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        try:
            self.perm_sig, self.perm_expiry = s[1:].split('@', 1)
        except ValueError:
            # A missing '@' makes the tuple unpacking above raise
            # ValueError (not IndexError, which the original caught and
            # which could never fire), so report the malformed hint.
            raise ValueError("bad permission hint {}".format(s))

    def permission_expired(self, as_of_dt=None):
        """Return true if the permission hint's expiry time has passed.

        perm_expiry is a naive UTC datetime, so the default comparison
        point is UTC now (the original compared against local time).
        """
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt
110
111
class Keep:
    """Convenience facade over a lazily created, shared KeepClient."""

    @staticmethod
    def global_client_object():
        # Create the module-wide singleton KeepClient on first use.
        # Use identity comparison with None (PEP 8), not ==.
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Read a block via the shared client; kwargs pass through."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Write a block via the shared client; kwargs pass through."""
        return Keep.global_client_object().put(data, **kwargs)
127
class KeepClient(object):
    """Client for reading and writing data blocks in Keep.

    Handles Keep server discovery (via the Arvados API), rendezvous
    probe ordering, a bounded in-memory read cache, and parallel
    replicated writes.
    """

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            # Semaphore bounds how many writers may be active at once.
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            # Expected kwargs: data, data_hash, service_root,
            # thread_limiter.
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # Older servers want the data wrapped in a
                        # (fake) PGP-signed message; retry with that.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        # Lazily populated list of Keep service root URLs.
        self.service_roots = None
        self._cache_lock = threading.Lock()
        # MRU-ordered list of CacheSlot (newest first).
        self._cache = []
        # default 256 megabyte cache
        self.cache_max = 256 * 1024 * 1024

    def shuffled_service_roots(self, hash):
        """Return the Keep service roots in probe order for 'hash'.

        The ordering is a deterministic rendezvous shuffle driven by
        successive 8-hex-digit slices of the hash.
        """
        if self.service_roots is None:
            self.lock.acquire()
            try:
                keep_disks = arvados.api().keep_disks().list().execute()['items']
                roots = (("http%s://%s:%d/" %
                          ('s' if f['service_ssl_flag'] else '',
                           f['service_host'],
                           f['service_port']))
                         for f in keep_disks)
                self.service_roots = sorted(set(roots))
                logging.debug(str(self.service_roots))
            finally:
                self.lock.release()

        # Build an ordering with which to query the Keep servers based on the
        # contents of the hash.
        # "hash" is a hex-encoded number at least 8 digits
        # (32 bits) long

        # seed used to calculate the next keep server from 'pool'
        # to be added to 'pseq'
        seed = hash

        # Keep servers still to be added to the ordering
        pool = self.service_roots[:]

        # output probe sequence
        pseq = []

        # iterate while there are servers left to be assigned
        while len(pool) > 0:
            if len(seed) < 8:
                # ran out of digits in the seed
                if len(pseq) < len(hash) / 4:
                    # the number of servers added to the probe sequence is less
                    # than the number of 4-digit slices in 'hash' so refill the
                    # seed with the last 4 digits and then append the contents
                    # of 'hash'.
                    seed = hash[-4:] + hash
                else:
                    # refill the seed with the contents of 'hash'
                    seed += hash

            # Take the next 8 digits (32 bytes) and interpret as an integer,
            # then modulus with the size of the remaining pool to get the next
            # selected server.
            probe = int(seed[0:8], 16) % len(pool)

            # (stray debug print removed; the full sequence is logged below)

            # Append the selected server to the probe sequence and remove it
            # from the pool.
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]

            # Remove the digits just used from the seed
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    class CacheSlot(object):
        """One cache entry; readers block on 'ready' until set() runs."""
        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            # Block until the fetching thread publishes the content.
            self.ready.wait()
            return self.content

        def set(self, value):
            self.content = value
            self.ready.set()

        def size(self):
            """Return the cached content's size in bytes (0 if empty)."""
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        self._cache_lock.acquire()
        try:
            # Drop slots that finished with no content (failed fetches).
            self._cache = [c for c in self._cache
                           if not (c.ready.is_set() and c.content is None)]
            sm = sum([slot.size() for slot in self._cache])
            while sm > self.cache_max:
                # Evict least-recently-used entries (list tail) until we
                # fit. The original recomputed the total with a stale
                # 'slot' variable ("for a in ..."), so the sum was wrong.
                del self._cache[-1]
                sm = sum([slot.size() for slot in self._cache])
        finally:
            self._cache_lock.release()

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        self._cache_lock.acquire()
        try:
            # Test if the locator is already in the cache
            for i in range(0, len(self._cache)):
                if self._cache[i].locator == locator:
                    n = self._cache[i]
                    if i != 0:
                        # move it to the front
                        del self._cache[i]
                        self._cache.insert(0, n)
                    return n, False

            # Add a new cache slot for the locator
            n = KeepClient.CacheSlot(locator)
            self._cache.insert(0, n)
            return n, True
        finally:
            self._cache_lock.release()

    def get(self, locator):
        """Fetch and return the block identified by 'locator'.

        Consults the cache first; on a miss, probes Keep servers in
        rendezvous order, then any '+K@instance' location hints.
        Raises arvados.errors.NotFoundError if no server has the block.
        """
        if re.search(r',', locator):
            # Comma-separated list of locators: concatenate the blocks.
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)

        slot, first = self.reserve_cache(expect_hash)

        if not first:
            # Another thread is (or was) fetching this block; wait for it.
            v = slot.get()
            return v

        try:
            for service_root in self.shuffled_service_roots(expect_hash):
                url = service_root + expect_hash
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token,
                           'Accept': 'application/octet-stream'}
                blob = self.get_url(url, headers, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob

            for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
                instance = location_hint.group(1)
                url = 'http://keep.' + instance + '.arvadosapi.com/' + expect_hash
                blob = self.get_url(url, {}, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob
        except:
            # Mark the slot failed before re-raising so threads blocked
            # in slot.get() wake up instead of hanging forever.
            slot.set(None)
            self.cap_cache()
            raise

        slot.set(None)
        self.cap_cache()
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        """GET 'url' and return the body if its md5 matches expect_hash.

        Returns None on any request failure or checksum mismatch.
        """
        h = httplib2.Http()
        try:
            logging.info("Request: GET %s" % (url))
            with timer.Timer() as t:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
            logging.info("Received %s bytes in %s msec (%s MiB/sec)" % (len(content),
                                                                        t.msecs,
                                                                        (len(content)/(1024*1024))/t.secs))
            if re.match(r'^2\d\d$', resp['status']):
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
        except Exception as e:
            logging.info("Request fail: GET %s => %s: %s" %
                         (url, type(e), str(e)))
        return None

    def put(self, data, **kwargs):
        """Write 'data' to Keep and return its size-hinted locator.

        Writes kwargs['copies'] replicas (default 2) in parallel and
        raises arvados.errors.KeepWriteError if fewer succeed.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies == want_copies:
            return (data_hash + '+' + str(len(data)))
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        """Wrap 'data' in the fake PGP clearsign old servers expect."""
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        """Write 'data' into $KEEP_LOCAL_STORE and return its locator."""
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        # Write to a temp name, then rename into place atomically.
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Read the block named by 'locator' from $KEEP_LOCAL_STORE."""
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            # The well-known empty block is not stored on disk.
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()