dacdba851af3a99baaa5cdda21f105cf9354d3ae
[arvados.git] / sdk / python / arvados / __init__.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20
21 import apiclient
22 import apiclient.discovery
23
24 # Arvados configuration settings are taken from $HOME/.config/arvados.
25 # Environment variables override settings in the config file.
26 #
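# A minimal sketch of the expected file format (hypothetical values; each
# line is VAR=value):
#
#   ARVADOS_API_HOST=arvados.example.com
#   ARVADOS_API_TOKEN=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#   ARVADOS_API_HOST_INSECURE=no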
27 class ArvadosConfig(dict):
28     def __init__(self, config_file):
29         dict.__init__(self)
30         if os.path.exists(config_file):
31             with open(config_file, "r") as f:
32                 for config_line in f:
33                     var, val = config_line.rstrip().split('=', 1)
34                     self[var] = val
35         for var in os.environ:
36             if var.startswith('ARVADOS_'):
37                 self[var] = os.environ[var]
38
39
40 config = ArvadosConfig(os.environ['HOME'] + '/.config/arvados')
41
42 if 'ARVADOS_DEBUG' in config:
43     logging.basicConfig(level=logging.DEBUG)
44
45 EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'
46
47 services = {}
48
49 class errors:
50     class SyntaxError(Exception):
51         pass
52     class AssertionError(Exception):
53         pass
54     class NotFoundError(Exception):
55         pass
56     class CommandFailedError(Exception):
57         pass
58     class KeepWriteError(Exception):
59         pass
60     class NotImplementedError(Exception):
61         pass
62
63 class CredentialsFromEnv(object):
64     @staticmethod
65     def http_request(self, uri, **kwargs):
66         global config
67         from httplib import BadStatusLine
68         if 'headers' not in kwargs:
69             kwargs['headers'] = {}
70         kwargs['headers']['Authorization'] = 'OAuth2 %s' % config.get('ARVADOS_API_TOKEN', 'ARVADOS_API_TOKEN_not_set')
71         try:
72             return self.orig_http_request(uri, **kwargs)
73         except BadStatusLine:
74             # This is how httplib tells us that it tried to reuse an
75             # existing connection but it was already closed by the
76             # server. In that case, yes, we would like to retry.
77             # Unfortunately, we are not absolutely certain that the
78             # previous call did not succeed, so this is slightly
79             # risky.
80             return self.orig_http_request(uri, **kwargs)
81     def authorize(self, http):
82         http.orig_http_request = http.request
83         http.request = types.MethodType(self.http_request, http)
84         return http
85
86 def task_set_output(self,s):
87     api('v1').job_tasks().update(uuid=self['uuid'],
88                                  body={
89             'output':s,
90             'success':True,
91             'progress':1.0
92             }).execute()
93
94 _current_task = None
95 def current_task():
96     global _current_task
97     if _current_task:
98         return _current_task
99     t = api('v1').job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
100     t = UserDict.UserDict(t)
101     t.set_output = types.MethodType(task_set_output, t)
102     t.tmpdir = os.environ['TASK_WORK']
103     _current_task = t
104     return t
105
106 _current_job = None
107 def current_job():
108     global _current_job
109     if _current_job:
110         return _current_job
111     t = api('v1').jobs().get(uuid=os.environ['JOB_UUID']).execute()
112     t = UserDict.UserDict(t)
113     t.tmpdir = os.environ['JOB_WORK']
114     _current_job = t
115     return t
116
117 def getjobparam(*args):
118     return current_job()['script_parameters'].get(*args)
119
120 # Monkey patch discovery._cast() so objects and arrays get serialized
121 # with json.dumps() instead of str().
122 _cast_orig = apiclient.discovery._cast
123 def _cast_objects_too(value, schema_type):
124     global _cast_orig
125     if (type(value) != type('') and
126         (schema_type == 'object' or schema_type == 'array')):
127         return json.dumps(value)
128     else:
129         return _cast_orig(value, schema_type)
130 apiclient.discovery._cast = _cast_objects_too
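# Illustration of why the patch above matters (a sketch, shown with a
# collection locator as the value): str() produces a Python repr that the
# API server cannot parse, while json.dumps() produces valid JSON.
#
#   >>> str({'input': 'd41d8cd98f00b204e9800998ecf8427e+0'})
#   "{'input': 'd41d8cd98f00b204e9800998ecf8427e+0'}"
#   >>> json.dumps({'input': 'd41d8cd98f00b204e9800998ecf8427e+0'})
#   '{"input": "d41d8cd98f00b204e9800998ecf8427e+0"}'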
131
132 def api(version=None):
133     global services, config
134     if not services.get(version):
135         apiVersion = version
136         if not version:
137             apiVersion = 'v1'
138             logging.info("Using default API version. " +
139                          "Call arvados.api('%s') instead." %
140                          apiVersion)
141         if 'ARVADOS_API_HOST' not in config:
142             raise Exception("ARVADOS_API_HOST is not set. Aborting.")
143         url = ('https://%s/discovery/v1/apis/{api}/{apiVersion}/rest' %
144                config['ARVADOS_API_HOST'])
145         credentials = CredentialsFromEnv()
146
147         # Use system's CA certificates (if we find them) instead of httplib2's
148         ca_certs = '/etc/ssl/certs/ca-certificates.crt'
149         if not os.path.exists(ca_certs):
150             ca_certs = None             # use httplib2 default
151
152         http = httplib2.Http(ca_certs=ca_certs)
153         http = credentials.authorize(http)
154         if re.match(r'(?i)^(true|1|yes)$',
155                     config.get('ARVADOS_API_HOST_INSECURE', 'no')):
156             http.disable_ssl_certificate_validation=True
157         services[version] = apiclient.discovery.build(
158             'arvados', apiVersion, http=http, discoveryServiceUrl=url)
159     return services[version]
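# A minimal usage sketch for api() (some_job_uuid is a placeholder; assumes
# ARVADOS_API_HOST and ARVADOS_API_TOKEN are configured as described above):
#
#   import arvados
#   job = arvados.api('v1').jobs().get(uuid=some_job_uuid).execute()
#   print job['uuid']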
160
161 class JobTask(object):
162     def __init__(self, parameters=dict(), runtime_constraints=dict()):
163         print "init jobtask %s %s" % (parameters, runtime_constraints)
164
165 class job_setup:
166     @staticmethod
167     def one_task_per_input_file(if_sequence=0, and_end_task=True):
168         if if_sequence != current_task()['sequence']:
169             return
170         job_input = current_job()['script_parameters']['input']
171         cr = CollectionReader(job_input)
172         for s in cr.all_streams():
173             for f in s.all_files():
174                 task_input = f.as_manifest()
175                 new_task_attrs = {
176                     'job_uuid': current_job()['uuid'],
177                     'created_by_job_task_uuid': current_task()['uuid'],
178                     'sequence': if_sequence + 1,
179                     'parameters': {
180                         'input':task_input
181                         }
182                     }
183                 api('v1').job_tasks().create(body=new_task_attrs).execute()
184         if and_end_task:
185             api('v1').job_tasks().update(uuid=current_task()['uuid'],
186                                        body={'success':True}
187                                        ).execute()
188             exit(0)
189
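    # A sketch of how a crunch script might use the helper above (a usage
    # example, not part of this module; the job runner is assumed to have
    # set JOB_UUID, TASK_UUID and script_parameters['input']):
    #
    #   import arvados
    #   arvados.job_setup.one_task_per_input_file(if_sequence=0,
    #                                             and_end_task=True)
    #   this_task = arvados.current_task()
    #   input_manifest = this_task['parameters']['input']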
190     @staticmethod
191     def one_task_per_input_stream(if_sequence=0, and_end_task=True):
192         if if_sequence != current_task()['sequence']:
193             return
194         job_input = current_job()['script_parameters']['input']
195         cr = CollectionReader(job_input)
196         for s in cr.all_streams():
197             task_input = s.tokens()
198             new_task_attrs = {
199                 'job_uuid': current_job()['uuid'],
200                 'created_by_job_task_uuid': current_task()['uuid'],
201                 'sequence': if_sequence + 1,
202                 'parameters': {
203                     'input':task_input
204                     }
205                 }
206             api('v1').job_tasks().create(body=new_task_attrs).execute()
207         if and_end_task:
208             api('v1').job_tasks().update(uuid=current_task()['uuid'],
209                                        body={'success':True}
210                                        ).execute()
211             exit(0)
212
213 class util:
214     @staticmethod
215     def clear_tmpdir(path=None):
216         """
217         Ensure the given directory (or TASK_TMPDIR if none given)
218         exists and is empty.
219         """
220         if path == None:
221             path = current_task().tmpdir
222         if os.path.exists(path):
223             p = subprocess.Popen(['rm', '-rf', path], stderr=subprocess.PIPE)
224             stdout, stderr = p.communicate(None)
225             if p.returncode != 0:
226                 raise Exception('rm -rf %s: %s' % (path, stderr))
227         os.mkdir(path)
228
229     @staticmethod
230     def run_command(execargs, **kwargs):
231         kwargs.setdefault('stdin', subprocess.PIPE)
232         kwargs.setdefault('stdout', subprocess.PIPE)
233         kwargs.setdefault('stderr', sys.stderr)
234         kwargs.setdefault('close_fds', True)
235         kwargs.setdefault('shell', False)
236         p = subprocess.Popen(execargs, **kwargs)
237         stdoutdata, stderrdata = p.communicate(None)
238         if p.returncode != 0:
239             raise errors.CommandFailedError(
240                 "run_command %s exit %d:\n%s" %
241                 (execargs, p.returncode, stderrdata))
242         return stdoutdata, stderrdata
243
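    # A minimal usage sketch for run_command (command and arguments are
    # illustrative only; stderr is inherited unless overridden via kwargs):
    #
    #   stdout, stderr = util.run_command(['ls', '-l', '/tmp'])
    #   for line in stdout.splitlines():
    #       print line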
244     @staticmethod
245     def git_checkout(url, version, path):
246         if not re.search('^/', path):
247             path = os.path.join(current_job().tmpdir, path)
248         if not os.path.exists(path):
249             util.run_command(["git", "clone", url, path],
250                              cwd=os.path.dirname(path))
251         util.run_command(["git", "checkout", version],
252                          cwd=path)
253         return path
254
255     @staticmethod
256     def tar_extractor(path, decompress_flag):
257         return subprocess.Popen(["tar",
258                                  "-C", path,
259                                  ("-x%sf" % decompress_flag),
260                                  "-"],
261                                 stdout=None,
262                                 stdin=subprocess.PIPE, stderr=sys.stderr,
263                                 shell=False, close_fds=True)
264
265     @staticmethod
266     def tarball_extract(tarball, path):
267         """Retrieve a tarball from Keep and extract it to a local
268         directory.  Return the absolute path where the tarball was
269         extracted. If the top level of the tarball contained just one
270         file or directory, return the absolute path of that single
271         item.
272
273         tarball -- collection locator
274         path -- where to extract the tarball: absolute, or relative to job tmp
275         """
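        # Usage sketch (the locator below is a hypothetical placeholder for
        # a collection containing a single tarball):
        #
        #   src_dir = util.tarball_extract(
        #       'c1bad4b39ca5a924e481008009d94e32+210', 'src')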
276         if not re.search('^/', path):
277             path = os.path.join(current_job().tmpdir, path)
278         lockfile = open(path + '.lock', 'w')
279         fcntl.flock(lockfile, fcntl.LOCK_EX)
280         try:
281             os.stat(path)
282         except OSError:
283             os.mkdir(path)
284         already_have_it = False
285         try:
286             if os.readlink(os.path.join(path, '.locator')) == tarball:
287                 already_have_it = True
288         except OSError:
289             pass
290         if not already_have_it:
291
292             # emulate "rm -f" (i.e., if the file does not exist, we win)
293             try:
294                 os.unlink(os.path.join(path, '.locator'))
295             except OSError:
296                 if os.path.exists(os.path.join(path, '.locator')):
297                     os.unlink(os.path.join(path, '.locator'))
298
299             for f in CollectionReader(tarball).all_files():
300                 if re.search('\.(tbz|tar.bz2)$', f.name()):
301                     p = util.tar_extractor(path, 'j')
302                 elif re.search('\.(tgz|tar.gz)$', f.name()):
303                     p = util.tar_extractor(path, 'z')
304                 elif re.search('\.tar$', f.name()):
305                     p = util.tar_extractor(path, '')
306                 else:
307                     raise errors.AssertionError(
308                         "tarball_extract cannot handle filename %s" % f.name())
309                 while True:
310                     buf = f.read(2**20)
311                     if len(buf) == 0:
312                         break
313                     p.stdin.write(buf)
314                 p.stdin.close()
315                 p.wait()
316                 if p.returncode != 0:
317                     lockfile.close()
318                     raise errors.CommandFailedError(
319                         "tar exited %d" % p.returncode)
320             os.symlink(tarball, os.path.join(path, '.locator'))
321         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
322         lockfile.close()
323         if len(tld_extracts) == 1:
324             return os.path.join(path, tld_extracts[0])
325         return path
326
327     @staticmethod
328     def zipball_extract(zipball, path):
329         """Retrieve a zip archive from Keep and extract it to a local
330         directory.  Return the absolute path where the archive was
331         extracted. If the top level of the archive contained just one
332         file or directory, return the absolute path of that single
333         item.
334
335         zipball -- collection locator
336         path -- where to extract the archive: absolute, or relative to job tmp
337         """
338         if not re.search('^/', path):
339             path = os.path.join(current_job().tmpdir, path)
340         lockfile = open(path + '.lock', 'w')
341         fcntl.flock(lockfile, fcntl.LOCK_EX)
342         try:
343             os.stat(path)
344         except OSError:
345             os.mkdir(path)
346         already_have_it = False
347         try:
348             if os.readlink(os.path.join(path, '.locator')) == zipball:
349                 already_have_it = True
350         except OSError:
351             pass
352         if not already_have_it:
353
354             # emulate "rm -f" (i.e., if the file does not exist, we win)
355             try:
356                 os.unlink(os.path.join(path, '.locator'))
357             except OSError:
358                 if os.path.exists(os.path.join(path, '.locator')):
359                     os.unlink(os.path.join(path, '.locator'))
360
361             for f in CollectionReader(zipball).all_files():
362                 if not re.search('\.zip$', f.name()):
363                     raise errors.NotImplementedError(
364                         "zipball_extract cannot handle filename %s" % f.name())
365                 zip_filename = os.path.join(path, os.path.basename(f.name()))
366                 zip_file = open(zip_filename, 'wb')
367                 while True:
368                     buf = f.read(2**20)
369                     if len(buf) == 0:
370                         break
371                     zip_file.write(buf)
372                 zip_file.close()
373                 
374                 p = subprocess.Popen(["unzip",
375                                       "-q", "-o",
376                                       "-d", path,
377                                       zip_filename],
378                                      stdout=None,
379                                      stdin=None, stderr=sys.stderr,
380                                      shell=False, close_fds=True)
381                 p.wait()
382                 if p.returncode != 0:
383                     lockfile.close()
384                     raise errors.CommandFailedError(
385                         "unzip exited %d" % p.returncode)
386                 os.unlink(zip_filename)
387             os.symlink(zipball, os.path.join(path, '.locator'))
388         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
389         lockfile.close()
390         if len(tld_extracts) == 1:
391             return os.path.join(path, tld_extracts[0])
392         return path
393
394     @staticmethod
395     def collection_extract(collection, path, files=[], decompress=True):
396         """Retrieve a collection from Keep and extract it to a local
397         directory.  Return the absolute path where the collection was
398         extracted.
399
400         collection -- collection locator
401         path -- where to extract: absolute, or relative to job tmp
402         """
403         matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
404         if matches:
405             collection_hash = matches.group(1)
406         else:
407             collection_hash = hashlib.md5(collection).hexdigest()
408         if not re.search('^/', path):
409             path = os.path.join(current_job().tmpdir, path)
410         lockfile = open(path + '.lock', 'w')
411         fcntl.flock(lockfile, fcntl.LOCK_EX)
412         try:
413             os.stat(path)
414         except OSError:
415             os.mkdir(path)
416         already_have_it = False
417         try:
418             if os.readlink(os.path.join(path, '.locator')) == collection_hash:
419                 already_have_it = True
420         except OSError:
421             pass
422
423         # emulate "rm -f" (i.e., if the file does not exist, we win)
424         try:
425             os.unlink(os.path.join(path, '.locator'))
426         except OSError:
427             if os.path.exists(os.path.join(path, '.locator')):
428                 os.unlink(os.path.join(path, '.locator'))
429
430         files_got = []
431         for s in CollectionReader(collection).all_streams():
432             stream_name = s.name()
433             for f in s.all_files():
434                 if (files == [] or
435                     ((f.name() not in files_got) and
436                      (f.name() in files or
437                       (decompress and f.decompressed_name() in files)))):
438                     outname = f.decompressed_name() if decompress else f.name()
439                     files_got += [outname]
440                     if os.path.exists(os.path.join(path, stream_name, outname)):
441                         continue
442                     util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
443                     outfile = open(os.path.join(path, stream_name, outname), 'wb')
444                     for buf in (f.readall_decompressed() if decompress
445                                 else f.readall()):
446                         outfile.write(buf)
447                     outfile.close()
448         if len(files_got) < len(files):
449             raise errors.AssertionError(
450                 "Wanted files %s but only got %s from %s" %
451                 (files, files_got,
452                  [z.name() for z in CollectionReader(collection).all_files()]))
453         os.symlink(collection_hash, os.path.join(path, '.locator'))
454
455         lockfile.close()
456         return path
457
458     @staticmethod
459     def mkdir_dash_p(path):
460         if not os.path.exists(path):
461             util.mkdir_dash_p(os.path.dirname(path))
462             try:
463                 os.mkdir(path)
464             except OSError:
465                 if not os.path.exists(path):
466                     os.mkdir(path)
467
468     @staticmethod
469     def stream_extract(stream, path, files=[], decompress=True):
470         """Retrieve a stream from Keep and extract it to a local
471         directory.  Return the absolute path where the stream was
472         extracted.
473
474         stream -- StreamReader object
475         path -- where to extract: absolute, or relative to job tmp
476         """
477         if not re.search('^/', path):
478             path = os.path.join(current_job().tmpdir, path)
479         lockfile = open(path + '.lock', 'w')
480         fcntl.flock(lockfile, fcntl.LOCK_EX)
481         try:
482             os.stat(path)
483         except OSError:
484             os.mkdir(path)
485
486         files_got = []
487         for f in stream.all_files():
488             if (files == [] or
489                 ((f.name() not in files_got) and
490                  (f.name() in files or
491                   (decompress and f.decompressed_name() in files)))):
492                 outname = f.decompressed_name() if decompress else f.name()
493                 files_got += [outname]
494                 if os.path.exists(os.path.join(path, outname)):
495                     os.unlink(os.path.join(path, outname))
496                 util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
497                 outfile = open(os.path.join(path, outname), 'wb')
498                 for buf in (f.readall_decompressed() if decompress
499                             else f.readall()):
500                     outfile.write(buf)
501                 outfile.close()
502         if len(files_got) < len(files):
503             raise errors.AssertionError(
504                 "Wanted files %s but only got %s from %s" %
505                 (files, files_got, [z.name() for z in stream.all_files()]))
506         lockfile.close()
507         return path
508
509     @staticmethod
510     def listdir_recursive(dirname, base=None):
511         allfiles = []
512         for ent in sorted(os.listdir(dirname)):
513             ent_path = os.path.join(dirname, ent)
514             ent_base = os.path.join(base, ent) if base else ent
515             if os.path.isdir(ent_path):
516                 allfiles += util.listdir_recursive(ent_path, ent_base)
517             else:
518                 allfiles += [ent_base]
519         return allfiles
520
521 class StreamFileReader(object):
522     def __init__(self, stream, pos, size, name):
523         self._stream = stream
524         self._pos = pos
525         self._size = size
526         self._name = name
527         self._filepos = 0
528     def name(self):
529         return self._name
530     def decompressed_name(self):
531         return re.sub('\.(bz2|gz)$', '', self._name)
532     def size(self):
533         return self._size
534     def stream_name(self):
535         return self._stream.name()
536     def read(self, size, **kwargs):
537         self._stream.seek(self._pos + self._filepos)
538         data = self._stream.read(min(size, self._size - self._filepos))
539         self._filepos += len(data)
540         return data
541     def readall(self, size=2**20, **kwargs):
542         while True:
543             data = self.read(size, **kwargs)
544             if data == '':
545                 break
546             yield data
547     def bunzip2(self, size):
548         decompressor = bz2.BZ2Decompressor()
549         for chunk in self.readall(size):
550             data = decompressor.decompress(chunk)
551             if data and data != '':
552                 yield data
553     def gunzip(self, size):
554         decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
555         for chunk in self.readall(size):
556             data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
557             if data and data != '':
558                 yield data
559     def readall_decompressed(self, size=2**20):
560         self._stream.seek(self._pos + self._filepos)
561         if re.search('\.bz2$', self._name):
562             return self.bunzip2(size)
563         elif re.search('\.gz$', self._name):
564             return self.gunzip(size)
565         else:
566             return self.readall(size)
567     def readlines(self, decompress=True):
568         if decompress:
569             datasource = self.readall_decompressed()
570         else:
571             self._stream.seek(self._pos + self._filepos)
572             datasource = self.readall()
573         data = ''
574         for newdata in datasource:
575             data += newdata
576             sol = 0
577             while True:
578                 eol = string.find(data, "\n", sol)
579                 if eol < 0:
580                     break
581                 yield data[sol:eol+1]
582                 sol = eol+1
583             data = data[sol:]
584         if data != '':
585             yield data
586     def as_manifest(self):
587         if self.size() == 0:
588             return ("%s %s 0:0:%s\n"
589                     % (self._stream.name(), EMPTY_BLOCK_LOCATOR, self.name()))
590         return string.join(self._stream.tokens_for_range(self._pos, self._size),
591                            " ") + "\n"
592
593 class StreamReader(object):
594     def __init__(self, tokens):
595         self._tokens = tokens
596         self._current_datablock_data = None
597         self._current_datablock_pos = 0
598         self._current_datablock_index = -1
599         self._pos = 0
600
601         self._stream_name = None
602         self.data_locators = []
603         self.files = []
604
605         for tok in self._tokens:
606             if self._stream_name == None:
607                 self._stream_name = tok.replace('\\040', ' ')
608             elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
609                 self.data_locators += [tok]
610             elif re.search(r'^\d+:\d+:\S+', tok):
611                 pos, size, name = tok.split(':',2)
612                 self.files += [[int(pos), int(size), name.replace('\\040', ' ')]]
613             else:
614                 raise errors.SyntaxError("Invalid manifest format")
615
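    # For reference, a stream's tokens look like this (a sketch; the
    # locator is the md5 of "foo" plus a size hint, and the file token is
    # pos:size:filename):
    #
    #   . acbd18db4cc2f85cedef654fccc4a4d8+3 0:3:foo.txt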
616     def tokens(self):
617         return self._tokens
618     def tokens_for_range(self, range_start, range_size):
619         resp = [self._stream_name]
620         return_all_tokens = False
621         block_start = 0
622         token_bytes_skipped = 0
623         for locator in self.data_locators:
624             sizehint = re.search(r'\+(\d+)', locator)
625             if not sizehint:
626                 return_all_tokens = True
627             if return_all_tokens:
628                 resp += [locator]
629                 continue
630             blocksize = int(sizehint.group(1))
631             if range_start + range_size <= block_start:
632                 break
633             if range_start < block_start + blocksize:
634                 resp += [locator]
635             else:
636                 token_bytes_skipped += blocksize
637             block_start += blocksize
638         for f in self.files:
639             if ((f[0] < range_start + range_size)
640                 and
641                 (f[0] + f[1] > range_start)
642                 and
643                 f[1] > 0):
644                 resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
645         return resp
646     def name(self):
647         return self._stream_name
648     def all_files(self):
649         for f in self.files:
650             pos, size, name = f
651             yield StreamFileReader(self, pos, size, name)
652     def nextdatablock(self):
653         if self._current_datablock_index < 0:
654             self._current_datablock_pos = 0
655             self._current_datablock_index = 0
656         else:
657             self._current_datablock_pos += self.current_datablock_size()
658             self._current_datablock_index += 1
659         self._current_datablock_data = None
660     def current_datablock_data(self):
661         if self._current_datablock_data == None:
662             self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
663         return self._current_datablock_data
664     def current_datablock_size(self):
665         if self._current_datablock_index < 0:
666             self.nextdatablock()
667         sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
668         if sizehint:
669             return int(sizehint.group(0))
670         return len(self.current_datablock_data())
671     def seek(self, pos):
672         """Set the position of the next read operation."""
673         self._pos = pos
674     def really_seek(self):
675         """Find and load the appropriate data block, so the byte at
676         _pos is in memory.
677         """
678         if self._pos == self._current_datablock_pos:
679             return True
680         if (self._current_datablock_pos != None and
681             self._pos >= self._current_datablock_pos and
682             self._pos <= self._current_datablock_pos + self.current_datablock_size()):
683             return True
684         if self._pos < self._current_datablock_pos:
685             self._current_datablock_index = -1
686             self.nextdatablock()
687         while (self._pos > self._current_datablock_pos and
688                self._pos > self._current_datablock_pos + self.current_datablock_size()):
689             self.nextdatablock()
690     def read(self, size):
691         """Read no more than size bytes -- but at least one byte,
692         unless _pos is already at the end of the stream.
693         """
694         if size == 0:
695             return ''
696         self.really_seek()
697         while self._pos >= self._current_datablock_pos + self.current_datablock_size():
698             self.nextdatablock()
699             if self._current_datablock_index >= len(self.data_locators):
700                 return None
701         data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
702         self._pos += len(data)
703         return data
704
705 class CollectionReader(object):
706     def __init__(self, manifest_locator_or_text):
707         if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
708             self._manifest_text = manifest_locator_or_text
709             self._manifest_locator = None
710         else:
711             self._manifest_locator = manifest_locator_or_text
712             self._manifest_text = None
713         self._streams = None
714     def __enter__(self):
715         return self
716     def __exit__(self, exc_type, exc_value, traceback):
717         pass
718     def _populate(self):
719         if self._streams != None:
720             return
721         if not self._manifest_text:
722             self._manifest_text = Keep.get(self._manifest_locator)
723         self._streams = []
724         for stream_line in self._manifest_text.split("\n"):
725             if stream_line != '':
726                 stream_tokens = stream_line.split()
727                 self._streams += [stream_tokens]
728     def all_streams(self):
729         self._populate()
730         resp = []
731         for s in self._streams:
732             resp += [StreamReader(s)]
733         return resp
734     def all_files(self):
735         for s in self.all_streams():
736             for f in s.all_files():
737                 yield f
738     def manifest_text(self):
739         self._populate()
740         return self._manifest_text
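# A minimal usage sketch for CollectionReader (collection_locator is a
# placeholder; a full manifest text string is accepted as well, as detected
# in __init__ above):
#
#   cr = CollectionReader(collection_locator)
#   for f in cr.all_files():
#       print f.stream_name(), f.name(), f.size()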
741
742 class CollectionWriter(object):
743     KEEP_BLOCK_SIZE = 2**26
744     def __init__(self):
745         self._data_buffer = []
746         self._data_buffer_len = 0
747         self._current_stream_files = []
748         self._current_stream_length = 0
749         self._current_stream_locators = []
750         self._current_stream_name = '.'
751         self._current_file_name = None
752         self._current_file_pos = 0
753         self._finished_streams = []
754     def __enter__(self):
755         return self
756     def __exit__(self, exc_type, exc_value, traceback):
757         self.finish()
758     def write_directory_tree(self,
759                              path, stream_name='.', max_manifest_depth=-1):
760         self.start_new_stream(stream_name)
761         todo = []
762         if max_manifest_depth == 0:
763             dirents = sorted(util.listdir_recursive(path))
764         else:
765             dirents = sorted(os.listdir(path))
766         for dirent in dirents:
767             target = os.path.join(path, dirent)
768             if os.path.isdir(target):
769                 todo += [[target,
770                           os.path.join(stream_name, dirent),
771                           max_manifest_depth-1]]
772             else:
773                 self.start_new_file(dirent)
774                 with open(target, 'rb') as f:
775                     while True:
776                         buf = f.read(2**26)
777                         if len(buf) == 0:
778                             break
779                         self.write(buf)
780         self.finish_current_stream()
781         map(lambda x: self.write_directory_tree(*x), todo)
782
783     def write(self, newdata):
784         if hasattr(newdata, '__iter__'):
785             for s in newdata:
786                 self.write(s)
787             return
788         self._data_buffer += [newdata]
789         self._data_buffer_len += len(newdata)
790         self._current_stream_length += len(newdata)
791         while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
792             self.flush_data()
793     def flush_data(self):
794         data_buffer = ''.join(self._data_buffer)
795         if data_buffer != '':
796             self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
797             self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
798             self._data_buffer_len = len(self._data_buffer[0])
799     def start_new_file(self, newfilename=None):
800         self.finish_current_file()
801         self.set_current_file_name(newfilename)
802     def set_current_file_name(self, newfilename):
803         if re.search(r'[\t\n]', newfilename):
804             raise errors.AssertionError(
805                 "Manifest filenames cannot contain whitespace: %s" %
806                 newfilename)
807         self._current_file_name = newfilename
808     def current_file_name(self):
809         return self._current_file_name
810     def finish_current_file(self):
811         if self._current_file_name == None:
812             if self._current_file_pos == self._current_stream_length:
813                 return
814             raise errors.AssertionError(
815                 "Cannot finish an unnamed file " +
816                 "(%d bytes at offset %d in '%s' stream)" %
817                 (self._current_stream_length - self._current_file_pos,
818                  self._current_file_pos,
819                  self._current_stream_name))
820         self._current_stream_files += [[self._current_file_pos,
821                                        self._current_stream_length - self._current_file_pos,
822                                        self._current_file_name]]
823         self._current_file_pos = self._current_stream_length
824     def start_new_stream(self, newstreamname='.'):
825         self.finish_current_stream()
826         self.set_current_stream_name(newstreamname)
827     def set_current_stream_name(self, newstreamname):
828         if re.search(r'[\t\n]', newstreamname):
829             raise errors.AssertionError(
830                 "Manifest stream names cannot contain whitespace")
831         self._current_stream_name = '.' if newstreamname=='' else newstreamname
832     def current_stream_name(self):
833         return self._current_stream_name
834     def finish_current_stream(self):
835         self.finish_current_file()
836         self.flush_data()
837         if len(self._current_stream_files) == 0:
838             pass
839         elif self._current_stream_name == None:
840             raise errors.AssertionError(
841                 "Cannot finish an unnamed stream (%d bytes in %d files)" %
842                 (self._current_stream_length, len(self._current_stream_files)))
843         else:
844             if len(self._current_stream_locators) == 0:
845                 self._current_stream_locators += [EMPTY_BLOCK_LOCATOR]
846             self._finished_streams += [[self._current_stream_name,
847                                        self._current_stream_locators,
848                                        self._current_stream_files]]
849         self._current_stream_files = []
850         self._current_stream_length = 0
851         self._current_stream_locators = []
852         self._current_stream_name = None
853         self._current_file_pos = 0
854         self._current_file_name = None
855     def finish(self):
856         return Keep.put(self.manifest_text())
857     def manifest_text(self):
858         self.finish_current_stream()
859         manifest = ''
860         for stream in self._finished_streams:
861             if not re.search(r'^\.(/.*)?$', stream[0]):
862                 manifest += './'
863             manifest += stream[0].replace(' ', '\\040')
864             for locator in stream[1]:
865                 manifest += " %s" % locator
866             for sfile in stream[2]:
867                 manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2].replace(' ', '\\040'))
868             manifest += "\n"
869         return manifest
870     def data_locators(self):
871         ret = []
872         for name, locators, files in self._finished_streams:
873             ret += locators
874         return ret
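# A minimal usage sketch for CollectionWriter (file name and data are
# illustrative):
#
#   cw = CollectionWriter()
#   cw.start_new_file('hello.txt')
#   cw.write('hello world\n')
#   manifest_locator = cw.finish()   # stores the manifest text in Keep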
875
876 global_client_object = None
877
878 class Keep:
879     @staticmethod
880     def global_client_object():
881         global global_client_object
882         if global_client_object == None:
883             global_client_object = KeepClient()
884         return global_client_object
885
886     @staticmethod
887     def get(locator, **kwargs):
888         return Keep.global_client_object().get(locator, **kwargs)
889
890     @staticmethod
891     def put(data, **kwargs):
892         return Keep.global_client_object().put(data, **kwargs)
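# A minimal round-trip sketch for the Keep wrapper. With KEEP_LOCAL_STORE
# set to an existing local directory, both calls use local files instead of
# Keep servers (see KeepClient.local_store_put/get below):
#
#   locator = Keep.put('hello world\n')
#   assert Keep.get(locator) == 'hello world\n'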
893
894 class KeepClient(object):
895
896     class ThreadLimiter(object):
897         """
898         Limit the number of threads running at a given time to
899         {desired successes} minus {successes reported}. When successes
900         reported == desired, wake up the remaining threads and tell
901         them to quit.
902
903         Should be used in a "with" block.
904         """
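        # Usage sketch (mirrors KeepWriterThread.run() below):
        #
        #   limiter = KeepClient.ThreadLimiter(want_copies)
        #   with limiter as l:
        #       if not l.shall_i_proceed():
        #           return
        #       # ...attempt one write, then l.increment_done() on success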
905         def __init__(self, todo):
906             self._todo = todo
907             self._done = 0
908             self._todo_lock = threading.Semaphore(todo)
909             self._done_lock = threading.Lock()
910         def __enter__(self):
911             self._todo_lock.acquire()
912             return self
913         def __exit__(self, type, value, traceback):
914             self._todo_lock.release()
915         def shall_i_proceed(self):
916             """
917             Return True if the current thread should proceed with its task,
918             or False if enough successes have already been reported.
919             """
920             with self._done_lock:
921                 return (self._done < self._todo)
922         def increment_done(self):
923             """
924             Report that the current thread was successful.
925             """
926             with self._done_lock:
927                 self._done += 1
928         def done(self):
929             """
930             Return how many successes were reported.
931             """
932             with self._done_lock:
933                 return self._done
934
935     class KeepWriterThread(threading.Thread):
936         """
937         Write a blob of data to the given Keep server. Call
938         increment_done() of the given ThreadLimiter if the write
939         succeeds.
940         """
941         def __init__(self, **kwargs):
942             super(KeepClient.KeepWriterThread, self).__init__()
943             self.args = kwargs
944         def run(self):
945             global config
946             with self.args['thread_limiter'] as limiter:
947                 if not limiter.shall_i_proceed():
948                     # My turn arrived, but the job has been done without
949                     # me.
950                     return
951                 logging.debug("KeepWriterThread %s proceeding %s %s" %
952                               (str(threading.current_thread()),
953                                self.args['data_hash'],
954                                self.args['service_root']))
955                 h = httplib2.Http()
956                 url = self.args['service_root'] + self.args['data_hash']
957                 api_token = config['ARVADOS_API_TOKEN']
958                 headers = {'Authorization': "OAuth2 %s" % api_token}
959                 try:
960                     resp, content = h.request(url.encode('utf-8'), 'PUT',
961                                               headers=headers,
962                                               body=self.args['data'])
963                     if (resp['status'] == '401' and
964                         re.match(r'Timestamp verification failed', content)):
965                         body = KeepClient.sign_for_old_server(
966                             self.args['data_hash'],
967                             self.args['data'])
968                         h = httplib2.Http()
969                         resp, content = h.request(url.encode('utf-8'), 'PUT',
970                                                   headers=headers,
971                                                   body=body)
972                     if re.match(r'^2\d\d$', resp['status']):
973                         logging.debug("KeepWriterThread %s succeeded %s %s" %
974                                       (str(threading.current_thread()),
975                                        self.args['data_hash'],
976                                        self.args['service_root']))
977                         return limiter.increment_done()
978                     logging.warning("Request fail: PUT %s => %s %s" %
979                                     (url, resp['status'], content))
980                 except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
981                     logging.warning("Request fail: PUT %s => %s: %s" %
982                                     (url, type(e), str(e)))
983
984     def __init__(self):
985         self.lock = threading.Lock()
986         self.service_roots = None
987
988     def shuffled_service_roots(self, hash):
989         if self.service_roots == None:
990             self.lock.acquire()
991             keep_disks = api().keep_disks().list().execute()['items']
992             roots = (("http%s://%s:%d/" %
993                       ('s' if f['service_ssl_flag'] else '',
994                        f['service_host'],
995                        f['service_port']))
996                      for f in keep_disks)
997             self.service_roots = sorted(set(roots))
998             logging.debug(str(self.service_roots))
999             self.lock.release()
1000         seed = hash
1001         pool = self.service_roots[:]
1002         pseq = []
1003         while len(pool) > 0:
1004             if len(seed) < 8:
1005                 if len(pseq) < len(hash) / 4: # first time around
1006                     seed = hash[-4:] + hash
1007                 else:
1008                     seed += hash
1009             probe = int(seed[0:8], 16) % len(pool)
1010             pseq += [pool[probe]]
1011             pool = pool[:probe] + pool[probe+1:]
1012             seed = seed[8:]
1013         logging.debug(str(pseq))
1014         return pseq
1015
1016     def get(self, locator):
1017         global config
1018         if re.search(r',', locator):
1019             return ''.join(self.get(x) for x in locator.split(','))
1020         if 'KEEP_LOCAL_STORE' in os.environ:
1021             return KeepClient.local_store_get(locator)
1022         expect_hash = re.sub(r'\+.*', '', locator)
1023         for service_root in self.shuffled_service_roots(expect_hash):
1024             h = httplib2.Http()
1025             url = service_root + expect_hash
1026             api_token = config['ARVADOS_API_TOKEN']
1027             headers = {'Authorization': "OAuth2 %s" % api_token,
1028                        'Accept': 'application/octet-stream'}
1029             try:
1030                 resp, content = h.request(url.encode('utf-8'), 'GET',
1031                                           headers=headers)
1032                 if re.match(r'^2\d\d$', resp['status']):
1033                     m = hashlib.new('md5')
1034                     m.update(content)
1035                     md5 = m.hexdigest()
1036                     if md5 == expect_hash:
1037                         return content
1038                     logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
1039             except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
1040                 logging.info("Request fail: GET %s => %s: %s" %
1041                              (url, type(e), str(e)))
1042         raise errors.NotFoundError("Block not found: %s" % expect_hash)
1043
1044     def put(self, data, **kwargs):
1045         if 'KEEP_LOCAL_STORE' in os.environ:
1046             return KeepClient.local_store_put(data)
1047         m = hashlib.new('md5')
1048         m.update(data)
1049         data_hash = m.hexdigest()
1050         have_copies = 0
1051         want_copies = kwargs.get('copies', 2)
1052         if not (want_copies > 0):
1053             return data_hash
1054         threads = []
1055         thread_limiter = KeepClient.ThreadLimiter(want_copies)
1056         for service_root in self.shuffled_service_roots(data_hash):
1057             t = KeepClient.KeepWriterThread(data=data,
1058                                             data_hash=data_hash,
1059                                             service_root=service_root,
1060                                             thread_limiter=thread_limiter)
1061             t.start()
1062             threads += [t]
1063         for t in threads:
1064             t.join()
1065         have_copies = thread_limiter.done()
1066         if have_copies == want_copies:
1067             return (data_hash + '+' + str(len(data)))
1068         raise errors.KeepWriteError(
1069             "Write fail for %s: wanted %d but wrote %d" %
1070             (data_hash, want_copies, have_copies))
1071
1072     @staticmethod
1073     def sign_for_old_server(data_hash, data):
1074         return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
1075
1076
1077     @staticmethod
1078     def local_store_put(data):
1079         m = hashlib.new('md5')
1080         m.update(data)
1081         md5 = m.hexdigest()
1082         locator = '%s+%d' % (md5, len(data))
1083         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
1084             f.write(data)
1085         os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
1086                   os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
1087         return locator
1088     @staticmethod
1089     def local_store_get(locator):
1090         r = re.search('^([0-9a-f]{32,})', locator)
1091         if not r:
1092             raise errors.NotFoundError(
1093                 "Invalid data locator: '%s'" % locator)
1094         if r.group(0) == EMPTY_BLOCK_LOCATOR.split('+')[0]:
1095             return ''
1096         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
1097             return f.read()
1098
1099 # We really shouldn't do this but some clients still use
1100 # arvados.service.* directly instead of arvados.api().*
1101 service = api()