Hash now uses get_task_param_mount() to read from the fuse mount instead of CollectionReader.
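An illustration (a sketch, not code from this commit) of the change described above: the hash crunch script switches from streaming its input with CollectionReader to reading ordinary files through the crunch fuse mount via the SDK helper get_task_param_mount(). The parameter name 'input' and the surrounding loop are assumed context, not the script's actual code.

    # Before: stream the input collection through the Keep client
    collection = arvados.CollectionReader(arvados.getjobparam('input'))
    # After: read the same input as plain files under the fuse mount
    input_path = arvados.get_task_param_mount('input')
    for filename in os.listdir(input_path):
        with open(os.path.join(input_path, filename)) as f:
            digest = hashlib.md5(f.read()).hexdigest()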
arvados.git: sdk/python/arvados/keep.py
import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading
import timer

from api import *
import config
import arvados.errors

global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)

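# Example (a sketch): the Keep class is a convenience facade over a
# shared KeepClient instance.
#
#   locator = Keep.put(data)    # returns "<md5-of-data>+<byte count>"
#   data = Keep.get(locator)
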
class KeepClient(object):

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
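
        Example (a sketch of the intended call pattern; do_one_write
        is hypothetical):

            limiter = KeepClient.ThreadLimiter(2)
            # in each writer thread:
            with limiter:
                if limiter.shall_i_proceed():
                    do_one_write()
                    limiter.increment_done()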
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
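                    # Presumably an older Keep server: it rejects the
                    # bare PUT with 401 "Timestamp verification failed".
                    # Retry once with the legacy PGP-clearsigned body
                    # from sign_for_old_server().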
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None
        self._cache_lock = threading.Lock()
        self._cache = []
        # default 256 megabyte cache
        self.cache_max = 256 * 1024 * 1024

    def shuffled_service_roots(self, hash):
        if self.service_roots is None:
            with self.lock:
                keep_disks = api().keep_disks().list().execute()['items']
                roots = (("http%s://%s:%d/" %
                          ('s' if f['service_ssl_flag'] else '',
                           f['service_host'],
                           f['service_port']))
                         for f in keep_disks)
                self.service_roots = sorted(set(roots))
                logging.debug(str(self.service_roots))
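        # Use the block hash itself as a pseudo-random seed, so that
        # every client probes the Keep servers for a given block in
        # the same hash-specific order.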
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    class CacheSlot(object):
        """A cache entry for one locator. Readers block in get() until
        the thread that reserved the slot delivers content via set()."""
        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            self.ready.wait()
            return self.content

        def set(self, value):
            self.content = value
            self.ready.set()

        def size(self):
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        with self._cache_lock:
            # Drop finished slots whose fetch failed to produce content.
            self._cache = filter(lambda c: not (c.ready.is_set() and c.content is None), self._cache)
            sm = sum([slot.size() for slot in self._cache])
            while sm > self.cache_max:
                # reserve_cache() keeps the list in most-recently-used
                # order, so evict from the tail.
                del self._cache[-1]
                sm = sum([slot.size() for slot in self._cache])

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        with self._cache_lock:
            # Test if the locator is already in the cache
            for i in xrange(0, len(self._cache)):
                if self._cache[i].locator == locator:
                    n = self._cache[i]
                    if i != 0:
                        # move it to the front
                        del self._cache[i]
                        self._cache.insert(0, n)
                    return n, False

            # Add a new cache slot for the locator
            n = KeepClient.CacheSlot(locator)
            self._cache.insert(0, n)
            return n, True

    def get(self, locator):
        #logging.debug("Keep.get %s" % (locator))

        if ',' in locator:
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)

        slot, first = self.reserve_cache(expect_hash)
        #logging.debug("%s %s %s" % (slot, first, expect_hash))

        if not first:
            v = slot.get()
            return v

        try:
            for service_root in self.shuffled_service_roots(expect_hash):
                url = service_root + expect_hash
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token,
                           'Accept': 'application/octet-stream'}
                blob = self.get_url(url, headers, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob

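            # A "+K@<instance>" hint in the locator apparently names a
            # remote Arvados instance; fall back to fetching the block
            # from that instance's public Keep proxy.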
            for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
                instance = location_hint.group(1)
                url = 'http://keep.' + instance + '.arvadosapi.com/' + expect_hash
                blob = self.get_url(url, {}, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob
        except:
            # Deliver None so other threads waiting on this slot wake
            # up, then re-raise.
            slot.set(None)
            self.cap_cache()
            raise

        slot.set(None)
        self.cap_cache()
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        h = httplib2.Http()
        try:
            logging.info("Request: GET %s" % (url))
            with timer.Timer() as t:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
            logging.info("Received %s bytes in %s msec (%s MiB/sec)" %
                         (len(content),
                          t.msecs,
                          (len(content)/(1024.0*1024))/t.secs))
            if re.match(r'^2\d\d$', resp['status']):
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
        except Exception as e:
            logging.info("Request fail: GET %s => %s: %s" %
                         (url, type(e), str(e)))
        return None

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if want_copies <= 0:
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        # Writer threads can race past shall_i_proceed(), so the count
        # of successful writes may exceed want_copies; that still
        # counts as success.
        if have_copies >= want_copies:
            return (data_hash + '+' + str(len(data)))
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        # Wrap the data in the PGP-clearsigned-message envelope that
        # older Keep servers evidently expect; the signature block
        # itself is left empty.
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)

    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()
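
# Example (a sketch): setting KEEP_LOCAL_STORE routes put/get through
# the flat-file local store above, which is handy in tests. The path
# is hypothetical and must already exist.
#
#   os.environ['KEEP_LOCAL_STORE'] = '/tmp/keep'
#   Keep.put('hello')   # -> '5d41402abc4b2a76b9719d911017c592+5'
#   Keep.get('5d41402abc4b2a76b9719d911017c592+5')   # -> 'hello'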