# arvados.git / sdk / python / arvados / __init__.py
# Merge branch '1911-python-sdk-pydoc'
import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading

import apiclient
import apiclient.discovery

config = None
EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'
services = {}

from stream import *
from collection import *
from keep import *


# Arvados configuration settings are taken from $HOME/.config/arvados.
# Environment variables override settings in the config file.
#
class ArvadosConfig(dict):
    def __init__(self, config_file):
        dict.__init__(self)
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                for config_line in f:
                    # maxsplit=1, so values may themselves contain '='.
                    var, val = config_line.rstrip().split('=', 1)
                    self[var] = val
        for var in os.environ:
            if var.startswith('ARVADOS_'):
                self[var] = os.environ[var]

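# A sketch of the expected config file format (hypothetical placeholder
# values):
#
#   ARVADOS_API_HOST=arvados.example.com
#   ARVADOS_API_TOKEN=1234567890abcdef
#
# Environment variables with the ARVADOS_ prefix override these settings:
#
#   config = ArvadosConfig(os.path.join(os.environ['HOME'], '.config/arvados'))
#   host = config.get('ARVADOS_API_HOST')
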
class errors:
    class SyntaxError(Exception):
        pass
    class AssertionError(Exception):
        pass
    class NotFoundError(Exception):
        pass
    class CommandFailedError(Exception):
        pass
    class KeepWriteError(Exception):
        pass
    class NotImplementedError(Exception):
        pass

class CredentialsFromEnv(object):
    # Declared @staticmethod so authorize() can rebind it to an
    # httplib2.Http instance with types.MethodType: inside this function,
    # 'self' is that Http object, not a CredentialsFromEnv instance.
    @staticmethod
    def http_request(self, uri, **kwargs):
        global config
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = ('OAuth2 %s' %
            config.get('ARVADOS_API_TOKEN', 'ARVADOS_API_TOKEN_not_set'))
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

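# Sketch of the intended use (api() below does this automatically):
#
#   http = httplib2.Http()
#   http = CredentialsFromEnv().authorize(http)
#   # http.request() now sends 'Authorization: OAuth2 <token>' on each call.
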
def task_set_output(self, s):
    api('v1').job_tasks().update(uuid=self['uuid'],
                                 body={
                                     'output': s,
                                     'success': True,
                                     'progress': 1.0
                                 }).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = api('v1').job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

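# Example: inside a crunch task (the dispatcher sets TASK_UUID and
# TASK_WORK), a script records its result like this:
#
#   this_task = current_task()
#   # ... compute output_locator ...
#   this_task.set_output(output_locator)
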
_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = api('v1').jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def getjobparam(*args):
    return current_job()['script_parameters'].get(*args)

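# Example: fetch a script parameter, with a default for when it is absent:
#
#   min_quality = getjobparam('min_quality', 30)
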
# Monkey patch discovery._cast() so objects and arrays get serialized
# with json.dumps() instead of str().
_cast_orig = apiclient.discovery._cast
def _cast_objects_too(value, schema_type):
    global _cast_orig
    if (not isinstance(value, str) and
        schema_type in ('object', 'array')):
        return json.dumps(value)
    else:
        return _cast_orig(value, schema_type)
apiclient.discovery._cast = _cast_objects_too

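# With the patch in place, a dict passed where the discovery document
# expects an 'object' reaches the API server as JSON rather than a Python
# repr() string:
#
#   _cast_objects_too({'foo': 1}, 'object')    # -> '{"foo": 1}'
#   _cast_objects_too('bar', 'object')         # falls through to _cast_orig
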
def api(version=None):
    global services, config

    if not config:
        config = ArvadosConfig(os.path.join(os.environ['HOME'],
                                            '.config/arvados'))
        if 'ARVADOS_DEBUG' in config:
            logging.basicConfig(level=logging.DEBUG)

    if not services.get(version):
        apiVersion = version
        if not version:
            apiVersion = 'v1'
            logging.info("Using default API version. " +
                         "Call arvados.api('%s') instead." %
                         apiVersion)
        if 'ARVADOS_API_HOST' not in config:
            raise Exception("ARVADOS_API_HOST is not set. Aborting.")
        url = ('https://%s/discovery/v1/apis/{api}/{apiVersion}/rest' %
               config['ARVADOS_API_HOST'])
        credentials = CredentialsFromEnv()

        # Use the system's CA certificates (if we find them) instead of
        # httplib2's bundled ones.
        ca_certs = '/etc/ssl/certs/ca-certificates.crt'
        if not os.path.exists(ca_certs):
            ca_certs = None             # use httplib2 default

        http = httplib2.Http(ca_certs=ca_certs)
        http = credentials.authorize(http)
        if re.match(r'(?i)^(true|1|yes)$',
                    config.get('ARVADOS_API_HOST_INSECURE', 'no')):
            http.disable_ssl_certificate_validation = True
        services[version] = apiclient.discovery.build(
            'arvados', apiVersion, http=http, discoveryServiceUrl=url)
    return services[version]

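# Typical use, assuming ARVADOS_API_HOST and ARVADOS_API_TOKEN are set in
# the environment or the config file:
#
#   arv = api('v1')
#   me = arv.users().current().execute()
#   print me['uuid']
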
class JobTask(object):
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                    }
                }
                api('v1').job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                         body={'success': True}
                                         ).execute()
            exit(0)

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input': task_input
                }
            }
            api('v1').job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                         body={'success': True}
                                         ).execute()
            exit(0)

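# Sketch of a crunch script built on the helpers above: the sequence-0 task
# fans out one new task per input file, then each sequence-1 task processes
# the single file named by its 'input' parameter:
#
#   job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#   this_task = current_task()
#   input_manifest = this_task['parameters']['input']
#   # ... process CollectionReader(input_manifest) ...
#   this_task.set_output(output_locator)
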
class util:
    @staticmethod
    def clear_tmpdir(path=None):
        """
        Ensure the given directory (or the current task's working
        directory, if none given) exists and is empty.
        """
        if path is None:
            path = current_task().tmpdir
        if os.path.exists(path):
            # Capture stderr so the exception message below is meaningful.
            p = subprocess.Popen(['rm', '-rf', path], stderr=subprocess.PIPE)
            stdout, stderr = p.communicate(None)
            if p.returncode != 0:
                raise Exception('rm -rf %s: %s' % (path, stderr))
        os.mkdir(path)

    @staticmethod
    def run_command(execargs, **kwargs):
        kwargs.setdefault('stdin', subprocess.PIPE)
        kwargs.setdefault('stdout', subprocess.PIPE)
        # Note: with the sys.stderr default below, stderrdata comes back as
        # None; pass stderr=subprocess.PIPE to capture it.
        kwargs.setdefault('stderr', sys.stderr)
        kwargs.setdefault('close_fds', True)
        kwargs.setdefault('shell', False)
        p = subprocess.Popen(execargs, **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise errors.CommandFailedError(
                "run_command %s exit %d:\n%s" %
                (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

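    # Example (hypothetical command and file; pass stderr=subprocess.PIPE
    # if you want the error text captured):
    #
    #   stdout, _ = util.run_command(['md5sum', 'input.txt'],
    #                                stderr=subprocess.PIPE)
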
    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

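    # Example: check out a tag into the job's scratch space (hypothetical
    # repository URL):
    #
    #   srcdir = util.git_checkout('https://git.example.com/tool.git',
    #                              'v1.0', 'tool-src')
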
    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                # Dots in the suffixes are escaped so they match literally.
                if re.search(r'\.(tbz|tar\.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search(r'\.(tgz|tar\.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search(r'\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise errors.AssertionError(
                        "tarball_extract cannot handle filename %s" % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

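    # Example: unpack a reference bundle into the job tmpdir; repeated calls
    # with the same locator reuse the cached extraction (hypothetical
    # locator):
    #
    #   refdir = util.tarball_extract('<collection-locator>', 'reference')
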
    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search(r'\.zip$', f.name()):
                    raise errors.NotImplementedError(
                        "zipball_extract cannot handle filename %s" % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
        if matches:
            collection_hash = matches.group(1)
        else:
            collection_hash = hashlib.md5(collection).hexdigest()
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection_hash:
                already_have_it = True
        except OSError:
            pass
        # (already_have_it is informational here; the per-file existence
        # check below is what makes re-extraction idempotent.)

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got,
                 [z.name() for z in CollectionReader(collection).all_files()]))
        os.symlink(collection_hash, os.path.join(path, '.locator'))

        lockfile.close()
        return path

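    # Example: fetch two named files from a collection (hypothetical locator
    # and filenames):
    #
    #   refs = util.collection_extract('<collection-locator>', 'refs',
    #                                  files=['genome.fa', 'genome.fa.fai'])
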
    @staticmethod
    def mkdir_dash_p(path):
        if not os.path.exists(path):
            # Stop recursing at the filesystem root or when dirname yields
            # an empty relative component; otherwise relative paths would
            # recurse forever.
            parent = os.path.dirname(path)
            if parent and parent != path:
                util.mkdir_dash_p(parent)
            try:
                os.mkdir(path)
            except OSError:
                # Tolerate a concurrent mkdir of the same directory.
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got, [z.name() for z in stream.all_files()]))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

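    # Example: enumerate a task's work tree, relative to its root:
    #
    #   for filename in util.listdir_recursive(current_task().tmpdir):
    #       print filename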