1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
14 "git.curoverse.com/arvados.git/lib/cloud"
15 "git.curoverse.com/arvados.git/sdk/go/arvados"
16 "github.com/aws/aws-sdk-go/aws"
17 "github.com/aws/aws-sdk-go/aws/credentials"
18 "github.com/aws/aws-sdk-go/aws/session"
19 "github.com/aws/aws-sdk-go/service/ec2"
20 "github.com/sirupsen/logrus"
21 "golang.org/x/crypto/ssh"
// EC2 tag keys used by this driver. The ALL_CAPS names are kept because they
// are referenced throughout the package.
const (
	// ARVADOS_DISPATCH_ID is the tag key whose value is the owning
	// dispatcher's cloud.InstanceSetID; Instances filters on it.
	ARVADOS_DISPATCH_ID = "arvados-dispatch-id"

	// TAG_PREFIX is prepended to caller-supplied tag keys when they are
	// written to EC2, and stripped again when tags are read back.
	TAG_PREFIX = "arvados-dispatch-tag-"
)
27 // Driver is the ec2 implementation of the cloud.Driver interface.
28 var Driver = cloud.DriverFunc(newEC2InstanceSet)
// ec2InstanceSetConfig is the driver configuration, decoded from the JSON
// passed to newEC2InstanceSet. JSON keys are matched against the field names
// (encoding/json also accepts case-insensitive matches).
type ec2InstanceSetConfig struct {
	AccessKeyID     string // AWS access key ID used for API requests
	SecretAccessKey string // secret key paired with AccessKeyID
	Region          string // AWS region the EC2 client talks to
	SecurityGroupId string // security group for new instances' network interface
	SubnetId        string // subnet new instances are launched into
	AdminUsername   string // login name reported by ec2Instance.RemoteUser
	KeyPairName     string // EC2 key pair name imported by Create and used at launch
}
40 type ec2InstanceSet struct {
41 ec2config ec2InstanceSetConfig
42 dispatcherID cloud.InstanceSetID
43 logger logrus.FieldLogger
48 func newEC2InstanceSet(config json.RawMessage, dispatcherID cloud.InstanceSetID, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
49 instanceSet := &ec2InstanceSet{
50 dispatcherID: dispatcherID,
53 err = json.Unmarshal(config, &instanceSet.ec2config)
57 awsConfig := aws.NewConfig().
58 WithCredentials(credentials.NewStaticCredentials(
59 instanceSet.ec2config.AccessKeyID,
60 instanceSet.ec2config.SecretAccessKey,
62 WithRegion(instanceSet.ec2config.Region)
63 instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig)))
64 return instanceSet, nil
67 func (instanceSet *ec2InstanceSet) Create(
68 instanceType arvados.InstanceType,
69 imageID cloud.ImageID,
70 newTags cloud.InstanceTags,
71 initCommand cloud.InitCommand,
72 publicKey ssh.PublicKey) (cloud.Instance, error) {
74 if !instanceSet.importedKey {
75 instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
76 KeyName: &instanceSet.ec2config.KeyPairName,
77 PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
79 instanceSet.importedKey = true
82 ec2tags := []*ec2.Tag{
84 Key: aws.String(ARVADOS_DISPATCH_ID),
85 Value: aws.String(string(instanceSet.dispatcherID)),
88 Key: aws.String("arvados-class"),
89 Value: aws.String("dynamic-compute"),
92 for k, v := range newTags {
93 ec2tags = append(ec2tags, &ec2.Tag{
94 Key: aws.String(TAG_PREFIX + k),
99 rii := ec2.RunInstancesInput{
100 ImageId: aws.String(string(imageID)),
101 InstanceType: &instanceType.ProviderType,
102 MaxCount: aws.Int64(1),
103 MinCount: aws.Int64(1),
104 KeyName: &instanceSet.ec2config.KeyPairName,
106 NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
107 &ec2.InstanceNetworkInterfaceSpecification{
108 AssociatePublicIpAddress: aws.Bool(false),
109 DeleteOnTermination: aws.Bool(true),
110 DeviceIndex: aws.Int64(0),
111 Groups: []*string{&instanceSet.ec2config.SecurityGroupId},
112 SubnetId: &instanceSet.ec2config.SubnetId,
114 DisableApiTermination: aws.Bool(false),
115 InstanceInitiatedShutdownBehavior: aws.String("terminate"),
116 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
117 TagSpecifications: []*ec2.TagSpecification{
118 &ec2.TagSpecification{
119 ResourceType: aws.String("instance"),
124 if instanceType.ExtraScratch > 0 {
125 rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{&ec2.BlockDeviceMapping{
126 DeviceName: aws.String("/dev/xvdt"),
127 Ebs: &ec2.EbsBlockDevice{
128 DeleteOnTermination: aws.Bool(true),
129 VolumeSize: aws.Int64((int64(instanceType.ExtraScratch) / 1000000000) + 1),
130 VolumeType: aws.String("gp2"),
134 if instanceType.Preemptible {
135 rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{
136 MarketType: aws.String("spot"),
137 SpotOptions: &ec2.SpotMarketOptions{
138 InstanceInterruptionBehavior: aws.String("terminate"),
139 MaxPrice: aws.String(fmt.Sprintf("%v", instanceType.Price)),
143 rsv, err := instanceSet.client.RunInstances(&rii)
150 provider: instanceSet,
151 instance: rsv.Instances[0],
155 func (instanceSet *ec2InstanceSet) Instances(cloud.InstanceTags) (instances []cloud.Instance, err error) {
156 dii := &ec2.DescribeInstancesInput{
157 Filters: []*ec2.Filter{&ec2.Filter{
158 Name: aws.String("tag:" + ARVADOS_DISPATCH_ID),
159 Values: []*string{aws.String(string(instanceSet.dispatcherID))},
163 dio, err := instanceSet.client.DescribeInstances(dii)
168 for _, rsv := range dio.Reservations {
169 for _, inst := range rsv.Instances {
170 instances = append(instances, &ec2Instance{instanceSet, inst})
173 if dio.NextToken == nil {
174 return instances, err
176 dii.NextToken = dio.NextToken
180 func (az *ec2InstanceSet) Stop() {
183 type ec2Instance struct {
184 provider *ec2InstanceSet
185 instance *ec2.Instance
188 func (inst *ec2Instance) ID() cloud.InstanceID {
189 return cloud.InstanceID(*inst.instance.InstanceId)
192 func (inst *ec2Instance) String() string {
193 return *inst.instance.InstanceId
196 func (inst *ec2Instance) ProviderType() string {
197 return *inst.instance.InstanceType
200 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
201 ec2tags := []*ec2.Tag{
203 Key: aws.String(ARVADOS_DISPATCH_ID),
204 Value: aws.String(string(inst.provider.dispatcherID)),
207 for k, v := range newTags {
208 ec2tags = append(ec2tags, &ec2.Tag{
209 Key: aws.String(TAG_PREFIX + k),
210 Value: aws.String(v),
214 _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{
215 Resources: []*string{inst.instance.InstanceId},
222 func (inst *ec2Instance) Tags() cloud.InstanceTags {
223 tags := make(map[string]string)
225 for _, t := range inst.instance.Tags {
226 if strings.HasPrefix(*t.Key, TAG_PREFIX) {
227 tags[(*t.Key)[len(TAG_PREFIX):]] = *t.Value
234 func (inst *ec2Instance) Destroy() error {
235 log.Printf("terminating %v", *inst.instance.InstanceId)
236 _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{
237 InstanceIds: []*string{inst.instance.InstanceId},
242 func (inst *ec2Instance) Address() string {
243 if inst.instance.PrivateIpAddress != nil {
244 return *inst.instance.PrivateIpAddress
250 func (inst *ec2Instance) RemoteUser() string {
251 return inst.provider.ec2config.AdminUsername
254 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
255 return cloud.ErrNotImplemented