6934: Replace custom config with yaml. Simplify code path.
[arvados.git] / sdk / pam / arvados_pam / __init__.py
index 4db6e58b129db6f958ccf559f349ba3a9602cb64..b7361f56d9ddabb0330b91ee7bde81e2d98e3219 100644 (file)
@@ -4,6 +4,7 @@ sys.argv=['']
 import arvados
 import os
 import syslog
+import yaml
 
 def auth_log(msg):
     """Send errors to default auth log"""
@@ -12,108 +13,111 @@ def auth_log(msg):
     syslog.closelog()
 
 def config_file():
-    return file('/etc/default/arvados_pam')
+    return file('/etc/default/arvados_pam.conf').read()
 
 def config():
-    txt = config_file().read()
-    c = dict()
-    for x in txt.splitlines(False):
-        if not x.strip().startswith('#'):
-            kv = x.split('=', 2)
-            c[kv[0].strip()] = kv[1].strip()
-    return c
+    return yaml.load(config_file())
 
 class AuthEvent(object):
-    def __init__(self, client_host, api_host, shell_host, username, token):
+    def __init__(self, config, service, client_host, username, token):
+        self.config = config
+        self.service = service
         self.client_host = client_host
-        self.api_host = api_host
-        self.shell_hostname = shell_host
         self.username = username
         self.token = token
-        self.vm = None
+
+        self.api_host = None
+        self.vm_uuid = None
         self.user = None
 
     def can_login(self):
+        """Return truthy IFF credentials should be accepted."""
         ok = False
         try:
+            self.api_host = self.config[self.service]['ARVADOS_API_HOST']
             self.arv = arvados.api('v1', host=self.api_host, token=self.token, cache=None)
-            self._lookup_vm()
-            if self._check_login_permission():
-                self.result = 'Authenticated'
-                ok = True
-            else:
-                self.result = 'Denied'
+
+            vmname = self.config[self.service]['virtual_machine_hostname']
+            vms = self.arv.virtual_machines().list(filters=[['hostname','=',vmname]]).execute()
+            if vms['items_available'] > 1:
+                raise Exception("lookup hostname %s returned %d records" % (vmname, vms['items_available']))
+            if vms['items_available'] == 0:
+                raise Exception("lookup hostname %s not found" % vmname)
+            vm = vms['items'][0]
+            if vm['hostname'] != vmname:
+                raise Exception("lookup hostname %s returned hostname %s" % (vmname, vm['hostname']))
+            self.vm_uuid = vm['uuid']
+
+            self.user = self.arv.users().current().execute()
+
+            filters = [
+                ['link_class','=','permission'],
+                ['name','=','can_login'],
+                ['head_uuid','=',self.vm_uuid],
+                ['tail_uuid','=',self.user['uuid']]]
+            for l in self.arv.links().list(filters=filters, limit=10000).execute()['items']:
+                if (l['properties']['username'] == self.username and
+                    l['tail_uuid'] == self.user['uuid'] and
+                    l['head_uuid'] == self.vm_uuid and
+                    l['link_class'] == 'permission' and
+                    l['name'] == 'can_login'):
+                    return self._report(True)
+
+            return self._report(False)
+
         except Exception as e:
-            self.result = 'Error: ' + repr(e)
+            return self._report(e)
+
+    def _report(self, result):
+        """Log the result. Return truthy IFF result is True.
+
+        result must be True, False, or an exception.
+        """
+        self.result = result
         auth_log(self.message())
-        return ok
-
-    def _lookup_vm(self):
-        """Load the VM record for this host into self.vm. Raise if not possible."""
-
-        vms = self.arv.virtual_machines().list(filters=[['hostname','=',self.shell_hostname]]).execute()
-        if vms['items_available'] > 1:
-            raise Exception("ambiguous VM hostname matched %d records" % vms['items_available'])
-        if vms['items_available'] == 0:
-            raise Exception("VM hostname not found")
-        self.vm = vms['items'][0]
-        if self.vm['hostname'] != self.shell_hostname:
-            raise Exception("API returned record with wrong hostname")
-
-    def _check_login_permission(self):
-        """Check permission to log in. Return True if permission is granted."""
-        self._lookup_vm()
-        self.user = self.arv.users().current().execute()
-        filters = [
-            ['link_class','=','permission'],
-            ['name','=','can_login'],
-            ['head_uuid','=',self.vm['uuid']],
-            ['tail_uuid','=',self.user['uuid']]]
-        for l in self.arv.links().list(filters=filters, limit=10000).execute()['items']:
-            if (l['properties']['username'] == self.username and
-                l['tail_uuid'] == self.user['uuid'] and
-                l['head_uuid'] == self.vm['uuid'] and
-                l['link_class'] == 'permission' and
-                l['name'] == 'can_login'):
-                return True
-        return False
+        return result == True
 
     def message(self):
+        """Return a log message describing the event and its outcome."""
+        if isinstance(self.result, Exception):
+            outcome = 'Error: ' + repr(self.result)
+        elif self.result == True:
+            outcome = 'Allow'
+        else:
+            outcome = 'Deny'
+
         if len(self.token) > 40:
             log_token = self.token[0:15]
         else:
             log_token = '<invalid>'
-        log_label = [self.client_host, self.api_host, self.shell_hostname, self.username, log_token]
-        if self.vm:
-            log_label += [self.vm.get('uuid')]
+
+        log_label = [self.client_host, self.api_host, self.vm_uuid, self.username, log_token]
+        if self.vm_uuid:
+            log_label += [self.vm_uuid]
         if self.user:
             log_label += [self.user.get('uuid'), self.user.get('full_name')]
-        return str(log_label) + ': ' + self.result
+        return str(log_label) + ': ' + outcome
 
 
 def pam_sm_authenticate(pamh, flags, argv):
     try:
-        user = pamh.get_user()
+        username = pamh.get_user()
     except pamh.exception as e:
         return e.pam_result
 
-    if not user:
+    if not username:
         return pamh.PAM_USER_UNKNOWN
 
     try:
-        resp = pamh.conversation(pamh.Message(pamh.PAM_PROMPT_ECHO_OFF, ''))
+        token = pamh.conversation(pamh.Message(pamh.PAM_PROMPT_ECHO_OFF, '')).resp
     except pamh.exception as e:
         return e.pam_result
 
-    try:
-        config = config()
-        api_host = config['ARVADOS_API_HOST'].strip()
-        shell_host = config['HOSTNAME'].strip()
-    except Exception as e:
-        auth_log("loading config: " + repr(e))
-        return False
-
-    if AuthEvent(pamh.rhost, api_host, shell_host, user, resp.resp).can_login():
+    if AuthEvent(config(),
+                 service=pamh.service,
+                 client_host=pamh.rhost,
+                 username=username,
+                 token=token).can_login():
         return pamh.PAM_SUCCESS
     else:
         return pamh.PAM_AUTH_ERR