2752: Remove trailing whitespace in arv-put.
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21
22 global_client_object = None
23
24 from api import *
25 import config
26 import arvados.errors
27
class Keep:
    """Convenience wrappers around a process-wide shared KeepClient.

    The first call lazily constructs a single KeepClient and every
    later call reuses it, so callers don't each pay for service
    discovery.
    """

    @staticmethod
    def global_client_object():
        """Return the shared KeepClient, creating it on first use."""
        global global_client_object
        # 'is None' (identity), not '== None': idiomatic and avoids any
        # __eq__ surprises.
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch the block named by *locator* via the shared client."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store *data* via the shared client; returns its locator."""
        return Keep.global_client_object().put(data, **kwargs)
43
44 class KeepClient(object):
45
46     class ThreadLimiter(object):
47         """
48         Limit the number of threads running at a given time to
49         {desired successes} minus {successes reported}. When successes
50         reported == desired, wake up the remaining threads and tell
51         them to quit.
52
53         Should be used in a "with" block.
54         """
55         def __init__(self, todo):
56             self._todo = todo
57             self._done = 0
58             self._response = None
59             self._todo_lock = threading.Semaphore(todo)
60             self._done_lock = threading.Lock()
61
62         def __enter__(self):
63             self._todo_lock.acquire()
64             return self
65
66         def __exit__(self, type, value, traceback):
67             self._todo_lock.release()
68
69         def shall_i_proceed(self):
70             """
71             Return true if the current thread should do stuff. Return
72             false if the current thread should just stop.
73             """
74             with self._done_lock:
75                 return (self._done < self._todo)
76
77         def save_response(self, response_body, replicas_stored):
78             """
79             Records a response body (a locator, possibly signed) returned by
80             the Keep server.  It is not necessary to save more than
81             one response, since we presume that any locator returned
82             in response to a successful request is valid.
83             """
84             with self._done_lock:
85                 self._done += replicas_stored
86                 self._response = response_body
87
88         def response(self):
89             """
90             Returns the body from the response to a PUT request.
91             """
92             with self._done_lock:
93                 return self._response
94
95         def done(self):
96             """
97             Return how many successes were reported.
98             """
99             with self._done_lock:
100                 return self._done
101
102     class KeepWriterThread(threading.Thread):
103         """
104         Write a blob of data to the given Keep server. On success, call
105         save_response() of the given ThreadLimiter to save the returned
106         locator.
107         """
108         def __init__(self, **kwargs):
109             super(KeepClient.KeepWriterThread, self).__init__()
110             self.args = kwargs
111
112         def run(self):
113             with self.args['thread_limiter'] as limiter:
114                 if not limiter.shall_i_proceed():
115                     # My turn arrived, but the job has been done without
116                     # me.
117                     return
118                 logging.debug("KeepWriterThread %s proceeding %s %s" %
119                               (str(threading.current_thread()),
120                                self.args['data_hash'],
121                                self.args['service_root']))
122                 h = httplib2.Http()
123                 url = self.args['service_root'] + self.args['data_hash']
124                 api_token = config.get('ARVADOS_API_TOKEN')
125                 headers = {'Authorization': "OAuth2 %s" % api_token}
126
127                 if self.args['using_proxy']:
128                     # We're using a proxy, so tell the proxy how many copies we
129                     # want it to store
130                     headers['X-Keep-Desired-Replication'] = str(self.args['want_copies'])
131
132                 try:
133                     logging.debug("Uploading to {}".format(url))
134                     resp, content = h.request(url.encode('utf-8'), 'PUT',
135                                               headers=headers,
136                                               body=self.args['data'])
137                     if (resp['status'] == '401' and
138                         re.match(r'Timestamp verification failed', content)):
139                         body = KeepClient.sign_for_old_server(
140                             self.args['data_hash'],
141                             self.args['data'])
142                         h = httplib2.Http()
143                         resp, content = h.request(url.encode('utf-8'), 'PUT',
144                                                   headers=headers,
145                                                   body=body)
146                     if re.match(r'^2\d\d$', resp['status']):
147                         logging.debug("KeepWriterThread %s succeeded %s %s" %
148                                       (str(threading.current_thread()),
149                                        self.args['data_hash'],
150                                        self.args['service_root']))
151                         replicas_stored = 1
152                         if 'x-keep-replicas-stored' in resp:
153                             # Tick the 'done' counter for the number of replica
154                             # reported stored by the server, for the case that
155                             # we're talking to a proxy or other backend that
156                             # stores to multiple copies for us.
157                             try:
158                                 replicas_stored = int(resp['x-keep-replicas-stored'])
159                             except ValueError:
160                                 pass
161                         return limiter.save_response(content.strip(), replicas_stored)
162
163                     logging.warning("Request fail: PUT %s => %s %s" %
164                                     (url, resp['status'], content))
165                 except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
166                     logging.warning("Request fail: PUT %s => %s: %s" %
167                                     (url, type(e), str(e)))
168
169     def __init__(self):
170         self.lock = threading.Lock()
171         self.service_roots = None
172         self._cache_lock = threading.Lock()
173         self._cache = []
174         # default 256 megabyte cache
175         self.cache_max = 256 * 1024 * 1024
176         self.using_proxy = False
177
178     def shuffled_service_roots(self, hash):
179         if self.service_roots == None:
180             self.lock.acquire()
181
182             # Override normal keep disk lookup with an explict proxy
183             # configuration.
184             keep_proxy_env = config.get("ARVADOS_KEEP_PROXY")
185             if keep_proxy_env != None and len(keep_proxy_env) > 0:
186
187                 if keep_proxy_env[-1:] != '/':
188                     keep_proxy_env += "/"
189                 self.service_roots = [keep_proxy_env]
190                 self.using_proxy = True
191             else:
192                 try:
193                     try:
194                         keep_services = arvados.api().keep_services().accessible().execute()['items']
195                     except Exception:
196                         keep_services = arvados.api().keep_disks().list().execute()['items']
197
198                     if len(keep_services) == 0:
199                         raise arvados.errors.NoKeepServersError()
200
201                     if 'service_type' in keep_services[0] and keep_services[0]['service_type'] == 'proxy':
202                         self.using_proxy = True
203
204                     roots = (("http%s://%s:%d/" %
205                               ('s' if f['service_ssl_flag'] else '',
206                                f['service_host'],
207                                f['service_port']))
208                              for f in keep_services)
209                     self.service_roots = sorted(set(roots))
210                     logging.debug(str(self.service_roots))
211                 finally:
212                     self.lock.release()
213
214         # Build an ordering with which to query the Keep servers based on the
215         # contents of the hash.
216         # "hash" is a hex-encoded number at least 8 digits
217         # (32 bits) long
218
219         # seed used to calculate the next keep server from 'pool'
220         # to be added to 'pseq'
221         seed = hash
222
223         # Keep servers still to be added to the ordering
224         pool = self.service_roots[:]
225
226         # output probe sequence
227         pseq = []
228
229         # iterate while there are servers left to be assigned
230         while len(pool) > 0:
231             if len(seed) < 8:
232                 # ran out of digits in the seed
233                 if len(pseq) < len(hash) / 4:
234                     # the number of servers added to the probe sequence is less
235                     # than the number of 4-digit slices in 'hash' so refill the
236                     # seed with the last 4 digits and then append the contents
237                     # of 'hash'.
238                     seed = hash[-4:] + hash
239                 else:
240                     # refill the seed with the contents of 'hash'
241                     seed += hash
242
243             # Take the next 8 digits (32 bytes) and interpret as an integer,
244             # then modulus with the size of the remaining pool to get the next
245             # selected server.
246             probe = int(seed[0:8], 16) % len(pool)
247
248             # Append the selected server to the probe sequence and remove it
249             # from the pool.
250             pseq += [pool[probe]]
251             pool = pool[:probe] + pool[probe+1:]
252
253             # Remove the digits just used from the seed
254             seed = seed[8:]
255         logging.debug(str(pseq))
256         return pseq
257
258     class CacheSlot(object):
259         def __init__(self, locator):
260             self.locator = locator
261             self.ready = threading.Event()
262             self.content = None
263
264         def get(self):
265             self.ready.wait()
266             return self.content
267
268         def set(self, value):
269             self.content = value
270             self.ready.set()
271
272         def size(self):
273             if self.content == None:
274                 return 0
275             else:
276                 return len(self.content)
277
278     def cap_cache(self):
279         '''Cap the cache size to self.cache_max'''
280         self._cache_lock.acquire()
281         try:
282             self._cache = filter(lambda c: not (c.ready.is_set() and c.content == None), self._cache)
283             sm = sum([slot.size() for slot in self._cache])
284             while sm > self.cache_max:
285                 del self._cache[-1]
286                 sm = sum([slot.size() for a in self._cache])
287         finally:
288             self._cache_lock.release()
289
290     def reserve_cache(self, locator):
291         '''Reserve a cache slot for the specified locator,
292         or return the existing slot.'''
293         self._cache_lock.acquire()
294         try:
295             # Test if the locator is already in the cache
296             for i in xrange(0, len(self._cache)):
297                 if self._cache[i].locator == locator:
298                     n = self._cache[i]
299                     if i != 0:
300                         # move it to the front
301                         del self._cache[i]
302                         self._cache.insert(0, n)
303                     return n, False
304
305             # Add a new cache slot for the locator
306             n = KeepClient.CacheSlot(locator)
307             self._cache.insert(0, n)
308             return n, True
309         finally:
310             self._cache_lock.release()
311
312     def get(self, locator):
313         #logging.debug("Keep.get %s" % (locator))
314
315         if re.search(r',', locator):
316             return ''.join(self.get(x) for x in locator.split(','))
317         if 'KEEP_LOCAL_STORE' in os.environ:
318             return KeepClient.local_store_get(locator)
319         expect_hash = re.sub(r'\+.*', '', locator)
320
321         slot, first = self.reserve_cache(expect_hash)
322         #logging.debug("%s %s %s" % (slot, first, expect_hash))
323
324         if not first:
325             v = slot.get()
326             return v
327
328         try:
329             for service_root in self.shuffled_service_roots(expect_hash):
330                 url = service_root + locator
331                 api_token = config.get('ARVADOS_API_TOKEN')
332                 headers = {'Authorization': "OAuth2 %s" % api_token,
333                            'Accept': 'application/octet-stream'}
334                 blob = self.get_url(url, headers, expect_hash)
335                 if blob:
336                     slot.set(blob)
337                     self.cap_cache()
338                     return blob
339
340             for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
341                 instance = location_hint.group(1)
342                 url = 'http://keep.' + instance + '.arvadosapi.com/' + locator
343                 blob = self.get_url(url, {}, expect_hash)
344                 if blob:
345                     slot.set(blob)
346                     self.cap_cache()
347                     return blob
348         except:
349             slot.set(None)
350             self.cap_cache()
351             raise
352
353         slot.set(None)
354         self.cap_cache()
355         raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)
356
357     def get_url(self, url, headers, expect_hash):
358         h = httplib2.Http()
359         try:
360             logging.info("Request: GET %s" % (url))
361             with timer.Timer() as t:
362                 resp, content = h.request(url.encode('utf-8'), 'GET',
363                                           headers=headers)
364             logging.info("Received %s bytes in %s msec (%s MiB/sec)" % (len(content),
365                                                                         t.msecs,
366                                                                         (len(content)/(1024*1024))/t.secs))
367             if re.match(r'^2\d\d$', resp['status']):
368                 m = hashlib.new('md5')
369                 m.update(content)
370                 md5 = m.hexdigest()
371                 if md5 == expect_hash:
372                     return content
373                 logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
374         except Exception as e:
375             logging.info("Request fail: GET %s => %s: %s" %
376                          (url, type(e), str(e)))
377         return None
378
379     def put(self, data, **kwargs):
380         if 'KEEP_LOCAL_STORE' in os.environ:
381             return KeepClient.local_store_put(data)
382         m = hashlib.new('md5')
383         m.update(data)
384         data_hash = m.hexdigest()
385         have_copies = 0
386         want_copies = kwargs.get('copies', 2)
387         if not (want_copies > 0):
388             return data_hash
389         threads = []
390         thread_limiter = KeepClient.ThreadLimiter(want_copies)
391         for service_root in self.shuffled_service_roots(data_hash):
392             t = KeepClient.KeepWriterThread(data=data,
393                                             data_hash=data_hash,
394                                             service_root=service_root,
395                                             thread_limiter=thread_limiter,
396                                             using_proxy=self.using_proxy,
397                                             want_copies=(want_copies if self.using_proxy else 1))
398             t.start()
399             threads += [t]
400         for t in threads:
401             t.join()
402         have_copies = thread_limiter.done()
403         # If we're done, return the response from Keep
404         if have_copies >= want_copies:
405             return thread_limiter.response()
406         raise arvados.errors.KeepWriteError(
407             "Write fail for %s: wanted %d but wrote %d" %
408             (data_hash, want_copies, have_copies))
409
410     @staticmethod
411     def sign_for_old_server(data_hash, data):
412         return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
413
414
415     @staticmethod
416     def local_store_put(data):
417         m = hashlib.new('md5')
418         m.update(data)
419         md5 = m.hexdigest()
420         locator = '%s+%d' % (md5, len(data))
421         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
422             f.write(data)
423         os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
424                   os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
425         return locator
426
427     @staticmethod
428     def local_store_get(locator):
429         r = re.search('^([0-9a-f]{32,})', locator)
430         if not r:
431             raise arvados.errors.NotFoundError(
432                 "Invalid data locator: '%s'" % locator)
433         if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
434             return ''
435         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
436             return f.read()