Merge branch '1880-check-duplicate-public-key'
[arvados.git] / sdk / python / arvados / __init__.py
index 230b9c24e52b56f9ddfbfafd10566880f9dc2339..dacdba851af3a99baaa5cdda21f105cf9354d3ae 100644 (file)
@@ -18,12 +18,34 @@ import fcntl
 import time
 import threading
 
-from apiclient import errors
-from apiclient.discovery import build
+import apiclient
+import apiclient.discovery
 
-if 'ARVADOS_DEBUG' in os.environ:
+# Arvados configuration settings are taken from $HOME/.config/arvados.
+# Environment variables override settings in the config file.
+#
+class ArvadosConfig(dict):
+    """Dictionary of Arvados settings.
+
+    Settings are read as VAR=value lines from config_file (if that
+    file exists). Every ARVADOS_* environment variable is then copied
+    in, replacing any value of the same name read from the file.
+    """
+    def __init__(self, config_file):
+        dict.__init__(self)
+        if os.path.exists(config_file):
+            with open(config_file, "r") as f:
+                for config_line in f:
+                    # Tolerate blank lines and lines with no '='
+                    # instead of failing to unpack below.
+                    if '=' not in config_line:
+                        continue
+                    # Split on the first '=' only, so a value may
+                    # itself contain '=' characters.
+                    var, val = config_line.rstrip().split('=', 1)
+                    self[var] = val
+        for var in os.environ:
+            if var.startswith('ARVADOS_'):
+                self[var] = os.environ[var]
+
+
+# Load settings once at import time. expanduser() is used rather than
+# os.environ['HOME'] so that importing this module does not raise
+# KeyError when HOME is unset (it then falls back to the passwd entry).
+config = ArvadosConfig(os.path.expanduser('~/.config/arvados'))
+
+if 'ARVADOS_DEBUG' in config:
     logging.basicConfig(level=logging.DEBUG)
 
+EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'
+
+services = {}
+
 class errors:
     class SyntaxError(Exception):
         pass
@@ -41,10 +63,11 @@ class errors:
 class CredentialsFromEnv(object):
     @staticmethod
     def http_request(self, uri, **kwargs):
+        global config
         from httplib import BadStatusLine
         if 'headers' not in kwargs:
             kwargs['headers'] = {}
-        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
+        kwargs['headers']['Authorization'] = 'OAuth2 %s' % config.get('ARVADOS_API_TOKEN', 'ARVADOS_API_TOKEN_not_set')
         try:
             return self.orig_http_request(uri, **kwargs)
         except BadStatusLine:
@@ -60,27 +83,9 @@ class CredentialsFromEnv(object):
         http.request = types.MethodType(self.http_request, http)
         return http
 
