Merge branch '1922-cache-discovery-python'
[arvados.git] / sdk / python / arvados / __init__.py
import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading

import apiclient
import apiclient.discovery

config = None
EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'
services = {}

from stream import *
from collection import *
from keep import *


# Arvados configuration settings are taken from $HOME/.config/arvados.
# Environment variables override settings in the config file.
#
class ArvadosConfig(dict):
    def __init__(self, config_file):
        dict.__init__(self)
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                for config_line in f:
                    # Split on the first '=' only, so values may
                    # themselves contain '=' characters.
                    var, val = config_line.rstrip().split('=', 1)
                    self[var] = val
        for var in os.environ:
            if var.startswith('ARVADOS_'):
                self[var] = os.environ[var]

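# Usage sketch (illustrative, not executed on import): given a config
# file containing "ARVADOS_API_HOST=api.example.com" (hypothetical
# host), and with ARVADOS_API_TOKEN exported in the environment, the
# environment value wins for any key present in both.
#
#   cfg = ArvadosConfig(os.path.join(os.environ['HOME'],
#                                    '.config/arvados'))
#   cfg['ARVADOS_API_HOST']   # -> 'api.example.com' (from the file)
#   cfg['ARVADOS_API_TOKEN']  # -> value from the environment
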
class errors:
    class SyntaxError(Exception):
        pass
    class AssertionError(Exception):
        pass
    class NotFoundError(Exception):
        pass
    class CommandFailedError(Exception):
        pass
    class KeepWriteError(Exception):
        pass
    class NotImplementedError(Exception):
        pass

class CredentialsFromEnv(object):
    # Declared @staticmethod so that self.http_request in authorize()
    # yields the plain function; types.MethodType then binds it to the
    # Http object, so "self" below is the Http instance, not a
    # CredentialsFromEnv instance.
    @staticmethod
    def http_request(self, uri, **kwargs):
        global config
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % config.get('ARVADOS_API_TOKEN', 'ARVADOS_API_TOKEN_not_set')
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

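# Usage sketch (illustrative): authorize() monkey-patches an existing
# httplib2.Http object so that every request carries the OAuth2 token
# from the config; api() below does exactly this.
#
#   http = httplib2.Http()
#   http = CredentialsFromEnv().authorize(http)
#   # http.request(...) now sends "Authorization: OAuth2 <token>"
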
def task_set_output(self, s):
    api('v1').job_tasks().update(uuid=self['uuid'],
                                 body={
                                     'output': s,
                                     'success': True,
                                     'progress': 1.0
                                 }).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = api('v1').job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = api('v1').jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

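# Usage sketch (illustrative): inside a crunch task, JOB_UUID,
# TASK_UUID, JOB_WORK and TASK_WORK are provided by the job
# dispatcher. manifest_text below is a hypothetical result.
#
#   this_task = current_task()
#   scratch = this_task.tmpdir          # per-task work directory
#   this_task.set_output(manifest_text) # record output, mark success
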
def getjobparam(*args):
    return current_job()['script_parameters'].get(*args)

# Monkey patch discovery._cast() so objects and arrays get serialized
# with json.dumps() instead of str().
_cast_orig = apiclient.discovery._cast
def _cast_objects_too(value, schema_type):
    global _cast_orig
    if (type(value) != type('') and
        (schema_type == 'object' or schema_type == 'array')):
        return json.dumps(value)
    else:
        return _cast_orig(value, schema_type)
apiclient.discovery._cast = _cast_objects_too

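# Why this matters (illustrative): without the patch, a dict parameter
# declared as an 'object' in the discovery document would be serialized
# with str(), producing "{'limit': 10}" -- not valid JSON. With the
# patch it becomes '{"limit": 10}', which the API server can parse.
#
#   _cast_objects_too({'limit': 10}, 'object')  # -> '{"limit": 10}'
#   _cast_objects_too('abc', 'string')          # falls through to _cast_orig
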
def http_cache(data_type):
    path = os.environ['HOME'] + '/.cache/arvados/' + data_type
    try:
        util.mkdir_dash_p(path)
    except OSError:
        path = None
    return path

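# Usage sketch (illustrative): http_cache('discovery') returns
# $HOME/.cache/arvados/discovery, creating it if needed, or None if the
# directory cannot be created -- in which case httplib2 simply runs
# without a cache.
#
#   cache_dir = http_cache('discovery')
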
def api(version=None):
    global services, config

    if not config:
        config = ArvadosConfig(os.environ['HOME'] + '/.config/arvados')
        if 'ARVADOS_DEBUG' in config:
            logging.basicConfig(level=logging.DEBUG)

    if not services.get(version):
        apiVersion = version
        if not version:
            apiVersion = 'v1'
            logging.info("Using default API version. " +
                         "Call arvados.api('%s') instead." %
                         apiVersion)
        if 'ARVADOS_API_HOST' not in config:
            raise Exception("ARVADOS_API_HOST is not set. Aborting.")
        url = ('https://%s/discovery/v1/apis/{api}/{apiVersion}/rest' %
               config['ARVADOS_API_HOST'])
        credentials = CredentialsFromEnv()

        # Use system's CA certificates (if we find them) instead of httplib2's
        ca_certs = '/etc/ssl/certs/ca-certificates.crt'
        if not os.path.exists(ca_certs):
            ca_certs = None             # use httplib2 default

        http = httplib2.Http(ca_certs=ca_certs,
                             cache=http_cache('discovery'))
        http = credentials.authorize(http)
        if re.match(r'(?i)^(true|1|yes)$',
                    config.get('ARVADOS_API_HOST_INSECURE', 'no')):
            http.disable_ssl_certificate_validation = True
        services[version] = apiclient.discovery.build(
            'arvados', apiVersion, http=http, discoveryServiceUrl=url)
    return services[version]

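# Usage sketch (illustrative; assumes ARVADOS_API_HOST and
# ARVADOS_API_TOKEN are configured, and the uuid below is hypothetical):
#
#   client = api('v1')
#   job = client.jobs().get(uuid='zzzzz-8i9sb-xxxxxxxxxxxxxxx').execute()
#   jobs = client.jobs().list(limit=10).execute()['items']
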
class JobTask(object):
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                    }
                }
                api('v1').job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                         body={'success': True}
                                         ).execute()
            exit(0)

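    # Usage sketch (illustrative): a crunch script typically calls this
    # at the top. Sequence-0 invocations queue one sequence-1 task per
    # input file and exit; sequence-1 invocations fall through and do
    # the real work on their single-file input.
    #
    #   job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
    #   this_input = current_task()['parameters']['input']
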
    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input': task_input
                }
            }
            api('v1').job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                         body={'success': True}
                                         ).execute()
            exit(0)

class util:
    @staticmethod
    def clear_tmpdir(path=None):
        """
        Ensure the given directory (or TASK_TMPDIR if none given)
        exists and is empty.
        """
        if path is None:
            path = current_task().tmpdir
        if os.path.exists(path):
            p = subprocess.Popen(['rm', '-rf', path])
            stdout, stderr = p.communicate(None)
            if p.returncode != 0:
                raise Exception('rm -rf %s: %s' % (path, stderr))
        os.mkdir(path)

    @staticmethod
    def run_command(execargs, **kwargs):
        kwargs.setdefault('stdin', subprocess.PIPE)
        kwargs.setdefault('stdout', subprocess.PIPE)
        kwargs.setdefault('stderr', sys.stderr)
        kwargs.setdefault('close_fds', True)
        kwargs.setdefault('shell', False)
        p = subprocess.Popen(execargs, **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise errors.CommandFailedError(
                "run_command %s exit %d:\n%s" %
                (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

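    # Usage sketch (illustrative): returns (stdout, stderr) on success
    # and raises errors.CommandFailedError on a nonzero exit. Note that
    # stderr defaults to sys.stderr, so stderrdata is None unless the
    # caller passes stderr=subprocess.PIPE.
    #
    #   out, _ = util.run_command(['ls', '-l', '/tmp'])
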
    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search(r'\.(tbz|tar\.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search(r'\.(tgz|tar\.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search(r'\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise errors.AssertionError(
                        "tarball_extract cannot handle filename %s" % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

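    # Usage sketch (illustrative; the locator below is hypothetical):
    # extraction is cached via the '.locator' symlink, so repeated calls
    # with the same collection are effectively no-ops.
    #
    #   src = util.tarball_extract(
    #       'c1bad4b39ca5a924e481008009d94e32+210',
    #       'src')  # relative paths land under current_job().tmpdir
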
    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search(r'\.zip$', f.name()):
                    raise errors.NotImplementedError(
                        "zipball_extract cannot handle filename %s" % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
        if matches:
            collection_hash = matches.group(1)
        else:
            collection_hash = hashlib.md5(collection).hexdigest()
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection_hash:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got,
                 [z.name() for z in CollectionReader(collection).all_files()]))
        os.symlink(collection_hash, os.path.join(path, '.locator'))

        lockfile.close()
        return path

    @staticmethod
    def mkdir_dash_p(path):
        if not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                # Lost a race with another process? Only retry (and
                # re-raise) if the directory still does not exist.
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got, [z.name() for z in stream.all_files()]))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

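    # Usage sketch (illustrative; paths are hypothetical): returns the
    # relative paths of all regular files under dirname, recursing into
    # subdirectories, sorted within each directory level.
    #
    #   util.listdir_recursive('/tmp/job')
    #   # -> ['a.txt', 'sub/b.txt', 'sub/deep/c.txt']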