Add support for ARVADOS_API_PORT.
[arvados.git] / sdk / python / arvados.py
import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading

from apiclient import errors
from apiclient.discovery import build

if 'ARVADOS_DEBUG' in os.environ:
    logging.basicConfig(level=logging.DEBUG)

class CredentialsFromEnv(object):
    # Wrap an httplib2.Http object's request method so that every API
    # request carries the OAuth2 token from ARVADOS_API_TOKEN.
    @staticmethod
    def http_request(self, uri, **kwargs):
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

url = ('https://%s:%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' %
       (os.environ['ARVADOS_API_HOST'],
        os.environ.get('ARVADOS_API_PORT') or "443"))
credentials = CredentialsFromEnv()

# Use system's CA certificates (if we find them) instead of httplib2's
ca_certs = '/etc/ssl/certs/ca-certificates.crt'
if not os.path.exists(ca_certs):
    ca_certs = None             # use httplib2 default

http = httplib2.Http(ca_certs=ca_certs)
http = credentials.authorize(http)
if re.match(r'(?i)^(true|1|yes)$',
            os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
    http.disable_ssl_certificate_validation = True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)

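# A minimal sketch of the environment this module expects. Host, port,
# and token values below are illustrative placeholders, not real
# endpoints:
#
#   export ARVADOS_API_HOST=api.example.com
#   export ARVADOS_API_PORT=8443          # optional; defaults to 443
#   export ARVADOS_API_TOKEN=yourtokenhere
#   export ARVADOS_API_HOST_INSECURE=no   # "true", "1", or "yes" skips
#                                         # SSL certificate validation
#
# With those set, importing this module yields a ready-to-use client in
# the module-level "service" object.
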
def task_set_output(self, s):
    service.job_tasks().update(uuid=self['uuid'],
                               job_task=json.dumps({'output': s,
                                                    'success': True,
                                                    'progress': 1.0})).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def api():
    return service

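# A hedged usage sketch for the helpers above, as they might appear in
# a crunch script (TASK_UUID, JOB_UUID, TASK_WORK and JOB_WORK are
# assumed to be set by the job dispatcher; 'input' is illustrative):
#
#   this_job = current_job()
#   this_task = current_task()
#   locator = this_job['script_parameters']['input']
#   this_task.set_output(locator)   # marks the task successful
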
class JobTask(object):
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                        }
                    }
                service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            exit(0)

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input': task_input
                    }
                }
            service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            exit(0)
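    # A hedged sketch of how a crunch script might use the helpers in
    # this class (sequence numbers and the 'input' parameter follow the
    # conventions assumed above; names are illustrative):
    #
    #   job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
    #   # Only tasks with sequence=1 get here; each sees one input file:
    #   manifest = current_task()['parameters']['input']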

class util:
    @staticmethod
    def run_command(execargs, **kwargs):
        if 'stdin' not in kwargs:
            kwargs['stdin'] = subprocess.PIPE
        if 'stdout' not in kwargs:
            kwargs['stdout'] = subprocess.PIPE
        if 'stderr' not in kwargs:
            kwargs['stderr'] = subprocess.PIPE
        p = subprocess.Popen(execargs, close_fds=True, shell=False,
                             **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("run_command %s exit %d:\n%s" %
                            (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata
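    # Usage sketch (the argument list is illustrative):
    #
    #   out, err = util.run_command(["ls", "-l", "/tmp"])
    #
    # stdin/stdout/stderr are captured via pipes unless overridden in
    # kwargs; a nonzero exit status raises an Exception.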

    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search(r'\.(tbz|tar\.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search(r'\.(tgz|tar\.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search(r'\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise Exception("tarball_extract cannot handle filename %s"
                                    % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path
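    # Usage sketch (the locator is a placeholder, not real data):
    #
    #   src = util.tarball_extract('acbd18db4cc2f85cedef654fccc4a4d8+3',
    #                              'src')
    #
    # Repeat calls with the same locator reuse the cached extraction:
    # the '.locator' symlink written above records what is on disk.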

    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search(r'\.zip$', f.name()):
                    raise Exception("zipball_extract cannot handle filename %s"
                                    % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got,
                             map(lambda z: z.name(),
                                 list(CollectionReader(collection).all_files()))))
        os.symlink(collection, os.path.join(path, '.locator'))

        lockfile.close()
        return path
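    # Usage sketch (locator and filenames are placeholders): fetch two
    # named files from a collection into the job's temp directory:
    #
    #   dir = util.collection_extract(
    #       collection='acbd18db4cc2f85cedef654fccc4a4d8+3',
    #       path='refdata',
    #       files=['chr1.fa', 'chr2.fa'])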

    @staticmethod
    def mkdir_dash_p(path):
        if not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                # If another process created path first, treat that as
                # success; otherwise retrying surfaces the real error.
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got, map(lambda z: z.name(),
                                                   list(stream.all_files()))))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

class StreamFileReader(object):
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub(r'\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search(r'\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search(r'\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"

class StreamReader(object):
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name is None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")

    def tokens(self):
        return self._tokens
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(1))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data is None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search(r'\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(1))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos is not None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

class CollectionReader(object):
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams is not None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text

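# Reading sketch: a CollectionReader accepts either a manifest locator
# or literal manifest text (the locator below is a placeholder):
#
#   cr = CollectionReader('acbd18db4cc2f85cedef654fccc4a4d8+3')
#   for f in cr.all_files():
#       print f.stream_name(), f.name(), f.size()
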
class CollectionWriter(object):
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        self.start_new_stream(stream_name)
        todo = []
        if max_manifest_depth == 0:
            dirents = util.listdir_recursive(path)
        else:
            dirents = sorted(os.listdir(path))
        for dirent in dirents:
            target = os.path.join(path, dirent)
            if os.path.isdir(target):
                todo += [[target,
                          os.path.join(stream_name, dirent),
                          max_manifest_depth-1]]
            else:
                self.start_new_file(dirent)
                with open(target, 'rb') as f:
                    while True:
                        buf = f.read(self.KEEP_BLOCK_SIZE)
                        if len(buf) == 0:
                            break
                        self.write(buf)
        self.finish_current_stream()
        for subdir_args in todo:
            self.write_directory_tree(*subdir_args)

    def write(self, newdata):
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        newfilename = re.sub(r' ', '\\\\040', newfilename)
        if re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name is None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" %
                            (self._current_stream_length - self._current_file_pos,
                             self._current_file_pos,
                             self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname='.'):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name is None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" %
                            (self._current_stream_length,
                             len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0]
            if len(stream[1]) == 0:
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest

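# A hedged usage sketch (stream and file names are illustrative):
#
#   cw = CollectionWriter()
#   cw.start_new_stream('./logs')
#   cw.start_new_file('hello.txt')
#   cw.write('hello world\n')
#   locator = cw.finish()   # stores the manifest in Keep, returns locator
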
global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)

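# Round-trip sketch: with KEEP_LOCAL_STORE set (e.g. to an existing
# scratch directory) this works without contacting any Keep server:
#
#   locator = Keep.put("foo")   # "acbd18db4cc2f85cedef654fccc4a4d8+3"
#   assert Keep.get(locator) == "foo"
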
class KeepClient(object):

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()
        def __enter__(self):
            self._todo_lock.acquire()
            return self
        def __exit__(self, type, value, traceback):
            self._todo_lock.release()
        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)
        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1
        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

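    # Sketch of the intended pattern (each KeepWriterThread shares one
    # limiter; the numbers are illustrative):
    #
    #   limiter = KeepClient.ThreadLimiter(2)   # want 2 successful writes
    #   # ...in each writer thread...
    #   with limiter:
    #       if limiter.shall_i_proceed():
    #           # attempt the write; on success:
    #           limiter.increment_done()
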
    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs
        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = os.environ['ARVADOS_API_TOKEN']
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None

    def shuffled_service_roots(self, hash):
        if self.service_roots is None:
            self.lock.acquire()
            try:
                # Re-check under the lock in case another thread won
                # the race and already populated the list.
                if self.service_roots is None:
                    keep_disks = api().keep_disks().list().execute()['items']
                    roots = (("http%s://%s:%d/" %
                              ('s' if f['service_ssl_flag'] else '',
                               f['service_host'],
                               f['service_port']))
                             for f in keep_disks)
                    self.service_roots = sorted(set(roots))
                    logging.debug(str(self.service_roots))
            finally:
                self.lock.release()
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

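    # Worked example of the probe sequence above, under assumed inputs:
    # with hash 'acbd18db4cc2f85cedef654fccc4a4d8' and a pool of three
    # roots, the first probe index is int('acbd18db', 16) % 3; that root
    # is removed from the pool, the next 8 hex digits of the seed pick
    # from the remaining two, and so on. Every client thus probes the
    # same hash-specific ordering of servers, so readers tend to find
    # data where other clients wrote it.
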
    def get(self, locator):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = os.environ['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise Exception("Not found: %s" % expect_hash)

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies == want_copies:
            return (data_hash + '+' + str(len(data)))
        raise Exception("Write fail for %s: wanted %d but wrote %d" %
                        (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        r = re.search(r'^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()