-url = ('https://%s:%s/discovery/v1/apis/'
-       '{api}/{apiVersion}/rest' %
-           (os.environ['ARVADOS_API_HOST'],
-            os.environ.get('ARVADOS_API_PORT') or "443"))
-credentials = CredentialsFromEnv()
-
-# Use system's CA certificates (if we find them) instead of httplib2's
-ca_certs = '/etc/ssl/certs/ca-certificates.crt'
-if not os.path.exists(ca_certs):
-    ca_certs = None             # use httplib2 default
-
-http = httplib2.Http(ca_certs=ca_certs)
-http = credentials.authorize(http)
-if re.match(r'(?i)^(true|1|yes)$',
-            os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
-    http.disable_ssl_certificate_validation=True
-service = build("arvados", "v1", http=http, discoveryServiceUrl=url)
-
 def task_set_output(self,s):
-    service.job_tasks().update(uuid=self['uuid'],
-                               body={
+    api('v1').job_tasks().update(uuid=self['uuid'],
+                                 body={
             'output':s,
             'success':True,
             'progress':1.0
@@ -91,7 +96,7 @@ def current_task():
     global _current_task
     if _current_task:
         return _current_task
-    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
+    t = api('v1').job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
     t = UserDict.UserDict(t)
     t.set_output = types.MethodType(task_set_output, t)
     t.tmpdir = os.environ['TASK_WORK']
@@ -103,7 +108,7 @@ def current_job():
     global _current_job
     if _current_job:
         return _current_job
-    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
+    t = api('v1').jobs().get(uuid=os.environ['JOB_UUID']).execute()
     t = UserDict.UserDict(t)
     t.tmpdir = os.environ['JOB_WORK']
     _current_job = t
@@ -112,8 +117,46 @@ def current_job():
 def getjobparam(*args):
     return current_job()['script_parameters'].get(*args)
 
-def api():
-    return service
+# Monkey patch discovery._cast() so objects and arrays get serialized
+# with json.dumps() instead of str().
+_cast_orig = apiclient.discovery._cast
+def _cast_objects_too(value, schema_type):
+    """Serialize a request parameter for the given discovery schema type.
+
+    A value destined for an 'object' or 'array' parameter is encoded
+    with json.dumps() unless it is already a string. Everything else
+    is delegated to the original apiclient.discovery._cast().
+    """
+    # basestring covers both str and unicode: already-serialized JSON
+    # passed as a unicode string must not be JSON-encoded a second time.
+    if (not isinstance(value, basestring) and
+        schema_type in ('object', 'array')):
+        return json.dumps(value)
+    return _cast_orig(value, schema_type)
+apiclient.discovery._cast = _cast_objects_too
+
+def api(version=None):
+    """Return a cached Arvados API client for the given API version.
+
+    version -- API version string such as 'v1'. If omitted (or falsy),
+    'v1' is used and an informational message is logged the first time.
+
+    Clients are built on first use and kept in the module-level
+    `services` dict. Raises Exception if ARVADOS_API_HOST is not
+    present in `config` when a client has to be built.
+    """
+    global services, config
+    if not services.get(version):
+        apiVersion = version
+        if not version:
+            apiVersion = 'v1'
+            logging.info("Using default API version. " +
+                         "Call arvados.api('%s') instead." %
+                         apiVersion)
+        # api() and api('v1') resolve to the same version, so share one
+        # client instead of building (and fetching the discovery
+        # document for) the same version twice.
+        if not services.get(apiVersion):
+            if 'ARVADOS_API_HOST' not in config:
+                raise Exception("ARVADOS_API_HOST is not set. Aborting.")
+            url = ('https://%s/discovery/v1/apis/{api}/{apiVersion}/rest' %
+                   config['ARVADOS_API_HOST'])
+            credentials = CredentialsFromEnv()
+
+            # Use system's CA certificates (if we find them) instead of
+            # httplib2's
+            ca_certs = '/etc/ssl/certs/ca-certificates.crt'
+            if not os.path.exists(ca_certs):
+                ca_certs = None             # use httplib2 default
+
+            http = httplib2.Http(ca_certs=ca_certs)
+            http = credentials.authorize(http)
+            if re.match(r'(?i)^(true|1|yes)$',
+                        config.get('ARVADOS_API_HOST_INSECURE', 'no')):
+                http.disable_ssl_certificate_validation=True
+            services[apiVersion] = apiclient.discovery.build(
+                'arvados', apiVersion, http=http, discoveryServiceUrl=url)
+        # Keep the entry under the key the caller used (possibly None).
+        services[version] = services[apiVersion]
+    return services[version]
 
 class JobTask(object):
     def __init__(self, parameters=dict(), runtime_constraints=dict()):
@@ -137,9 +180,9 @@ class job_setup:
                         'input':task_input
                         }
                     }
-                service.job_tasks().create(body=new_task_attrs).execute()
+                api('v1').job_tasks().create(body=new_task_attrs).execute()
         if and_end_task:
-            service.job_tasks().update(uuid=current_task()['uuid'],
+            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                        body={'success':True}
                                        ).execute()
             exit(0)
@@ -160,9 +203,9 @@ class job_setup:
                     'input':task_input
                     }
                 }
-            service.job_tasks().create(body=new_task_attrs).execute()
+            api('v1').job_tasks().create(body=new_task_attrs).execute()
         if and_end_task:
-            service.job_tasks().update(uuid=current_task()['uuid'],
+            api('v1').job_tasks().update(uuid=current_task()['uuid'],
                                        body={'success':True}
                                        ).execute()
             exit(0)
@@ -542,8 +585,8 @@ class StreamFileReader(object):
             yield data
     def as_manifest(self):
         if self.size() == 0:
-            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
-                    % (self._stream.name(), self.name()))
+            return ("%s %s 0:0:%s\n"
+                    % (self._stream.name(), EMPTY_BLOCK_LOCATOR, self.name()))
         return string.join(self._stream.tokens_for_range(self._pos, self._size),
                            " ") + "\n"
 
@@ -561,12 +604,12 @@ class StreamReader(object):
 
         for tok in self._tokens:
             if self._stream_name == None:
-                self._stream_name = tok
+                self._stream_name = tok.replace('\\040', ' ')
             elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                 self.data_locators += [tok]
             elif re.search(r'^\d+:\d+:\S+', tok):
                 pos, size, name = tok.split(':',2)
-                self.files += [[int(pos), int(size), name]]
+                self.files += [[int(pos), int(size), name.replace('\\040', ' ')]]
             else:
                 raise errors.SyntaxError("Invalid manifest format")
 
@@ -757,8 +800,7 @@ class CollectionWriter(object):
         self.finish_current_file()
         self.set_current_file_name(newfilename)
     def set_current_file_name(self, newfilename):
-        newfilename = re.sub(r' ', '\\\\040', newfilename)
-        if re.search(r'[ \t\n]', newfilename):
+        if re.search(r'[\t\n]', newfilename):
             raise errors.AssertionError(
                 "Manifest filenames cannot contain whitespace: %s" %
                 newfilename)
@@ -783,7 +825,7 @@ class CollectionWriter(object):
         self.finish_current_stream()
         self.set_current_stream_name(newstreamname)
     def set_current_stream_name(self, newstreamname):
-        if re.search(r'[ \t\n]', newstreamname):
+        if re.search(r'[\t\n]', newstreamname):
             raise errors.AssertionError(
                 "Manifest stream names cannot contain whitespace")
         self._current_stream_name = '.' if newstreamname=='' else newstreamname
@@ -799,6 +841,8 @@ class CollectionWriter(object):
                 "Cannot finish an unnamed stream (%d bytes in %d files)" %
                 (self._current_stream_length, len(self._current_stream_files)))
         else:
