# 2755: add support for signed locators in the Python SDK.
# [arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21
22 global_client_object = None
23
24 from api import *
25 import config
26 import arvados.errors
27
class Keep:
    """Module-level convenience wrapper around one shared KeepClient.

    Lazily creates a single KeepClient on first use and routes all
    get/put calls through it.
    """

    @staticmethod
    def global_client_object():
        """Return the process-wide KeepClient, creating it on first use."""
        global global_client_object
        # 'is None' is the correct identity test for the sentinel.
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch the block named by +locator+ via the shared client."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store +data+ via the shared client; returns the server's locator."""
        return Keep.global_client_object().put(data, **kwargs)
43
class KeepClient(object):
    """Stores and retrieves content-addressed data blocks from Keep servers.

    A block is named by the MD5 hex digest of its contents, optionally
    followed by "+" hints (size, signatures, +K@instance location hints).
    GET responses are verified against the expected hash and kept in an
    in-memory most-recently-used cache; PUT requests are fanned out to
    several Keep servers in parallel until the desired number of copies
    has been stored.
    """

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            # todo: how many successful PUTs we still want to see.
            self._todo = todo
            self._done = 0
            self._response = None
            # Allow at most `todo` writer threads to proceed concurrently.
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def save_response(self, response_body):
            """
            Records a response body (a locator, possibly signed) returned by
            the Keep server.  It is not necessary to save more than
            one response, since we presume that any locator returned
            in response to a successful request is valid.
            """
            with self._done_lock:
                self._done += 1
                self._response = response_body

        def response(self):
            """
            Returns the body from the response to a PUT request.
            """
            with self._done_lock:
                return self._response

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. On success, call
        save_response() of the given ThreadLimiter to save the returned
        locator.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            # Expected kwargs: data, data_hash, service_root, thread_limiter.
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # Older Keep servers want a PGP-style signed body;
                        # retry once with the legacy format.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        # The stripped response body is the (possibly
                        # signed) locator.
                        return limiter.save_response(content.strip())
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        # Lazily-discovered list of Keep service base URLs.
        self.service_roots = None
        self._cache_lock = threading.Lock()
        # MRU cache of CacheSlot objects, most recent first.
        self._cache = []
        # default 256 megabyte cache
        self.cache_max = 256 * 1024 * 1024

    def shuffled_service_roots(self, hash):
        """Return the Keep service roots in the probe order for +hash+.

        The ordering is derived deterministically from the hex digits of
        the hash, so every client probes servers for a given block in
        the same sequence.
        """
        if self.service_roots is None:
            self.lock.acquire()
            try:
                # Re-check under the lock: another thread may have
                # populated the list while we waited.
                if self.service_roots is None:
                    keep_disks = arvados.api().keep_disks().list().execute()['items']
                    roots = (("http%s://%s:%d/" %
                              ('s' if f['service_ssl_flag'] else '',
                               f['service_host'],
                               f['service_port']))
                             for f in keep_disks)
                    self.service_roots = sorted(set(roots))
                    logging.debug(str(self.service_roots))
            finally:
                self.lock.release()

        # Build an ordering with which to query the Keep servers based on the
        # contents of the hash.
        # "hash" is a hex-encoded number at least 8 digits
        # (32 bits) long

        # seed used to calculate the next keep server from 'pool'
        # to be added to 'pseq'
        seed = hash

        # Keep servers still to be added to the ordering
        pool = self.service_roots[:]

        # output probe sequence
        pseq = []

        # iterate while there are servers left to be assigned
        while len(pool) > 0:
            if len(seed) < 8:
                # ran out of digits in the seed
                if len(pseq) < len(hash) / 4:
                    # the number of servers added to the probe sequence is less
                    # than the number of 4-digit slices in 'hash' so refill the
                    # seed with the last 4 digits and then append the contents
                    # of 'hash'.
                    seed = hash[-4:] + hash
                else:
                    # refill the seed with the contents of 'hash'
                    seed += hash

            # Take the next 8 digits (32 bits) and interpret as an integer,
            # then modulus with the size of the remaining pool to get the next
            # selected server.
            probe = int(seed[0:8], 16) % len(pool)

            # Append the selected server to the probe sequence and remove it
            # from the pool.
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]

            # Remove the digits just used from the seed
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    class CacheSlot(object):
        """One cache entry: a locator plus its (eventually) fetched content.

        The slot is created before the fetch starts; readers block on
        get() until set() publishes the content (or None on failure).
        """
        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            """Block until content is published, then return it."""
            self.ready.wait()
            return self.content

        def set(self, value):
            """Publish +value+ and wake any threads waiting in get()."""
            self.content = value
            self.ready.set()

        def size(self):
            """Return the cached content's size in bytes (0 if absent)."""
            if self.content is None:
                return 0
            else:
                return len(self.content)

        def size(self):
            """Return the cached content's size in bytes (0 if absent)."""
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        self._cache_lock.acquire()
        try:
            # Drop slots whose fetch finished but yielded no content.
            self._cache = [c for c in self._cache
                           if not (c.ready.is_set() and c.content is None)]
            sm = sum(c.size() for c in self._cache)
            while sm > self.cache_max:
                # Evict the least recently used entry (list tail).
                del self._cache[-1]
                # Bug fix: the original recomputed the total as
                # sum([slot.size() for a in self._cache]), summing a
                # stale leaked variable instead of each element.
                sm = sum(c.size() for c in self._cache)
        finally:
            self._cache_lock.release()

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        self._cache_lock.acquire()
        try:
            # Test if the locator is already in the cache
            for i in range(0, len(self._cache)):
                if self._cache[i].locator == locator:
                    n = self._cache[i]
                    if i != 0:
                        # move it to the front
                        del self._cache[i]
                        self._cache.insert(0, n)
                    return n, False

            # Add a new cache slot for the locator
            n = KeepClient.CacheSlot(locator)
            self._cache.insert(0, n)
            return n, True
        finally:
            self._cache_lock.release()

    def get(self, locator):
        """Fetch the block named by +locator+ and return its contents.

        A comma-separated locator list is fetched piecewise and
        concatenated.  Raises arvados.errors.NotFoundError when no
        server produces a block matching the expected hash.
        """
        #logging.debug("Keep.get %s" % (locator))

        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        # Strip "+" hints: cache key and checksum use the bare hash.
        expect_hash = re.sub(r'\+.*', '', locator)

        slot, first = self.reserve_cache(expect_hash)
        #logging.debug("%s %s %s" % (slot, first, expect_hash))

        if not first:
            # Another thread is (or was) fetching this block; wait for it.
            v = slot.get()
            return v

        try:
            for service_root in self.shuffled_service_roots(expect_hash):
                url = service_root + locator
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token,
                           'Accept': 'application/octet-stream'}
                blob = self.get_url(url, headers, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob

            # Fall back to any +K@instance location hints in the locator.
            for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
                instance = location_hint.group(1)
                url = 'http://keep.' + instance + '.arvadosapi.com/' + locator
                blob = self.get_url(url, {}, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob
        except:
            # Publish failure so threads waiting on this slot don't hang,
            # then re-raise the original exception.
            slot.set(None)
            self.cap_cache()
            raise

        slot.set(None)
        self.cap_cache()
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        """GET +url+ and return the body if its MD5 matches +expect_hash+.

        Returns None on any transport error, non-2xx status, or
        checksum mismatch.
        """
        h = httplib2.Http()
        try:
            logging.info("Request: GET %s" % (url))
            with timer.Timer() as t:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
            logging.info("Received %s bytes in %s msec (%s MiB/sec)" % (len(content),
                                                                        t.msecs,
                                                                        (len(content)/(1024*1024))/t.secs))
            if re.match(r'^2\d\d$', resp['status']):
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
        except Exception as e:
            logging.info("Request fail: GET %s => %s: %s" %
                         (url, type(e), str(e)))
        return None

    def put(self, data, **kwargs):
        """Store +data+ in Keep; return the locator the server reported.

        Accepts copies=N (default 2) for the desired replication.
        Raises arvados.errors.KeepWriteError if fewer than the requested
        number of copies could be written.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        # If we're done, return the response from Keep.  Several writer
        # threads can succeed concurrently, so the reported success
        # count may exceed want_copies; test with >= rather than ==.
        if have_copies >= want_copies:
            return thread_limiter.response()
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        """Wrap +data+ in the legacy PGP-style signed body format.

        Older Keep servers expect the timestamp and hash inside a
        clearsigned envelope; the signature itself is empty.
        """
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        """Store +data+ in $KEEP_LOCAL_STORE and return its locator.

        Writes to a .tmp file first and renames, so readers never see a
        partially written block.
        """
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Read the block named by +locator+ from $KEEP_LOCAL_STORE."""
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        # The well-known empty block is never stored on disk.
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()