2751: Test tweak to clear settings after changing environment variables.
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21
# Process-wide KeepClient singleton, lazily created by
# Keep.global_client_object() and shared by the Keep.* static helpers.
global_client_object = None
23
24 from api import *
25 import config
26 import arvados.errors
27
class Keep:
    """Convenience wrappers around a process-wide shared KeepClient."""

    @staticmethod
    def global_client_object():
        """Return the module-level KeepClient singleton, creating it on
        first use."""
        global global_client_object
        # 'is None' identity test instead of '== None' equality.
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch the block identified by `locator` via the shared client."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store `data` via the shared client; returns its locator."""
        return Keep.global_client_object().put(data, **kwargs)
43
class KeepClient(object):
    """Client for the Keep content-addressed block store.

    Reads and writes blocks against a set of Keep servers (or a single
    Keep proxy), keeping fetched blocks in a bounded in-memory MRU cache.
    """

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            # At most `todo` writer threads may hold the semaphore at once.
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            # Expected kwargs: data, data_hash, service_root,
            # thread_limiter, using_proxy, want_copies.
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token}

                if self.args['using_proxy']:
                    # We're using a proxy, so tell the proxy how many copies we
                    # want it to store
                    headers['X-Keep-Desired-Replication'] = str(self.args['want_copies'])

                try:
                    logging.debug("Uploading to {}".format(url))
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # Old servers reject unsigned bodies; retry with a
                        # PGP-style signature wrapper.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))

                        if 'x-keep-replicas-stored' in resp:
                            # Tick the 'done' counter for the number of replica
                            # reported stored by the server, for the case that
                            # we're talking to a proxy or other backend that
                            # stores to multiple copies for us.
                            replicas = int(resp['x-keep-replicas-stored'])
                            while replicas > 0:
                                limiter.increment_done()
                                replicas -= 1
                        else:
                            limiter.increment_done()
                        return
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        # Guards lazy discovery of self.service_roots.
        self.lock = threading.Lock()
        self.service_roots = None
        self._cache_lock = threading.Lock()
        # MRU list of CacheSlot objects; index 0 is most recently used.
        self._cache = []
        # default 256 megabyte cache
        self.cache_max = 256 * 1024 * 1024
        self.using_proxy = False

    def shuffled_service_roots(self, hash):
        """Return the Keep service roots in the probe order derived
        from `hash` (a hex string at least 8 digits long).

        The service list is discovered once (proxy env var, keep_services
        API, or legacy keep_disks API) and cached in self.service_roots.
        """
        if self.service_roots is None:
            self.lock.acquire()
            try:
                # Re-check under the lock: another thread may have
                # completed discovery while we were waiting.
                if self.service_roots is None:
                    # Override normal keep disk lookup with an explict proxy
                    # configuration.
                    keep_proxy_env = config.get("ARVADOS_KEEP_PROXY")
                    if keep_proxy_env is not None and len(keep_proxy_env) > 0:
                        if keep_proxy_env[-1:] != '/':
                            keep_proxy_env += "/"
                        self.service_roots = [keep_proxy_env]
                        self.using_proxy = True
                    else:
                        try:
                            keep_services = arvados.api().keep_services().accessible().execute()['items']
                        except Exception:
                            # Fall back to the legacy API for older servers.
                            keep_services = arvados.api().keep_disks().list().execute()['items']

                        if len(keep_services) == 0:
                            raise arvados.errors.NoKeepServersError()

                        if 'service_type' in keep_services[0] and keep_services[0]['service_type'] == 'proxy':
                            self.using_proxy = True

                        roots = (("http%s://%s:%d/" %
                                  ('s' if f['service_ssl_flag'] else '',
                                   f['service_host'],
                                   f['service_port']))
                                 for f in keep_services)
                        self.service_roots = sorted(set(roots))
                        logging.debug(str(self.service_roots))
            finally:
                # BUGFIX: the lock was previously released only on the
                # API-discovery path, leaking it when ARVADOS_KEEP_PROXY
                # was set.
                self.lock.release()

        # Build an ordering with which to query the Keep servers based on the
        # contents of the hash.
        # "hash" is a hex-encoded number at least 8 digits
        # (32 bits) long

        # seed used to calculate the next keep server from 'pool'
        # to be added to 'pseq'
        seed = hash

        # Keep servers still to be added to the ordering
        pool = self.service_roots[:]

        # output probe sequence
        pseq = []

        # iterate while there are servers left to be assigned
        while len(pool) > 0:
            if len(seed) < 8:
                # ran out of digits in the seed
                if len(pseq) < len(hash) / 4:
                    # the number of servers added to the probe sequence is less
                    # than the number of 4-digit slices in 'hash' so refill the
                    # seed with the last 4 digits and then append the contents
                    # of 'hash'.
                    seed = hash[-4:] + hash
                else:
                    # refill the seed with the contents of 'hash'
                    seed += hash

            # Take the next 8 digits (32 bytes) and interpret as an integer,
            # then modulus with the size of the remaining pool to get the next
            # selected server.
            probe = int(seed[0:8], 16) % len(pool)

            # Append the selected server to the probe sequence and remove it
            # from the pool.  (A leftover debug 'print' statement that dumped
            # every probe to stdout was removed here.)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]

            # Remove the digits just used from the seed
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    class CacheSlot(object):
        """One cache entry: holds a locator and (once ready) its content."""
        def __init__(self, locator):
            self.locator = locator
            # Signalled when content (or a fetch failure, i.e. None)
            # has been recorded.
            self.ready = threading.Event()
            self.content = None

        def get(self):
            """Block until content is available, then return it."""
            self.ready.wait()
            return self.content

        def set(self, value):
            """Record the fetched content and wake any waiters."""
            self.content = value
            self.ready.set()

        def size(self):
            """Return the content size in bytes (0 while unset/failed)."""
            if self.content is None:
                return 0
            else:
                return len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        self._cache_lock.acquire()
        try:
            # Drop slots whose fetch completed with no content.
            self._cache = [c for c in self._cache
                           if not (c.ready.is_set() and c.content is None)]
            sm = sum([slot.size() for slot in self._cache])
            while sm > self.cache_max:
                # Evict the least recently used slot (tail of the list).
                del self._cache[-1]
                # BUGFIX: the recomputation previously read the stale
                # 'slot' loop variable for every element ("for a in ..."),
                # so the total was len(cache) * size(last slot) instead of
                # the real cache size.
                sm = sum([slot.size() for slot in self._cache])
        finally:
            self._cache_lock.release()

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.

        Returns (slot, is_new): is_new is True when the caller is
        responsible for filling the slot.'''
        self._cache_lock.acquire()
        try:
            # Test if the locator is already in the cache
            for i, slot in enumerate(self._cache):
                if slot.locator == locator:
                    n = slot
                    if i != 0:
                        # move it to the front
                        del self._cache[i]
                        self._cache.insert(0, n)
                    return n, False

            # Add a new cache slot for the locator
            n = KeepClient.CacheSlot(locator)
            self._cache.insert(0, n)
            return n, True
        finally:
            self._cache_lock.release()

    def get(self, locator):
        """Fetch and return the block(s) named by `locator`.

        A comma-separated locator is fetched piecewise and concatenated.
        Raises arvados.errors.NotFoundError if no server has the block.
        """
        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)

        slot, first = self.reserve_cache(expect_hash)

        if not first:
            # Another thread is (or was) fetching this block; wait for it.
            v = slot.get()
            return v

        try:
            for service_root in self.shuffled_service_roots(expect_hash):
                url = service_root + expect_hash
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token,
                           'Accept': 'application/octet-stream'}
                blob = self.get_url(url, headers, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob

            # Try location hints (+K@instance) as a last resort.
            for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
                instance = location_hint.group(1)
                url = 'http://keep.' + instance + '.arvadosapi.com/' + expect_hash
                blob = self.get_url(url, {}, expect_hash)
                if blob:
                    slot.set(blob)
                    self.cap_cache()
                    return blob
        except:
            # Mark the slot failed so waiters don't block forever,
            # then re-raise.
            slot.set(None)
            self.cap_cache()
            raise

        slot.set(None)
        self.cap_cache()
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        """GET `url` and return the body iff its md5 matches
        `expect_hash`; return None on any failure."""
        h = httplib2.Http()
        try:
            logging.info("Request: GET %s" % (url))
            with timer.Timer() as t:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
            logging.info("Received %s bytes in %s msec (%s MiB/sec)" % (len(content),
                                                                        t.msecs,
                                                                        (len(content)/(1024*1024))/t.secs))
            if re.match(r'^2\d\d$', resp['status']):
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
        except Exception as e:
            logging.info("Request fail: GET %s => %s: %s" %
                         (url, type(e), str(e)))
        return None

    def put(self, data, **kwargs):
        """Store `data` on `copies` servers (default 2) and return its
        size-qualified locator.

        Raises arvados.errors.KeepWriteError if fewer than the wanted
        number of copies were written.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            # When talking to a proxy, ask it for all copies in one
            # request; otherwise each server stores a single copy.
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter,
                                            using_proxy=self.using_proxy,
                                            want_copies=(want_copies if self.using_proxy else 1))
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies >= want_copies:
            return (data_hash + '+' + str(len(data)))
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        """Wrap `data` in the PGP-signed-message envelope expected by
        pre-timestamp-auth Keep servers."""
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        """Testing hook: write `data` under $KEEP_LOCAL_STORE and return
        its locator (atomic via write-to-tmp-then-rename)."""
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Testing hook: read the block named by `locator` from
        $KEEP_LOCAL_STORE."""
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()