caching wip
[arvados.git] / sdk / python / arvados / keep.py
import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading
import timer

from api import *
import config
import arvados.errors

global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)

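# Usage sketch (values are illustrative, not from a real cluster): the
# Keep class is a convenience facade over one shared KeepClient.
#
#   locator = Keep.put("foo")   # e.g. "acbd18db4cc2f85cedef654fccc4a4d8+3"
#   data = Keep.get(locator)    # "foo"
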
class KeepClient(object):

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. Once the
        number of successes reported reaches the number desired, each
        remaining thread is told to quit as its turn arrives.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None
        # MRU-ordered list of CacheSlot objects, protected by _cache_lock.
        self._cache_lock = threading.Lock()
        self._cache = []
        # default 256 megabyte cache
        self.cache_max = 256 * 1024 * 1024

    def shuffled_service_roots(self, hash):
        if self.service_roots is None:
            self.lock.acquire()
            try:
                keep_disks = api().keep_disks().list().execute()['items']
                roots = (("http%s://%s:%d/" %
                          ('s' if f['service_ssl_flag'] else '',
                           f['service_host'],
                           f['service_port']))
                         for f in keep_disks)
                self.service_roots = sorted(set(roots))
                logging.debug(str(self.service_roots))
            finally:
                self.lock.release()
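        # Shuffle the service list deterministically, using successive
        # 32-bit slices of the block hash as probe values. Every client
        # computes the same probe order for a given block, so readers
        # try services in the same order that writers filled them.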
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    class CacheSlot(object):
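        """A cache entry for a single data block.

        The first thread to reserve the slot fetches the block and
        calls set(); any other thread asking for the same locator
        blocks in get() on the Event until the content (or None, on
        failure) is ready.
        """
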
        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            self.ready.wait()
            return self.content

        def set(self, value):
            self.content = value
            self.ready.set()

        def size(self):
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        self._cache_lock.acquire()
        try:
            # Drop slots whose fetch finished with no content (failed
            # lookups), then evict least-recently-used slots until the
            # total cached size fits under cache_max.
            self._cache = filter(lambda c: not (c.ready.is_set() and c.content is None), self._cache)
            sm = sum([slot.size() for slot in self._cache])
            while sm > self.cache_max:
                del self._cache[-1]
                sm = sum([slot.size() for slot in self._cache])
        finally:
            self._cache_lock.release()

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        self._cache_lock.acquire()
        try:
            # Test if the locator is already in the cache
            for i in xrange(0, len(self._cache)):
                if self._cache[i].locator == locator:
                    n = self._cache[i]
                    if i != 0:
                        # Move the hit to the front of the MRU list.
                        del self._cache[i]
                        self._cache.insert(0, n)
                    return n, False

            # Add a new cache slot for the locator
            n = KeepClient.CacheSlot(locator)
            self._cache.insert(0, n)
            return n, True
        finally:
            self._cache_lock.release()

    def get(self, locator):
        logging.debug("Keep.get %s" % (locator))

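        # A comma-separated locator names a sequence of blocks: fetch
        # each block and return the concatenated data.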
        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)

        slot, first = self.reserve_cache(expect_hash)
        if not first:
            # Another thread is already fetching this block; wait for
            # its result rather than fetching it again.
            v = slot.get()
            if v is None:
                raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)
            return v

        for service_root in self.shuffled_service_roots(expect_hash):
            url = service_root + expect_hash
            api_token = config.get('ARVADOS_API_TOKEN')
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            blob = self.get_url(url, headers, expect_hash)
            if blob:
                slot.set(blob)
                self.cap_cache()
                return blob

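        # Fall back on any +K@<instance> location hints in the locator,
        # trying the Keep service of the named remote instance.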
        for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
            instance = location_hint.group(1)
            url = 'http://keep.' + instance + '.arvadosapi.com/' + expect_hash
            blob = self.get_url(url, {}, expect_hash)
            if blob:
                slot.set(blob)
                self.cap_cache()
                return blob

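        # Everything failed. Store the miss so that threads waiting on
        # this slot wake up, and let cap_cache() discard the empty slot.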
        slot.set(None)
        self.cap_cache()

        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        h = httplib2.Http()
        try:
            logging.info("Request: GET %s" % (url))
            with timer.Timer() as t:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
            logging.info("Received %s bytes in %s msec (%s bytes/sec)" %
                         (len(content), t.msecs, len(content)/t.secs))
            if re.match(r'^2\d\d$', resp['status']):
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
        except Exception as e:
            logging.info("Request fail: GET %s => %s: %s" %
                         (url, type(e), str(e)))
        return None

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        # Threads racing on the limiter can report a few more successes
        # than requested, so test >= rather than ==.
        if have_copies >= want_copies:
            return (data_hash + '+' + str(len(data)))
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
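        # Compatibility shim: an old Keep server rejects a plain PUT
        # with "Timestamp verification failed" and expects the block to
        # arrive wrapped in a PGP clearsigned message carrying a
        # timestamp and the block hash (see KeepWriterThread.run).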
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)

    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
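        # Write to a temp file, then rename into place, so that readers
        # never see a partially written block.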
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()