]> git.arvados.org - arvados.git/blob - sdk/python/arvados/keep.py
Added readfrom()
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20
21 global_client_object = None
22
23 from api import *
24 import config
25 import arvados.errors
26
class Keep:
    """Convenience wrapper holding one shared, lazily-created KeepClient.

    The client is stored in the module global `global_client_object`;
    get()/put() forward to it.
    """

    @staticmethod
    def global_client_object():
        """Return the shared KeepClient, creating it on first use."""
        global global_client_object
        # `is None` (identity), not `== None`: correct and idiomatic test
        # for the unset sentinel.
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch the block named by `locator` via the shared client."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store `data` via the shared client; returns its locator."""
        return Keep.global_client_object().put(data, **kwargs)
42
43 class KeepClient(object):
44
45     class ThreadLimiter(object):
46         """
47         Limit the number of threads running at a given time to
48         {desired successes} minus {successes reported}. When successes
49         reported == desired, wake up the remaining threads and tell
50         them to quit.
51
52         Should be used in a "with" block.
53         """
54         def __init__(self, todo):
55             self._todo = todo
56             self._done = 0
57             self._todo_lock = threading.Semaphore(todo)
58             self._done_lock = threading.Lock()
59
60         def __enter__(self):
61             self._todo_lock.acquire()
62             return self
63
64         def __exit__(self, type, value, traceback):
65             self._todo_lock.release()
66
67         def shall_i_proceed(self):
68             """
69             Return true if the current thread should do stuff. Return
70             false if the current thread should just stop.
71             """
72             with self._done_lock:
73                 return (self._done < self._todo)
74
75         def increment_done(self):
76             """
77             Report that the current thread was successful.
78             """
79             with self._done_lock:
80                 self._done += 1
81
82         def done(self):
83             """
84             Return how many successes were reported.
85             """
86             with self._done_lock:
87                 return self._done
88
89     class KeepWriterThread(threading.Thread):
90         """
91         Write a blob of data to the given Keep server. Call
92         increment_done() of the given ThreadLimiter if the write
93         succeeds.
94         """
95         def __init__(self, **kwargs):
96             super(KeepClient.KeepWriterThread, self).__init__()
97             self.args = kwargs
98
99         def run(self):
100             with self.args['thread_limiter'] as limiter:
101                 if not limiter.shall_i_proceed():
102                     # My turn arrived, but the job has been done without
103                     # me.
104                     return
105                 logging.debug("KeepWriterThread %s proceeding %s %s" %
106                               (str(threading.current_thread()),
107                                self.args['data_hash'],
108                                self.args['service_root']))
109                 h = httplib2.Http()
110                 url = self.args['service_root'] + self.args['data_hash']
111                 api_token = config.get('ARVADOS_API_TOKEN')
112                 headers = {'Authorization': "OAuth2 %s" % api_token}
113                 try:
114                     resp, content = h.request(url.encode('utf-8'), 'PUT',
115                                               headers=headers,
116                                               body=self.args['data'])
117                     if (resp['status'] == '401' and
118                         re.match(r'Timestamp verification failed', content)):
119                         body = KeepClient.sign_for_old_server(
120                             self.args['data_hash'],
121                             self.args['data'])
122                         h = httplib2.Http()
123                         resp, content = h.request(url.encode('utf-8'), 'PUT',
124                                                   headers=headers,
125                                                   body=body)
126                     if re.match(r'^2\d\d$', resp['status']):
127                         logging.debug("KeepWriterThread %s succeeded %s %s" %
128                                       (str(threading.current_thread()),
129                                        self.args['data_hash'],
130                                        self.args['service_root']))
131                         return limiter.increment_done()
132                     logging.warning("Request fail: PUT %s => %s %s" %
133                                     (url, resp['status'], content))
134                 except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
135                     logging.warning("Request fail: PUT %s => %s: %s" %
136                                     (url, type(e), str(e)))
137
138     def __init__(self):
139         self.lock = threading.Lock()
140         self.service_roots = None
141         self._cache_lock = threading.Lock()
142         self._cache = []
143         # default 256 megabyte cache
144         self._cache_max = 256 * 1024 * 1024
145
146     def shuffled_service_roots(self, hash):
147         if self.service_roots == None:
148             self.lock.acquire()
149             keep_disks = api().keep_disks().list().execute()['items']
150             roots = (("http%s://%s:%d/" %
151                       ('s' if f['service_ssl_flag'] else '',
152                        f['service_host'],
153                        f['service_port']))
154                      for f in keep_disks)
155             self.service_roots = sorted(set(roots))
156             logging.debug(str(self.service_roots))
157             self.lock.release()
158         seed = hash
159         pool = self.service_roots[:]
160         pseq = []
161         while len(pool) > 0:
162             if len(seed) < 8:
163                 if len(pseq) < len(hash) / 4: # first time around
164                     seed = hash[-4:] + hash
165                 else:
166                     seed += hash
167             probe = int(seed[0:8], 16) % len(pool)
168             pseq += [pool[probe]]
169             pool = pool[:probe] + pool[probe+1:]
170             seed = seed[8:]
171         logging.debug(str(pseq))
172         return pseq
173
174     def get(self, locator):
175         if re.search(r',', locator):
176             return ''.join(self.get(x) for x in locator.split(','))
177         if 'KEEP_LOCAL_STORE' in os.environ:
178             return KeepClient.local_store_get(locator)
179         expect_hash = re.sub(r'\+.*', '', locator)
180
181         c = self.check_cache(expect_hash)
182         if c:
183             return c
184
185         for service_root in self.shuffled_service_roots(expect_hash):
186             url = service_root + expect_hash
187             api_token = config.get('ARVADOS_API_TOKEN')
188             headers = {'Authorization': "OAuth2 %s" % api_token,
189                        'Accept': 'application/octet-stream'}
190             blob = self.get_url(url, headers, expect_hash)
191             if blob:
192                 self.put_cache(expect_hash, blob)
193                 return blob
194
195         for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
196             instance = location_hint.group(1)
197             url = 'http://keep.' + instance + '.arvadosapi.com/' + expect_hash
198             blob = self.get_url(url, {}, expect_hash)
199             if blob:
200                 self.put_cache(expect_hash, blob)
201                 return blob
202         raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)
203
204     def get_url(self, url, headers, expect_hash):
205         h = httplib2.Http()
206         try:
207             resp, content = h.request(url.encode('utf-8'), 'GET',
208                                       headers=headers)
209             if re.match(r'^2\d\d$', resp['status']):
210                 m = hashlib.new('md5')
211                 m.update(content)
212                 md5 = m.hexdigest()
213                 if md5 == expect_hash:
214                     return content
215                 logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
216         except Exception as e:
217             logging.info("Request fail: GET %s => %s: %s" %
218                          (url, type(e), str(e)))
219         return None
220
221     def put(self, data, **kwargs):
222         if 'KEEP_LOCAL_STORE' in os.environ:
223             return KeepClient.local_store_put(data)
224         m = hashlib.new('md5')
225         m.update(data)
226         data_hash = m.hexdigest()
227         have_copies = 0
228         want_copies = kwargs.get('copies', 2)
229         if not (want_copies > 0):
230             return data_hash
231         threads = []
232         thread_limiter = KeepClient.ThreadLimiter(want_copies)
233         for service_root in self.shuffled_service_roots(data_hash):
234             t = KeepClient.KeepWriterThread(data=data,
235                                             data_hash=data_hash,
236                                             service_root=service_root,
237                                             thread_limiter=thread_limiter)
238             t.start()
239             threads += [t]
240         for t in threads:
241             t.join()
242         have_copies = thread_limiter.done()
243         if have_copies == want_copies:
244             return (data_hash + '+' + str(len(data)))
245         raise arvados.errors.KeepWriteError(
246             "Write fail for %s: wanted %d but wrote %d" %
247             (data_hash, want_copies, have_copies))
248
249     def put_cache(self, locator, data):
250         """Put a block into the cache."""
251         if self.check_cache(locator) != None:
252             return
253         self.cache_lock.acquire()
254         try:
255             # first check cache size and delete stuff from the end if necessary
256             sm = sum([len(a[1]) for a in self._cache]) + len(data)
257             while sum > self._cache_max:
258                 del self._cache[-1]
259                 sm = sum([len(a[1]) for a in self._cache]) + len(data)
260
261             # now add the new block at the front of the list
262             self._cache.insert(0, [locator, data])
263         finally:
264             self.cache_lock.release()
265
266     def check_cache(self, locator):
267         """Get a block from the cache.  Also moves the block to the front of the list."""
268         self._cache_lock.acquire()
269         try:
270             for i in xrange(0, len(self._cache)):
271                 if self._cache[i][0] == locator:
272                     n = self._cache[i]
273                     del self._cache[i]
274                     self._cache.insert(0, n)
275                     return n[1]   
276         finally:
277             self.cache_lock.release()
278         return None            
279
280     @staticmethod
281     def sign_for_old_server(data_hash, data):
282         return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
283
284
285     @staticmethod
286     def local_store_put(data):
287         m = hashlib.new('md5')
288         m.update(data)
289         md5 = m.hexdigest()
290         locator = '%s+%d' % (md5, len(data))
291         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
292             f.write(data)
293         os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
294                   os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
295         return locator
296
297     @staticmethod
298     def local_store_get(locator):
299         r = re.search('^([0-9a-f]{32,})', locator)
300         if not r:
301             raise arvados.errors.NotFoundError(
302                 "Invalid data locator: '%s'" % locator)
303         if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
304             return ''
305         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
306             return f.read()