Renamed Keep/Stream/Collection submodules to keep/stream/collection (lower case)
[arvados.git] sdk/python/arvados/keep.py
import httplib
import httplib2
import logging
import os
import re
import hashlib
import time
import threading

# These names are assumed to be supplied by the enclosing arvados package
# (based on how they are used below): api() returns a configured API client,
# config maps configuration keys such as ARVADOS_API_TOKEN to values, and
# errors provides the SDK's exception classes.
from arvados import api, config, errors

# Locator of the zero-length block: the md5 of the empty string, plus size.
EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'

global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)

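# Example usage of the convenience wrappers above (a sketch; it assumes
# ARVADOS_API_HOST and ARVADOS_API_TOKEN are configured so the client can
# discover and authenticate to Keep servers):
#
#   locator = Keep.put('foo')   # e.g. 'acbd18db4cc2f85cedef654fccc4a4d8+3'
#   data = Keep.get(locator)    # returns 'foo'
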
class KeepClient(object):

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, exc_type, exc_value, exc_traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

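    # Sketch of how ThreadLimiter is meant to be used (the worker logic is
    # hypothetical, standing in for a real write attempt):
    #
    #   limiter = KeepClient.ThreadLimiter(2)    # want 2 successes
    #   def worker():
    #       with limiter:
    #           if not limiter.shall_i_proceed():
    #               return                       # enough successes already
    #           if try_one_write():              # hypothetical helper
    #               limiter.increment_done()
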
    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config['ARVADOS_API_TOKEN']
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # Fall back to the old server protocol: wrap the
                        # block in a dummy PGP-signed message and retry.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

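    # In effect, each writer thread issues (values hypothetical):
    #
    #   PUT http://keep.example:25107/acbd18db4cc2f85cedef654fccc4a4d8
    #   Authorization: OAuth2 <api token>
    #   <raw block data as the request body>
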
    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None

    def shuffled_service_roots(self, hash):
        # Build the service list once. The check must happen under the
        # lock, otherwise two threads could both run the discovery step.
        with self.lock:
            if self.service_roots is None:
                keep_disks = api().keep_disks().list().execute()['items']
                roots = (("http%s://%s:%d/" %
                          ('s' if f['service_ssl_flag'] else '',
                           f['service_host'],
                           f['service_port']))
                         for f in keep_disks)
                self.service_roots = sorted(set(roots))
                logging.debug(str(self.service_roots))
        # Use consecutive 8-hex-digit slices of the hash as probes: each
        # probe picks (and removes) one server from the remaining pool, so
        # every client visits the servers in the same hash-determined order.
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq.append(pool[probe])
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

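    # Worked example of the probe sequence above (a sketch with made-up
    # inputs): with hash 'acbd18db4cc2f85cedef654fccc4a4d8' and a pool of
    # 4 service roots, the first probe is int('acbd18db', 16) % 4 == 3, so
    # pool[3] is tried first; the next slice '4cc2f85c' picks one of the
    # remaining 3 roots, and so on until the pool is empty.
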
    def get(self, locator):
        if ',' in locator:
            # A comma-separated list of locators: concatenate the blocks.
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = config['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    # Verify the content against the locator's hash before
                    # returning it.
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise errors.NotFoundError("Block not found: %s" % expect_hash)

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        # Start one writer thread per service root; the ThreadLimiter lets
        # only want_copies of them do real work.
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads.append(t)
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies == want_copies:
            return (data_hash + '+' + str(len(data)))
        raise errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        # Wrap the block in a dummy PGP-signed message, as expected by old
        # Keep servers: "<timestamp> <hash>" followed by the data.
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n"
                 "%d %s\n"
                 "-----BEGIN PGP SIGNATURE-----\n\n"
                 "-----END PGP SIGNATURE-----\n"
                 % (int(time.time()), data_hash)) + data)

    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        store = os.environ['KEEP_LOCAL_STORE']
        # Write to a temp file, then rename, so readers never see a
        # partially written block.
        with open(os.path.join(store, md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(store, md5 + '.tmp'),
                  os.path.join(store, md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        r = re.search(r'^([0-9a-f]{32,})', locator)
        if not r:
            raise errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()
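
# Example of the KEEP_LOCAL_STORE test mode (a sketch; '/tmp/keep' is an
# arbitrary existing directory, not something this module creates):
#
#   os.environ['KEEP_LOCAL_STORE'] = '/tmp/keep'
#   locator = Keep.put('foo')   # writes /tmp/keep/acbd18db...
#   Keep.get(locator)           # reads the block back from disk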