Accept comma-separated list of locators in Keep.get()
[arvados.git] / sdk/python/arvados/__init__.py
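
With this change, the locator argument to Keep.get() may also be a comma-separated
list of block locators; each block is fetched in turn and the results are
concatenated. A minimal usage sketch (the locators below are hypothetical):

    import arvados
    data = arvados.Keep.get(
        'acbd18db4cc2f85cedef654fccc4a4d8+3,37b51d194a7513e45b56f6524f2d51f2+3')
    # equivalent to Keep.get(<first locator>) + Keep.get(<second locator>)
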
import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading

import apiclient
import apiclient.discovery

if 'ARVADOS_DEBUG' in os.environ:
    logging.basicConfig(level=logging.DEBUG)

EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'

services = {}

class errors:
    class SyntaxError(Exception):
        pass
    class AssertionError(Exception):
        pass
    class NotFoundError(Exception):
        pass
    class CommandFailedError(Exception):
        pass
    class KeepWriteError(Exception):
        pass
    class NotImplementedError(Exception):
        pass

class CredentialsFromEnv(object):
    @staticmethod
    def http_request(self, uri, **kwargs):
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

def task_set_output(self,s):
    api('v1').job_tasks().update(uuid=self['uuid'],
                                 body={
            'output':s,
            'success':True,
            'progress':1.0
            }).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = api('v1').job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = api('v1').jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def getjobparam(*args):
    return current_job()['script_parameters'].get(*args)

# Monkey patch discovery._cast() so objects and arrays get serialized
# with json.dumps() instead of str().
_cast_orig = apiclient.discovery._cast
def _cast_objects_too(value, schema_type):
    global _cast_orig
    if (type(value) != type('') and
        (schema_type == 'object' or schema_type == 'array')):
        return json.dumps(value)
    else:
        return _cast_orig(value, schema_type)
apiclient.discovery._cast = _cast_objects_too

def api(version=None):
    global services
    if not services.get(version):
        apiVersion = version
        if not version:
            apiVersion = 'v1'
            logging.info("Using default API version. " +
                         "Call arvados.api('%s') instead." %
                         apiVersion)
        if 'ARVADOS_API_HOST' not in os.environ:
            raise Exception("ARVADOS_API_HOST is not set. Aborting.")
        url = ('https://%s/discovery/v1/apis/{api}/{apiVersion}/rest' %
               os.environ['ARVADOS_API_HOST'])
        credentials = CredentialsFromEnv()

        # Use system's CA certificates (if we find them) instead of httplib2's
        ca_certs = '/etc/ssl/certs/ca-certificates.crt'
        if not os.path.exists(ca_certs):
            ca_certs = None             # use httplib2 default

        http = httplib2.Http(ca_certs=ca_certs)
        http = credentials.authorize(http)
        if re.match(r'(?i)^(true|1|yes)$',
                    os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
            http.disable_ssl_certificate_validation=True
        services[version] = apiclient.discovery.build(
            'arvados', apiVersion, http=http, discoveryServiceUrl=url)
    return services[version]

class JobTask(object):
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input':task_input
                        }
                    }
                api('v1').job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                         body={'success':True}
                                         ).execute()
            exit(0)

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input':task_input
                    }
                }
            api('v1').job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                         body={'success':True}
                                         ).execute()
            exit(0)

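# A minimal sketch of a crunch script built on the helpers above; the script
# body and parameter names are hypothetical, not part of this module:
#
#   import arvados
#
#   # Queue one new task per input file at sequence 0, then exit.
#   arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#
#   this_task = arvados.current_task()
#   task_input = this_task['parameters']['input']   # single-file manifest text
#   out = arvados.CollectionWriter()
#   out.set_current_file_name('output.txt')
#   for f in arvados.CollectionReader(task_input).all_files():
#       for chunk in f.readall():
#           out.write(chunk)
#   this_task.set_output(out.finish())
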
class util:
    @staticmethod
    def clear_tmpdir(path=None):
        """
        Ensure the given directory (or TASK_TMPDIR if none given)
        exists and is empty.
        """
        if path == None:
            path = current_task().tmpdir
        if os.path.exists(path):
            p = subprocess.Popen(['rm', '-rf', path])
            stdout, stderr = p.communicate(None)
            if p.returncode != 0:
                raise Exception('rm -rf %s: %s' % (path, stderr))
        os.mkdir(path)

    @staticmethod
    def run_command(execargs, **kwargs):
        kwargs.setdefault('stdin', subprocess.PIPE)
        kwargs.setdefault('stdout', subprocess.PIPE)
        kwargs.setdefault('stderr', sys.stderr)
        kwargs.setdefault('close_fds', True)
        kwargs.setdefault('shell', False)
        p = subprocess.Popen(execargs, **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise errors.CommandFailedError(
                "run_command %s exit %d:\n%s" %
                (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search('\.(tbz|tar\.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search('\.(tgz|tar\.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search('\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise errors.AssertionError(
                        "tarball_extract cannot handle filename %s" % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search('\.zip$', f.name()):
                    raise errors.NotImplementedError(
                        "zipball_extract cannot handle filename %s" % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
        if matches:
            collection_hash = matches.group(1)
        else:
            collection_hash = hashlib.md5(collection).hexdigest()
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection_hash:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got,
                 [z.name() for z in CollectionReader(collection).all_files()]))
        os.symlink(collection_hash, os.path.join(path, '.locator'))

        lockfile.close()
        return path

    @staticmethod
    def mkdir_dash_p(path):
        if not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got, [z.name() for z in stream.all_files()]))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

class StreamFileReader(object):
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub('\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search('\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search('\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s %s 0:0:%s\n"
                    % (self._stream.name(), EMPTY_BLOCK_LOCATOR, self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"

class StreamReader(object):
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name == None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':',2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise errors.SyntaxError("Invalid manifest format")

    def tokens(self):
        return self._tokens
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(0))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data == None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(0))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos != None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

class CollectionReader(object):
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams != None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text

class CollectionWriter(object):
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write_directory_tree(self,
                             path, stream_name='.', max_manifest_depth=-1):
        self.start_new_stream(stream_name)
        todo = []
        if max_manifest_depth == 0:
            dirents = sorted(util.listdir_recursive(path))
        else:
            dirents = sorted(os.listdir(path))
        for dirent in dirents:
            target = os.path.join(path, dirent)
            if os.path.isdir(target):
                todo += [[target,
                          os.path.join(stream_name, dirent),
                          max_manifest_depth-1]]
            else:
                self.start_new_file(dirent)
                with open(target, 'rb') as f:
                    while True:
                        buf = f.read(2**26)
                        if len(buf) == 0:
                            break
                        self.write(buf)
        self.finish_current_stream()
        map(lambda x: self.write_directory_tree(*x), todo)

    def write(self, newdata):
        if hasattr(newdata, '__iter__'):
            for s in newdata:
                self.write(s)
            return
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        newfilename = re.sub(r' ', '\\\\040', newfilename)
        if re.search(r'[ \t\n]', newfilename):
            raise errors.AssertionError(
                "Manifest filenames cannot contain whitespace: %s" %
                newfilename)
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name == None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise errors.AssertionError(
                "Cannot finish an unnamed file " +
                "(%d bytes at offset %d in '%s' stream)" %
                (self._current_stream_length - self._current_file_pos,
                 self._current_file_pos,
                 self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname='.'):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise errors.AssertionError(
                "Manifest stream names cannot contain whitespace")
        self._current_stream_name = '.' if newstreamname=='' else newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name == None:
            raise errors.AssertionError(
                "Cannot finish an unnamed stream (%d bytes in %d files)" %
                (self._current_stream_length, len(self._current_stream_files)))
        else:
            if len(self._current_stream_locators) == 0:
                self._current_stream_locators += [EMPTY_BLOCK_LOCATOR]
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            if not re.search(r'^\.(/.*)?$', stream[0]):
                manifest += './'
            manifest += stream[0]
            for locator in stream[1]:
                manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest
    def data_locators(self):
        ret = []
        for name, locators, files in self._finished_streams:
            ret += locators
        return ret

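# A minimal write/read round trip with the two classes above (the file name
# and contents are hypothetical):
#
#   cw = CollectionWriter()
#   cw.start_new_file('hello.txt')
#   cw.write('hello world\n')
#   locator = cw.finish()              # stores the manifest in Keep
#   for f in CollectionReader(locator).all_files():
#       print f.name(), f.read(2**20)
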
global_client_object = None

class Keep:
    @staticmethod
    def global_client_object():
        global global_client_object
        if global_client_object == None:
            global_client_object = KeepClient()
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        return Keep.global_client_object().put(data, **kwargs)

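# For tests without a Keep service, setting KEEP_LOCAL_STORE makes Keep.put()
# and Keep.get() read and write an existing local directory instead (the path
# below is hypothetical):
#
#   os.environ['KEEP_LOCAL_STORE'] = '/tmp/keep_local_store'
#   loc = Keep.put('foo')   # writes /tmp/keep_local_store/acbd18db4cc2f85cedef654fccc4a4d8
#   Keep.get(loc)           # returns 'foo'
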
class KeepClient(object):

    class ThreadLimiter(object):
        """
        Limit the number of threads running at a given time to
        {desired successes} minus {successes reported}. When successes
        reported == desired, wake up the remaining threads and tell
        them to quit.

        Should be used in a "with" block.
        """
        def __init__(self, todo):
            self._todo = todo
            self._done = 0
            self._todo_lock = threading.Semaphore(todo)
            self._done_lock = threading.Lock()
        def __enter__(self):
            self._todo_lock.acquire()
            return self
        def __exit__(self, type, value, traceback):
            self._todo_lock.release()
        def shall_i_proceed(self):
            """
            Return true if the current thread should do stuff. Return
            false if the current thread should just stop.
            """
            with self._done_lock:
                return (self._done < self._todo)
        def increment_done(self):
            """
            Report that the current thread was successful.
            """
            with self._done_lock:
                self._done += 1
        def done(self):
            """
            Return how many successes were reported.
            """
            with self._done_lock:
                return self._done

    class KeepWriterThread(threading.Thread):
        """
        Write a blob of data to the given Keep server. Call
        increment_done() of the given ThreadLimiter if the write
        succeeds.
        """
        def __init__(self, **kwargs):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.args = kwargs
        def run(self):
            with self.args['thread_limiter'] as limiter:
                if not limiter.shall_i_proceed():
                    # My turn arrived, but the job has been done without
                    # me.
                    return
                logging.debug("KeepWriterThread %s proceeding %s %s" %
                              (str(threading.current_thread()),
                               self.args['data_hash'],
                               self.args['service_root']))
                h = httplib2.Http()
                url = self.args['service_root'] + self.args['data_hash']
                api_token = os.environ['ARVADOS_API_TOKEN']
                headers = {'Authorization': "OAuth2 %s" % api_token}
                try:
                    resp, content = h.request(url.encode('utf-8'), 'PUT',
                                              headers=headers,
                                              body=self.args['data'])
                    if (resp['status'] == '401' and
                        re.match(r'Timestamp verification failed', content)):
                        body = KeepClient.sign_for_old_server(
                            self.args['data_hash'],
                            self.args['data'])
                        h = httplib2.Http()
                        resp, content = h.request(url.encode('utf-8'), 'PUT',
                                                  headers=headers,
                                                  body=body)
                    if re.match(r'^2\d\d$', resp['status']):
                        logging.debug("KeepWriterThread %s succeeded %s %s" %
                                      (str(threading.current_thread()),
                                       self.args['data_hash'],
                                       self.args['service_root']))
                        return limiter.increment_done()
                    logging.warning("Request fail: PUT %s => %s %s" %
                                    (url, resp['status'], content))
                except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
                    logging.warning("Request fail: PUT %s => %s: %s" %
                                    (url, type(e), str(e)))

    def __init__(self):
        self.lock = threading.Lock()
        self.service_roots = None

    def shuffled_service_roots(self, hash):
        if self.service_roots == None:
            self.lock.acquire()
            keep_disks = api().keep_disks().list().execute()['items']
            roots = (("http%s://%s:%d/" %
                      ('s' if f['service_ssl_flag'] else '',
                       f['service_host'],
                       f['service_port']))
                     for f in keep_disks)
            self.service_roots = sorted(set(roots))
            logging.debug(str(self.service_roots))
            self.lock.release()
        seed = hash
        pool = self.service_roots[:]
        pseq = []
        while len(pool) > 0:
            if len(seed) < 8:
                if len(pseq) < len(hash) / 4: # first time around
                    seed = hash[-4:] + hash
                else:
                    seed += hash
            probe = int(seed[0:8], 16) % len(pool)
            pseq += [pool[probe]]
            pool = pool[:probe] + pool[probe+1:]
            seed = seed[8:]
        logging.debug(str(pseq))
        return pseq

    def get(self, locator):
        # "locator" may be a single block locator or a comma-separated list
        # of locators; for a list, fetch each block and concatenate the results.
        if re.search(r',', locator):
            return ''.join(self.get(x) for x in locator.split(','))
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_get(locator)
        expect_hash = re.sub(r'\+.*', '', locator)
        for service_root in self.shuffled_service_roots(expect_hash):
            h = httplib2.Http()
            url = service_root + expect_hash
            api_token = os.environ['ARVADOS_API_TOKEN']
            headers = {'Authorization': "OAuth2 %s" % api_token,
                       'Accept': 'application/octet-stream'}
            try:
                resp, content = h.request(url.encode('utf-8'), 'GET',
                                          headers=headers)
                if re.match(r'^2\d\d$', resp['status']):
                    m = hashlib.new('md5')
                    m.update(content)
                    md5 = m.hexdigest()
                    if md5 == expect_hash:
                        return content
                    logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
            except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
                logging.info("Request fail: GET %s => %s: %s" %
                             (url, type(e), str(e)))
        raise errors.NotFoundError("Block not found: %s" % expect_hash)

    def put(self, data, **kwargs):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return KeepClient.local_store_put(data)
        m = hashlib.new('md5')
        m.update(data)
        data_hash = m.hexdigest()
        have_copies = 0
        want_copies = kwargs.get('copies', 2)
        if not (want_copies > 0):
            return data_hash
        threads = []
        thread_limiter = KeepClient.ThreadLimiter(want_copies)
        for service_root in self.shuffled_service_roots(data_hash):
            t = KeepClient.KeepWriterThread(data=data,
                                            data_hash=data_hash,
                                            service_root=service_root,
                                            thread_limiter=thread_limiter)
            t.start()
            threads += [t]
        for t in threads:
            t.join()
        have_copies = thread_limiter.done()
        if have_copies == want_copies:
            return (data_hash + '+' + str(len(data)))
        raise errors.KeepWriteError(
            "Write fail for %s: wanted %d but wrote %d" %
            (data_hash, want_copies, have_copies))

    @staticmethod
    def sign_for_old_server(data_hash, data):
        return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)


    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise errors.NotFoundError(
                "Invalid data locator: '%s'" % locator)
        if r.group(0) == EMPTY_BLOCK_LOCATOR.split('+')[0]:
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()

# We really shouldn't do this but some clients still use
# arvados.service.* directly instead of arvados.api().*
service = api()