Merge branch 'master' of git.curoverse.com:arvados refs #1885
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21
22 global_client_object = None
23
24 from api import *
25 import config
26 import arvados.errors
27
class Keep:
    """Module-level convenience wrappers around a shared KeepClient.

    A single KeepClient is created lazily on first use and reused by
    all callers of Keep.get() / Keep.put().
    """
    @staticmethod
    def global_client_object():
        """Return the process-wide KeepClient, creating it on first call."""
        global global_client_object
        # Use identity comparison for the None sentinel (PEP 8).
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch a data block via the shared client; see KeepClient.get()."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store a data block via the shared client; see KeepClient.put()."""
        return Keep.global_client_object().put(data, **kwargs)
43
class KeepClient(object):
    """Client for the Keep content-addressed block store.

    Reads and writes data blocks identified by their md5 hash,
    spreading writes across multiple Keep servers in parallel and
    keeping a bounded, most-recently-used in-memory read cache.
    """

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            # todo: number of successes wanted before surplus threads
            # give up.
            self._todo = todo
            self._done = 0
            # At most 'todo' threads hold the semaphore at once, so no
            # more writers run concurrently than copies still needed.
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            # Expected kwargs: data, data_hash, service_root,
            # thread_limiter.
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # Older Keep servers require the body wrapped in
                        # a timestamped cleartext-signed message; retry
                        # once in that form.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    # Network-level failure: log it and let a thread
                    # for another server try instead.
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        # Guards lazy initialization of service_roots.
        self.lock = threading.Lock()
        self.service_roots = None
        self._cache_lock = threading.Lock()
        # Most-recently-used block cache, newest entries first.
        self._cache = []
        # default 256 megabyte cache
        self.cache_max = 256 * 1024 * 1024

    def shuffled_service_roots(self, hash):
        """Return the Keep service URLs ordered by a probe sequence
        derived from 'hash' (a hex string), so every client probes
        servers for a given block in the same order."""
        if self.service_roots is None:
            with self.lock:
                # Re-check inside the lock: another thread may have
                # initialized the list while we were waiting.
                if self.service_roots is None:
                    keep_disks = arvados.api().keep_disks().list().execute()['items']
                    roots = (("http%s://%s:%d/" %
                              ('s' if f['service_ssl_flag'] else '',
                               f['service_host'],
                               f['service_port']))
                             for f in keep_disks)
                    self.service_roots = sorted(set(roots))
                    logging.debug(str(self.service_roots))

        # Build an ordering with which to query the Keep servers based on the
        # contents of the hash.
        # "hash" is a hex-encoded number at least 8 digits
        # (32 bits) long

        # seed used to calculate the next keep server from 'pool'
        # to be added to 'pseq'
        seed = hash

        # Keep servers still to be added to the ordering
        pool = self.service_roots[:]

        # output probe sequence
        pseq = []

        # iterate while there are servers left to be assigned
        while len(pool) > 0:
            if len(seed) < 8:
                # ran out of digits in the seed
                if len(pseq) < len(hash) // 4:
                    # the number of servers added to the probe sequence is less
                    # than the number of 4-digit slices in 'hash' so refill the
                    # seed with the last 4 digits and then append the contents
                    # of 'hash'.
                    seed = hash[-4:] + hash
                else:
                    # refill the seed with the contents of 'hash'
                    seed += hash

            # Take the next 8 digits (32 bits) and interpret as an integer,
            # then modulus with the size of the remaining pool to get the next
            # selected server.
            probe = int(seed[0:8], 16) % len(pool)

            # Append the selected server to the probe sequence and remove it
            # from the pool.
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]

            # Remove the digits just used from the seed
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    class CacheSlot(object):
        """One cache entry. Holds a block's content once fetched and
        lets other threads block until the content is available."""
        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            # Block until some thread calls set().
            self.ready.wait()
            return self.content

        def set(self, value):
            self.content = value
            self.ready.set()

        def size(self):
            """Size in bytes of the cached content, 0 if unset."""
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        with self._cache_lock:
            # Drop completed slots whose fetch failed (ready, but no
            # content).
            self._cache = [c for c in self._cache
                           if not (c.ready.is_set() and c.content is None)]
            sm = sum(c.size() for c in self._cache)
            while sm > self.cache_max:
                # Evict the least recently used entry (list tail) and
                # recompute the total.
                del self._cache[-1]
                sm = sum(c.size() for c in self._cache)

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        with self._cache_lock:
            # Test if the locator is already in the cache
            for i, c in enumerate(self._cache):
                if c.locator == locator:
                    if i != 0:
                        # move it to the front (most recently used)
                        del self._cache[i]
                        self._cache.insert(0, c)
                    return c, False

            # Add a new cache slot for the locator
            n = KeepClient.CacheSlot(locator)
            self._cache.insert(0, n)
            return n, True

    def get(self, locator):
        """Fetch the block identified by 'locator' (hex md5 hash,
        optionally followed by +hints). Returns the block content;
        raises arvados.errors.NotFoundError if no server has it."""
        #logging.debug("Keep.get %s" % (locator))

        if re.search(r',', locator):
            # Comma-separated list of locators: fetch each block and
            # concatenate.
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)

        slot, first = self.reserve_cache(expect_hash)
        #logging.debug("%s %s %s" % (slot, first, expect_hash))

        if not first:
            # Another thread is (or was) fetching this block; wait for
            # its result.
            v = slot.get()
            return v

        try:
            for service_root in self.shuffled_service_roots(expect_hash):
                url = service_root + expect_hash
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token,
                           'Accept': 'application/octet-stream'}
                blob = self.get_url(url, headers, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob

            # Fall back to any +K@instance location hints embedded in
            # the locator.
            for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
                instance = location_hint.group(1)
                url = 'http://keep.' + instance + '.arvadosapi.com/' + expect_hash
                blob = self.get_url(url, {}, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob
        except:
            # Mark the slot failed so threads waiting in slot.get()
            # don't hang, then re-raise.
            slot.set(None)
            self.cap_cache()
            raise

        slot.set(None)
        self.cap_cache()
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        """GET 'url' and return the response body if the status is 2xx
        and the body's md5 matches 'expect_hash'; otherwise None."""
        h = httplib2.Http()
        try:
            logging.info("Request: GET %s" % (url))
            with timer.Timer() as t:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
            logging.info("Received %s bytes in %s msec (%s MiB/sec)" % (len(content),
                                                                        t.msecs,
                                                                        (len(content)/(1024*1024))/t.secs))
            if re.match(r'^2\d\d$', resp['status']):
                # Verify content integrity before trusting the block.
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
        except Exception as e:
            logging.info("Request fail: GET %s => %s: %s" %
                         (url, type(e), str(e)))
        return None

    def put(self, data, **kwargs):
        """Store 'data' on kwargs['copies'] (default 2) Keep servers.
        Returns the locator 'hash+size' on success; raises
        arvados.errors.KeepWriteError if too few copies were written."""
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        # Start one writer per server; the limiter stops surplus
        # threads once enough copies have succeeded.
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies == want_copies:
            return (data_hash + '+' + str(len(data)))
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        """Wrap 'data' in the timestamped cleartext-signed-message
        format expected by old Keep servers (see KeepWriterThread)."""
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        """Testing hook: store 'data' in the directory named by
        $KEEP_LOCAL_STORE instead of contacting Keep servers."""
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        # Write to a temp name, then rename, so readers never observe
        # a partially written block.
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Testing hook: read a block from $KEEP_LOCAL_STORE."""
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            # The canonical empty block needs no file on disk.
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()