Add threading locks in Keep client
[arvados.git] / sdk / python / arvados.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20
21 from apiclient import errors
22 from apiclient.discovery import build
23
24 if 'ARVADOS_DEBUG' in os.environ:
25     logging.basicConfig(level=logging.DEBUG)
26
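# CredentialsFromEnv mimics just enough of the oauth2client credentials
# interface for apiclient: authorize() wraps an httplib2.Http object so that
# every request carries the ARVADOS_API_TOKEN from the environment as an
# OAuth2 Authorization header, and retries once when the server has closed a
# reused connection.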
27 class CredentialsFromEnv:
28     @staticmethod
29     def http_request(self, uri, **kwargs):
30         from httplib import BadStatusLine
31         if 'headers' not in kwargs:
32             kwargs['headers'] = {}
33         kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
34         try:
35             return self.orig_http_request(uri, **kwargs)
36         except BadStatusLine:
37             # This is how httplib tells us that it tried to reuse an
38             # existing connection but it was already closed by the
39             # server. In that case, yes, we would like to retry.
40             # Unfortunately, we are not absolutely certain that the
41             # previous call did not succeed, so this is slightly
42             # risky.
43             return self.orig_http_request(uri, **kwargs)
44     def authorize(self, http):
45         http.orig_http_request = http.request
46         http.request = types.MethodType(self.http_request, http)
47         return http
48
49 url = ('https://%s/discovery/v1/apis/'
50        '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
51 credentials = CredentialsFromEnv()
52
53 # Use system's CA certificates (if we find them) instead of httplib2's
54 ca_certs = '/etc/ssl/certs/ca-certificates.crt'
55 if not os.path.exists(ca_certs):
56     ca_certs = None             # use httplib2 default
57
58 http = httplib2.Http(ca_certs=ca_certs)
59 http = credentials.authorize(http)
60 if re.match(r'(?i)^(true|1|yes)$',
61             os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
62     http.disable_ssl_certificate_validation = True
63 service = build("arvados", "v1", http=http, discoveryServiceUrl=url)
64
65 def task_set_output(self,s):
66     service.job_tasks().update(uuid=self['uuid'],
67                                job_task=json.dumps({
68                 'output':s,
69                 'success':True,
70                 'progress':1.0
71                 })).execute()
72
73 _current_task = None
74 def current_task():
75     global _current_task
76     if _current_task:
77         return _current_task
78     t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
79     t = UserDict.UserDict(t)
80     t.set_output = types.MethodType(task_set_output, t)
81     t.tmpdir = os.environ['TASK_WORK']
82     _current_task = t
83     return t
84
85 _current_job = None
86 def current_job():
87     global _current_job
88     if _current_job:
89         return _current_job
90     t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
91     t = UserDict.UserDict(t)
92     t.tmpdir = os.environ['JOB_WORK']
93     _current_job = t
94     return t
95
96 def api():
97     return service
98
99 class JobTask:
100     def __init__(self, parameters=dict(), runtime_constraints=dict()):
101         print "init jobtask %s %s" % (parameters, runtime_constraints)
102
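# The helpers below queue new job tasks via the API. A minimal crunch-script
# sketch (illustrative only; do_work is a hypothetical placeholder, not part
# of this module):
#
#   import arvados
#   arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#   this_task = arvados.current_task()
#   input_manifest = this_task['parameters']['input']
#   out = arvados.CollectionWriter()
#   do_work(input_manifest, out)
#   this_task.set_output(out.finish())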
103 class job_setup:
104     @staticmethod
105     def one_task_per_input_file(if_sequence=0, and_end_task=True):
106         if if_sequence != current_task()['sequence']:
107             return
108         job_input = current_job()['script_parameters']['input']
109         cr = CollectionReader(job_input)
110         for s in cr.all_streams():
111             for f in s.all_files():
112                 task_input = f.as_manifest()
113                 new_task_attrs = {
114                     'job_uuid': current_job()['uuid'],
115                     'created_by_job_task_uuid': current_task()['uuid'],
116                     'sequence': if_sequence + 1,
117                     'parameters': {
118                         'input':task_input
119                         }
120                     }
121                 service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
122         if and_end_task:
123             service.job_tasks().update(uuid=current_task()['uuid'],
124                                        job_task=json.dumps({'success':True})
125                                        ).execute()
126             exit(0)
127
128     @staticmethod
129     def one_task_per_input_stream(if_sequence=0, and_end_task=True):
130         if if_sequence != current_task()['sequence']:
131             return
132         job_input = current_job()['script_parameters']['input']
133         cr = CollectionReader(job_input)
134         for s in cr.all_streams():
135             task_input = s.tokens()
136             new_task_attrs = {
137                 'job_uuid': current_job()['uuid'],
138                 'created_by_job_task_uuid': current_task()['uuid'],
139                 'sequence': if_sequence + 1,
140                 'parameters': {
141                     'input':task_input
142                     }
143                 }
144             service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
145         if and_end_task:
146             service.job_tasks().update(uuid=current_task()['uuid'],
147                                        job_task=json.dumps({'success':True})
148                                        ).execute()
149             exit(0)
150
151 class util:
152     @staticmethod
153     def run_command(execargs, **kwargs):
154         if 'stdin' not in kwargs:
155             kwargs['stdin'] = subprocess.PIPE
156         if 'stdout' not in kwargs:
157             kwargs['stdout'] = subprocess.PIPE
158         if 'stderr' not in kwargs:
159             kwargs['stderr'] = subprocess.PIPE
160         p = subprocess.Popen(execargs, close_fds=True, shell=False,
161                              **kwargs)
162         stdoutdata, stderrdata = p.communicate(None)
163         if p.returncode != 0:
164             raise Exception("run_command %s exit %d:\n%s" %
165                             (execargs, p.returncode, stderrdata))
166         return stdoutdata, stderrdata
167
168     @staticmethod
169     def git_checkout(url, version, path):
170         if not re.search('^/', path):
171             path = os.path.join(current_job().tmpdir, path)
172         if not os.path.exists(path):
173             util.run_command(["git", "clone", url, path],
174                              cwd=os.path.dirname(path))
175         util.run_command(["git", "checkout", version],
176                          cwd=path)
177         return path
178
179     @staticmethod
180     def tar_extractor(path, decompress_flag):
181         return subprocess.Popen(["tar",
182                                  "-C", path,
183                                  ("-x%sf" % decompress_flag),
184                                  "-"],
185                                 stdout=None,
186                                 stdin=subprocess.PIPE, stderr=sys.stderr,
187                                 shell=False, close_fds=True)
188
189     @staticmethod
190     def tarball_extract(tarball, path):
191         """Retrieve a tarball from Keep and extract it to a local
192         directory.  Return the absolute path where the tarball was
193         extracted. If the top level of the tarball contained just one
194         file or directory, return the absolute path of that single
195         item.
196
197         tarball -- collection locator
198         path -- where to extract the tarball: absolute, or relative to job tmp
199         """
200         if not re.search('^/', path):
201             path = os.path.join(current_job().tmpdir, path)
202         lockfile = open(path + '.lock', 'w')
203         fcntl.flock(lockfile, fcntl.LOCK_EX)
204         try:
205             os.stat(path)
206         except OSError:
207             os.mkdir(path)
208         already_have_it = False
209         try:
210             if os.readlink(os.path.join(path, '.locator')) == tarball:
211                 already_have_it = True
212         except OSError:
213             pass
214         if not already_have_it:
215
216             # emulate "rm -f" (i.e., if the file does not exist, we win)
217             try:
218                 os.unlink(os.path.join(path, '.locator'))
219             except OSError:
220                 if os.path.exists(os.path.join(path, '.locator')):
221                     os.unlink(os.path.join(path, '.locator'))
222
223             for f in CollectionReader(tarball).all_files():
224                 if re.search('\.(tbz|tar.bz2)$', f.name()):
225                     p = util.tar_extractor(path, 'j')
226                 elif re.search('\.(tgz|tar.gz)$', f.name()):
227                     p = util.tar_extractor(path, 'z')
228                 elif re.search('\.tar$', f.name()):
229                     p = util.tar_extractor(path, '')
230                 else:
231                     raise Exception("tarball_extract cannot handle filename %s"
232                                     % f.name())
233                 while True:
234                     buf = f.read(2**20)
235                     if len(buf) == 0:
236                         break
237                     p.stdin.write(buf)
238                 p.stdin.close()
239                 p.wait()
240                 if p.returncode != 0:
241                     lockfile.close()
242                     raise Exception("tar exited %d" % p.returncode)
243             os.symlink(tarball, os.path.join(path, '.locator'))
244         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
245         lockfile.close()
246         if len(tld_extracts) == 1:
247             return os.path.join(path, tld_extracts[0])
248         return path
249
250     @staticmethod
251     def zipball_extract(zipball, path):
252         """Retrieve a zip archive from Keep and extract it to a local
253         directory.  Return the absolute path where the archive was
254         extracted. If the top level of the archive contained just one
255         file or directory, return the absolute path of that single
256         item.
257
258         zipball -- collection locator
259         path -- where to extract the archive: absolute, or relative to job tmp
260         """
261         if not re.search('^/', path):
262             path = os.path.join(current_job().tmpdir, path)
263         lockfile = open(path + '.lock', 'w')
264         fcntl.flock(lockfile, fcntl.LOCK_EX)
265         try:
266             os.stat(path)
267         except OSError:
268             os.mkdir(path)
269         already_have_it = False
270         try:
271             if os.readlink(os.path.join(path, '.locator')) == zipball:
272                 already_have_it = True
273         except OSError:
274             pass
275         if not already_have_it:
276
277             # emulate "rm -f" (i.e., if the file does not exist, we win)
278             try:
279                 os.unlink(os.path.join(path, '.locator'))
280             except OSError:
281                 if os.path.exists(os.path.join(path, '.locator')):
282                     os.unlink(os.path.join(path, '.locator'))
283
284             for f in CollectionReader(zipball).all_files():
285                 if not re.search('\.zip$', f.name()):
286                     raise Exception("zipball_extract cannot handle filename %s"
287                                     % f.name())
288                 zip_filename = os.path.join(path, os.path.basename(f.name()))
289                 zip_file = open(zip_filename, 'wb')
290                 while True:
291                     buf = f.read(2**20)
292                     if len(buf) == 0:
293                         break
294                     zip_file.write(buf)
295                 zip_file.close()
296                 
297                 p = subprocess.Popen(["unzip",
298                                       "-q", "-o",
299                                       "-d", path,
300                                       zip_filename],
301                                      stdout=None,
302                                      stdin=None, stderr=sys.stderr,
303                                      shell=False, close_fds=True)
304                 p.wait()
305                 if p.returncode != 0:
306                     lockfile.close()
307                     raise Exception("unzip exited %d" % p.returncode)
308                 os.unlink(zip_filename)
309             os.symlink(zipball, os.path.join(path, '.locator'))
310         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
311         lockfile.close()
312         if len(tld_extracts) == 1:
313             return os.path.join(path, tld_extracts[0])
314         return path
315
316     @staticmethod
317     def collection_extract(collection, path, files=[], decompress=True):
318         """Retrieve a collection from Keep and extract it to a local
319         directory.  Return the absolute path where the collection was
320         extracted.
321
322         collection -- collection locator
323         path -- where to extract: absolute, or relative to job tmp
324         """
325         if not re.search('^/', path):
326             path = os.path.join(current_job().tmpdir, path)
327         lockfile = open(path + '.lock', 'w')
328         fcntl.flock(lockfile, fcntl.LOCK_EX)
329         try:
330             os.stat(path)
331         except OSError:
332             os.mkdir(path)
333         already_have_it = False
334         try:
335             if os.readlink(os.path.join(path, '.locator')) == collection:
336                 already_have_it = True
337         except OSError:
338             pass
339
340         # emulate "rm -f" (i.e., if the file does not exist, we win)
341         try:
342             os.unlink(os.path.join(path, '.locator'))
343         except OSError:
344             if os.path.exists(os.path.join(path, '.locator')):
345                 os.unlink(os.path.join(path, '.locator'))
346
347         files_got = []
348         for s in CollectionReader(collection).all_streams():
349             stream_name = s.name()
350             for f in s.all_files():
351                 if (files == [] or
352                     ((f.name() not in files_got) and
353                      (f.name() in files or
354                       (decompress and f.decompressed_name() in files)))):
355                     outname = f.decompressed_name() if decompress else f.name()
356                     files_got += [outname]
357                     if os.path.exists(os.path.join(path, stream_name, outname)):
358                         continue
359                     util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
360                     outfile = open(os.path.join(path, stream_name, outname), 'wb')
361                     for buf in (f.readall_decompressed() if decompress
362                                 else f.readall()):
363                         outfile.write(buf)
364                     outfile.close()
365         if len(files_got) < len(files):
366             raise Exception("Wanted files %s but only got %s from %s" % (files, files_got, map(lambda z: z.name(), list(CollectionReader(collection).all_files()))))
367         os.symlink(collection, os.path.join(path, '.locator'))
368
369         lockfile.close()
370         return path
371
372     @staticmethod
373     def mkdir_dash_p(path):
374         if not os.path.exists(path):
375             util.mkdir_dash_p(os.path.dirname(path))
376             try:
377                 os.mkdir(path)
378             except OSError:
379                 if not os.path.exists(path):
380                     os.mkdir(path)
381
382     @staticmethod
383     def stream_extract(stream, path, files=[], decompress=True):
384         """Retrieve a stream from Keep and extract it to a local
385         directory.  Return the absolute path where the stream was
386         extracted.
387
388         stream -- StreamReader object
389         path -- where to extract: absolute, or relative to job tmp
390         """
391         if not re.search('^/', path):
392             path = os.path.join(current_job().tmpdir, path)
393         lockfile = open(path + '.lock', 'w')
394         fcntl.flock(lockfile, fcntl.LOCK_EX)
395         try:
396             os.stat(path)
397         except OSError:
398             os.mkdir(path)
399
400         files_got = []
401         for f in stream.all_files():
402             if (files == [] or
403                 ((f.name() not in files_got) and
404                  (f.name() in files or
405                   (decompress and f.decompressed_name() in files)))):
406                 outname = f.decompressed_name() if decompress else f.name()
407                 files_got += [outname]
408                 if os.path.exists(os.path.join(path, outname)):
409                     os.unlink(os.path.join(path, outname))
410                 util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
411                 outfile = open(os.path.join(path, outname), 'wb')
412                 for buf in (f.readall_decompressed() if decompress
413                             else f.readall()):
414                     outfile.write(buf)
415                 outfile.close()
416         if len(files_got) < len(files):
417             raise Exception("Wanted files %s but only got %s from %s" %
418                             (files, files_got, map(lambda z: z.name(),
419                                                    list(stream.all_files()))))
420         lockfile.close()
421         return path
422
423     @staticmethod
424     def listdir_recursive(dirname, base=None):
425         allfiles = []
426         for ent in sorted(os.listdir(dirname)):
427             ent_path = os.path.join(dirname, ent)
428             ent_base = os.path.join(base, ent) if base else ent
429             if os.path.isdir(ent_path):
430                 allfiles += util.listdir_recursive(ent_path, ent_base)
431             else:
432                 allfiles += [ent_base]
433         return allfiles
434
435 class StreamFileReader:
436     def __init__(self, stream, pos, size, name):
437         self._stream = stream
438         self._pos = pos
439         self._size = size
440         self._name = name
441         self._filepos = 0
442     def name(self):
443         return self._name
444     def decompressed_name(self):
445         return re.sub('\.(bz2|gz)$', '', self._name)
446     def size(self):
447         return self._size
448     def stream_name(self):
449         return self._stream.name()
450     def read(self, size, **kwargs):
451         self._stream.seek(self._pos + self._filepos)
452         data = self._stream.read(min(size, self._size - self._filepos))
453         self._filepos += len(data)
454         return data
455     def readall(self, size=2**20, **kwargs):
456         while True:
457             data = self.read(size, **kwargs)
458             if data == '':
459                 break
460             yield data
461     def bunzip2(self, size):
462         decompressor = bz2.BZ2Decompressor()
463         for chunk in self.readall(size):
464             data = decompressor.decompress(chunk)
465             if data and data != '':
466                 yield data
467     def gunzip(self, size):
468         decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
469         for chunk in self.readall(size):
470             data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
471             if data and data != '':
472                 yield data
473     def readall_decompressed(self, size=2**20):
474         self._stream.seek(self._pos + self._filepos)
475         if re.search('\.bz2$', self._name):
476             return self.bunzip2(size)
477         elif re.search('\.gz$', self._name):
478             return self.gunzip(size)
479         else:
480             return self.readall(size)
481     def readlines(self, decompress=True):
482         if decompress:
483             datasource = self.readall_decompressed()
484         else:
485             self._stream.seek(self._pos + self._filepos)
486             datasource = self.readall()
487         data = ''
488         for newdata in datasource:
489             data += newdata
490             sol = 0
491             while True:
492                 eol = string.find(data, "\n", sol)
493                 if eol < 0:
494                     break
495                 yield data[sol:eol+1]
496                 sol = eol+1
497             data = data[sol:]
498         if data != '':
499             yield data
500     def as_manifest(self):
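        # An empty file has no data blocks, so it is represented by the
        # locator of the empty block (the MD5 of zero bytes) plus a 0:0
        # file token.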
501         if self.size() == 0:
502             return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
503                     % (self._stream.name(), self.name()))
504         return string.join(self._stream.tokens_for_range(self._pos, self._size),
505                            " ") + "\n"
506
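# A manifest is a sequence of newline-terminated stream lines. Each line
# holds a stream name, one or more data block locators (32 hex digits of the
# block's MD5, optionally followed by "+" hints such as the block size), and
# one or more file tokens of the form "<position>:<size>:<filename>".
# For example (the hash below is illustrative only):
#
#   . 930625b054ce894ac40596c3f5a0d947+1678 0:1678:README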
507 class StreamReader:
508     def __init__(self, tokens):
509         self._tokens = tokens
510         self._current_datablock_data = None
511         self._current_datablock_pos = 0
512         self._current_datablock_index = -1
513         self._pos = 0
514
515         self._stream_name = None
516         self.data_locators = []
517         self.files = []
518
519         for tok in self._tokens:
520             if self._stream_name == None:
521                 self._stream_name = tok
522             elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
523                 self.data_locators += [tok]
524             elif re.search(r'^\d+:\d+:\S+', tok):
525                 pos, size, name = tok.split(':',2)
526                 self.files += [[int(pos), int(size), name]]
527             else:
528                 raise Exception("Invalid manifest format")
529
530     def tokens(self):
531         return self._tokens
532     def tokens_for_range(self, range_start, range_size):
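        # Return the subset of this stream's manifest tokens needed to cover
        # the byte range [range_start, range_start + range_size): the stream
        # name, the data block locators that overlap the range, and the file
        # tokens that overlap it, with file offsets shifted down by the size
        # of any blocks skipped at the front.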
533         resp = [self._stream_name]
534         return_all_tokens = False
535         block_start = 0
536         token_bytes_skipped = 0
537         for locator in self.data_locators:
538             sizehint = re.search(r'\+(\d+)', locator)
539             if not sizehint:
540                 return_all_tokens = True
541             if return_all_tokens:
542                 resp += [locator]
543                 continue
544             blocksize = int(sizehint.group(0))
545             if range_start + range_size <= block_start:
546                 break
547             if range_start < block_start + blocksize:
548                 resp += [locator]
549             else:
550                 token_bytes_skipped += blocksize
551             block_start += blocksize
552         for f in self.files:
553             if ((f[0] < range_start + range_size)
554                 and
555                 (f[0] + f[1] > range_start)
556                 and
557                 f[1] > 0):
558                 resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
559         return resp
560     def name(self):
561         return self._stream_name
562     def all_files(self):
563         for f in self.files:
564             pos, size, name = f
565             yield StreamFileReader(self, pos, size, name)
566     def nextdatablock(self):
567         if self._current_datablock_index < 0:
568             self._current_datablock_pos = 0
569             self._current_datablock_index = 0
570         else:
571             self._current_datablock_pos += self.current_datablock_size()
572             self._current_datablock_index += 1
573         self._current_datablock_data = None
574     def current_datablock_data(self):
575         if self._current_datablock_data == None:
576             self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
577         return self._current_datablock_data
578     def current_datablock_size(self):
579         if self._current_datablock_index < 0:
580             self.nextdatablock()
581         sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
582         if sizehint:
583             return int(sizehint.group(0))
584         return len(self.current_datablock_data())
585     def seek(self, pos):
586         """Set the position of the next read operation."""
587         self._pos = pos
588     def really_seek(self):
589         """Find and load the appropriate data block, so the byte at
590         _pos is in memory.
591         """
592         if self._pos == self._current_datablock_pos:
593             return True
594         if (self._current_datablock_pos != None and
595             self._pos >= self._current_datablock_pos and
596             self._pos <= self._current_datablock_pos + self.current_datablock_size()):
597             return True
598         if self._pos < self._current_datablock_pos:
599             self._current_datablock_index = -1
600             self.nextdatablock()
601         while (self._pos > self._current_datablock_pos and
602                self._pos > self._current_datablock_pos + self.current_datablock_size()):
603             self.nextdatablock()
604     def read(self, size):
605         """Read no more than size bytes -- but at least one byte,
606         unless _pos is already at the end of the stream.
607         """
608         if size == 0:
609             return ''
610         self.really_seek()
611         while self._pos >= self._current_datablock_pos + self.current_datablock_size():
612             self.nextdatablock()
613             if self._current_datablock_index >= len(self.data_locators):
614                 return None
615         data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
616         self._pos += len(data)
617         return data
618
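# Read-side sketch (illustrative only): open a collection by locator or by
# literal manifest text, then iterate over its files.
#
#   cr = CollectionReader(manifest_locator_or_text)
#   for f in cr.all_files():
#       print f.stream_name(), f.name(), f.size()
#       data = f.read(2**20)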
619 class CollectionReader:
620     def __init__(self, manifest_locator_or_text):
621         if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
622             self._manifest_text = manifest_locator_or_text
623             self._manifest_locator = None
624         else:
625             self._manifest_locator = manifest_locator_or_text
626             self._manifest_text = None
627         self._streams = None
628     def __enter__(self):
629         return self
630     def __exit__(self, exc_type, exc_value, traceback):
631         pass
632     def _populate(self):
633         if self._streams != None:
634             return
635         if not self._manifest_text:
636             self._manifest_text = Keep.get(self._manifest_locator)
637         self._streams = []
638         for stream_line in self._manifest_text.split("\n"):
639             if stream_line != '':
640                 stream_tokens = stream_line.split()
641                 self._streams += [stream_tokens]
642     def all_streams(self):
643         self._populate()
644         resp = []
645         for s in self._streams:
646             resp += [StreamReader(s)]
647         return resp
648     def all_files(self):
649         for s in self.all_streams():
650             for f in s.all_files():
651                 yield f
652     def manifest_text(self):
653         self._populate()
654         return self._manifest_text
655
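# Write-side sketch (illustrative only; some_data is a placeholder):
#
#   cw = CollectionWriter()
#   cw.start_new_file('output.txt')
#   cw.write(some_data)
#   locator = cw.finish()   # stores the manifest in Keep, returns its locator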
656 class CollectionWriter:
657     KEEP_BLOCK_SIZE = 2**26
658     def __init__(self):
659         self._data_buffer = []
660         self._data_buffer_len = 0
661         self._current_stream_files = []
662         self._current_stream_length = 0
663         self._current_stream_locators = []
664         self._current_stream_name = '.'
665         self._current_file_name = None
666         self._current_file_pos = 0
667         self._finished_streams = []
668     def __enter__(self):
669         return self
670     def __exit__(self, exc_type, exc_value, traceback):
671         self.finish()
672     def write_directory_tree(self,
673                              path, stream_name='.', max_manifest_depth=-1):
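        # Write the contents of the local directory 'path' into this
        # collection. Each subdirectory becomes its own stream until
        # max_manifest_depth reaches zero, at which point the remaining
        # subtree is flattened into a single stream.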
674         self.start_new_stream(stream_name)
675         todo = []
676         if max_manifest_depth == 0:
677             dirents = util.listdir_recursive(path)
678         else:
679             dirents = sorted(os.listdir(path))
680         for dirent in dirents:
681             target = os.path.join(path, dirent)
682             if os.path.isdir(target):
683                 todo += [[target,
684                           os.path.join(stream_name, dirent),
685                           max_manifest_depth-1]]
686             else:
687                 self.start_new_file(dirent)
688                 with open(target, 'rb') as f:
689                     while True:
690                         buf = f.read(2**26)
691                         if len(buf) == 0:
692                             break
693                         self.write(buf)
694         self.finish_current_stream()
695         map(lambda x: self.write_directory_tree(*x), todo)
696
697     def write(self, newdata):
698         self._data_buffer += [newdata]
699         self._data_buffer_len += len(newdata)
700         self._current_stream_length += len(newdata)
701         while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
702             self.flush_data()
703     def flush_data(self):
704         data_buffer = ''.join(self._data_buffer)
705         if data_buffer != '':
706             self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
707             self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
708             self._data_buffer_len = len(self._data_buffer[0])
709     def start_new_file(self, newfilename=None):
710         self.finish_current_file()
711         self.set_current_file_name(newfilename)
712     def set_current_file_name(self, newfilename):
713         newfilename = re.sub(r' ', '\\\\040', newfilename)
714         if re.search(r'[ \t\n]', newfilename):
715             raise AssertionError("Manifest filenames cannot contain whitespace")
716         self._current_file_name = newfilename
717     def current_file_name(self):
718         return self._current_file_name
719     def finish_current_file(self):
720         if self._current_file_name == None:
721             if self._current_file_pos == self._current_stream_length:
722                 return
723             raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
724         self._current_stream_files += [[self._current_file_pos,
725                                        self._current_stream_length - self._current_file_pos,
726                                        self._current_file_name]]
727         self._current_file_pos = self._current_stream_length
728     def start_new_stream(self, newstreamname='.'):
729         self.finish_current_stream()
730         self.set_current_stream_name(newstreamname)
731     def set_current_stream_name(self, newstreamname):
732         if re.search(r'[ \t\n]', newstreamname):
733             raise AssertionError("Manifest stream names cannot contain whitespace")
734         self._current_stream_name = newstreamname
735     def current_stream_name(self):
736         return self._current_stream_name
737     def finish_current_stream(self):
738         self.finish_current_file()
739         self.flush_data()
740         if len(self._current_stream_files) == 0:
741             pass
742         elif self._current_stream_name == None:
743             raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
744         else:
745             self._finished_streams += [[self._current_stream_name,
746                                        self._current_stream_locators,
747                                        self._current_stream_files]]
748         self._current_stream_files = []
749         self._current_stream_length = 0
750         self._current_stream_locators = []
751         self._current_stream_name = None
752         self._current_file_pos = 0
753         self._current_file_name = None
754     def finish(self):
755         return Keep.put(self.manifest_text())
756     def manifest_text(self):
757         self.finish_current_stream()
758         manifest = ''
759         for stream in self._finished_streams:
760             if not re.search(r'^\.(/.*)?$', stream[0]):
761                 manifest += './'
762             manifest += stream[0]
763             if len(stream[1]) == 0:
764                 manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
765             else:
766                 for locator in stream[1]:
767                     manifest += " %s" % locator
768             for sfile in stream[2]:
769                 manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
770             manifest += "\n"
771         return manifest
772
773 global_client_object = None
774
775 class Keep:
776     @staticmethod
777     def global_client_object():
778         global global_client_object
779         if global_client_object == None:
780             global_client_object = KeepClient()
781         return global_client_object
782
783     @staticmethod
784     def get(locator):
785         return Keep.global_client_object().get(locator)
786
787     @staticmethod
788     def put(data):
789         return Keep.global_client_object().put(data)
790
791 class KeepClient:
792     def __init__(self):
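        # self.lock guards lazy initialization of self.service_roots so a
        # single KeepClient instance can be shared safely between threads.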
793         self.lock = threading.Lock()
794         self.service_roots = None
795
796     def shuffled_service_roots(self, hash):
797         if self.service_roots == None:
798             with self.lock:
799                 if self.service_roots == None:  # re-check under the lock
800                     keep_disks = api().keep_disks().list().execute()['items']
801                     roots = (("http%s://%s:%d/" %
802                               ('s' if f['service_ssl_flag'] else '',
803                                f['service_host'],
804                                f['service_port']))
805                              for f in keep_disks)
806                     self.service_roots = sorted(set(roots))
807                     logging.debug(str(self.service_roots))
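        # Deterministically shuffle the service list, using successive
        # 8-hex-digit slices of the block hash as probe values, so that every
        # client derives the same service ordering for a given hash.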
808         seed = hash
809         pool = self.service_roots[:]
810         pseq = []
811         while len(pool) > 0:
812             if len(seed) < 8:
813                 if len(pseq) < len(hash) / 4: # first time around
814                     seed = hash[-4:] + hash
815                 else:
816                     seed += hash
817             probe = int(seed[0:8], 16) % len(pool)
818             pseq += [pool[probe]]
819             pool = pool[:probe] + pool[probe+1:]
820             seed = seed[8:]
821         logging.debug(str(pseq))
822         return pseq
823
824     def get(self, locator):
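        # Try each Keep service in probe-sequence order and accept a response
        # only if its MD5 matches the hash in the locator; raise if no
        # service returns a verifiable copy.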
825         if 'KEEP_LOCAL_STORE' in os.environ:
826             return KeepClient.local_store_get(locator)
827         expect_hash = re.sub(r'\+.*', '', locator)
828         for service_root in self.shuffled_service_roots(expect_hash):
829             h = httplib2.Http()
830             url = service_root + expect_hash
831             api_token = os.environ['ARVADOS_API_TOKEN']
832             headers = {'Authorization': "OAuth2 %s" % api_token,
833                        'Accept': 'application/octet-stream'}
834             try:
835                 resp, content = h.request(url, 'GET', headers=headers)
836                 if re.match(r'^2\d\d$', resp['status']):
837                     m = hashlib.new('md5')
838                     m.update(content)
839                     md5 = m.hexdigest()
840                     if md5 == expect_hash:
841                         return content
842                     logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
843             except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
844                 logging.info("Request fail: GET %s => %s: %s" %
845                              (url, type(e), str(e)))
846         raise Exception("Not found: %s" % expect_hash)
847
848     def put(self, data, **kwargs):
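        # Write the block to Keep services in probe-sequence order until
        # 'copies' services (default 2) have accepted it, then return the
        # locator "<md5>+<size>". Raise if fewer copies could be stored.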
849         if 'KEEP_LOCAL_STORE' in os.environ:
850             return KeepClient.local_store_put(data)
851         m = hashlib.new('md5')
852         m.update(data)
853         data_hash = m.hexdigest()
854         have_copies = 0
855         want_copies = kwargs.get('copies', 2)
856         for service_root in self.shuffled_service_roots(data_hash):
857             h = httplib2.Http()
858             url = service_root + data_hash
859             api_token = os.environ['ARVADOS_API_TOKEN']
860             headers = {'Authorization': "OAuth2 %s" % api_token}
861             try:
862                 resp, content = h.request(url, 'PUT',
863                                           headers=headers,
864                                           body=data)
865                 if (resp['status'] == '401' and
866                     re.match(r'Timestamp verification failed', content)):
867                     body = self.sign_for_old_server(data_hash, data)
868                     h = httplib2.Http()
869                     resp, content = h.request(url, 'PUT',
870                                               headers=headers,
871                                               body=body)
872                 if re.match(r'^2\d\d$', resp['status']):
873                     have_copies += 1
874                     if have_copies == want_copies:
875                         return data_hash + '+' + str(len(data))
876                 else:
877                     logging.warning("Request fail: PUT %s => %s %s" %
878                                     (url, resp['status'], content))
879             except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
880                 logging.warning("Request fail: PUT %s => %s: %s" %
881                                 (url, type(e), str(e)))
882         raise Exception("Write fail for %s: wanted %d but wrote %d" %
883                         (data_hash, want_copies, have_copies))
884
885     def sign_for_old_server(self, data_hash, data):
886         return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
887
888
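    # When the KEEP_LOCAL_STORE environment variable is set, get() and put()
    # above short-circuit to these helpers, which keep blocks as plain files
    # in that directory -- handy for local testing without real Keep services.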
889     @staticmethod
890     def local_store_put(data):
891         m = hashlib.new('md5')
892         m.update(data)
893         md5 = m.hexdigest()
894         locator = '%s+%d' % (md5, len(data))
895         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
896             f.write(data)
897         os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
898                   os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
899         return locator
900     @staticmethod
901     def local_store_get(locator):
902         r = re.search('^([0-9a-f]{32,})', locator)
903         if not r:
904             raise Exception("Keep.get: invalid data locator '%s'" % locator)
905         if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
906             return ''
907         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
908             return f.read()