Merge branch '1977-provenance-report' of git.clinicalfuture.com:arvados into 1977...
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20
# Process-wide KeepClient singleton, created lazily by
# Keep.global_client_object() below.
global_client_object = None
22
23 from api import *
24 import config
25 import arvados.errors
26
class Keep:
    """Static convenience wrappers around a shared KeepClient instance."""

    @staticmethod
    def global_client_object():
        """Return the process-wide KeepClient, creating it on first use."""
        global global_client_object
        # Identity comparison with the None singleton (PEP 8), not ==.
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch the data block named by `locator` via the shared client."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store `data` via the shared client and return its locator."""
        return Keep.global_client_object().put(data, **kwargs)
42
class KeepClient(object):
    """Read and write data blocks on a cluster of Keep servers.

    Writes are fanned out to several servers in parallel via
    KeepWriterThread; ThreadLimiter stops spawned threads once enough
    copies have been confirmed.
    """

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            # todo = number of successes wanted before remaining
            # threads are told to stop.
            self._todo = todo
            self._done = 0
            # At most `todo` threads may be inside the "with" region
            # at any one time.
            self._todo_lock = threading.Semaphore(todo)
            # Protects _done, which several worker threads update.
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1

        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.

        Expected kwargs: data, data_hash, service_root, thread_limiter.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # Old servers want the body wrapped in a (dummy)
                        # PGP-signed timestamp message; retry once.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        # Sorted list of Keep service URLs, fetched lazily by
        # shuffled_service_roots().
        self.service_roots = None

    def shuffled_service_roots(self, hash):
        """Return the service URLs, probe-ordered for the given hash.

        The ordering is a deterministic permutation seeded by successive
        32-bit slices of the (hex) hash, so every client probes the same
        servers in the same order for a given block.
        """
        if self.service_roots is None:
            # "with" guarantees the lock is released even if the API
            # call raises (the old acquire/release pair leaked the lock
            # on error).  Re-check inside the lock so two racing threads
            # don't both fetch the disk list.
            with self.lock:
                if self.service_roots is None:
                    keep_disks = api().keep_disks().list().execute()['items']
                    roots = (("http%s://%s:%d/" %
                              ('s' if f['service_ssl_flag'] else '',
                               f['service_host'],
                               f['service_port']))
                             for f in keep_disks)
                    self.service_roots = sorted(set(roots))
                    logging.debug(str(self.service_roots))
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                # Out of seed material: extend it from the hash itself
                # (with a 4-digit twist the first time around so the
                # extension differs from the initial seed).
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            # Use the next 32 bits of seed to pick one of the remaining
            # services.
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    def get(self, locator):
        """Fetch and return the data block(s) named by `locator`.

        A comma-separated locator list is fetched piecewise and
        concatenated.  Raises arvados.errors.NotFoundError if no server
        returns content matching the expected MD5.
        """
        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        # Strip size/permission hints: the bare MD5 addresses the block.
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = config.get('ARVADOS_API_TOKEN')
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    # Verify content integrity before trusting it.
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def put(self, data, **kwargs):
        """Store `data` on `copies` Keep servers (default 2) in parallel.

        Returns the locator "<md5>+<size>" on success; raises
        arvados.errors.KeepWriteError if fewer than `copies` servers
        confirmed the write.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        # ">=", not "==": ThreadLimiter can over-count when several
        # threads pass shall_i_proceed() before any of them reports
        # success, so more copies than requested is still a success.
        if have_copies >= want_copies:
            return (data_hash + '+' + str(len(data)))
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        """Wrap `data` in the dummy PGP-signed-timestamp envelope that
        pre-OAuth Keep servers require."""
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        """Testing hook: write `data` into $KEEP_LOCAL_STORE instead of
        Keep, and return its locator."""
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        # Write to a temp name, then rename, so readers never see a
        # partially written block.
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Testing hook: read a block from $KEEP_LOCAL_STORE."""
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        # The well-known empty block needs no backing file.
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()