Use hash to determine probe order
[arvados.git] / sdk / python / arvados.py
import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time

from apiclient import errors
from apiclient.discovery import build

class CredentialsFromEnv:
    @staticmethod
    def http_request(self, uri, **kwargs):
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

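# Build an authenticated API client from the environment: ARVADOS_API_HOST
# names the API server whose discovery document is loaded below, and the
# token injected above by CredentialsFromEnv authenticates every request.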
url = ('https://%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()

# Use system's CA certificates (if we find them) instead of httplib2's
ca_certs = '/etc/ssl/certs/ca-certificates.crt'
if not os.path.exists(ca_certs):
    ca_certs = None             # use httplib2 default

http = httplib2.Http(ca_certs=ca_certs)
http = credentials.authorize(http)
if re.match(r'(?i)^(true|1|yes)$',
            os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
    http.disable_ssl_certificate_validation=True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)

def task_set_output(self,s):
    service.job_tasks().update(uuid=self['uuid'],
                               job_task=json.dumps({
                'output':s,
                'success':True,
                'progress':1.0
                })).execute()

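# current_task() and current_job() are memoized accessors for the job task
# and job records named by the TASK_UUID and JOB_UUID environment variables.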
_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def api():
    return service

class JobTask:
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)

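# Helpers for fanning a job out into parallel tasks: each method below queues
# one new job task per input file (or per input stream) at sequence
# if_sequence+1, then marks the current task successful and exits.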
class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input':task_input
                        }
                    }
                service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success':True})
                                       ).execute()
            exit(0)

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input':task_input
                    }
                }
            service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success':True})
                                       ).execute()
            exit(0)

class util:
    @staticmethod
    def run_command(execargs, **kwargs):
        if 'stdin' not in kwargs:
            kwargs['stdin'] = subprocess.PIPE
        if 'stdout' not in kwargs:
            kwargs['stdout'] = subprocess.PIPE
        if 'stderr' not in kwargs:
            kwargs['stderr'] = subprocess.PIPE
        p = subprocess.Popen(execargs, close_fds=True, shell=False,
                             **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("run_command %s exit %d:\n%s" %
                            (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

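    # The extractors below serialize concurrent callers with an exclusive
    # lock on "<path>.lock" and record the collection locator in a
    # "<path>/.locator" symlink, so a repeated call with the same locator
    # can skip re-fetching and re-extracting the data.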
    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search('\.(tbz|tar.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search('\.(tgz|tar.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search('\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise Exception("tarball_extract cannot handle filename %s"
                                    % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search('\.zip$', f.name()):
                    raise Exception("zipball_extract cannot handle filename %s"
                                    % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" % (files, files_got, map(lambda z: z.name(), list(CollectionReader(collection).all_files()))))
        os.symlink(collection, os.path.join(path, '.locator'))

        lockfile.close()
        return path

    @staticmethod
    def mkdir_dash_p(path):
        if not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got, map(lambda z: z.name(),
                                                   list(stream.all_files()))))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

class DataReader:
    def __init__(self, data_locator):
        self.data_locator = data_locator
        self.p = subprocess.Popen(["whget", "-r", self.data_locator, "-"],
                                  stdout=subprocess.PIPE,
                                  stdin=None, stderr=subprocess.PIPE,
                                  shell=False, close_fds=True)
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
    def read(self, size, **kwargs):
        return self.p.stdout.read(size, **kwargs)
    def close(self):
        self.p.stdout.close()
        if not self.p.stderr.closed:
            for err in self.p.stderr:
                print >> sys.stderr, err
            self.p.stderr.close()
        self.p.wait()
        if self.p.returncode != 0:
            raise Exception("whget subprocess exited %d" % self.p.returncode)

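# StreamFileReader presents a single file (a byte range within a manifest
# stream) with read/readall/readlines accessors; readall_decompressed()
# transparently decompresses .bz2 and .gz files based on the file name.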
class StreamFileReader:
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub('\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search('\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search('\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"

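# A manifest stream is a whitespace-separated token list: a stream name,
# one or more data block locators of the form "<md5sum>[+hints]", and one
# or more file tokens of the form "<offset>:<length>:<filename>" giving each
# file's byte range within the concatenation of the stream's blocks.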
class StreamReader:
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name == None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':',2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")

    def tokens(self):
        return self._tokens
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(0))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data == None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(0))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos != None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

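# CollectionReader accepts either a collection locator (fetched from Keep on
# first use) or literal manifest text, and exposes the streams and files the
# manifest describes.  Typical use (illustrative):
#
#     for f in CollectionReader(locator).all_files():
#         print f.stream_name(), f.name(), f.size()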
class CollectionReader:
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams != None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text

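# CollectionWriter buffers written data into Keep blocks of up to
# KEEP_BLOCK_SIZE (2**26 = 64 MiB) bytes and assembles the manifest as files
# and streams are finished.  Typical use (illustrative):
#
#     cw = CollectionWriter()
#     cw.start_new_file('hello.txt')
#     cw.write('hello world\n')
#     locator = cw.finish()   # stores the manifest in Keep, returns its locator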
class CollectionWriter:
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        self.start_new_stream(stream_name)
        todo = []
        if max_manifest_depth == 0:
            dirents = util.listdir_recursive(path)
        else:
            dirents = sorted(os.listdir(path))
        for dirent in dirents:
            target = os.path.join(path, dirent)
            if os.path.isdir(target):
                todo += [[target,
                          os.path.join(stream_name, dirent),
                          max_manifest_depth-1]]
            else:
                self.start_new_file(dirent)
                with open(target, 'rb') as f:
                    while True:
                        buf = f.read(2**26)
                        if len(buf) == 0:
                            break
                        self.write(buf)
        self.finish_current_stream()
        map(lambda x: self.write_directory_tree(*x), todo)

    def write(self, newdata):
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        newfilename = re.sub(r' ', '\\\\040', newfilename)
        if re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name == None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                       self._current_stream_length - self._current_file_pos,
                                       self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname='.'):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name == None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                       self._current_stream_locators,
                                       self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0]
            if len(stream[1]) == 0:
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest

global_client_object = None

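# Keep is a thin facade over a shared KeepClient instance.  When the
# KEEP_LOCAL_STORE environment variable is set, blocks are read and written
# as plain files in that directory instead of talking to Keep servers
# (useful for tests).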
class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object == None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator):
        return Keep.global_client_object().get(locator)

    @staticmethod
    def put(data):
        return Keep.global_client_object().put(data)

class KeepClient:
    def __init__(self):
        self.service_roots = None

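    # Probe order is derived from the block hash itself: successive
    # 8-hex-digit slices of the hash pick the next server out of the
    # remaining pool, so every client visits the same servers in the same
    # order for a given block, spreading load without a central index.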
    def shuffled_service_roots(self, hash):
        if self.service_roots == None:
            keep_disks = api().keep_disks().list().execute()['items']
            roots = (("http%s://%s:%d/" %
                      ('s' if f['service_ssl_flag'] else '',
                       f['service_host'],
                       f['service_port']))
                     for f in keep_disks)
            self.service_roots = sorted(set(roots))
            logging.debug(str(self.service_roots))
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

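    # Read a block: try each service in probe order, verify the returned
    # content against the md5 part of the locator, and return the first
    # copy that checks out.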
    def get(self, locator):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = os.environ['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url, 'GET', headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise Exception("Not found: %s" % expect_hash)

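    # Write a block: send it to services in probe order until want_copies
    # (default 2) replicas have been stored, then return the
    # "<md5sum>+<size>" locator for the data.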
    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        for service_root in self.shuffled_service_roots(data_hash):
            h = httplib2.Http()
            url = service_root + data_hash
            api_token = os.environ['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token}
            try:
                resp, content = h.request(url, 'PUT',
                                          headers=headers,
                                          body=data)
                if (resp['status'] == '401' and
                    re.match(r'Timestamp verification failed', content)):
                    body = self.sign_for_old_server(data_hash, data)
                    h = httplib2.Http()
                    resp, content = h.request(url, 'PUT',
                                              headers=headers,
                                              body=body)
                if re.match(r'^2\d\d$', resp['status']):
                    have_copies += 1
                    if have_copies == want_copies:
                        return data_hash + '+' + str(len(data))
                else:
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
            except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                logging.warning("Request fail: PUT %s => %s: %s" %
                                (url, type(e), str(e)))
        raise Exception("Write fail for %s: wanted %d but wrote %d" %
                        (data_hash, want_copies, have_copies))

    def sign_for_old_server(self, data_hash, data):
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()