+            if len(self._current_stream_locators) == 0:
+                self._current_stream_locators += [EMPTY_BLOCK_LOCATOR]
             self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
@@ -816,14 +860,11 @@ class CollectionWriter(object):
         for stream in self._finished_streams:
             if not re.search(r'^\.(/.*)?$', stream[0]):
                 manifest += './'
-            manifest += stream[0]
-            if len(stream[1]) == 0:
-                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
-            else:
-                for locator in stream[1]:
-                    manifest += " %s" % locator
+            manifest += stream[0].replace(' ', '\\040')
+            for locator in stream[1]:
+                manifest += " %s" % locator
             for sfile in stream[2]:
-                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
+                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2].replace(' ', '\\040'))
             manifest += "\n"
         return manifest
     def data_locators(self):
@@ -901,6 +942,7 @@ class KeepClient(object):
             super(KeepClient.KeepWriterThread, self).__init__()
             self.args = kwargs
         def run(self):
+            global config
             with self.args['thread_limiter'] as limiter:
                 if not limiter.shall_i_proceed():
                     # My turn arrived, but the job has been done without
@@ -912,7 +954,7 @@ class KeepClient(object):
                                self.args['service_root']))
                 h = httplib2.Http()
                 url = self.args['service_root'] + self.args['data_hash']
-                api_token = os.environ['ARVADOS_API_TOKEN']
+                api_token = config['ARVADOS_API_TOKEN']
                 headers = {'Authorization': "OAuth2 %s" % api_token}
                 try:
                     resp, content = h.request(url.encode('utf-8'), 'PUT',
@@ -972,13 +1014,16 @@ class KeepClient(object):
         return pseq
 
     def get(self, locator):
+        global config
+        if re.search(r',', locator):
+            return ''.join(self.get(x) for x in locator.split(','))
         if 'KEEP_LOCAL_STORE' in os.environ:
             return KeepClient.local_store_get(locator)
         expect_hash = re.sub(r'\+.*', '', locator)
         for service_root in self.shuffled_service_roots(expect_hash):
             h = httplib2.Http()
             url = service_root + expect_hash
-            api_token = os.environ['ARVADOS_API_TOKEN']
+            api_token = config['ARVADOS_API_TOKEN']
             headers = {'Authorization': "OAuth2 %s" % api_token,
                        'Accept': 'application/octet-stream'}
             try:
@@ -1046,7 +1091,11 @@ class KeepClient(object):
         if not r:
             raise errors.NotFoundError(
                 "Invalid data locator: '%s'" % locator)
-        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
+        if r.group(0) == EMPTY_BLOCK_LOCATOR.split('+')[0]:
             return ''
         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
             return f.read()
+
+# We really shouldn't do this but some clients still use
+# arvados.service.* directly instead of arvados.api().*
+service = api()