# arvados.git / sdk / python / arvados.py
import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading

from apiclient import errors
from apiclient.discovery import build

if 'ARVADOS_DEBUG' in os.environ:
    logging.basicConfig(level=logging.DEBUG)

class CredentialsFromEnv(object):
    """Take the Arvados API token from the environment and attach it to
    every request as an OAuth2 Authorization header.
    """
    @staticmethod
    def http_request(self, uri, **kwargs):
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        # Monkey-patch the given httplib2.Http object so every request
        # carries the API token.
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

url = ('https://%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()

# Use the system's CA certificates (if we find them) instead of httplib2's.
ca_certs = '/etc/ssl/certs/ca-certificates.crt'
if not os.path.exists(ca_certs):
    ca_certs = None             # use httplib2 default

http = httplib2.Http(ca_certs=ca_certs)
http = credentials.authorize(http)
if re.match(r'(?i)^(true|1|yes)$',
            os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
    http.disable_ssl_certificate_validation = True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)

def task_set_output(self, s):
    service.job_tasks().update(uuid=self['uuid'],
                               job_task=json.dumps({
                                   'output': s,
                                   'success': True,
                                   'progress': 1.0
                               })).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def api():
    return service

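# Example (a sketch, assuming ARVADOS_API_HOST and ARVADOS_API_TOKEN are
# set in the environment before this module is imported):
#
#   import arvados
#   me = arvados.api().users().current().execute()
#   print me['uuid']
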
class JobTask(object):
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                    }
                }
                service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            exit(0)

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input': task_input
                }
            }
            service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            exit(0)

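# Example crunch script skeleton (a sketch; TASK_UUID, JOB_UUID, TASK_WORK
# and JOB_WORK are assumed to be set by the job dispatcher, and
# `output_locator` is a hypothetical result produced by the task):
#
#   import arvados
#   arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#   this_task = arvados.current_task()
#   input_manifest = this_task['parameters']['input']
#   # ...process arvados.CollectionReader(input_manifest)...
#   this_task.set_output(output_locator)
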
class util:
    @staticmethod
    def run_command(execargs, **kwargs):
        if 'stdin' not in kwargs:
            kwargs['stdin'] = subprocess.PIPE
        if 'stdout' not in kwargs:
            kwargs['stdout'] = subprocess.PIPE
        if 'stderr' not in kwargs:
            kwargs['stderr'] = subprocess.PIPE
        p = subprocess.Popen(execargs, close_fds=True, shell=False,
                             **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("run_command %s exit %d:\n%s" %
                            (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

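    # Example (a sketch):
    #
    #   out, err = util.run_command(["git", "--version"])
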
    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search(r'\.(tbz|tar\.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search(r'\.(tgz|tar\.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search(r'\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise Exception("tarball_extract cannot handle filename %s"
                                    % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search(r'\.zip$', f.name()):
                    raise Exception("zipball_extract cannot handle filename %s"
                                    % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got,
                             map(lambda z: z.name(),
                                 list(CollectionReader(collection).all_files()))))
        os.symlink(collection, os.path.join(path, '.locator'))

        lockfile.close()
        return path

    @staticmethod
    def mkdir_dash_p(path):
        """Create path and any missing parent directories, like `mkdir -p`."""
        if not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                # Tolerate a race: another process may have created it.
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got, map(lambda z: z.name(),
                                                   list(stream.all_files()))))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

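# A manifest text has one stream per line. Each stream line is a stream
# name, one or more data block locators, then one or more file tokens of
# the form position:size:filename. For example (hypothetical locator,
# the md5 of the 3-byte block "foo"):
#
#   . acbd18db4cc2f85cedef654fccc4a4d8+3 0:3:foo.txt
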
class StreamFileReader(object):
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub(r'\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search(r'\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search(r'\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"

class StreamReader(object):
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name is None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")

    def tokens(self):
        return self._tokens
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                # Without a size hint we cannot tell which blocks fall
                # inside the range, so return all remaining locators.
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(1))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data is None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search(r'\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(1))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos is not None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

class CollectionReader(object):
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams is not None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text

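# Example (a sketch; `collection_locator` stands in for a real manifest
# locator or manifest text):
#
#   cr = CollectionReader(collection_locator)
#   for f in cr.all_files():
#       print f.stream_name(), f.name(), f.size()
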
class CollectionWriter(object):
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        self.start_new_stream(stream_name)
        todo = []
        if max_manifest_depth == 0:
            dirents = util.listdir_recursive(path)
        else:
            dirents = sorted(os.listdir(path))
        for dirent in dirents:
            target = os.path.join(path, dirent)
            if os.path.isdir(target):
                todo += [[target,
                          os.path.join(stream_name, dirent),
                          max_manifest_depth-1]]
            else:
                self.start_new_file(dirent)
                with open(target, 'rb') as f:
                    while True:
                        buf = f.read(2**26)
                        if len(buf) == 0:
                            break
                        self.write(buf)
        self.finish_current_stream()
        map(lambda x: self.write_directory_tree(*x), todo)

    def write(self, newdata):
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        # Encode spaces as "\040" per the manifest format.
        newfilename = re.sub(r' ', '\\\\040', newfilename)
        if re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name is None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname='.'):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name is None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0]
            if len(stream[1]) == 0:
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest

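# Example (a sketch):
#
#   cw = CollectionWriter()
#   cw.start_new_file('hello.txt')
#   cw.write('hello world\n')
#   locator = cw.finish()   # store the manifest in Keep, return its locator
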
global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object is None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)

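# Example (a sketch; setting KEEP_LOCAL_STORE to a local directory, e.g.
# in tests, bypasses the network entirely):
#
#   locator = Keep.put("foo")   # "acbd18db4cc2f85cedef654fccc4a4d8+3"
#   data = Keep.get(locator)
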
class KeepClient(object):

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()
        def __enter__(self):
            self._todo_lock.acquire()
            return self
        def __exit__(self, type, value, traceback):
            self._todo_lock.release()
        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)
        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1
        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

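    # Usage sketch for ThreadLimiter (do_work() is hypothetical; each
    # worker thread runs something like this):
    #
    #   limiter = KeepClient.ThreadLimiter(2)
    #   with limiter:
    #       if limiter.shall_i_proceed():
    #           if do_work():
    #               limiter.increment_done()
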
    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs
        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done
                    # without me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = os.environ['ARVADOS_API_TOKEN']
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None

    def shuffled_service_roots(self, hash):
        if self.service_roots is None:
            self.lock.acquire()
            keep_disks = api().keep_disks().list().execute()['items']
            roots = (("http%s://%s:%d/" %
                      ('s' if f['service_ssl_flag'] else '',
                       f['service_host'],
                       f['service_port']))
                     for f in keep_disks)
            self.service_roots = sorted(set(roots))
            logging.debug(str(self.service_roots))
            self.lock.release()
        # Build a deterministic permutation of the service roots: each
        # successive 8-hex-digit window of the hash picks the next
        # server, so every client probes the same servers in the same
        # order for a given block.
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    def get(self, locator):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = os.environ['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise Exception("Not found: %s" % expect_hash)

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies == want_copies:
            return (data_hash + '+' + str(len(data)))
        raise Exception("Write fail for %s: wanted %d but wrote %d" %
                        (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)

    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator

    @staticmethod
    def local_store_get(locator):
        r = re.search(r'^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()