Merge branch 'master' into 2756-eventbus-in-workbench
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21
22 global_client_object = None
23
24 from api import *
25 import config
26 import arvados.errors
27
class Keep:
    """Module-level convenience wrapper around a single shared KeepClient.

    NOTE(review): lazy creation below is not thread-safe; two threads
    racing here could each build a KeepClient (harmless but wasteful).
    """
    @staticmethod
    def global_client_object():
        # Create the process-wide KeepClient on first use.
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch the data block named by locator via the shared client."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store data via the shared client and return its locator."""
        return Keep.global_client_object().put(data, **kwargs)
43
class KeepClient(object):
    """Client for reading and writing data blocks on Keep servers.

    Handles lazy server discovery via the Arvados API, a deterministic
    per-block probe ordering over those servers, a bounded in-memory
    read cache, and parallel replicated writes.
    """

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo        # number of successes wanted
            self._done = 0           # successes reported so far
            self._response = None    # body from one successful request
            # Admit at most `todo` threads into their "with" blocks at
            # once; as successes arrive, shall_i_proceed() tells the
            # remaining threads to stop.
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def save_response(self, response_body):
            """
            Records a response body (a locator, possibly signed) returned by
            the Keep server.  It is not necessary to save more than
            one response, since we presume that any locator returned
            in response to a successful request is valid.
            """
            with self._done_lock:
                self._done += 1
                self._response = response_body

        def response(self):
            """
            Returns the body from the response to a PUT request.
            """
            with self._done_lock:
                return self._response

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. On success, call
        save_response() of the given ThreadLimiter to save the returned
        locator.

        Expected kwargs: data, data_hash, service_root, thread_limiter.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # Old servers want the block wrapped in a (dummy)
                        # PGP-signed message; retry once in that format.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.save_response(content.strip())
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    # Network-level failure: log and give up on this
                    # server; sibling writer threads may still succeed.
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()          # guards service_roots discovery
        self.service_roots = None             # lazily fetched list of URLs
        self._cache_lock = threading.Lock()   # guards _cache
        self._cache = []                      # CacheSlots, most recent first
        # default 256 megabyte cache
        self.cache_max = 256 * 1024 * 1024

    def shuffled_service_roots(self, hash):
        """Return all Keep service roots, ordered by the probe sequence
        derived from `hash`, so every client contacts the servers for a
        given block in the same order."""
        if self.service_roots is None:
            self.lock.acquire()
            try:
                # Re-check under the lock: another thread may have
                # populated the list while we were waiting, and we
                # should not fetch it twice.
                if self.service_roots is None:
                    keep_disks = arvados.api().keep_disks().list().execute()['items']
                    roots = (("http%s://%s:%d/" %
                              ('s' if f['service_ssl_flag'] else '',
                               f['service_host'],
                               f['service_port']))
                             for f in keep_disks)
                    self.service_roots = sorted(set(roots))
                    logging.debug(str(self.service_roots))
            finally:
                self.lock.release()

        # Build an ordering with which to query the Keep servers based on the
        # contents of the hash.
        # "hash" is a hex-encoded number at least 8 digits
        # (32 bits) long

        # seed used to calculate the next keep server from 'pool'
        # to be added to 'pseq'
        seed = hash

        # Keep servers still to be added to the ordering
        pool = self.service_roots[:]

        # output probe sequence
        pseq = []

        # iterate while there are servers left to be assigned
        while len(pool) > 0:
            if len(seed) < 8:
                # ran out of digits in the seed
                if len(pseq) < len(hash) / 4:
                    # the number of servers added to the probe sequence is less
                    # than the number of 4-digit slices in 'hash' so refill the
                    # seed with the last 4 digits and then append the contents
                    # of 'hash'.
                    seed = hash[-4:] + hash
                else:
                    # refill the seed with the contents of 'hash'
                    seed += hash

            # Take the next 8 digits (32 bits) and interpret as an integer,
            # then modulus with the size of the remaining pool to get the next
            # selected server.
            probe = int(seed[0:8], 16) % len(pool)

            # Append the selected server to the probe sequence and remove it
            # from the pool.
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]

            # Remove the digits just used from the seed
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    class CacheSlot(object):
        """One entry in the read cache.  Readers block in get() until
        the fetching thread calls set(); content of None marks a fetch
        that failed (cap_cache prunes such entries)."""
        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            # Block until the fetching thread has stored a result.
            self.ready.wait()
            return self.content

        def set(self, value):
            self.content = value
            self.ready.set()

        def size(self):
            # Unfinished or failed slots count as zero bytes.
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        self._cache_lock.acquire()
        try:
            # Drop entries whose fetch finished with no content.
            self._cache = [c for c in self._cache
                           if not (c.ready.is_set() and c.content is None)]
            sm = sum(c.size() for c in self._cache)
            while sm > self.cache_max:
                # Evict the least recently used entry (tail of the list).
                # Fixed: the recount previously summed a stale leaked
                # comprehension variable instead of each remaining slot.
                del self._cache[-1]
                sm = sum(c.size() for c in self._cache)
        finally:
            self._cache_lock.release()

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        self._cache_lock.acquire()
        try:
            # Test if the locator is already in the cache
            for i, slot in enumerate(self._cache):
                if slot.locator == locator:
                    if i != 0:
                        # move it to the front (most recently used)
                        del self._cache[i]
                        self._cache.insert(0, slot)
                    return slot, False

            # Add a new cache slot for the locator
            n = KeepClient.CacheSlot(locator)
            self._cache.insert(0, n)
            return n, True
        finally:
            self._cache_lock.release()

    def get(self, locator):
        """Fetch and return the data block named by `locator`.

        A comma-separated locator is fetched piecewise and the pieces
        concatenated.  Raises arvados.errors.NotFoundError when no
        server (or location hint) yields the block.
        """
        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        # Strip size/hints: the bare md5 is both cache key and checksum.
        expect_hash = re.sub(r'\+.*', '', locator)

        slot, first = self.reserve_cache(expect_hash)

        if not first:
            # Another thread is (or was) fetching this block; wait for
            # its result rather than fetching again.
            v = slot.get()
            return v

        try:
            for service_root in self.shuffled_service_roots(expect_hash):
                url = service_root + locator
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token,
                           'Accept': 'application/octet-stream'}
                blob = self.get_url(url, headers, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob

            # Fall back to any +K@<instance> location hints in the locator.
            for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
                instance = location_hint.group(1)
                url = 'http://keep.' + instance + '.arvadosapi.com/' + locator
                blob = self.get_url(url, {}, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob
        except:
            # Deliberately bare: mark the slot failed so waiting threads
            # wake up, then re-raise whatever happened (including
            # KeyboardInterrupt).
            slot.set(None)
            self.cap_cache()
            raise

        slot.set(None)
        self.cap_cache()
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        """GET `url` and return the body if its md5 equals
        `expect_hash`; return None on any failure (never raises)."""
        h = httplib2.Http()
        try:
            logging.info("Request: GET %s" % (url))
            with timer.Timer() as t:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
            # Guard against a zero-length timing interval so that a
            # successful (very fast) fetch is not discarded by a
            # ZeroDivisionError raised while merely logging the rate.
            if t.secs > 0:
                speed = (len(content)/(1024*1024))/t.secs
            else:
                speed = 0
            logging.info("Received %s bytes in %s msec (%s MiB/sec)" % (len(content),
                                                                        t.msecs,
                                                                        speed))
            if re.match(r'^2\d\d$', resp['status']):
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                # Corrupt or truncated data: caller will try other servers.
                logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
        except Exception as e:
            logging.info("Request fail: GET %s => %s: %s" %
                         (url, type(e), str(e)))
        return None

    def put(self, data, **kwargs):
        """Store `data` on `copies` Keep servers (kwarg, default 2) and
        return the locator reported by one of them.

        Raises arvados.errors.KeepWriteError if fewer than `copies`
        servers accepted the data.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        # One writer per server; the limiter tells the surplus threads
        # to stop once enough copies have been written.
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        # If we're done, return the response from Keep.  Racing writers
        # can over-deliver (have_copies > want_copies), which is still
        # success -- hence >= rather than the old ==.
        if have_copies >= want_copies:
            return thread_limiter.response()
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        """Wrap `data` in the dummy PGP-signed-message format that
        pre-timestamp Keep servers require for PUT bodies."""
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        """Write `data` under $KEEP_LOCAL_STORE (testing backend) and
        return its md5+size locator.  Write-then-rename is atomic."""
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Read the block named by `locator` from $KEEP_LOCAL_STORE.

        Raises arvados.errors.NotFoundError for a malformed locator;
        the well-known empty-block locator short-circuits to ''.
        """
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()