Merge branch 'master' into 13822-nm-delayed-daemon
[arvados.git] / services / nodemanager / arvnodeman / computenode / driver / ec2.py
index d89c48e270bcc119638c70fc3d5f2928fbe1f8e3..56812d258a92212b02a53d9775534d8b23b50b69 100644 (file)
@@ -1,4 +1,7 @@
 #!/usr/bin/env python
+# Copyright (C) The Arvados Authors. All rights reserved.
+#
+# SPDX-License-Identifier: AGPL-3.0
 
 from __future__ import absolute_import, print_function
 
@@ -49,12 +52,15 @@ class ComputeNodeDriver(BaseComputeNodeDriver):
         self.tags = {key[4:]: value
                      for key, value in list_kwargs.iteritems()
                      if key.startswith('tag:')}
+        # Tags are assigned at instance creation time
+        create_kwargs.setdefault('ex_metadata', {})
+        create_kwargs['ex_metadata'].update(self.tags)
         super(ComputeNodeDriver, self).__init__(
             auth_kwargs, {'ex_filters': list_kwargs}, create_kwargs,
             driver_class)
 
     def _init_image_id(self, image_id):
-        return 'image', self.search_for(image_id, 'list_images')
+        return 'image', self.search_for(image_id, 'list_images', ex_owner='self')
 
     def _init_security_groups(self, group_names):
         return 'ex_security_groups', [
@@ -64,26 +70,48 @@ class ComputeNodeDriver(BaseComputeNodeDriver):
     def _init_subnet_id(self, subnet_id):
         return 'ex_subnet', self.search_for(subnet_id, 'ex_list_subnets')
 
+    create_cloud_name = staticmethod(arvados_node_fqdn)
+
     def arvados_create_kwargs(self, size, arvados_node):
-        return {'name': arvados_node_fqdn(arvados_node),
+        kw = {'name': self.create_cloud_name(arvados_node),
                 'ex_userdata': self._make_ping_url(arvados_node)}
-
-    def post_create_node(self, cloud_node):
-        self.real.ex_create_tags(cloud_node, self.tags)
+        # libcloud/ec2 disk sizes are in GB, Arvados/SLURM "scratch" is in MB: convert to whole GB and add 1 so the request is never short
+        scratch = int(size.scratch / 1000) + 1
+        if scratch > size.disk:
+            volsize = scratch - size.disk
+            if volsize > 16384:
+                # Must be 1-16384 for General Purpose SSD (gp2) devices
+                # https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_EbsBlockDevice.html
+                self._logger.warning("Requested EBS volume size %d is too large, capping size request to 16384 GB", volsize)
+                volsize = 16384
+            kw["ex_blockdevicemappings"] = [{
+                "DeviceName": "/dev/xvdt",
+                "Ebs": {
+                    "DeleteOnTermination": True,
+                    "VolumeSize": volsize,
+                    "VolumeType": "gp2"
+                }}]
+        if size.preemptible:
+            # Request a Spot instance for this node
+            kw['ex_spot_market'] = True
+        return kw
 
     def sync_node(self, cloud_node, arvados_node):
         self.real.ex_create_tags(cloud_node,
                                  {'Name': arvados_node_fqdn(arvados_node)})
 
-    def find_node(self, name):
-        raise NotImplementedError("ec2.ComputeNodeDriver.find_node")
+    def create_node(self, size, arvados_node):
+        # Record the Arvados-assigned cloud size id in ex_metadata so the new instance is tagged with it.
+        self.create_kwargs['ex_metadata'].update({'arvados_node_size': size.id})
+        return super(ComputeNodeDriver, self).create_node(size, arvados_node)
 
     def list_nodes(self):
         # Need to populate Node.size
         nodes = super(ComputeNodeDriver, self).list_nodes()
         for n in nodes:
             if not n.size:
-                n.size = self.sizes[n.extra["instance_type"]]
+                n.size = self.sizes()[n.extra["instance_type"]]
+            n.extra['arvados_node_size'] = n.extra.get('tags', {}).get('arvados_node_size')
         return nodes
 
     @classmethod
@@ -95,3 +123,7 @@ class ComputeNodeDriver(BaseComputeNodeDriver):
         time_str = node.extra['launch_time'].split('.', 2)[0] + 'UTC'
         return time.mktime(time.strptime(
                 time_str,'%Y-%m-%dT%H:%M:%S%Z')) - time.timezone
+
+    @classmethod
+    def node_id(cls, node):
+        return node.id