Allow caller to override close_fds in util.run_command()
[arvados.git] / sdk / python / arvados.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20
21 from apiclient import errors
22 from apiclient.discovery import build
23
24 if 'ARVADOS_DEBUG' in os.environ:
25     logging.basicConfig(level=logging.DEBUG)
26
27 class CredentialsFromEnv(object):
28     @staticmethod
29     def http_request(self, uri, **kwargs):
30         from httplib import BadStatusLine
31         if 'headers' not in kwargs:
32             kwargs['headers'] = {}
33         kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
34         try:
35             return self.orig_http_request(uri, **kwargs)
36         except BadStatusLine:
37             # This is how httplib tells us that it tried to reuse an
38             # existing connection but it was already closed by the
39             # server. In that case, yes, we would like to retry.
40             # Unfortunately, we are not absolutely certain that the
41             # previous call did not succeed, so this is slightly
42             # risky.
43             return self.orig_http_request(uri, **kwargs)
44     def authorize(self, http):
45         http.orig_http_request = http.request
46         http.request = types.MethodType(self.http_request, http)
47         return http
48
49 url = ('https://%s/discovery/v1/apis/'
50        '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
51 credentials = CredentialsFromEnv()
52
53 # Use system's CA certificates (if we find them) instead of httplib2's
54 ca_certs = '/etc/ssl/certs/ca-certificates.crt'
55 if not os.path.exists(ca_certs):
56     ca_certs = None             # use httplib2 default
57
58 http = httplib2.Http(ca_certs=ca_certs)
59 http = credentials.authorize(http)
60 if re.match(r'(?i)^(true|1|yes)$',
61             os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
62     http.disable_ssl_certificate_validation=True
63 service = build("arvados", "v1", http=http, discoveryServiceUrl=url)
64
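# Example (editor's annotation, not part of the original file): the module-level
# client above is configured entirely from environment variables at import time,
# so a caller must set them before importing this module. The host and token
# values below are placeholders.
#
#     import os
#     os.environ['ARVADOS_API_HOST'] = 'api.example.arvadosapi.com'
#     os.environ['ARVADOS_API_TOKEN'] = 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
#     os.environ['ARVADOS_API_HOST_INSECURE'] = 'no'   # "true"/"1"/"yes" skips SSL verification
#     import arvados
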
65 def task_set_output(self,s):
66     service.job_tasks().update(uuid=self['uuid'],
67                                job_task=json.dumps({
68                 'output':s,
69                 'success':True,
70                 'progress':1.0
71                 })).execute()
72
73 _current_task = None
74 def current_task():
75     global _current_task
76     if _current_task:
77         return _current_task
78     t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
79     t = UserDict.UserDict(t)
80     t.set_output = types.MethodType(task_set_output, t)
81     t.tmpdir = os.environ['TASK_WORK']
82     _current_task = t
83     return t
84
85 _current_job = None
86 def current_job():
87     global _current_job
88     if _current_job:
89         return _current_job
90     t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
91     t = UserDict.UserDict(t)
92     t.tmpdir = os.environ['JOB_WORK']
93     _current_job = t
94     return t
95
96 def api():
97     return service
98
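# Example (editor's annotation): typical use of api(), current_job() and
# current_task() from inside a crunch script. Assumes the job-runner
# environment (JOB_UUID, TASK_UUID, JOB_WORK, TASK_WORK) is present;
# "output_locator" is a placeholder for a collection locator the script
# produced.
#
#     import arvados
#     this_job = arvados.current_job()
#     this_task = arvados.current_task()
#     input_locator = this_job['script_parameters']['input']
#     # ... do the work under this_task.tmpdir, save results to Keep ...
#     this_task.set_output(output_locator)
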
99 class JobTask(object):
100     def __init__(self, parameters=dict(), runtime_constraints=dict()):
101         print "init jobtask %s %s" % (parameters, runtime_constraints)
102
103 class job_setup:
104     @staticmethod
105     def one_task_per_input_file(if_sequence=0, and_end_task=True):
106         if if_sequence != current_task()['sequence']:
107             return
108         job_input = current_job()['script_parameters']['input']
109         cr = CollectionReader(job_input)
110         for s in cr.all_streams():
111             for f in s.all_files():
112                 task_input = f.as_manifest()
113                 new_task_attrs = {
114                     'job_uuid': current_job()['uuid'],
115                     'created_by_job_task_uuid': current_task()['uuid'],
116                     'sequence': if_sequence + 1,
117                     'parameters': {
118                         'input':task_input
119                         }
120                     }
121                 service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
122         if and_end_task:
123             service.job_tasks().update(uuid=current_task()['uuid'],
124                                        job_task=json.dumps({'success':True})
125                                        ).execute()
126             exit(0)
127
128     @staticmethod
129     def one_task_per_input_stream(if_sequence=0, and_end_task=True):
130         if if_sequence != current_task()['sequence']:
131             return
132         job_input = current_job()['script_parameters']['input']
133         cr = CollectionReader(job_input)
134         for s in cr.all_streams():
135             task_input = s.tokens()
136             new_task_attrs = {
137                 'job_uuid': current_job()['uuid'],
138                 'created_by_job_task_uuid': current_task()['uuid'],
139                 'sequence': if_sequence + 1,
140                 'parameters': {
141                     'input':task_input
142                     }
143                 }
144             service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
145         if and_end_task:
146             service.job_tasks().update(uuid=current_task()['uuid'],
147                                        job_task=json.dumps({'success':True})
148                                        ).execute()
149             exit(0)
150
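# Example (editor's annotation): a two-phase crunch script can use job_setup
# to fan out one child task per input file, then process a single file in the
# later sequence. "do_work" is a placeholder for the script's own logic.
#
#     import arvados
#     arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#     # Only tasks with sequence >= 1 get past this point; each one sees a
#     # single file's manifest in current_task()['parameters']['input'].
#     task_input = arvados.current_task()['parameters']['input']
#     do_work(task_input)
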
151 class util:
152     @staticmethod
153     def run_command(execargs, **kwargs):
154         kwargs.setdefault('stdin', subprocess.PIPE)
155         kwargs.setdefault('stdout', subprocess.PIPE)
156         kwargs.setdefault('stderr', subprocess.PIPE)
157         kwargs.setdefault('close_fds', True)
158         kwargs.setdefault('shell', False)
159         p = subprocess.Popen(execargs, **kwargs)
160         stdoutdata, stderrdata = p.communicate(None)
161         if p.returncode != 0:
162             raise Exception("run_command %s exit %d:\n%s" %
163                             (execargs, p.returncode, stderrdata))
164         return stdoutdata, stderrdata
165
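    # Example (editor's annotation): every Popen default above is set with
    # setdefault(), so a caller can override any of them -- including
    # close_fds, per this commit -- simply by passing a keyword argument.
    # The command shown is an arbitrary placeholder.
    #
    #     stdoutdata, stderrdata = util.run_command(
    #         ["echo", "hello"],      # placeholder command
    #         close_fds=False)        # override the close_fds=True default
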
166     @staticmethod
167     def git_checkout(url, version, path):
168         if not re.search('^/', path):
169             path = os.path.join(current_job().tmpdir, path)
170         if not os.path.exists(path):
171             util.run_command(["git", "clone", url, path],
172                              cwd=os.path.dirname(path))
173         util.run_command(["git", "checkout", version],
174                          cwd=path)
175         return path
176
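    # Example (editor's annotation): check out a specific commit of a remote
    # repository into the job's temporary directory. The URL and version are
    # placeholders; a relative path is taken relative to current_job().tmpdir.
    #
    #     src = util.git_checkout(
    #         url="https://git.example.com/repo.git",
    #         version="0123456789abcdef0123456789abcdef01234567",
    #         path="src")
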
177     @staticmethod
178     def tar_extractor(path, decompress_flag):
179         return subprocess.Popen(["tar",
180                                  "-C", path,
181                                  ("-x%sf" % decompress_flag),
182                                  "-"],
183                                 stdout=None,
184                                 stdin=subprocess.PIPE, stderr=sys.stderr,
185                                 shell=False, close_fds=True)
186
187     @staticmethod
188     def tarball_extract(tarball, path):
189         """Retrieve a tarball from Keep and extract it to a local
190         directory.  Return the absolute path where the tarball was
191         extracted. If the top level of the tarball contained just one
192         file or directory, return the absolute path of that single
193         item.
194
195         tarball -- collection locator
196         path -- where to extract the tarball: absolute, or relative to job tmp
197         """
198         if not re.search('^/', path):
199             path = os.path.join(current_job().tmpdir, path)
200         lockfile = open(path + '.lock', 'w')
201         fcntl.flock(lockfile, fcntl.LOCK_EX)
202         try:
203             os.stat(path)
204         except OSError:
205             os.mkdir(path)
206         already_have_it = False
207         try:
208             if os.readlink(os.path.join(path, '.locator')) == tarball:
209                 already_have_it = True
210         except OSError:
211             pass
212         if not already_have_it:
213
214             # emulate "rm -f" (i.e., if the file does not exist, we win)
215             try:
216                 os.unlink(os.path.join(path, '.locator'))
217             except OSError:
218                 if os.path.exists(os.path.join(path, '.locator')):
219                     os.unlink(os.path.join(path, '.locator'))
220
221             for f in CollectionReader(tarball).all_files():
222                 if re.search('\.(tbz|tar.bz2)$', f.name()):
223                     p = util.tar_extractor(path, 'j')
224                 elif re.search('\.(tgz|tar.gz)$', f.name()):
225                     p = util.tar_extractor(path, 'z')
226                 elif re.search('\.tar$', f.name()):
227                     p = util.tar_extractor(path, '')
228                 else:
229                     raise Exception("tarball_extract cannot handle filename %s"
230                                     % f.name())
231                 while True:
232                     buf = f.read(2**20)
233                     if len(buf) == 0:
234                         break
235                     p.stdin.write(buf)
236                 p.stdin.close()
237                 p.wait()
238                 if p.returncode != 0:
239                     lockfile.close()
240                     raise Exception("tar exited %d" % p.returncode)
241             os.symlink(tarball, os.path.join(path, '.locator'))
242         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
243         lockfile.close()
244         if len(tld_extracts) == 1:
245             return os.path.join(path, tld_extracts[0])
246         return path
247
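    # Example (editor's annotation): fetch a tarball stored as a Keep
    # collection and unpack it under the job's temporary directory. The
    # script parameter name is a placeholder.
    #
    #     tarball_locator = current_job()['script_parameters']['reference_tarball']
    #     ref_dir = util.tarball_extract(tarball_locator, 'reference')
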
248     @staticmethod
249     def zipball_extract(zipball, path):
250         """Retrieve a zip archive from Keep and extract it to a local
251         directory.  Return the absolute path where the archive was
252         extracted. If the top level of the archive contained just one
253         file or directory, return the absolute path of that single
254         item.
255
256         zipball -- collection locator
257         path -- where to extract the archive: absolute, or relative to job tmp
258         """
259         if not re.search('^/', path):
260             path = os.path.join(current_job().tmpdir, path)
261         lockfile = open(path + '.lock', 'w')
262         fcntl.flock(lockfile, fcntl.LOCK_EX)
263         try:
264             os.stat(path)
265         except OSError:
266             os.mkdir(path)
267         already_have_it = False
268         try:
269             if os.readlink(os.path.join(path, '.locator')) == zipball:
270                 already_have_it = True
271         except OSError:
272             pass
273         if not already_have_it:
274
275             # emulate "rm -f" (i.e., if the file does not exist, we win)
276             try:
277                 os.unlink(os.path.join(path, '.locator'))
278             except OSError:
279                 if os.path.exists(os.path.join(path, '.locator')):
280                     os.unlink(os.path.join(path, '.locator'))
281
282             for f in CollectionReader(zipball).all_files():
283                 if not re.search('\.zip$', f.name()):
284                     raise Exception("zipball_extract cannot handle filename %s"
285                                     % f.name())
286                 zip_filename = os.path.join(path, os.path.basename(f.name()))
287                 zip_file = open(zip_filename, 'wb')
288                 while True:
289                     buf = f.read(2**20)
290                     if len(buf) == 0:
291                         break
292                     zip_file.write(buf)
293                 zip_file.close()
294
295                 p = subprocess.Popen(["unzip",
296                                       "-q", "-o",
297                                       "-d", path,
298                                       zip_filename],
299                                      stdout=None,
300                                      stdin=None, stderr=sys.stderr,
301                                      shell=False, close_fds=True)
302                 p.wait()
303                 if p.returncode != 0:
304                     lockfile.close()
305                     raise Exception("unzip exited %d" % p.returncode)
306                 os.unlink(zip_filename)
307             os.symlink(zipball, os.path.join(path, '.locator'))
308         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
309         lockfile.close()
310         if len(tld_extracts) == 1:
311             return os.path.join(path, tld_extracts[0])
312         return path
313
314     @staticmethod
315     def collection_extract(collection, path, files=[], decompress=True):
316         """Retrieve a collection from Keep and extract it to a local
317         directory.  Return the absolute path where the collection was
318         extracted.
319
320         collection -- collection locator
321         path -- where to extract: absolute, or relative to job tmp
322         """
323         matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
324         if matches:
325             collection_hash = matches.group(1)
326         else:
327             collection_hash = hashlib.md5(collection).hexdigest()
328         if not re.search('^/', path):
329             path = os.path.join(current_job().tmpdir, path)
330         lockfile = open(path + '.lock', 'w')
331         fcntl.flock(lockfile, fcntl.LOCK_EX)
332         try:
333             os.stat(path)
334         except OSError:
335             os.mkdir(path)
336         already_have_it = False
337         try:
338             if os.readlink(os.path.join(path, '.locator')) == collection_hash:
339                 already_have_it = True
340         except OSError:
341             pass
342
343         # emulate "rm -f" (i.e., if the file does not exist, we win)
344         try:
345             os.unlink(os.path.join(path, '.locator'))
346         except OSError:
347             if os.path.exists(os.path.join(path, '.locator')):
348                 os.unlink(os.path.join(path, '.locator'))
349
350         files_got = []
351         for s in CollectionReader(collection).all_streams():
352             stream_name = s.name()
353             for f in s.all_files():
354                 if (files == [] or
355                     ((f.name() not in files_got) and
356                      (f.name() in files or
357                       (decompress and f.decompressed_name() in files)))):
358                     outname = f.decompressed_name() if decompress else f.name()
359                     files_got += [outname]
360                     if os.path.exists(os.path.join(path, stream_name, outname)):
361                         continue
362                     util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
363                     outfile = open(os.path.join(path, stream_name, outname), 'wb')
364                     for buf in (f.readall_decompressed() if decompress
365                                 else f.readall()):
366                         outfile.write(buf)
367                     outfile.close()
368         if len(files_got) < len(files):
369             raise Exception("Wanted files %s but only got %s from %s" % (files, files_got, map(lambda z: z.name(), list(CollectionReader(collection).all_files()))))
370         os.symlink(collection_hash, os.path.join(path, '.locator'))
371
372         lockfile.close()
373         return path
374
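    # Example (editor's annotation): extract only selected files from a
    # collection, decompressing .gz/.bz2 files on the way out. The locator
    # variable and file names are placeholders.
    #
    #     out_dir = util.collection_extract(
    #         collection=collection_locator,
    #         path='refdata',
    #         files=['genome.fa', 'genome.fa.fai'],
    #         decompress=True)
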
375     @staticmethod
376     def mkdir_dash_p(path):
377         if not os.path.exists(path):
378             util.mkdir_dash_p(os.path.dirname(path))
379             try:
380                 os.mkdir(path)
381             except OSError:
382                 if not os.path.exists(path):
383                     os.mkdir(path)
384
385     @staticmethod
386     def stream_extract(stream, path, files=[], decompress=True):
387         """Retrieve a stream from Keep and extract it to a local
388         directory.  Return the absolute path where the stream was
389         extracted.
390
391         stream -- StreamReader object
392         path -- where to extract: absolute, or relative to job tmp
393         """
394         if not re.search('^/', path):
395             path = os.path.join(current_job().tmpdir, path)
396         lockfile = open(path + '.lock', 'w')
397         fcntl.flock(lockfile, fcntl.LOCK_EX)
398         try:
399             os.stat(path)
400         except OSError:
401             os.mkdir(path)
402
403         files_got = []
404         for f in stream.all_files():
405             if (files == [] or
406                 ((f.name() not in files_got) and
407                  (f.name() in files or
408                   (decompress and f.decompressed_name() in files)))):
409                 outname = f.decompressed_name() if decompress else f.name()
410                 files_got += [outname]
411                 if os.path.exists(os.path.join(path, outname)):
412                     os.unlink(os.path.join(path, outname))
413                 util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
414                 outfile = open(os.path.join(path, outname), 'wb')
415                 for buf in (f.readall_decompressed() if decompress
416                             else f.readall()):
417                     outfile.write(buf)
418                 outfile.close()
419         if len(files_got) < len(files):
420             raise Exception("Wanted files %s but only got %s from %s" %
421                             (files, files_got, map(lambda z: z.name(),
422                                                    list(stream.all_files()))))
423         lockfile.close()
424         return path
425
426     @staticmethod
427     def listdir_recursive(dirname, base=None):
428         allfiles = []
429         for ent in sorted(os.listdir(dirname)):
430             ent_path = os.path.join(dirname, ent)
431             ent_base = os.path.join(base, ent) if base else ent
432             if os.path.isdir(ent_path):
433                 allfiles += util.listdir_recursive(ent_path, ent_base)
434             else:
435                 allfiles += [ent_base]
436         return allfiles
437
438 class StreamFileReader(object):
439     def __init__(self, stream, pos, size, name):
440         self._stream = stream
441         self._pos = pos
442         self._size = size
443         self._name = name
444         self._filepos = 0
445     def name(self):
446         return self._name
447     def decompressed_name(self):
448         return re.sub('\.(bz2|gz)$', '', self._name)
449     def size(self):
450         return self._size
451     def stream_name(self):
452         return self._stream.name()
453     def read(self, size, **kwargs):
454         self._stream.seek(self._pos + self._filepos)
455         data = self._stream.read(min(size, self._size - self._filepos))
456         self._filepos += len(data)
457         return data
458     def readall(self, size=2**20, **kwargs):
459         while True:
460             data = self.read(size, **kwargs)
461             if data == '':
462                 break
463             yield data
464     def bunzip2(self, size):
465         decompressor = bz2.BZ2Decompressor()
466         for chunk in self.readall(size):
467             data = decompressor.decompress(chunk)
468             if data and data != '':
469                 yield data
470     def gunzip(self, size):
471         decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
472         for chunk in self.readall(size):
473             data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
474             if data and data != '':
475                 yield data
476     def readall_decompressed(self, size=2**20):
477         self._stream.seek(self._pos + self._filepos)
478         if re.search('\.bz2$', self._name):
479             return self.bunzip2(size)
480         elif re.search('\.gz$', self._name):
481             return self.gunzip(size)
482         else:
483             return self.readall(size)
484     def readlines(self, decompress=True):
485         if decompress:
486             datasource = self.readall_decompressed()
487         else:
488             self._stream.seek(self._pos + self._filepos)
489             datasource = self.readall()
490         data = ''
491         for newdata in datasource:
492             data += newdata
493             sol = 0
494             while True:
495                 eol = string.find(data, "\n", sol)
496                 if eol < 0:
497                     break
498                 yield data[sol:eol+1]
499                 sol = eol+1
500             data = data[sol:]
501         if data != '':
502             yield data
503     def as_manifest(self):
504         if self.size() == 0:
505             return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
506                     % (self._stream.name(), self.name()))
507         return string.join(self._stream.tokens_for_range(self._pos, self._size),
508                            " ") + "\n"
509
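# Example (editor's annotation): StreamFileReader objects are normally obtained
# from CollectionReader.all_files(); readlines() transparently decompresses
# .gz/.bz2 files. The locator, file name, and handle_line are placeholders.
#
#     for f in CollectionReader(collection_locator).all_files():
#         if f.name() == 'log.txt.gz':
#             for line in f.readlines():
#                 handle_line(line)
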
510 class StreamReader(object):
511     def __init__(self, tokens):
512         self._tokens = tokens
513         self._current_datablock_data = None
514         self._current_datablock_pos = 0
515         self._current_datablock_index = -1
516         self._pos = 0
517
518         self._stream_name = None
519         self.data_locators = []
520         self.files = []
521
522         for tok in self._tokens:
523             if self._stream_name == None:
524                 self._stream_name = tok
525             elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
526                 self.data_locators += [tok]
527             elif re.search(r'^\d+:\d+:\S+', tok):
528                 pos, size, name = tok.split(':',2)
529                 self.files += [[int(pos), int(size), name]]
530             else:
531                 raise Exception("Invalid manifest format")
532
533     def tokens(self):
534         return self._tokens
535     def tokens_for_range(self, range_start, range_size):
536         resp = [self._stream_name]
537         return_all_tokens = False
538         block_start = 0
539         token_bytes_skipped = 0
540         for locator in self.data_locators:
541             sizehint = re.search(r'\+(\d+)', locator)
542             if not sizehint:
543                 return_all_tokens = True
544             if return_all_tokens:
545                 resp += [locator]
546                 continue
547             blocksize = int(sizehint.group(0))
548             if range_start + range_size <= block_start:
549                 break
550             if range_start < block_start + blocksize:
551                 resp += [locator]
552             else:
553                 token_bytes_skipped += blocksize
554             block_start += blocksize
555         for f in self.files:
556             if ((f[0] < range_start + range_size)
557                 and
558                 (f[0] + f[1] > range_start)
559                 and
560                 f[1] > 0):
561                 resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
562         return resp
563     def name(self):
564         return self._stream_name
565     def all_files(self):
566         for f in self.files:
567             pos, size, name = f
568             yield StreamFileReader(self, pos, size, name)
569     def nextdatablock(self):
570         if self._current_datablock_index < 0:
571             self._current_datablock_pos = 0
572             self._current_datablock_index = 0
573         else:
574             self._current_datablock_pos += self.current_datablock_size()
575             self._current_datablock_index += 1
576         self._current_datablock_data = None
577     def current_datablock_data(self):
578         if self._current_datablock_data == None:
579             self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
580         return self._current_datablock_data
581     def current_datablock_size(self):
582         if self._current_datablock_index < 0:
583             self.nextdatablock()
584         sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
585         if sizehint:
586             return int(sizehint.group(0))
587         return len(self.current_datablock_data())
588     def seek(self, pos):
589         """Set the position of the next read operation."""
590         self._pos = pos
591     def really_seek(self):
592         """Find and load the appropriate data block, so the byte at
593         _pos is in memory.
594         """
595         if self._pos == self._current_datablock_pos:
596             return True
597         if (self._current_datablock_pos != None and
598             self._pos >= self._current_datablock_pos and
599             self._pos <= self._current_datablock_pos + self.current_datablock_size()):
600             return True
601         if self._pos < self._current_datablock_pos:
602             self._current_datablock_index = -1
603             self.nextdatablock()
604         while (self._pos > self._current_datablock_pos and
605                self._pos > self._current_datablock_pos + self.current_datablock_size()):
606             self.nextdatablock()
607     def read(self, size):
608         """Read no more than size bytes -- but at least one byte,
609         unless _pos is already at the end of the stream.
610         """
611         if size == 0:
612             return ''
613         self.really_seek()
614         while self._pos >= self._current_datablock_pos + self.current_datablock_size():
615             self.nextdatablock()
616             if self._current_datablock_index >= len(self.data_locators):
617                 return None
618         data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
619         self._pos += len(data)
620         return data
621
622 class CollectionReader(object):
623     def __init__(self, manifest_locator_or_text):
624         if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
625             self._manifest_text = manifest_locator_or_text
626             self._manifest_locator = None
627         else:
628             self._manifest_locator = manifest_locator_or_text
629             self._manifest_text = None
630         self._streams = None
631     def __enter__(self):
632         return self
633     def __exit__(self, exc_type, exc_value, traceback):
634         pass
635     def _populate(self):
636         if self._streams != None:
637             return
638         if not self._manifest_text:
639             self._manifest_text = Keep.get(self._manifest_locator)
640         self._streams = []
641         for stream_line in self._manifest_text.split("\n"):
642             if stream_line != '':
643                 stream_tokens = stream_line.split()
644                 self._streams += [stream_tokens]
645     def all_streams(self):
646         self._populate()
647         resp = []
648         for s in self._streams:
649             resp += [StreamReader(s)]
650         return resp
651     def all_files(self):
652         for s in self.all_streams():
653             for f in s.all_files():
654                 yield f
655     def manifest_text(self):
656         self._populate()
657         return self._manifest_text
658
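# Example (editor's annotation): CollectionReader accepts either a manifest
# locator or literal manifest text. The locator is a placeholder.
#
#     cr = CollectionReader(collection_locator)
#     for s in cr.all_streams():
#         for f in s.all_files():
#             print s.name(), f.name(), f.size()
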
659 class CollectionWriter(object):
660     KEEP_BLOCK_SIZE = 2**26
661     def __init__(self):
662         self._data_buffer = []
663         self._data_buffer_len = 0
664         self._current_stream_files = []
665         self._current_stream_length = 0
666         self._current_stream_locators = []
667         self._current_stream_name = '.'
668         self._current_file_name = None
669         self._current_file_pos = 0
670         self._finished_streams = []
671     def __enter__(self):
672         return self
673     def __exit__(self, exc_type, exc_value, traceback):
674         self.finish()
675     def write_directory_tree(self,
676                              path, stream_name='.', max_manifest_depth=-1):
677         self.start_new_stream(stream_name)
678         todo = []
679         if max_manifest_depth == 0:
680             dirents = util.listdir_recursive(path)
681         else:
682             dirents = sorted(os.listdir(path))
683         for dirent in dirents:
684             target = os.path.join(path, dirent)
685             if os.path.isdir(target):
686                 todo += [[target,
687                           os.path.join(stream_name, dirent),
688                           max_manifest_depth-1]]
689             else:
690                 self.start_new_file(dirent)
691                 with open(target, 'rb') as f:
692                     while True:
693                         buf = f.read(2**26)
694                         if len(buf) == 0:
695                             break
696                         self.write(buf)
697         self.finish_current_stream()
698         map(lambda x: self.write_directory_tree(*x), todo)
699
700     def write(self, newdata):
701         self._data_buffer += [newdata]
702         self._data_buffer_len += len(newdata)
703         self._current_stream_length += len(newdata)
704         while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
705             self.flush_data()
706     def flush_data(self):
707         data_buffer = ''.join(self._data_buffer)
708         if data_buffer != '':
709             self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
710             self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
711             self._data_buffer_len = len(self._data_buffer[0])
712     def start_new_file(self, newfilename=None):
713         self.finish_current_file()
714         self.set_current_file_name(newfilename)
715     def set_current_file_name(self, newfilename):
716         newfilename = re.sub(r' ', '\\\\040', newfilename)
717         if re.search(r'[ \t\n]', newfilename):
718             raise AssertionError("Manifest filenames cannot contain whitespace")
719         self._current_file_name = newfilename
720     def current_file_name(self):
721         return self._current_file_name
722     def finish_current_file(self):
723         if self._current_file_name == None:
724             if self._current_file_pos == self._current_stream_length:
725                 return
726             raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
727         self._current_stream_files += [[self._current_file_pos,
728                                        self._current_stream_length - self._current_file_pos,
729                                        self._current_file_name]]
730         self._current_file_pos = self._current_stream_length
731     def start_new_stream(self, newstreamname='.'):
732         self.finish_current_stream()
733         self.set_current_stream_name(newstreamname)
734     def set_current_stream_name(self, newstreamname):
735         if re.search(r'[ \t\n]', newstreamname):
736             raise AssertionError("Manifest stream names cannot contain whitespace")
737         self._current_stream_name = newstreamname
738     def current_stream_name(self):
739         return self._current_stream_name
740     def finish_current_stream(self):
741         self.finish_current_file()
742         self.flush_data()
743         if len(self._current_stream_files) == 0:
744             pass
745         elif self._current_stream_name == None:
746             raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
747         else:
748             self._finished_streams += [[self._current_stream_name,
749                                        self._current_stream_locators,
750                                        self._current_stream_files]]
751         self._current_stream_files = []
752         self._current_stream_length = 0
753         self._current_stream_locators = []
754         self._current_stream_name = None
755         self._current_file_pos = 0
756         self._current_file_name = None
757     def finish(self):
758         return Keep.put(self.manifest_text())
759     def manifest_text(self):
760         self.finish_current_stream()
761         manifest = ''
762         for stream in self._finished_streams:
763             if not re.search(r'^\.(/.*)?$', stream[0]):
764                 manifest += './'
765             manifest += stream[0]
766             if len(stream[1]) == 0:
767                 manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
768             else:
769                 for locator in stream[1]:
770                     manifest += " %s" % locator
771             for sfile in stream[2]:
772                 manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
773             manifest += "\n"
774         return manifest
775
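# Example (editor's annotation): build a new collection from a local directory
# tree and store it in Keep; finish() returns the manifest locator. The input
# path is a placeholder.
#
#     cw = CollectionWriter()
#     cw.write_directory_tree(path='/tmp/outputs', max_manifest_depth=0)
#     output_locator = cw.finish()
#     current_task().set_output(output_locator)
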
776 global_client_object = None
777
778 class Keep:
779     @staticmethod
780     def global_client_object():
781         global global_client_object
782         if global_client_object == None:
783             global_client_object = KeepClient()
784         return global_client_object
785
786     @staticmethod
787     def get(locator, **kwargs):
788         return Keep.global_client_object().get(locator, **kwargs)
789
790     @staticmethod
791     def put(data, **kwargs):
792         return Keep.global_client_object().put(data, **kwargs)
793
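# Example (editor's annotation): a round trip through the shared Keep client,
# assuming reachable Keep services (or KEEP_LOCAL_STORE, see the end of this
# file).
#
#     locator = Keep.put("foo")     # -> "acbd18db4cc2f85cedef654fccc4a4d8+3"
#     assert Keep.get(locator) == "foo"
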
794 class KeepClient(object):
795
796     class ThreadLimiter(object):
797         """
798         Limit the number of threads running at a given time to
799         {desired successes} minus {successes reported}. When successes
800         reported == desired, wake up the remaining threads and tell
801         them to quit.
802
803         Should be used in a "with" block.
804         """
805         def __init__(self, todo):
806             self._todo = todo
807             self._done = 0
808             self._todo_lock = threading.Semaphore(todo)
809             self._done_lock = threading.Lock()
810         def __enter__(self):
811             self._todo_lock.acquire()
812             return self
813         def __exit__(self, type, value, traceback):
814             self._todo_lock.release()
815         def shall_i_proceed(self):
816             """
817             Return true if the current thread should do stuff. Return
818             false if the current thread should just stop.
819             """
820             with self._done_lock:
821                 return (self._done < self._todo)
822         def increment_done(self):
823             """
824             Report that the current thread was successful.
825             """
826             with self._done_lock:
827                 self._done += 1
828         def done(self):
829             """
830             Return how many successes were reported.
831             """
832             with self._done_lock:
833                 return self._done
834
835     class KeepWriterThread(threading.Thread):
836         """
837         Write a blob of data to the given Keep server. Call
838         increment_done() of the given ThreadLimiter if the write
839         succeeds.
840         """
841         def __init__(self, **kwargs):
842             super(KeepClient.KeepWriterThread, self).__init__()
843             self.args = kwargs
844         def run(self):
845             with self.args['thread_limiter'] as limiter:
846                 if not limiter.shall_i_proceed():
847                     # My turn arrived, but the job has been done without
848                     # me.
849                     return
850                 logging.debug("KeepWriterThread %s proceeding %s %s" %
851                               (str(threading.current_thread()),
852                                self.args['data_hash'],
853                                self.args['service_root']))
854                 h = httplib2.Http()
855                 url = self.args['service_root'] + self.args['data_hash']
856                 api_token = os.environ['ARVADOS_API_TOKEN']
857                 headers = {'Authorization': "OAuth2 %s" % api_token}
858                 try:
859                     resp, content = h.request(url.encode('utf-8'), 'PUT',
860                                               headers=headers,
861                                               body=self.args['data'])
862                     if (resp['status'] == '401' and
863                         re.match(r'Timestamp verification failed', content)):
864                         body = KeepClient.sign_for_old_server(
865                             self.args['data_hash'],
866                             self.args['data'])
867                         h = httplib2.Http()
868                         resp, content = h.request(url.encode('utf-8'), 'PUT',
869                                                   headers=headers,
870                                                   body=body)
871                     if re.match(r'^2\d\d$', resp['status']):
872                         logging.debug("KeepWriterThread %s succeeded %s %s" %
873                                       (str(threading.current_thread()),
874                                        self.args['data_hash'],
875                                        self.args['service_root']))
876                         return limiter.increment_done()
877                     logging.warning("Request fail: PUT %s => %s %s" %
878                                     (url, resp['status'], content))
879                 except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
880                     logging.warning("Request fail: PUT %s => %s: %s" %
881                                     (url, type(e), str(e)))
882
883     def __init__(self):
884         self.lock = threading.Lock()
885         self.service_roots = None
886
887     def shuffled_service_roots(self, hash):
888         if self.service_roots == None:
889             self.lock.acquire()
890             keep_disks = api().keep_disks().list().execute()['items']
891             roots = (("http%s://%s:%d/" %
892                       ('s' if f['service_ssl_flag'] else '',
893                        f['service_host'],
894                        f['service_port']))
895                      for f in keep_disks)
896             self.service_roots = sorted(set(roots))
897             logging.debug(str(self.service_roots))
898             self.lock.release()
899         seed = hash
900         pool = self.service_roots[:]
901         pseq = []
902         while len(pool) > 0:
903             if len(seed) < 8:
904                 if len(pseq) < len(hash) / 4: # first time around
905                     seed = hash[-4:] + hash
906                 else:
907                     seed += hash
908             probe = int(seed[0:8], 16) % len(pool)
909             pseq += [pool[probe]]
910             pool = pool[:probe] + pool[probe+1:]
911             seed = seed[8:]
912         logging.debug(str(pseq))
913         return pseq
914
915     def get(self, locator):
916         if 'KEEP_LOCAL_STORE' in os.environ:
917             return KeepClient.local_store_get(locator)
918         expect_hash = re.sub(r'\+.*', '', locator)
919         for service_root in self.shuffled_service_roots(expect_hash):
920             h = httplib2.Http()
921             url = service_root + expect_hash
922             api_token = os.environ['ARVADOS_API_TOKEN']
923             headers = {'Authorization': "OAuth2 %s" % api_token,
924                        'Accept': 'application/octet-stream'}
925             try:
926                 resp, content = h.request(url.encode('utf-8'), 'GET',
927                                           headers=headers)
928                 if re.match(r'^2\d\d$', resp['status']):
929                     m = hashlib.new('md5')
930                     m.update(content)
931                     md5 = m.hexdigest()
932                     if md5 == expect_hash:
933                         return content
934                     logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
935             except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
936                 logging.info("Request fail: GET %s => %s: %s" %
937                              (url, type(e), str(e)))
938         raise Exception("Not found: %s" % expect_hash)
939
940     def put(self, data, **kwargs):
941         if 'KEEP_LOCAL_STORE' in os.environ:
942             return KeepClient.local_store_put(data)
943         m = hashlib.new('md5')
944         m.update(data)
945         data_hash = m.hexdigest()
946         have_copies = 0
947         want_copies = kwargs.get('copies', 2)
948         if not (want_copies > 0):
949             return data_hash
950         threads = []
951         thread_limiter = KeepClient.ThreadLimiter(want_copies)
952         for service_root in self.shuffled_service_roots(data_hash):
953             t = KeepClient.KeepWriterThread(data=data,
954                                             data_hash=data_hash,
955                                             service_root=service_root,
956                                             thread_limiter=thread_limiter)
957             t.start()
958             threads += [t]
959         for t in threads:
960             t.join()
961         have_copies = thread_limiter.done()
962         if have_copies == want_copies:
963             return (data_hash + '+' + str(len(data)))
964         raise Exception("Write fail for %s: wanted %d but wrote %d" %
965                         (data_hash, want_copies, have_copies))
966
967     @staticmethod
968     def sign_for_old_server(data_hash, data):
969         return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
970
971
972     @staticmethod
973     def local_store_put(data):
974         m = hashlib.new('md5')
975         m.update(data)
976         md5 = m.hexdigest()
977         locator = '%s+%d' % (md5, len(data))
978         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
979             f.write(data)
980         os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
981                   os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
982         return locator
983     @staticmethod
984     def local_store_get(locator):
985         r = re.search('^([0-9a-f]{32,})', locator)
986         if not r:
987             raise Exception("Keep.get: invalid data locator '%s'" % locator)
988         if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
989             return ''
990         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
991             return f.read()
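
# Example (editor's annotation): when KEEP_LOCAL_STORE is set, Keep.put() and
# Keep.get() read and write plain files in that directory instead of talking
# to Keep servers, which is handy for local testing. The directory path is a
# placeholder and must already exist.
#
#     os.environ['KEEP_LOCAL_STORE'] = '/tmp/keep'
#     locator = Keep.put('hello world')
#     assert Keep.get(locator) == 'hello world'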