# Fuse driver works for mounting collections and reading files. Tested with jlake.
# [arvados.git] / sdk / python / arvados / keep.py
import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading

# Process-wide KeepClient singleton, created lazily by Keep.global_client_object().
global_client_object = None

from api import *
import config
import arvados.errors
26
class Keep:
    """Static convenience wrappers around a process-wide KeepClient singleton."""

    @staticmethod
    def global_client_object():
        """Return the shared KeepClient, creating it on first use."""
        global global_client_object
        # Use identity comparison with None, per PEP 8.
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch the data block identified by `locator` via the shared client."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store `data` in Keep via the shared client; return its locator."""
        return Keep.global_client_object().put(data, **kwargs)
42
class KeepClient(object):
    """Client for storing and retrieving data blocks in Keep.

    Keeps a small in-memory LRU block cache and writes blocks to
    multiple Keep services in parallel until the desired number of
    copies has been stored.
    """

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            # todo: number of successful writes wanted before the
            # remaining threads are told to stop.
            self._todo = todo
            self._done = 0
            # At most `todo` threads are allowed past __enter__ at once.
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return self._done < self._todo

        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.

        Expected kwargs: data, data_hash, service_root, thread_limiter.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # Old servers want a PGP-style signed body;
                        # retry once in that format.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        # Lazily-initialized, sorted list of Keep service base URLs.
        self.service_roots = None
        self._cache_lock = threading.Lock()
        # LRU block cache: list of [locator, data], most recent first.
        self._cache = []
        # default 256 megabyte cache
        self._cache_max = 256 * 1024 * 1024

    def shuffled_service_roots(self, hash):
        """Return the Keep service URLs, permuted pseudo-randomly by `hash`.

        Every client derives the same probe order for a given block, so
        readers look for a block on the same servers writers chose.
        """
        if self.service_roots is None:
            with self.lock:
                # Re-check under the lock: another thread may have
                # initialized service_roots while we waited.
                if self.service_roots is None:
                    keep_disks = api().keep_disks().list().execute()['items']
                    roots = (("http%s://%s:%d/" %
                              ('s' if f['service_ssl_flag'] else '',
                               f['service_host'],
                               f['service_port']))
                             for f in keep_disks)
                    self.service_roots = sorted(set(roots))
                    logging.debug(str(self.service_roots))
        # Draw successive 32-bit hex slices from the digest to pick the
        # next service out of the dwindling pool, stretching the seed
        # with more digest material whenever it runs low.
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) // 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    def get(self, locator):
        """Fetch and return the content of the block named by `locator`.

        Checks the local cache first, then probes Keep services in the
        shuffled order, then falls back to any +K@instance location
        hints. Raises arvados.errors.NotFoundError on failure.
        """
        if re.search(r',', locator):
            # Comma-separated list of locators: concatenate the blocks.
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)

        c = self.check_cache(expect_hash)
        if c:
            return c

        for service_root in self.shuffled_service_roots(expect_hash):
            url = service_root + expect_hash
            api_token = config.get('ARVADOS_API_TOKEN')
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            blob = self.get_url(url, headers, expect_hash)
            if blob:
                self.put_cache(expect_hash, blob)
                return blob

        for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
            instance = location_hint.group(1)
            url = 'http://keep.' + instance + '.arvadosapi.com/' + expect_hash
            blob = self.get_url(url, {}, expect_hash)
            if blob:
                self.put_cache(expect_hash, blob)
                return blob
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        """GET `url` and return its body, or None on error.

        The body is returned only if its md5 matches `expect_hash`;
        mismatches and transport errors are logged and yield None.
        """
        h = httplib2.Http()
        try:
            resp, content = h.request(url.encode('utf-8'), 'GET',
                                      headers=headers)
            if re.match(r'^2\d\d$', resp['status']):
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
        except Exception as e:
            logging.info("Request fail: GET %s => %s: %s" %
                         (url, type(e), str(e)))
        return None

    def put(self, data, **kwargs):
        """Store `data` on `copies` (default 2) Keep services in parallel.

        Returns the locator "<md5>+<size>" on success; raises
        arvados.errors.KeepWriteError if fewer copies were written.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies == want_copies:
            return (data_hash + '+' + str(len(data)))
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    def put_cache(self, locator, data):
        """Put a block into the cache."""
        if self.check_cache(locator) is not None:
            return
        with self._cache_lock:
            # Evict from the cold end of the list until the new block
            # fits. Track the total incrementally instead of re-summing
            # each pass, and stop if the cache empties (a block larger
            # than the whole cache is stored anyway).
            cache_total = sum(len(a[1]) for a in self._cache) + len(data)
            while cache_total > self._cache_max and self._cache:
                logging.debug("cache size %d exceeds max %d, evicting" %
                              (cache_total, self._cache_max))
                cache_total -= len(self._cache[-1][1])
                del self._cache[-1]

            # now add the new block at the front of the list
            self._cache.insert(0, [locator, data])

    def check_cache(self, locator):
        """Get a block from the cache.  Also moves the block to the front of the list."""
        with self._cache_lock:
            for i, item in enumerate(self._cache):
                if item[0] == locator:
                    # Hit: move to the front (most recently used).
                    del self._cache[i]
                    self._cache.insert(0, item)
                    return item[1]
        return None

    @staticmethod
    def sign_for_old_server(data_hash, data):
        """Wrap `data` in the PGP-signed-message envelope old servers expect."""
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)

    @staticmethod
    def local_store_put(data):
        """Store `data` in $KEEP_LOCAL_STORE and return its locator.

        Writes to a .tmp file first, then renames, so readers never see
        a partially-written block.
        """
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Read and return the block named by `locator` from $KEEP_LOCAL_STORE."""
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            # The well-known empty block is never stored on disk.
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()