import gflags
import httplib
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl
import time
import threading

import apiclient
import apiclient.discovery

from stream import *
from collection import *
from keep import *

config = None
EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'
services = {}

# Arvados configuration settings are taken from $HOME/.config/arvados.
# Environment variables override settings in the config file.
#
class ArvadosConfig(dict):
    def __init__(self, config_file):
        dict.__init__(self)
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                for config_line in f:
                    # Split on the first '=' only, so values may
                    # themselves contain '='.
                    var, val = config_line.rstrip().split('=', 1)
                    self[var] = val
        for var in os.environ:
            if var.startswith('ARVADOS_'):
                self[var] = os.environ[var]

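# Config file format sketch (hypothetical values): one VAR=value pair
# per line in $HOME/.config/arvados, e.g.
#
#   ARVADOS_API_HOST=arvados.example.com
#   ARVADOS_API_TOKEN=0123456789abcdef
#
# ARVADOS_* environment variables are read last, so they override the
# corresponding file settings.
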
class errors:
    class SyntaxError(Exception):
        pass
    class AssertionError(Exception):
        pass
    class NotFoundError(Exception):
        pass
    class CommandFailedError(Exception):
        pass
    class KeepWriteError(Exception):
        pass
    class NotImplementedError(Exception):
        pass

class CredentialsFromEnv(object):
    # Declared @staticmethod so that self.http_request yields the plain
    # function; authorize() then binds it to the Http object with
    # types.MethodType, so `self` is the Http instance at call time.
    @staticmethod
    def http_request(self, uri, **kwargs):
        global config
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % config.get('ARVADOS_API_TOKEN', 'ARVADOS_API_TOKEN_not_set')
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

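# Illustrative sketch (not a separate API; this is what api() below
# does internally):
#
#   http = httplib2.Http()
#   http = CredentialsFromEnv().authorize(http)
#   # http.request(...) now sends "Authorization: OAuth2 <token>" and
#   # retries once when httplib reports a stale reused connection.
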
def task_set_output(self, s):
    api('v1').job_tasks().update(uuid=self['uuid'],
                                 body={
                                     'output': s,
                                     'success': True,
                                     'progress': 1.0
                                 }).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = api('v1').job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

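# Usage sketch inside a crunch task script (assumes the TASK_UUID and
# TASK_WORK environment variables that crunch sets for each task):
#
#   this_task = current_task()
#   workdir = this_task.tmpdir
#   # ...do the work, write a manifest, then:
#   this_task.set_output(<output manifest locator>)
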
_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = api('v1').jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def getjobparam(*args):
    return current_job()['script_parameters'].get(*args)

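# Sketch, assuming a job whose script_parameters were submitted as
# {"input": "<locator>", "min_quality": "30"} (hypothetical values):
#
#   getjobparam('input')           # => "<locator>"
#   getjobparam('min_quality')     # => "30"
#   getjobparam('missing', None)   # => None (extra args go to dict.get)
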
# Monkey patch discovery._cast() so objects and arrays get serialized
# with json.dumps() instead of str().
_cast_orig = apiclient.discovery._cast
def _cast_objects_too(value, schema_type):
    global _cast_orig
    if (type(value) != type('') and
        (schema_type == 'object' or schema_type == 'array')):
        return json.dumps(value)
    else:
        return _cast_orig(value, schema_type)
apiclient.discovery._cast = _cast_objects_too

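# Why this matters (illustrative): without the patch, a dict passed as
# a request parameter whose schema type is 'object' would be cast with
# str(), yielding "{'a': 1}" -- not valid JSON. With the patch it is
# serialized as '{"a": 1}'.
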
def api(version=None):
    global services, config

    if not config:
        config = ArvadosConfig(os.environ['HOME'] + '/.config/arvados')
        if 'ARVADOS_DEBUG' in config:
            logging.basicConfig(level=logging.DEBUG)

    if not services.get(version):
        apiVersion = version
        if not version:
            apiVersion = 'v1'
            logging.info("Using default API version. " +
                         "Call arvados.api('%s') instead." %
                         apiVersion)
        if 'ARVADOS_API_HOST' not in config:
            raise Exception("ARVADOS_API_HOST is not set. Aborting.")
        url = ('https://%s/discovery/v1/apis/{api}/{apiVersion}/rest' %
               config['ARVADOS_API_HOST'])
        credentials = CredentialsFromEnv()

        # Use system's CA certificates (if we find them) instead of httplib2's
        ca_certs = '/etc/ssl/certs/ca-certificates.crt'
        if not os.path.exists(ca_certs):
            ca_certs = None             # use httplib2 default

        http = httplib2.Http(ca_certs=ca_certs)
        http = credentials.authorize(http)
        if re.match(r'(?i)^(true|1|yes)$',
                    config.get('ARVADOS_API_HOST_INSECURE', 'no')):
            http.disable_ssl_certificate_validation = True
        services[version] = apiclient.discovery.build(
            'arvados', apiVersion, http=http, discoveryServiceUrl=url)
    return services[version]

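# Typical use (sketch; resources such as collections() come from the
# server's discovery document, so what is available depends on the API
# host):
#
#   client = api('v1')
#   listing = client.collections().list(limit=10).execute()
#
# Clients are cached per version in the module-level `services` dict,
# so repeated api('v1') calls reuse the same discovery-built client.
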
class JobTask(object):
    def __init__(self, parameters=dict(), runtime_constraints=dict()):
        print "init jobtask %s %s" % (parameters, runtime_constraints)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                    }
                }
                api('v1').job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                         body={'success': True}
                                         ).execute()
            sys.exit(0)

    @staticmethod
    def one_task_per_input_stream(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            task_input = s.tokens()
            new_task_attrs = {
                'job_uuid': current_job()['uuid'],
                'created_by_job_task_uuid': current_task()['uuid'],
                'sequence': if_sequence + 1,
                'parameters': {
                    'input': task_input
                }
            }
            api('v1').job_tasks().create(body=new_task_attrs).execute()
        if and_end_task:
            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                         body={'success': True}
                                         ).execute()
            sys.exit(0)

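# Sketch of a map-style crunch script built on job_setup (hypothetical
# workflow): at sequence 0 it fans out one new task per input file and
# exits; each sequence-1 task then processes its own file:
#
#   job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#   input_locator = current_task()['parameters']['input']
#   # ...process input_locator and record the result via set_output()...
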
class util:
    @staticmethod
    def clear_tmpdir(path=None):
        """
        Ensure the given directory (or the current task's tmpdir if
        none given) exists and is empty.
        """
        if path is None:
            path = current_task().tmpdir
        if os.path.exists(path):
            # Capture stderr so a failure message can report it.
            p = subprocess.Popen(['rm', '-rf', path],
                                 stderr=subprocess.PIPE)
            stdout, stderr = p.communicate(None)
            if p.returncode != 0:
                raise Exception('rm -rf %s: %s' % (path, stderr))
        os.mkdir(path)

    @staticmethod
    def run_command(execargs, **kwargs):
        kwargs.setdefault('stdin', subprocess.PIPE)
        kwargs.setdefault('stdout', subprocess.PIPE)
        kwargs.setdefault('stderr', sys.stderr)
        kwargs.setdefault('close_fds', True)
        kwargs.setdefault('shell', False)
        p = subprocess.Popen(execargs, **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise errors.CommandFailedError(
                "run_command %s exit %d:\n%s" %
                (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

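    # Usage sketch: run_command raises errors.CommandFailedError on a
    # nonzero exit, so callers may use the return value directly.
    # Note stderr defaults to sys.stderr (streamed, not captured), so
    # stderrdata is None unless the caller passes stderr=subprocess.PIPE.
    #
    #   stdout, _ = util.run_command(['ls', '-1'], cwd='/tmp')
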
    @staticmethod
    def git_checkout(url, version, path):
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tar_extractor(path, decompress_flag):
        return subprocess.Popen(["tar",
                                 "-C", path,
                                 ("-x%sf" % decompress_flag),
                                 "-"],
                                stdout=None,
                                stdin=subprocess.PIPE, stderr=sys.stderr,
                                shell=False, close_fds=True)

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                if re.search(r'\.(tbz|tar\.bz2)$', f.name()):
                    p = util.tar_extractor(path, 'j')
                elif re.search(r'\.(tgz|tar\.gz)$', f.name()):
                    p = util.tar_extractor(path, 'z')
                elif re.search(r'\.tar$', f.name()):
                    p = util.tar_extractor(path, '')
                else:
                    raise errors.AssertionError(
                        "tarball_extract cannot handle filename %s" % f.name())
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

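    # Usage sketch (hypothetical locator): unpack a tarball collection
    # into the job's scratch space; the .lock file serializes
    # concurrent tasks, and the .locator symlink lets a repeat call
    # skip re-extraction:
    #
    #   srcdir = util.tarball_extract('<tarball collection locator>', 'src')
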
    @staticmethod
    def zipball_extract(zipball, path):
        """Retrieve a zip archive from Keep and extract it to a local
        directory.  Return the absolute path where the archive was
        extracted. If the top level of the archive contained just one
        file or directory, return the absolute path of that single
        item.

        zipball -- collection locator
        path -- where to extract the archive: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == zipball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(zipball).all_files():
                if not re.search(r'\.zip$', f.name()):
                    raise errors.NotImplementedError(
                        "zipball_extract cannot handle filename %s" % f.name())
                zip_filename = os.path.join(path, os.path.basename(f.name()))
                zip_file = open(zip_filename, 'wb')
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    zip_file.write(buf)
                zip_file.close()

                p = subprocess.Popen(["unzip",
                                      "-q", "-o",
                                      "-d", path,
                                      zip_filename],
                                     stdout=None,
                                     stdin=None, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise errors.CommandFailedError(
                        "unzip exited %d" % p.returncode)
                os.unlink(zip_filename)
            os.symlink(zipball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        """
        matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
        if matches:
            collection_hash = matches.group(1)
        else:
            collection_hash = hashlib.md5(collection).hexdigest()
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection_hash:
                already_have_it = True
        except OSError:
            pass

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        files_got = []
        for s in CollectionReader(collection).all_streams():
            stream_name = s.name()
            for f in s.all_files():
                if (files == [] or
                    ((f.name() not in files_got) and
                     (f.name() in files or
                      (decompress and f.decompressed_name() in files)))):
                    outname = f.decompressed_name() if decompress else f.name()
                    files_got += [outname]
                    if os.path.exists(os.path.join(path, stream_name, outname)):
                        continue
                    util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
                    outfile = open(os.path.join(path, stream_name, outname), 'wb')
                    for buf in (f.readall_decompressed() if decompress
                                else f.readall()):
                        outfile.write(buf)
                    outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got,
                 [z.name() for z in CollectionReader(collection).all_files()]))
        os.symlink(collection_hash, os.path.join(path, '.locator'))

        lockfile.close()
        return path

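    # Sketch (hypothetical locator and file names): fetch selected files
    # from a collection; with decompress=True, asking for 'chr1.fa' is
    # also satisfied by a 'chr1.fa.gz' (or .bz2) stored in the
    # collection, which gets written out decompressed:
    #
    #   refdir = util.collection_extract('<collection locator>', 'ref',
    #                                    files=['chr1.fa'],
    #                                    decompress=True)
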
    @staticmethod
    def mkdir_dash_p(path):
        # Guard against empty path components so a relative path does
        # not recurse forever (os.path.dirname('x') == '').
        if path and not os.path.exists(path):
            util.mkdir_dash_p(os.path.dirname(path))
            try:
                os.mkdir(path)
            except OSError:
                # Lost a race with another process? If the directory
                # still does not exist, retry so the error propagates.
                if not os.path.exists(path):
                    os.mkdir(path)

    @staticmethod
    def stream_extract(stream, path, files=[], decompress=True):
        """Retrieve a stream from Keep and extract it to a local
        directory.  Return the absolute path where the stream was
        extracted.

        stream -- StreamReader object
        path -- where to extract: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)

        files_got = []
        for f in stream.all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    os.unlink(os.path.join(path, outname))
                util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
                outfile = open(os.path.join(path, outname), 'wb')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise errors.AssertionError(
                "Wanted files %s but only got %s from %s" %
                (files, files_got, [z.name() for z in stream.all_files()]))
        lockfile.close()
        return path

    @staticmethod
    def listdir_recursive(dirname, base=None):
        allfiles = []
        for ent in sorted(os.listdir(dirname)):
            ent_path = os.path.join(dirname, ent)
            ent_base = os.path.join(base, ent) if base else ent
            if os.path.isdir(ent_path):
                allfiles += util.listdir_recursive(ent_path, ent_base)
            else:
                allfiles += [ent_base]
        return allfiles

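    # Sketch: listdir_recursive returns sorted, slash-joined paths
    # relative to dirname. For a hypothetical tree containing a/x and
    # a/b/y:
    #
    #   util.listdir_recursive('a')   # => ['b/y', 'x']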