Merge branch '1922-cache-discovery-python'
[arvados.git] / sdk / python / arvados / keep.py
import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading

global_client_object = None

from arvados import *
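# api(), config, errors, and EMPTY_BLOCK_LOCATOR used below all come
# from this wildcard import of the arvados package.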

class Keep:
    @staticmethod
    def global_client_object():
        # Lazily create one shared KeepClient for the whole process.
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)

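# A minimal usage sketch (assumes either a configured API host and
# ARVADOS_API_TOKEN, or KEEP_LOCAL_STORE pointing at a scratch
# directory):
#
#   locator = Keep.put("foo")   # "acbd18db4cc2f85cedef654fccc4a4d8+3"
#   data = Keep.get(locator)    # "foo"
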
class KeepClient(object):

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

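    # A sketch of the intended ThreadLimiter pattern (hypothetical
    # caller; the real one is KeepWriterThread.run below):
    #
    #   limiter = KeepClient.ThreadLimiter(want_copies)
    #   # ...then, in each of several worker threads:
    #   with limiter:
    #       if limiter.shall_i_proceed():
    #           do_one_write()            # hypothetical work
    #           limiter.increment_done()
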
    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done
                    # without me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config['ARVADOS_API_TOKEN']
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # An old-style server wants the block wrapped
                        # in a (placeholder) PGP clearsigned message;
                        # re-sign and retry once.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None

    def shuffled_service_roots(self, data_hash):
        if self.service_roots is None:
            with self.lock:
                # Another thread may have filled in service_roots
                # while we were waiting for the lock.
                if self.service_roots is None:
                    keep_disks = api().keep_disks().list().execute()['items']
                    roots = (("http%s://%s:%d/" %
                              ('s' if f['service_ssl_flag'] else '',
                               f['service_host'],
                               f['service_port']))
                             for f in keep_disks)
                    self.service_roots = sorted(set(roots))
                    logging.debug(str(self.service_roots))
        # Build a probe sequence: consume the hash 8 hex digits at a
        # time, using each chunk to pick (and remove) one server from
        # the remaining pool.  When the seed runs low, extend it
        # deterministically so every client computes the same order.
        seed = data_hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(data_hash) / 4: # first time around
                    seed = data_hash[-4:] + data_hash
                else:
                    seed += data_hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq.append(pool[probe])
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq
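    # Worked example (hypothetical 4-server pool): for a block whose
    # md5 is "acbd18db4cc2f85cedef654fccc4a4d8", the first probe is
    # int("acbd18db", 16) % 4 == 3, so pool[3] is tried first; the
    # next 8 hex digits pick from the remaining 3 servers, and so on.
    # Readers and writers derive the same sequence from the same hash,
    # so readers probe the servers most likely to hold the block first.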

    def get(self, locator):
        if ',' in locator:
            # A comma-separated locator names several blocks; fetch
            # each one and concatenate them.
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = config['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    # Verify the content against the hash in the
                    # locator before returning it.
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise errors.NotFoundError("Block not found: %s" % expect_hash)

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        want_copies = kwargs.get('copies', 2)
        if want_copies < 1:
            return data_hash
        # Start one writer thread per service in probe order.  The
        # shared ThreadLimiter lets at most want_copies writes proceed
        # at once and tells the remaining threads to quit as soon as
        # enough copies have been stored.
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads.append(t)
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies == want_copies:
            return (data_hash + '+' + str(len(data)))
        raise errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n"
                 "-----BEGIN PGP SIGNATURE-----\n\n"
                 "-----END PGP SIGNATURE-----\n"
                 % (int(time.time()), data_hash)) + data)
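    # The clearsigned body built above looks like this (hypothetical
    # timestamp, data "foo"):
    #
    #   -----BEGIN PGP SIGNED MESSAGE-----
    #
    #
    #   1372096800 acbd18db4cc2f85cedef654fccc4a4d8
    #   -----BEGIN PGP SIGNATURE-----
    #
    #   -----END PGP SIGNATURE-----
    #   foo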
    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        # Write to a temp file, then rename: readers never see a
        # partially written block.
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
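    # e.g. with KEEP_LOCAL_STORE=/tmp/keep (a hypothetical scratch
    # directory), local_store_put("foo") writes
    # /tmp/keep/acbd18db4cc2f85cedef654fccc4a4d8 and returns
    # "acbd18db4cc2f85cedef654fccc4a4d8+3".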
    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == EMPTY_BLOCK_LOCATOR.split('+')[0]:
            # The well-known empty block is never actually stored;
            # short-circuit it.
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()