Use collection hash instead of entire manifest as symlink target
[arvados.git] / sdk / python / arvados.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20
21 from apiclient import errors
22 from apiclient.discovery import build
23
24 if 'ARVADOS_DEBUG' in os.environ:
25     logging.basicConfig(level=logging.DEBUG)
26
27 class CredentialsFromEnv(object):
28     @staticmethod
29     def http_request(self, uri, **kwargs):
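        # Note: this is declared @staticmethod but still takes "self".
        # authorize() below rebinds it onto the Http object with
        # types.MethodType, so "self" is the httplib2.Http instance by the
        # time it is actually called.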
30         from httplib import BadStatusLine
31         if 'headers' not in kwargs:
32             kwargs['headers'] = {}
33         kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
34         try:
35             return self.orig_http_request(uri, **kwargs)
36         except BadStatusLine:
37             # This is how httplib tells us that it tried to reuse an
38             # existing connection but it was already closed by the
39             # server. In that case, yes, we would like to retry.
40             # Unfortunately, we are not absolutely certain that the
41             # previous call did not succeed, so this is slightly
42             # risky.
43             return self.orig_http_request(uri, **kwargs)
44     def authorize(self, http):
45         http.orig_http_request = http.request
46         http.request = types.MethodType(self.http_request, http)
47         return http
48
49 url = ('https://%s/discovery/v1/apis/'
50        '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
51 credentials = CredentialsFromEnv()
52
53 # Use system's CA certificates (if we find them) instead of httplib2's
54 ca_certs = '/etc/ssl/certs/ca-certificates.crt'
55 if not os.path.exists(ca_certs):
56     ca_certs = None             # use httplib2 default
57
58 http = httplib2.Http(ca_certs=ca_certs)
59 http = credentials.authorize(http)
60 if re.match(r'(?i)^(true|1|yes)$',
61             os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
62     http.disable_ssl_certificate_validation = True
63 service = build("arvados", "v1", http=http, discoveryServiceUrl=url)
64
65 def task_set_output(self,s):
66     service.job_tasks().update(uuid=self['uuid'],
67                                job_task=json.dumps({
68                 'output':s,
69                 'success':True,
70                 'progress':1.0
71                 })).execute()
72
73 _current_task = None
74 def current_task():
75     global _current_task
76     if _current_task:
77         return _current_task
78     t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
79     t = UserDict.UserDict(t)
80     t.set_output = types.MethodType(task_set_output, t)
81     t.tmpdir = os.environ['TASK_WORK']
82     _current_task = t
83     return t
84
85 _current_job = None
86 def current_job():
87     global _current_job
88     if _current_job:
89         return _current_job
90     t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
91     t = UserDict.UserDict(t)
92     t.tmpdir = os.environ['JOB_WORK']
93     _current_job = t
94     return t
95
96 def api():
97     return service
98
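# Illustrative sketch (not itself part of the SDK): a crunch script would
# typically combine the helpers above along these lines --
#
#   import arvados
#   this_job = arvados.current_job()
#   this_task = arvados.current_task()
#   input_locator = this_job['script_parameters']['input']
#   # ... do the work ...
#   this_task.set_output(output_locator)
#
# arvados.api() exposes the discovery-based client directly, e.g.
# arvados.api().jobs().get(uuid=some_job_uuid).execute(). The names
# input_locator, output_locator and some_job_uuid are hypothetical.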
99 class JobTask(object):
100     def __init__(self, parameters=dict(), runtime_constraints=dict()):
101         print "init jobtask %s %s" % (parameters, runtime_constraints)
102
103 class job_setup:
104     @staticmethod
105     def one_task_per_input_file(if_sequence=0, and_end_task=True):
106         if if_sequence != current_task()['sequence']:
107             return
108         job_input = current_job()['script_parameters']['input']
109         cr = CollectionReader(job_input)
110         for s in cr.all_streams():
111             for f in s.all_files():
112                 task_input = f.as_manifest()
113                 new_task_attrs = {
114                     'job_uuid': current_job()['uuid'],
115                     'created_by_job_task_uuid': current_task()['uuid'],
116                     'sequence': if_sequence + 1,
117                     'parameters': {
118                         'input':task_input
119                         }
120                     }
121                 service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
122         if and_end_task:
123             service.job_tasks().update(uuid=current_task()['uuid'],
124                                        job_task=json.dumps({'success':True})
125                                        ).execute()
126             sys.exit(0)
127
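    # Illustrative sketch: a map-style crunch script calls the method above at
    # sequence 0 so that each input file becomes its own task --
    #
    #   arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
    #   this_task = arvados.current_task()
    #   input_manifest = this_task['parameters']['input']  # single-file manifest text
    #   for f in arvados.CollectionReader(input_manifest).all_files():
    #       pass                                            # process the one input file
    #
    # The 'input' parameter holds the as_manifest() text created above.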
128     @staticmethod
129     def one_task_per_input_stream(if_sequence=0, and_end_task=True):
130         if if_sequence != current_task()['sequence']:
131             return
132         job_input = current_job()['script_parameters']['input']
133         cr = CollectionReader(job_input)
134         for s in cr.all_streams():
135             task_input = s.tokens()
136             new_task_attrs = {
137                 'job_uuid': current_job()['uuid'],
138                 'created_by_job_task_uuid': current_task()['uuid'],
139                 'sequence': if_sequence + 1,
140                 'parameters': {
141                     'input':task_input
142                     }
143                 }
144             service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
145         if and_end_task:
146             service.job_tasks().update(uuid=current_task()['uuid'],
147                                        job_task=json.dumps({'success':True})
148                                        ).execute()
149             sys.exit(0)
150
151 class util:
152     @staticmethod
153     def run_command(execargs, **kwargs):
154         if 'stdin' not in kwargs:
155             kwargs['stdin'] = subprocess.PIPE
156         if 'stdout' not in kwargs:
157             kwargs['stdout'] = subprocess.PIPE
158         if 'stderr' not in kwargs:
159             kwargs['stderr'] = subprocess.PIPE
160         p = subprocess.Popen(execargs, close_fds=True, shell=False,
161                              **kwargs)
162         stdoutdata, stderrdata = p.communicate(None)
163         if p.returncode != 0:
164             raise Exception("run_command %s exit %d:\n%s" %
165                             (execargs, p.returncode, stderrdata))
166         return stdoutdata, stderrdata
167
168     @staticmethod
169     def git_checkout(url, version, path):
170         if not re.search('^/', path):
171             path = os.path.join(current_job().tmpdir, path)
172         if not os.path.exists(path):
173             util.run_command(["git", "clone", url, path],
174                              cwd=os.path.dirname(path))
175         util.run_command(["git", "checkout", version],
176                          cwd=path)
177         return path
178
179     @staticmethod
180     def tar_extractor(path, decompress_flag):
181         return subprocess.Popen(["tar",
182                                  "-C", path,
183                                  ("-x%sf" % decompress_flag),
184                                  "-"],
185                                 stdout=None,
186                                 stdin=subprocess.PIPE, stderr=sys.stderr,
187                                 shell=False, close_fds=True)
188
189     @staticmethod
190     def tarball_extract(tarball, path):
191         """Retrieve a tarball from Keep and extract it to a local
192         directory.  Return the absolute path where the tarball was
193         extracted. If the top level of the tarball contained just one
194         file or directory, return the absolute path of that single
195         item.
196
197         tarball -- collection locator
198         path -- where to extract the tarball: absolute, or relative to job tmp
199         """
200         if not re.search('^/', path):
201             path = os.path.join(current_job().tmpdir, path)
202         lockfile = open(path + '.lock', 'w')
203         fcntl.flock(lockfile, fcntl.LOCK_EX)
204         try:
205             os.stat(path)
206         except OSError:
207             os.mkdir(path)
208         already_have_it = False
209         try:
210             if os.readlink(os.path.join(path, '.locator')) == tarball:
211                 already_have_it = True
212         except OSError:
213             pass
214         if not already_have_it:
215
216             # emulate "rm -f" (i.e., if the file does not exist, we win)
217             try:
218                 os.unlink(os.path.join(path, '.locator'))
219             except OSError:
220                 if os.path.exists(os.path.join(path, '.locator')):
221                     os.unlink(os.path.join(path, '.locator'))
222
223             for f in CollectionReader(tarball).all_files():
224                 if re.search('\.(tbz|tar.bz2)$', f.name()):
225                     p = util.tar_extractor(path, 'j')
226                 elif re.search('\.(tgz|tar.gz)$', f.name()):
227                     p = util.tar_extractor(path, 'z')
228                 elif re.search('\.tar$', f.name()):
229                     p = util.tar_extractor(path, '')
230                 else:
231                     raise Exception("tarball_extract cannot handle filename %s"
232                                     % f.name())
233                 while True:
234                     buf = f.read(2**20)
235                     if len(buf) == 0:
236                         break
237                     p.stdin.write(buf)
238                 p.stdin.close()
239                 p.wait()
240                 if p.returncode != 0:
241                     lockfile.close()
242                     raise Exception("tar exited %d" % p.returncode)
243             os.symlink(tarball, os.path.join(path, '.locator'))
244         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
245         lockfile.close()
246         if len(tld_extracts) == 1:
247             return os.path.join(path, tld_extracts[0])
248         return path
249
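    # Illustrative sketch: extracting a source tarball inside a job --
    #
    #   src_dir = arvados.util.tarball_extract(
    #       tarball=this_job['script_parameters']['source_tarball'],  # hypothetical parameter name
    #       path='source')                                            # relative to the job tmpdir
    #
    # The '.locator' symlink left in the target directory lets a later task on
    # the same node skip the download when the locator has not changed.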
250     @staticmethod
251     def zipball_extract(zipball, path):
252         """Retrieve a zip archive from Keep and extract it to a local
253         directory.  Return the absolute path where the archive was
254         extracted. If the top level of the archive contained just one
255         file or directory, return the absolute path of that single
256         item.
257
258         zipball -- collection locator
259         path -- where to extract the archive: absolute, or relative to job tmp
260         """
261         if not re.search('^/', path):
262             path = os.path.join(current_job().tmpdir, path)
263         lockfile = open(path + '.lock', 'w')
264         fcntl.flock(lockfile, fcntl.LOCK_EX)
265         try:
266             os.stat(path)
267         except OSError:
268             os.mkdir(path)
269         already_have_it = False
270         try:
271             if os.readlink(os.path.join(path, '.locator')) == zipball:
272                 already_have_it = True
273         except OSError:
274             pass
275         if not already_have_it:
276
277             # emulate "rm -f" (i.e., if the file does not exist, we win)
278             try:
279                 os.unlink(os.path.join(path, '.locator'))
280             except OSError:
281                 if os.path.exists(os.path.join(path, '.locator')):
282                     os.unlink(os.path.join(path, '.locator'))
283
284             for f in CollectionReader(zipball).all_files():
285                 if not re.search('\.zip$', f.name()):
286                     raise Exception("zipball_extract cannot handle filename %s"
287                                     % f.name())
288                 zip_filename = os.path.join(path, os.path.basename(f.name()))
289                 zip_file = open(zip_filename, 'wb')
290                 while True:
291                     buf = f.read(2**20)
292                     if len(buf) == 0:
293                         break
294                     zip_file.write(buf)
295                 zip_file.close()
296
297                 p = subprocess.Popen(["unzip",
298                                       "-q", "-o",
299                                       "-d", path,
300                                       zip_filename],
301                                      stdout=None,
302                                      stdin=None, stderr=sys.stderr,
303                                      shell=False, close_fds=True)
304                 p.wait()
305                 if p.returncode != 0:
306                     lockfile.close()
307                     raise Exception("unzip exited %d" % p.returncode)
308                 os.unlink(zip_filename)
309             os.symlink(zipball, os.path.join(path, '.locator'))
310         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
311         lockfile.close()
312         if len(tld_extracts) == 1:
313             return os.path.join(path, tld_extracts[0])
314         return path
315
316     @staticmethod
317     def collection_extract(collection, path, files=[], decompress=True):
318         """Retrieve a collection from Keep and extract it to a local
319         directory.  Return the absolute path where the collection was
320         extracted.
321
322         collection -- collection locator
323         path -- where to extract: absolute, or relative to job tmp
324         """
325         matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
326         if matches:
327             collection_hash = matches.group(1)
328         else:
329             collection_hash = hashlib.md5(collection).hexdigest()
330         if not re.search('^/', path):
331             path = os.path.join(current_job().tmpdir, path)
332         lockfile = open(path + '.lock', 'w')
333         fcntl.flock(lockfile, fcntl.LOCK_EX)
334         try:
335             os.stat(path)
336         except OSError:
337             os.mkdir(path)
338         already_have_it = False
339         try:
340             if os.readlink(os.path.join(path, '.locator')) == collection_hash:
341                 already_have_it = True
342         except OSError:
343             pass
344
345         # emulate "rm -f" (i.e., if the file does not exist, we win)
346         try:
347             os.unlink(os.path.join(path, '.locator'))
348         except OSError:
349             if os.path.exists(os.path.join(path, '.locator')):
350                 os.unlink(os.path.join(path, '.locator'))
351
352         files_got = []
353         for s in CollectionReader(collection).all_streams():
354             stream_name = s.name()
355             for f in s.all_files():
356                 if (files == [] or
357                     ((f.name() not in files_got) and
358                      (f.name() in files or
359                       (decompress and f.decompressed_name() in files)))):
360                     outname = f.decompressed_name() if decompress else f.name()
361                     files_got += [outname]
362                     if os.path.exists(os.path.join(path, stream_name, outname)):
363                         continue
364                     util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
365                     outfile = open(os.path.join(path, stream_name, outname), 'wb')
366                     for buf in (f.readall_decompressed() if decompress
367                                 else f.readall()):
368                         outfile.write(buf)
369                     outfile.close()
370         if len(files_got) < len(files):
371             raise Exception("Wanted files %s but only got %s from %s" % (files, files_got, map(lambda z: z.name(), list(CollectionReader(collection).all_files()))))
372         os.symlink(collection_hash, os.path.join(path, '.locator'))
373
374         lockfile.close()
375         return path
376
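    # Illustrative sketch: fetch two specific files from a collection, storing
    # any .gz/.bz2 members decompressed on disk --
    #
    #   ref_dir = arvados.util.collection_extract(
    #       collection=this_job['script_parameters']['reference'],  # hypothetical parameter name
    #       path='reference',
    #       files=['chr1.fa', 'chr2.fa'],                           # hypothetical file names
    #       decompress=True)
    #
    # The return value is the absolute extraction path.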
377     @staticmethod
378     def mkdir_dash_p(path):
379         if not os.path.exists(path):
380             util.mkdir_dash_p(os.path.dirname(path))
381             try:
382                 os.mkdir(path)
383             except OSError:
384                 if not os.path.exists(path):
385                     os.mkdir(path)
386
387     @staticmethod
388     def stream_extract(stream, path, files=[], decompress=True):
389         """Retrieve a stream from Keep and extract it to a local
390         directory.  Return the absolute path where the stream was
391         extracted.
392
393         stream -- StreamReader object
394         path -- where to extract: absolute, or relative to job tmp
395         """
396         if not re.search('^/', path):
397             path = os.path.join(current_job().tmpdir, path)
398         lockfile = open(path + '.lock', 'w')
399         fcntl.flock(lockfile, fcntl.LOCK_EX)
400         try:
401             os.stat(path)
402         except OSError:
403             os.mkdir(path)
404
405         files_got = []
406         for f in stream.all_files():
407             if (files == [] or
408                 ((f.name() not in files_got) and
409                  (f.name() in files or
410                   (decompress and f.decompressed_name() in files)))):
411                 outname = f.decompressed_name() if decompress else f.name()
412                 files_got += [outname]
413                 if os.path.exists(os.path.join(path, outname)):
414                     os.unlink(os.path.join(path, outname))
415                 util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
416                 outfile = open(os.path.join(path, outname), 'wb')
417                 for buf in (f.readall_decompressed() if decompress
418                             else f.readall()):
419                     outfile.write(buf)
420                 outfile.close()
421         if len(files_got) < len(files):
422             raise Exception("Wanted files %s but only got %s from %s" %
423                             (files, files_got, map(lambda z: z.name(),
424                                                    list(stream.all_files()))))
425         lockfile.close()
426         return path
427
428     @staticmethod
429     def listdir_recursive(dirname, base=None):
430         allfiles = []
431         for ent in sorted(os.listdir(dirname)):
432             ent_path = os.path.join(dirname, ent)
433             ent_base = os.path.join(base, ent) if base else ent
434             if os.path.isdir(ent_path):
435                 allfiles += util.listdir_recursive(ent_path, ent_base)
436             else:
437                 allfiles += [ent_base]
438         return allfiles
439
440 class StreamFileReader(object):
441     def __init__(self, stream, pos, size, name):
442         self._stream = stream
443         self._pos = pos
444         self._size = size
445         self._name = name
446         self._filepos = 0
447     def name(self):
448         return self._name
449     def decompressed_name(self):
450         return re.sub('\.(bz2|gz)$', '', self._name)
451     def size(self):
452         return self._size
453     def stream_name(self):
454         return self._stream.name()
455     def read(self, size, **kwargs):
456         self._stream.seek(self._pos + self._filepos)
457         data = self._stream.read(min(size, self._size - self._filepos))
458         self._filepos += len(data)
459         return data
460     def readall(self, size=2**20, **kwargs):
461         while True:
462             data = self.read(size, **kwargs)
463             if data == '':
464                 break
465             yield data
466     def bunzip2(self, size):
467         decompressor = bz2.BZ2Decompressor()
468         for chunk in self.readall(size):
469             data = decompressor.decompress(chunk)
470             if data and data != '':
471                 yield data
472     def gunzip(self, size):
473         decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
474         for chunk in self.readall(size):
475             data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
476             if data and data != '':
477                 yield data
478     def readall_decompressed(self, size=2**20):
479         self._stream.seek(self._pos + self._filepos)
480         if re.search('\.bz2$', self._name):
481             return self.bunzip2(size)
482         elif re.search('\.gz$', self._name):
483             return self.gunzip(size)
484         else:
485             return self.readall(size)
486     def readlines(self, decompress=True):
487         if decompress:
488             datasource = self.readall_decompressed()
489         else:
490             self._stream.seek(self._pos + self._filepos)
491             datasource = self.readall()
492         data = ''
493         for newdata in datasource:
494             data += newdata
495             sol = 0
496             while True:
497                 eol = string.find(data, "\n", sol)
498                 if eol < 0:
499                     break
500                 yield data[sol:eol+1]
501                 sol = eol+1
502             data = data[sol:]
503         if data != '':
504             yield data
505     def as_manifest(self):
506         if self.size() == 0:
507             return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
508                     % (self._stream.name(), self.name()))
509         return string.join(self._stream.tokens_for_range(self._pos, self._size),
510                            " ") + "\n"
511
512 class StreamReader(object):
513     def __init__(self, tokens):
514         self._tokens = tokens
515         self._current_datablock_data = None
516         self._current_datablock_pos = 0
517         self._current_datablock_index = -1
518         self._pos = 0
519
520         self._stream_name = None
521         self.data_locators = []
522         self.files = []
523
524         for tok in self._tokens:
525             if self._stream_name == None:
526                 self._stream_name = tok
527             elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
528                 self.data_locators += [tok]
529             elif re.search(r'^\d+:\d+:\S+', tok):
530                 pos, size, name = tok.split(':',2)
531                 self.files += [[int(pos), int(size), name]]
532             else:
533                 raise Exception("Invalid manifest format")
534
535     def tokens(self):
536         return self._tokens
537     def tokens_for_range(self, range_start, range_size):
538         resp = [self._stream_name]
539         return_all_tokens = False
540         block_start = 0
541         token_bytes_skipped = 0
542         for locator in self.data_locators:
543             sizehint = re.search(r'\+(\d+)', locator)
544             if not sizehint:
545                 return_all_tokens = True
546             if return_all_tokens:
547                 resp += [locator]
548                 continue
549             blocksize = int(sizehint.group(1))
550             if range_start + range_size <= block_start:
551                 break
552             if range_start < block_start + blocksize:
553                 resp += [locator]
554             else:
555                 token_bytes_skipped += blocksize
556             block_start += blocksize
557         for f in self.files:
558             if ((f[0] < range_start + range_size)
559                 and
560                 (f[0] + f[1] > range_start)
561                 and
562                 f[1] > 0):
563                 resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
564         return resp
565     def name(self):
566         return self._stream_name
567     def all_files(self):
568         for f in self.files:
569             pos, size, name = f
570             yield StreamFileReader(self, pos, size, name)
571     def nextdatablock(self):
572         if self._current_datablock_index < 0:
573             self._current_datablock_pos = 0
574             self._current_datablock_index = 0
575         else:
576             self._current_datablock_pos += self.current_datablock_size()
577             self._current_datablock_index += 1
578         self._current_datablock_data = None
579     def current_datablock_data(self):
580         if self._current_datablock_data == None:
581             self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
582         return self._current_datablock_data
583     def current_datablock_size(self):
584         if self._current_datablock_index < 0:
585             self.nextdatablock()
586         sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
587         if sizehint:
588             return int(sizehint.group(1))
589         return len(self.current_datablock_data())
590     def seek(self, pos):
591         """Set the position of the next read operation."""
592         self._pos = pos
593     def really_seek(self):
594         """Find and load the appropriate data block, so the byte at
595         _pos is in memory.
596         """
597         if self._pos == self._current_datablock_pos:
598             return True
599         if (self._current_datablock_pos != None and
600             self._pos >= self._current_datablock_pos and
601             self._pos <= self._current_datablock_pos + self.current_datablock_size()):
602             return True
603         if self._pos < self._current_datablock_pos:
604             self._current_datablock_index = -1
605             self.nextdatablock()
606         while (self._pos > self._current_datablock_pos and
607                self._pos > self._current_datablock_pos + self.current_datablock_size()):
608             self.nextdatablock()
609     def read(self, size):
610         """Read no more than size bytes -- but at least one byte,
611         unless _pos is already at the end of the stream.
612         """
613         if size == 0:
614             return ''
615         self.really_seek()
616         while self._pos >= self._current_datablock_pos + self.current_datablock_size():
617             self.nextdatablock()
618             if self._current_datablock_index >= len(self.data_locators):
619                 return None
620         data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
621         self._pos += len(data)
622         return data
623
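# Illustrative sketch of the manifest format parsed by StreamReader: each
# manifest line describes one stream --
#
#   <stream name> <block locator> [<block locator> ...] <pos>:<size>:<filename> ...
#
# e.g. (hypothetical block hash):
#
#   ./mydir 0123456789abcdef0123456789abcdef+100 0:30:a.txt 30:70:b.txt
#
# Each file is a byte range over the concatenation of the stream's blocks, and
# a locator is the block's md5 hex digest followed by optional "+" hints such
# as the block size.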
624 class CollectionReader(object):
625     def __init__(self, manifest_locator_or_text):
626         if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
627             self._manifest_text = manifest_locator_or_text
628             self._manifest_locator = None
629         else:
630             self._manifest_locator = manifest_locator_or_text
631             self._manifest_text = None
632         self._streams = None
633     def __enter__(self):
634         return self
635     def __exit__(self, exc_type, exc_value, traceback):
636         pass
637     def _populate(self):
638         if self._streams != None:
639             return
640         if not self._manifest_text:
641             self._manifest_text = Keep.get(self._manifest_locator)
642         self._streams = []
643         for stream_line in self._manifest_text.split("\n"):
644             if stream_line != '':
645                 stream_tokens = stream_line.split()
646                 self._streams += [stream_tokens]
647     def all_streams(self):
648         self._populate()
649         resp = []
650         for s in self._streams:
651             resp += [StreamReader(s)]
652         return resp
653     def all_files(self):
654         for s in self.all_streams():
655             for f in s.all_files():
656                 yield f
657     def manifest_text(self):
658         self._populate()
659         return self._manifest_text
660
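# Illustrative sketch: reading the contents of a collection --
#
#   cr = CollectionReader(collection_locator)           # locator or manifest text
#   for s in cr.all_streams():
#       for f in s.all_files():
#           print s.name(), f.name(), f.size()
#           for chunk in f.readall():
#               pass                                     # consume chunk
#
# readall_decompressed() and readlines() are available when the stored file is
# gzip/bzip2-compressed or line-oriented text.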
661 class CollectionWriter(object):
662     KEEP_BLOCK_SIZE = 2**26
663     def __init__(self):
664         self._data_buffer = []
665         self._data_buffer_len = 0
666         self._current_stream_files = []
667         self._current_stream_length = 0
668         self._current_stream_locators = []
669         self._current_stream_name = '.'
670         self._current_file_name = None
671         self._current_file_pos = 0
672         self._finished_streams = []
673     def __enter__(self):
674         return self
675     def __exit__(self, exc_type, exc_value, traceback):
676         self.finish()
677     def write_directory_tree(self,
678                              path, stream_name='.', max_manifest_depth=-1):
679         self.start_new_stream(stream_name)
680         todo = []
681         if max_manifest_depth == 0:
682             dirents = util.listdir_recursive(path)
683         else:
684             dirents = sorted(os.listdir(path))
685         for dirent in dirents:
686             target = os.path.join(path, dirent)
687             if os.path.isdir(target):
688                 todo += [[target,
689                           os.path.join(stream_name, dirent),
690                           max_manifest_depth-1]]
691             else:
692                 self.start_new_file(dirent)
693                 with open(target, 'rb') as f:
694                     while True:
695                         buf = f.read(2**26)
696                         if len(buf) == 0:
697                             break
698                         self.write(buf)
699         self.finish_current_stream()
700         map(lambda x: self.write_directory_tree(*x), todo)
701
702     def write(self, newdata):
703         self._data_buffer += [newdata]
704         self._data_buffer_len += len(newdata)
705         self._current_stream_length += len(newdata)
706         while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
707             self.flush_data()
708     def flush_data(self):
709         data_buffer = ''.join(self._data_buffer)
710         if data_buffer != '':
711             self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
712             self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
713             self._data_buffer_len = len(self._data_buffer[0])
714     def start_new_file(self, newfilename=None):
715         self.finish_current_file()
716         self.set_current_file_name(newfilename)
717     def set_current_file_name(self, newfilename):
718         newfilename = re.sub(r' ', '\\\\040', newfilename)
719         if re.search(r'[ \t\n]', newfilename):
720             raise AssertionError("Manifest filenames cannot contain whitespace")
721         self._current_file_name = newfilename
722     def current_file_name(self):
723         return self._current_file_name
724     def finish_current_file(self):
725         if self._current_file_name == None:
726             if self._current_file_pos == self._current_stream_length:
727                 return
728             raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
729         self._current_stream_files += [[self._current_file_pos,
730                                        self._current_stream_length - self._current_file_pos,
731                                        self._current_file_name]]
732         self._current_file_pos = self._current_stream_length
733     def start_new_stream(self, newstreamname='.'):
734         self.finish_current_stream()
735         self.set_current_stream_name(newstreamname)
736     def set_current_stream_name(self, newstreamname):
737         if re.search(r'[ \t\n]', newstreamname):
738             raise AssertionError("Manifest stream names cannot contain whitespace")
739         self._current_stream_name = newstreamname
740     def current_stream_name(self):
741         return self._current_stream_name
742     def finish_current_stream(self):
743         self.finish_current_file()
744         self.flush_data()
745         if len(self._current_stream_files) == 0:
746             pass
747         elif self._current_stream_name == None:
748             raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
749         else:
750             self._finished_streams += [[self._current_stream_name,
751                                        self._current_stream_locators,
752                                        self._current_stream_files]]
753         self._current_stream_files = []
754         self._current_stream_length = 0
755         self._current_stream_locators = []
756         self._current_stream_name = None
757         self._current_file_pos = 0
758         self._current_file_name = None
759     def finish(self):
760         return Keep.put(self.manifest_text())
761     def manifest_text(self):
762         self.finish_current_stream()
763         manifest = ''
764         for stream in self._finished_streams:
765             if not re.search(r'^\.(/.*)?$', stream[0]):
766                 manifest += './'
767             manifest += stream[0]
768             if len(stream[1]) == 0:
769                 manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
770             else:
771                 for locator in stream[1]:
772                     manifest += " %s" % locator
773             for sfile in stream[2]:
774                 manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
775             manifest += "\n"
776         return manifest
777
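# Illustrative sketch: building a new collection --
#
#   cw = CollectionWriter()
#   cw.start_new_stream('output')
#   cw.start_new_file('result.txt')
#   cw.write(result_data)                   # may be called repeatedly
#   output_locator = cw.finish()            # stores the manifest in Keep
#   current_task().set_output(output_locator)
#
# write_directory_tree() performs the same steps for a directory tree already
# on disk, one stream per subdirectory (subject to max_manifest_depth).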
778 global_client_object = None
779
780 class Keep:
781     @staticmethod
782     def global_client_object():
783         global global_client_object
784         if global_client_object == None:
785             global_client_object = KeepClient()
786         return global_client_object
787
788     @staticmethod
789     def get(locator, **kwargs):
790         return Keep.global_client_object().get(locator, **kwargs)
791
792     @staticmethod
793     def put(data, **kwargs):
794         return Keep.global_client_object().put(data, **kwargs)
795
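# Illustrative sketch: the module-level convenience wrappers --
#
#   locator = Keep.put("some data")         # returns "<md5 hex digest>+<size>"
#   data = Keep.get(locator)
#
# Both calls share a single KeepClient created lazily by
# Keep.global_client_object().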
796 class KeepClient(object):
797
798     class ThreadLimiter(object):
799         """
800         Limit the number of threads running at a given time to
801         {desired successes} minus {successes reported}. Once enough
802         successes have been reported, the remaining threads see
803         shall_i_proceed() == False and return without doing any work.
804
805         Should be used in a "with" block.
806         """
807         def __init__(self, todo):
808             self._todo = todo
809             self._done = 0
810             self._todo_lock = threading.Semaphore(todo)
811             self._done_lock = threading.Lock()
812         def __enter__(self):
813             self._todo_lock.acquire()
814             return self
815         def __exit__(self, type, value, traceback):
816             self._todo_lock.release()
817         def shall_i_proceed(self):
818             """
819             Return True if the current thread should perform its write.
820             Return False if enough writes have already succeeded.
821             """
822             with self._done_lock:
823                 return (self._done < self._todo)
824         def increment_done(self):
825             """
826             Report that the current thread was successful.
827             """
828             with self._done_lock:
829                 self._done += 1
830         def done(self):
831             """
832             Return how many successes were reported.
833             """
834             with self._done_lock:
835                 return self._done
836
837     class KeepWriterThread(threading.Thread):
838         """
839         Write a blob of data to the given Keep server. Call
840         increment_done() of the given ThreadLimiter if the write
841         succeeds.
842         """
843         def __init__(self, **kwargs):
844             super(KeepClient.KeepWriterThread, self).__init__()
845             self.args = kwargs
846         def run(self):
847             with self.args['thread_limiter'] as limiter:
848                 if not limiter.shall_i_proceed():
849                     # My turn arrived, but the job has been done without
850                     # me.
851                     return
852                 logging.debug("KeepWriterThread %s proceeding %s %s" %
853                               (str(threading.current_thread()),
854                                self.args['data_hash'],
855                                self.args['service_root']))
856                 h = httplib2.Http()
857                 url = self.args['service_root'] + self.args['data_hash']
858                 api_token = os.environ['ARVADOS_API_TOKEN']
859                 headers = {'Authorization': "OAuth2 %s" % api_token}
860                 try:
861                     resp, content = h.request(url.encode('utf-8'), 'PUT',
862                                               headers=headers,
863                                               body=self.args['data'])
864                     if (resp['status'] == '401' and
865                         re.match(r'Timestamp verification failed', content)):
866                         body = KeepClient.sign_for_old_server(
867                             self.args['data_hash'],
868                             self.args['data'])
869                         h = httplib2.Http()
870                         resp, content = h.request(url.encode('utf-8'), 'PUT',
871                                                   headers=headers,
872                                                   body=body)
873                     if re.match(r'^2\d\d$', resp['status']):
874                         logging.debug("KeepWriterThread %s succeeded %s %s" %
875                                       (str(threading.current_thread()),
876                                        self.args['data_hash'],
877                                        self.args['service_root']))
878                         return limiter.increment_done()
879                     logging.warning("Request fail: PUT %s => %s %s" %
880                                     (url, resp['status'], content))
881                 except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
882                     logging.warning("Request fail: PUT %s => %s: %s" %
883                                     (url, type(e), str(e)))
884
885     def __init__(self):
886         self.lock = threading.Lock()
887         self.service_roots = None
888
889     def shuffled_service_roots(self, hash):
890         if self.service_roots == None:
891             self.lock.acquire()
892             keep_disks = api().keep_disks().list().execute()['items']
893             roots = (("http%s://%s:%d/" %
894                       ('s' if f['service_ssl_flag'] else '',
895                        f['service_host'],
896                        f['service_port']))
897                      for f in keep_disks)
898             self.service_roots = sorted(set(roots))
899             logging.debug(str(self.service_roots))
900             self.lock.release()
901         seed = hash
902         pool = self.service_roots[:]
903         pseq = []
904         while len(pool) > 0:
905             if len(seed) < 8:
906                 if len(pseq) < len(hash) / 4: # first time around
907                     seed = hash[-4:] + hash
908                 else:
909                     seed += hash
910             probe = int(seed[0:8], 16) % len(pool)
911             pseq += [pool[probe]]
912             pool = pool[:probe] + pool[probe+1:]
913             seed = seed[8:]
914         logging.debug(str(pseq))
915         return pseq
916
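    # The probe order computed above is deterministic for a given block hash:
    # the hash is consumed eight hex digits at a time, each value modulo the
    # number of servers still in the pool selects the next server, and the
    # chosen server is removed before the next draw. Every client therefore
    # tries the same servers in the same order for the same block.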
917     def get(self, locator):
918         if 'KEEP_LOCAL_STORE' in os.environ:
919             return KeepClient.local_store_get(locator)
920         expect_hash = re.sub(r'\+.*', '', locator)
921         for service_root in self.shuffled_service_roots(expect_hash):
922             h = httplib2.Http()
923             url = service_root + expect_hash
924             api_token = os.environ['ARVADOS_API_TOKEN']
925             headers = {'Authorization': "OAuth2 %s" % api_token,
926                        'Accept': 'application/octet-stream'}
927             try:
928                 resp, content = h.request(url.encode('utf-8'), 'GET',
929                                           headers=headers)
930                 if re.match(r'^2\d\d$', resp['status']):
931                     m = hashlib.new('md5')
932                     m.update(content)
933                     md5 = m.hexdigest()
934                     if md5 == expect_hash:
935                         return content
936                     logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
937             except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
938                 logging.info("Request fail: GET %s => %s: %s" %
939                              (url, type(e), str(e)))
940         raise Exception("Not found: %s" % expect_hash)
941
942     def put(self, data, **kwargs):
943         if 'KEEP_LOCAL_STORE' in os.environ:
944             return KeepClient.local_store_put(data)
945         m = hashlib.new('md5')
946         m.update(data)
947         data_hash = m.hexdigest()
948         have_copies = 0
949         want_copies = kwargs.get('copies', 2)
950         if not (want_copies > 0):
951             return data_hash
952         threads = []
953         thread_limiter = KeepClient.ThreadLimiter(want_copies)
954         for service_root in self.shuffled_service_roots(data_hash):
955             t = KeepClient.KeepWriterThread(data=data,
956                                             data_hash=data_hash,
957                                             service_root=service_root,
958                                             thread_limiter=thread_limiter)
959             t.start()
960             threads += [t]
961         for t in threads:
962             t.join()
963         have_copies = thread_limiter.done()
964         if have_copies == want_copies:
965             return (data_hash + '+' + str(len(data)))
966         raise Exception("Write fail for %s: wanted %d but wrote %d" %
967                         (data_hash, want_copies, have_copies))
968
969     @staticmethod
970     def sign_for_old_server(data_hash, data):
971         return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
972
973
974     @staticmethod
975     def local_store_put(data):
976         m = hashlib.new('md5')
977         m.update(data)
978         md5 = m.hexdigest()
979         locator = '%s+%d' % (md5, len(data))
980         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'wb') as f:
981             f.write(data)
982         os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
983                   os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
984         return locator
985     @staticmethod
986     def local_store_get(locator):
987         r = re.search('^([0-9a-f]{32,})', locator)
988         if not r:
989             raise Exception("Keep.get: invalid data locator '%s'" % locator)
990         if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
991             return ''
992         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'rb') as f:
993             return f.read()
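# Illustrative sketch: setting KEEP_LOCAL_STORE makes put()/get() use plain
# files instead of Keep servers, which is convenient for tests --
#
#   os.environ['KEEP_LOCAL_STORE'] = '/tmp/keep_local'  # hypothetical path; must already exist
#   locator = Keep.put("test data")
#   assert Keep.get(locator) == "test data"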