Rearranging modules to eliminate recursive imports.
[arvados.git] / sdk / python / arvados / keep.py
import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading

from api import *
import util

global_client_object = None

class Keep:
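    """Module-level convenience wrapper around a shared KeepClient.

    A minimal usage sketch (the hash shown is md5('foo')):

        locator = Keep.put('foo')   # 'acbd18db4cc2f85cedef654fccc4a4d8+3'
        data = Keep.get(locator)    # 'foo'
    """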
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)

class KeepClient(object):
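    """Client for the Keep content-addressed block store.

    Blocks are addressed by the md5 of their contents; get() verifies
    the checksum of whatever it fetches, and put() writes to several
    services in parallel until the desired number of copies is stored.
    """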

    class ThreadLimiter(object):
45         """
46         Limit the number of threads running at a given time to
47         {desired successes} minus {successes reported}. When successes
48         reported == desired, wake up the remaining threads and tell
49         them to quit.
50
51         Should be used in a "with" block.
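
        A minimal usage sketch (attempt_write is a hypothetical
        stand-in for the caller's write attempt):

            limiter = KeepClient.ThreadLimiter(2)
            # In each worker thread:
            with limiter:
                if not limiter.shall_i_proceed():
                    return
                if attempt_write():
                    limiter.increment_done()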
52         """
53         def __init__(self, todo):
54             self._todo = todo
55             self._done = 0
56             self._todo_lock = threading.Semaphore(todo)
57             self._done_lock = threading.Lock()
58
59         def __enter__(self):
60             self._todo_lock.acquire()
61             return self
62
63         def __exit__(self, type, value, traceback):
64             self._todo_lock.release()
65
66         def shall_i_proceed(self):
67             """
68             Return true if the current thread should do stuff. Return
69             false if the current thread should just stop.
70             """
71             with self._done_lock:
72                 return (self._done < self._todo)
73
74         def increment_done(self):
75             """
76             Report that the current thread was successful.
77             """
78             with self._done_lock:
79                 self._done += 1
80
81         def done(self):
82             """
83             Return how many successes were reported.
84             """
85             with self._done_lock:
86                 return self._done
87
    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs

        def run(self):
            global config
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has already been
                    # done without me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config['ARVADOS_API_TOKEN']
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # An old server wants the data wrapped in a
                        # timestamped "PGP signed message" envelope.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None

    def shuffled_service_roots(self, hash):
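        """
        Return the service roots in an order determined by the given
        hash: successive 8-hex-digit slices of the hash pick the next
        service out of the remaining pool, so every client probes the
        services in the same order for a given block.
        """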
        if self.service_roots is None:
            self.lock.acquire()
            try:
                # Re-check under the lock in case another thread
                # fetched the list while we were waiting.
                if self.service_roots is None:
                    keep_disks = api().keep_disks().list().execute()['items']
                    roots = (("http%s://%s:%d/" %
                              ('s' if f['service_ssl_flag'] else '',
                               f['service_host'],
                               f['service_port']))
                             for f in keep_disks)
                    self.service_roots = sorted(set(roots))
                    logging.debug(str(self.service_roots))
            finally:
                self.lock.release()
        # Use successive slices of the hash as probe values to build a
        # deterministic permutation of the service pool.
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    def get(self, locator):
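        """
        Fetch and verify one data block, or, for a comma-separated
        list of locators, the concatenation of the named blocks.
        Content is accepted only if its md5 matches the hash in the
        locator; raise errors.NotFoundError if no service returns a
        good copy.
        """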
        global config
        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = config['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise errors.NotFoundError("Block not found: %s" % expect_hash)

    def put(self, data, **kwargs):
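        """
        Store data in Keep, writing to kwargs.get('copies', 2)
        services in parallel. Return a locator of the form <md5>+<size>
        on success; raise errors.KeepWriteError if fewer than the
        requested number of copies could be written.
        """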
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies == want_copies:
            return (data_hash + '+' + str(len(data)))
        raise errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
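        """
        Wrap data in the timestamped "PGP signed message" envelope
        that older Keep servers expect (see the 401 fallback in
        KeepWriterThread.run).
        """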
        return ("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n"
                "%d %s\n"
                "-----BEGIN PGP SIGNATURE-----\n\n"
                "-----END PGP SIGNATURE-----\n"
                % (int(time.time()), data_hash)) + data

    @staticmethod
    def local_store_put(data):
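        """
        Substitute for put() when the KEEP_LOCAL_STORE environment
        variable is set: store data as a file in that directory,
        writing to a .tmp file and renaming it into place so readers
        never see a partial block. Return the locator <md5>+<size>.
        """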
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'],
                               md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
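        """
        Substitute for get() when the KEEP_LOCAL_STORE environment
        variable is set: read the block back from a file in that
        directory, keyed by the md5 portion of the locator.
        """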
        r = re.search(r'^([0-9a-f]{32,})', locator)
        if not r:
            raise errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'],
                               r.group(0)), 'r') as f:
            return f.read()