Created a separate arvados-fuse-driver package.
[arvados.git] sdk/python/arvados/keep.py
import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading
import timer

from api import *
import config
import arvados.errors

global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)

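# Usage sketch (illustrative, not part of the module): the Keep facade
# shares one lazily-created KeepClient across all callers, so typical
# code needs only the two static methods. The locator shown is the real
# md5+size for the string "foo".
#
#   locator = Keep.put("foo")   # -> "acbd18db4cc2f85cedef654fccc4a4d8+3"
#   data = Keep.get(locator)    # -> "foo"
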
class KeepClient(object):

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

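    # Usage sketch (illustrative): each writer thread enters the limiter,
    # checks whether more successes are still needed, and reports its own
    # success before leaving. A minimal skeleton, assuming two desired
    # copies:
    #
    #   limiter = KeepClient.ThreadLimiter(2)
    #   with limiter:
    #       if limiter.shall_i_proceed():
    #           # ... attempt one write ...
    #           limiter.increment_done()
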
    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # Retry with the body wrapped in a PGP-style
                        # clearsigned message for older Keep servers.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None
        self._cache_lock = threading.Lock()
        self._cache = []
        # default 256 megabyte cache
        self.cache_max = 256 * 1024 * 1024

    def shuffled_service_roots(self, hash):
        if self.service_roots is None:
            self.lock.acquire()
            try:
                keep_disks = arvados.api().keep_disks().list().execute()['items']
                roots = (("http%s://%s:%d/" %
                          ('s' if f['service_ssl_flag'] else '',
                           f['service_host'],
                           f['service_port']))
                         for f in keep_disks)
                self.service_roots = sorted(set(roots))
                logging.debug(str(self.service_roots))
            finally:
                self.lock.release()

        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

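    # Worked example (illustrative): the probe sequence consumes the hash
    # eight hex digits at a time, so every client derives the same server
    # ordering for a given block. With 4 service roots and a hash starting
    # "0000000f...", the first probe is int("0000000f", 16) % 4 == 15 % 4
    # == 3, so pool[3] is tried first; that root leaves the pool and the
    # next eight digits choose among the remaining three, and so on.
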
    class CacheSlot(object):
        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            self.ready.wait()
            return self.content

        def set(self, value):
            self.content = value
            self.ready.set()

        def size(self):
            if self.content is None:
                return 0
            else:
                return len(self.content)

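    # Illustrative note: CacheSlot doubles as a synchronization point.
    # The first thread to want a block gets a fresh slot and fetches the
    # data; other threads asking for the same block wait in get() until
    # the fetcher calls set(). Assuming "client" is a KeepClient:
    #
    #   slot, first = client.reserve_cache(expect_hash)
    #   if first:
    #       slot.set(fetched_blob)   # wakes every waiter
    #   else:
    #       blob = slot.get()        # blocks until set() is called
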
    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        self._cache_lock.acquire()
        try:
            # Drop failed fetches (slots that are ready but hold no content).
            self._cache = filter(lambda c: not (c.ready.is_set() and c.content is None), self._cache)
            sm = sum([slot.size() for slot in self._cache])
            while sm > self.cache_max:
                del self._cache[-1]
                sm = sum([slot.size() for slot in self._cache])
        finally:
            self._cache_lock.release()

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        self._cache_lock.acquire()
        try:
            # Test if the locator is already in the cache
            for i in xrange(0, len(self._cache)):
                if self._cache[i].locator == locator:
                    n = self._cache[i]
                    if i != 0:
                        # move it to the front
                        del self._cache[i]
                        self._cache.insert(0, n)
                    return n, False

            # Add a new cache slot for the locator
            n = KeepClient.CacheSlot(locator)
            self._cache.insert(0, n)
            return n, True
        finally:
            self._cache_lock.release()

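    # Illustrative note on eviction: reserve_cache keeps the list in
    # most-recently-used order, so cap_cache trims from the tail. E.g.
    # with cache_max = 64 MiB and three ready 32 MiB slots cached, one
    # cap_cache() call drops the least recently used slot, bringing the
    # total back down to 64 MiB.
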
    def get(self, locator):
        #logging.debug("Keep.get %s" % (locator))

        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)

        slot, first = self.reserve_cache(expect_hash)
        #logging.debug("%s %s %s" % (slot, first, expect_hash))

        if not first:
            v = slot.get()
            return v

        try:
            for service_root in self.shuffled_service_roots(expect_hash):
                url = service_root + expect_hash
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token,
                           'Accept': 'application/octet-stream'}
                blob = self.get_url(url, headers, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob

            for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
                instance = location_hint.group(1)
                url = 'http://keep.' + instance + '.arvadosapi.com/' + expect_hash
                blob = self.get_url(url, {}, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob
        except:
            slot.set(None)
            self.cap_cache()
            raise

        slot.set(None)
        self.cap_cache()
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

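    # Locator forms accepted by get(), with illustrative values:
    #
    #   "acbd18db4cc2f85cedef654fccc4a4d8+3"          plain md5+size
    #   "acbd18db4cc2f85cedef654fccc4a4d8+3+K@qr1hi"  md5+size with a
    #       location hint; the hinted cluster's Keep proxy is tried after
    #       the local service roots ("qr1hi" is a hypothetical cluster id)
    #   "<locator1>,<locator2>"                       comma-joined list,
    #       fetched block by block and concatenated
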
    def get_url(self, url, headers, expect_hash):
        h = httplib2.Http()
        try:
            logging.info("Request: GET %s" % (url))
            with timer.Timer() as t:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
            logging.info("Received %s bytes in %s msec (%s MiB/sec)" %
                         (len(content),
                          t.msecs,
                          (len(content)/(1024.0*1024))/t.secs))
            if re.match(r'^2\d\d$', resp['status']):
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
        except Exception as e:
            logging.info("Request fail: GET %s => %s: %s" %
                         (url, type(e), str(e)))
        return None

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        # Concurrent writers can over-report, so treat extra copies as
        # success rather than requiring an exact match.
        if have_copies >= want_copies:
            return (data_hash + '+' + str(len(data)))
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

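    # Usage sketch (illustrative): put() returns an md5+size locator once
    # the requested number of copies is confirmed, and raises
    # KeepWriteError otherwise.
    #
    #   locator = KeepClient().put("foo", copies=2)
    #   # -> "acbd18db4cc2f85cedef654fccc4a4d8+3" if at least two
    #   #    servers accepted the block
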
    @staticmethod
    def sign_for_old_server(data_hash, data):
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n"
                 "%d %s\n"
                 "-----BEGIN PGP SIGNATURE-----\n\n"
                 "-----END PGP SIGNATURE-----\n"
                 % (int(time.time()), data_hash)) + data)

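    # Rendered form (illustrative): for a block with hash
    # "acbd18db4cc2f85cedef654fccc4a4d8" signed at Unix time 1400000000,
    # sign_for_old_server() prefixes the raw data with:
    #
    #   -----BEGIN PGP SIGNED MESSAGE-----
    #
    #
    #   1400000000 acbd18db4cc2f85cedef654fccc4a4d8
    #   -----BEGIN PGP SIGNATURE-----
    #
    #   -----END PGP SIGNATURE-----
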
    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        r = re.search(r'^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()
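
# Test-mode sketch (illustrative): setting KEEP_LOCAL_STORE reroutes both
# put() and get() to flat files in that directory, so Keep-using code can
# be exercised without any Keep servers ('/tmp/keep' is a hypothetical
# path that must already exist):
#
#   os.environ['KEEP_LOCAL_STORE'] = '/tmp/keep'
#   locator = Keep.put("foo")  # writes /tmp/keep/acbd18db4cc2f85cedef654fccc4a4d8
#   Keep.get(locator)          # -> "foo"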