Merge branch '2070-read-remote-arvados'
[arvados.git] / sdk / python / arvados / keep.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20
21 global_client_object = None
22
23 from api import *
24 import config
25 import arvados.errors
26
class Keep:
    """Static convenience wrappers around one process-wide KeepClient.

    Keep.get() / Keep.put() lazily create a single shared KeepClient
    and use it for every call, so most callers never construct a
    client themselves.
    """

    @staticmethod
    def global_client_object():
        """Return the shared KeepClient, creating it on first use.

        NOTE(review): creation is unsynchronized; two threads racing
        here could briefly build two clients. The last one wins and
        both are functional, so this is harmless for current usage.
        """
        global global_client_object
        # Use identity comparison for the None singleton (PEP 8),
        # not '== None'.
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch a data block via the shared client; see KeepClient.get()."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store a data block via the shared client; see KeepClient.put()."""
        return Keep.global_client_object().put(data, **kwargs)
42
class KeepClient(object):
    """Store and retrieve content-addressed data blocks from Keep servers.

    Blocks are addressed by the MD5 hex digest of their content.  Writes
    are fanned out to several servers in parallel worker threads; reads
    probe servers one at a time in a per-block pseudo-random order so
    load spreads evenly across the cluster.  If the KEEP_LOCAL_STORE
    environment variable is set, blocks are read/written to that local
    directory instead (used by tests).
    """

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            # todo: number of successes wanted; done: successes so far.
            self._todo = todo
            self._done = 0
            # At most `todo` workers run concurrently; each releases its
            # slot on exit (success or failure) for the queued threads.
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()

        def __enter__(self):
            self._todo_lock.acquire()
            return self

        def __exit__(self, type, value, traceback):
            self._todo_lock.release()

        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)

        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1

        def done(self):
            """
            Return how many successes were reported.  Note this can
            exceed the number requested: threads that passed
            shall_i_proceed() concurrently may all go on to succeed.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.

        Expected kwargs: data, data_hash, service_root, thread_limiter.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs

        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = config.get('ARVADOS_API_TOKEN')
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        # Old Keep servers want a (dummy) PGP-signed
                        # request body; retry once in that format.
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        # Guards lazy discovery of self.service_roots.
        self.lock = threading.Lock()
        # Sorted list of Keep service base URLs, or None until the
        # first call to shuffled_service_roots().
        self.service_roots = None

    def shuffled_service_roots(self, hash):
        """Return all Keep service URLs in a pseudo-random order.

        The order is derived deterministically from `hash` (a hex MD5
        string), so a given block always probes servers in the same
        order while different blocks spread load across servers.
        """
        if self.service_roots is None:
            # FIX: the original checked service_roots outside the lock
            # and never re-checked after acquiring it, so racing threads
            # all re-fetched the disk list; it also leaked the lock if
            # the API call raised. Use `with` and re-check.
            with self.lock:
                if self.service_roots is None:
                    keep_disks = api().keep_disks().list().execute()['items']
                    roots = (("http%s://%s:%d/" %
                              ('s' if f['service_ssl_flag'] else '',
                               f['service_host'],
                               f['service_port']))
                             for f in keep_disks)
                    self.service_roots = sorted(set(roots))
                    logging.debug(str(self.service_roots))
        # Consume successive 32-bit hex slices of the hash to pick the
        # next server from the shrinking pool; when the seed runs out,
        # extend it by recycling the hash.
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    def get(self, locator):
        """Fetch and return the data block named by `locator`.

        A comma-separated locator is fetched piecewise and concatenated.
        Falls back to any +K@instance location hints in the locator, and
        raises arvados.errors.NotFoundError if no server has the block.
        """
        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        # Strip +size/+hint suffixes; the bare hash is the block name.
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            url = service_root + expect_hash
            api_token = config.get('ARVADOS_API_TOKEN')
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            blob = self.get_url(url, headers, expect_hash)
            if blob:
                return blob
        for location_hint in re.finditer(r'\+K@([a-z0-9]+)', locator):
            instance = location_hint.group(1)
            url = 'http://keep.' + instance + '.arvadosapi.com/' + expect_hash
            blob = self.get_url(url, {}, expect_hash)
            if blob:
                return blob
        raise arvados.errors.NotFoundError("Block not found: %s" % expect_hash)

    def get_url(self, url, headers, expect_hash):
        """GET one URL; return its body if the MD5 matches, else None.

        Any transport error is logged and treated as a miss so the
        caller can try the next server.
        """
        h = httplib2.Http()
        try:
            resp, content = h.request(url.encode('utf-8'), 'GET',
                                      headers=headers)
            if re.match(r'^2\d\d$', resp['status']):
                m = hashlib.new('md5')
                m.update(content)
                md5 = m.hexdigest()
                if md5 == expect_hash:
                    return content
                logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
        except Exception as e:
            logging.info("Request fail: GET %s => %s: %s" %
                         (url, type(e), str(e)))
        return None

    def put(self, data, **kwargs):
        """Store `data` in Keep; return its locator "hash+size".

        kwargs:
          copies -- how many distinct servers must acknowledge the
                    write before it counts as successful (default 2).

        Raises arvados.errors.KeepWriteError if fewer than `copies`
        servers accepted the data.
        """
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        # FIX: ThreadLimiter can report more successes than requested
        # (concurrent writers that all passed its gate), so a fully
        # successful put must not be rejected by an exact '==' test.
        if have_copies >= want_copies:
            return (data_hash + '+' + str(len(data)))
        raise arvados.errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        """Wrap data in the dummy PGP-clearsigned envelope that pre-2070
        Keep servers expect: "<timestamp> <hash>" in the signed header,
        followed by the raw data."""
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        """Write `data` into $KEEP_LOCAL_STORE and return its locator.

        Writes to a .tmp file first, then renames, so readers never see
        a partially-written block.
        """
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        """Read the block named by `locator` from $KEEP_LOCAL_STORE.

        Raises arvados.errors.NotFoundError for a malformed locator; the
        well-known empty-block locator short-circuits to ''.
        """
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()