2798: Merged branch with code to read environment variables with branch working on...
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20 import timer
21
22 global_client_object = None
23
24 from api import *
25 import config
26 import arvados.errors
27
28 class Keep:
29     @staticmethod
30     def global_client_object():
31         global global_client_object
32         if global_client_object == None:
33             global_client_object = KeepClient()
34         return global_client_object
35
36     @staticmethod
37     def get(locator, **kwargs):
38         return Keep.global_client_object().get(locator, **kwargs)
39
40     @staticmethod
41     def put(data, **kwargs):
42         return Keep.global_client_object().put(data, **kwargs)
43
44 class KeepClient(object):
45
46     class ThreadLimiter(object):
47         """
48         Limit the number of threads running at a given time to
49         {desired successes} minus {successes reported}. When successes
50         reported == desired, wake up the remaining threads and tell
51         them to quit.
52
53         Should be used in a "with" block.
54         """
55         def __init__(self, todo):
56             self._todo = todo
57             self._done = 0
58             self._todo_lock = threading.Semaphore(todo)
59             self._done_lock = threading.Lock()
60
61         def __enter__(self):
62             self._todo_lock.acquire()
63             return self
64
65         def __exit__(self, type, value, traceback):
66             self._todo_lock.release()
67
68         def shall_i_proceed(self):
69             """
70             Return true if the current thread should do stuff. Return
71             false if the current thread should just stop.
72             """
73             with self._done_lock:
74                 return (self._done < self._todo)
75
76         def increment_done(self):
77             """
78             Report that the current thread was successful.
79             """
80             with self._done_lock:
81                 self._done += 1
82
83         def done(self):
84             """
85             Return how many successes were reported.
86             """
87             with self._done_lock:
88                 return self._done
89
90     class KeepWriterThread(threading.Thread):
91         """
92         Write a blob of data to the given Keep server. Call
93         increment_done() of the given ThreadLimiter if the write
94         succeeds.
95         """
96         def __init__(self, **kwargs):
97             super(KeepClient.KeepWriterThread, self).__init__()
98             self.args = kwargs
99
100         def run(self):
101             with self.args['thread_limiter'] as limiter:
102                 if not limiter.shall_i_proceed():
103                     # My turn arrived, but the job has been done without
104                     # me.
105                     return
106                 logging.debug("KeepWriterThread %s proceeding %s %s" %
107                               (str(threading.current_thread()),
108                                self.args['data_hash'],
109                                self.args['service_root']))
110                 h = httplib2.Http()
111                 url = self.args['service_root'] + self.args['data_hash']
112                 api_token = config.get('ARVADOS_API_TOKEN')
113                 headers = {'Authorization': "OAuth2 %s" % api_token}
114                 try:
115                     resp, content = h.request(url.encode('utf-8'), 'PUT',
116                                               headers=headers,
117                                               body=self.args['data'])
118                     if (resp['status'] == '401' and
119                         re.match(r'Timestamp verification failed', content)):
120                         body = KeepClient.sign_for_old_server(
121                             self.args['data_hash'],
122                             self.args['data'])
123                         h = httplib2.Http()
124                         resp, content = h.request(url.encode('utf-8'), 'PUT',
125                                                   headers=headers,
126                                                   body=body)
127                     if re.match(r'^2\d\d$', resp['status']):
128                         logging.debug("KeepWriterThread %s succeeded %s %s" %
129                                       (str(threading.current_thread()),
130                                        self.args['data_hash'],
131                                        self.args['service_root']))
132                         return limiter.increment_done()
133                     logging.warning("Request fail: PUT %s => %s %s" %
134                                     (url, resp['status'], content))
135                 except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
136                     logging.warning("Request fail: PUT %s => %s: %s" %
137                                     (url, type(e), str(e)))
138
139     def __init__(self):
140         self.lock = threading.Lock()
141         self.service_roots = None
142         self._cache_lock = threading.Lock()
143         self._cache = []
144         # default 256 megabyte cache
145         self.cache_max = 256 * 1024 * 1024
146
147     def shuffled_service_roots(self, hash):
148         if self.service_roots == None:
149             self.lock.acquire()
150             try:
151                 keep_disks = arvados.api().keep_disks().list().execute()['items']
152                 roots = (("http%s://%s:%d/" %
153                           ('s' if f['service_ssl_flag'] else '',
154                            f['service_host'],
155                            f['service_port']))
156                          for f in keep_disks)
157                 self.service_roots = sorted(set(roots))
158                 logging.debug(str(self.service_roots))
159             finally:
160                 self.lock.release()
161
162         # Build an ordering with which to query the Keep servers based on the
163         # contents of the hash.
164         # "hash" is a hex-encoded number at least 8 digits
165         # (32 bits) long
166
167         # seed used to calculate the next keep server from 'pool'
168         # to be added to 'pseq'
169         seed = hash
170
171         # Keep servers still to be added to the ordering
172         pool = self.service_roots[:]
173
174         # output probe sequence
175         pseq = []
176
177         # iterate while there are servers left to be assigned
178         while len(pool) > 0:
179             if len(seed) < 8:
180                 # ran out of digits in the seed
181                 if len(pseq) < len(hash) / 4:
182                     # the number of servers added to the probe sequence is less
183                     # than the number of 4-digit slices in 'hash' so refill the
184                     # seed with the last 4 digits and then append the contents
185                     # of 'hash'.
186                     seed = hash[-4:] + hash
187                 else:
188                     # refill the seed with the contents of 'hash'
189                     seed += hash
190
191             # Take the next 8 digits (32 bytes) and interpret as an integer,
192             # then modulus with the size of the remaining pool to get the next
193             # selected server.
194             probe = int(seed[0:8], 16) % len(pool)
195
196             print seed[0:8], int(seed[0:8], 16), len(pool), probe
197
198             # Append the selected server to the probe sequence and remove it
199             # from the pool.
200             pseq += [pool[probe]]
201             pool = pool[:probe] + pool[probe+1:]
202
203             # Remove the digits just used from the seed
204             seed = seed[8:]
205         logging.debug(str(pseq))
206         return pseq
207
208     class CacheSlot(object):
209         def __init__(self, locator):
210             self.locator = locator
211             self.ready = threading.Event()
212             self.content = None
213
214         def get(self):
215             self.ready.wait()
216             return self.content
217
218         def set(self, value):
219             self.content = value
220             self.ready.set()
221
222         def size(self):
223             if self.content == None:
224                 return 0
225             else:
226                 return len(self.content)
227
228     def cap_cache(self):
229         '''Cap the cache size to self.cache_max'''
230         self._cache_lock.acquire()
231         try:
232             self._cache = filter(lambda c: not (c.ready.is_set() and c.content == None), self._cache)
233             sm = sum([slot.size() for slot in self._cache])
234             while sm > self.cache_max:
235                 del self._cache[-1]
236                 sm = sum([slot.size() for a in self._cache])
237         finally:
238             self._cache_lock.release()
239
240     def reserve_cache(self, locator):
241         '''Reserve a cache slot for the specified locator,
242         or return the existing slot.'''
243         self._cache_lock.acquire()
244         try:
245             # Test if the locator is already in the cache
246             for i in xrange(0, len(self._cache)):
247                 if self._cache[i].locator == locator:
248                     n = self._cache[i]
249                     if i != 0:
250                         # move it to the front
251                         del self._cache[i]
252                         self._cache.insert(0, n)
253                     return n, False
254
255             # Add a new cache slot for the locator
256             n = KeepClient.CacheSlot(locator)
257             self._cache.insert(0, n)
258             return n, True
259         finally:
260             self._cache_lock.release()
261
262     def get(self, locator):
263         #logging.debug("Keep.get %s" % (locator))
264
265         if re.search(r',', locator):
266             return ''.join(self.get(x) for x in locator.split(','))
267         if 'KEEP_LOCAL_STORE' in os.environ:
268             return KeepClient.local_store_get(locator)
269         expect_hash = re.sub(r'\+.*', '', locator)
270
271         slot, first = self.reserve_cache(expect_hash)
272         #logging.debug("%s %s %s" % (slot, first, expect_hash))
273
274         if not first:
275             v = slot.get()
276             return v
277
278         try:
279             for service_root in self.shuffled_service_roots(expect_hash):
280                 url = service_root + expect_hash
281                 api_token = config.get('ARVADOS_API_TOKEN')
282                 headers = {'Authorization': "OAuth2 %s" % api_token,
283                            'Accept': 'application/octet-stream'}
284                 blob = self.get_url(url, headers, expect_hash)
285                 if blob:
286                     slot.set(blob)
287                     self.cap_cache()
288                     return blob
289
290             for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
291                 instance = location_hint.group(1)
292                 url = 'http://keep.' + instance + '.arvadosapi.com/' + expect_hash
293                 blob = self.get_url(url, {}, expect_hash)
294                 if blob:
295                     slot.set(blob)
296                     self.cap_cache()
297                     return blob
298         except:
299             slot.set(None)
300             self.cap_cache()
301             raise
302
303         slot.set(None)
304         self.cap_cache()
305         raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)
306
307     def get_url(self, url, headers, expect_hash):
308         h = httplib2.Http()
309         try:
310             logging.info("Request: GET %s" % (url))
311             with timer.Timer() as t:
312                 resp, content = h.request(url.encode('utf-8'), 'GET',
313                                           headers=headers)
314             logging.info("Received %s bytes in %s msec (%s MiB/sec)" % (len(content),
315                                                                         t.msecs,
316                                                                         (len(content)/(1024*1024))/t.secs))
317             if re.match(r'^2\d\d$', resp['status']):
318                 m = hashlib.new('md5')
319                 m.update(content)
320                 md5 = m.hexdigest()
321                 if md5 == expect_hash:
322                     return content
323                 logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
324         except Exception as e:
325             logging.info("Request fail: GET %s => %s: %s" %
326                          (url, type(e), str(e)))
327         return None
328
329     def put(self, data, **kwargs):
330         if 'KEEP_LOCAL_STORE' in os.environ:
331             return KeepClient.local_store_put(data)
332         m = hashlib.new('md5')
333         m.update(data)
334         data_hash = m.hexdigest()
335         have_copies = 0
336         want_copies = kwargs.get('copies', 2)
337         if not (want_copies > 0):
338             return data_hash
339         threads = []
340         thread_limiter = KeepClient.ThreadLimiter(want_copies)
341         for service_root in self.shuffled_service_roots(data_hash):
342             t = KeepClient.KeepWriterThread(data=data,
343                                             data_hash=data_hash,
344                                             service_root=service_root,
345                                             thread_limiter=thread_limiter)
346             t.start()
347             threads += [t]
348         for t in threads:
349             t.join()
350         have_copies = thread_limiter.done()
351         if have_copies == want_copies:
352             return (data_hash + '+' + str(len(data)))
353         raise arvados.errors.KeepWriteError(
354             "Write fail for %s: wanted %d but wrote %d" %
355             (data_hash, want_copies, have_copies))
356
357     @staticmethod
358     def sign_for_old_server(data_hash, data):
359         return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
360
361
362     @staticmethod
363     def local_store_put(data):
364         m = hashlib.new('md5')
365         m.update(data)
366         md5 = m.hexdigest()
367         locator = '%s+%d' % (md5, len(data))
368         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
369             f.write(data)
370         os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
371                   os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
372         return locator
373
374     @staticmethod
375     def local_store_get(locator):
376         r = re.search('^([0-9a-f]{32,})', locator)
377         if not r:
378             raise arvados.errors.NotFoundError(
379                 "Invalid data locator: '%s'" % locator)
380         if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
381             return ''
382         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
383             return f.read()