e4e8588d5461f817e9bb75036d58433f46053801
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "encoding/json"
9         "strings"
10
11         "git.curoverse.com/arvados.git/lib/cloud"
12         "git.curoverse.com/arvados.git/sdk/go/arvados"
13         "github.com/aws/aws-sdk-go/aws"
14         "github.com/aws/aws-sdk-go/aws/credentials"
15         "github.com/aws/aws-sdk-go/aws/session"
16         "github.com/aws/aws-sdk-go/service/ec2"
17         "github.com/sirupsen/logrus"
18         "golang.org/x/crypto/ssh"
19 )
20
// Tag keys used by this driver on the EC2 instances it manages.
const (
	// ARVADOS_DISPATCH_ID is the key of the tag that carries the
	// dispatcher's instance set ID. It is attached to every instance
	// on creation and is the filter used when listing instances.
	ARVADOS_DISPATCH_ID = "arvados-dispatch-id"

	// TAG_PREFIX is prepended to each caller-supplied tag key before
	// the tag is written to EC2, and stripped again when tags are read
	// back.
	//
	// NOTE(review): "disispatch-" looks like a typo for "dispatch-".
	// Changing it would orphan the tags on instances that were created
	// with the current value, so confirm before fixing.
	TAG_PREFIX = "disispatch-"
)
23
// Driver is the ec2 implementation of the cloud.Driver interface.
// It wraps newEC2InstanceSet, which builds an InstanceSet from a JSON
// configuration blob, a dispatcher ID and a logger.
var Driver = cloud.DriverFunc(newEC2InstanceSet)
26
// ec2InstanceSetConfig holds the driver settings. It is populated by
// unmarshalling the JSON configuration passed to newEC2InstanceSet, so
// the field names double as the JSON keys.
type ec2InstanceSetConfig struct {
	// AccessKeyID and SecretAccessKey are the static AWS credentials
	// used for all API calls.
	AccessKeyID     string
	SecretAccessKey string
	// Region is the AWS region the EC2 client is configured for.
	Region          string
	// SecurityGroupId is the single security group assigned to new
	// instances.
	SecurityGroupId string
	// SubnetId is the subnet new instances are launched in.
	SubnetId        string
	// AdminUsername is reported by RemoteUser() for every instance.
	AdminUsername   string
	// KeyPairName is the name under which the dispatcher's public key
	// is imported, and the key pair name given to new instances.
	KeyPairName     string
}
36
// ec2InstanceSet is the EC2-backed instance set returned by
// newEC2InstanceSet.
type ec2InstanceSet struct {
	// ec2config is the decoded driver configuration.
	ec2config    ec2InstanceSetConfig
	// dispatcherID is written to the ARVADOS_DISPATCH_ID tag of every
	// instance and used as the filter when listing instances.
	dispatcherID cloud.InstanceSetID
	// logger is the logger supplied to newEC2InstanceSet.
	logger       logrus.FieldLogger
	// client is the EC2 API client used for all requests.
	client       *ec2.EC2
	// importedKey is set by Create once the SSH key pair import has
	// been done, so that subsequent Create calls skip it.
	importedKey  bool
}
44
45 func newEC2InstanceSet(config json.RawMessage, dispatcherID cloud.InstanceSetID, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
46         instanceSet := &ec2InstanceSet{
47                 dispatcherID: dispatcherID,
48                 logger:       logger,
49         }
50         err = json.Unmarshal(config, &instanceSet.ec2config)
51         if err != nil {
52                 return nil, err
53         }
54         awsConfig := aws.NewConfig().
55                 WithCredentials(credentials.NewStaticCredentials(
56                         instanceSet.ec2config.AccessKeyID,
57                         instanceSet.ec2config.SecretAccessKey,
58                         "")).
59                 WithRegion(instanceSet.ec2config.Region)
60         instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig)))
61         return instanceSet, nil
62 }
63
64 func (instanceSet *ec2InstanceSet) Create(
65         instanceType arvados.InstanceType,
66         imageID cloud.ImageID,
67         newTags cloud.InstanceTags,
68         initCommand cloud.InitCommand,
69         publicKey ssh.PublicKey) (cloud.Instance, error) {
70
71         if !instanceSet.importedKey {
72                 instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
73                         KeyName:           &instanceSet.ec2config.KeyPairName,
74                         PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
75                 })
76                 instanceSet.importedKey = true
77         }
78
79         ec2tags := []*ec2.Tag{
80                 &ec2.Tag{
81                         Key:   aws.String(ARVADOS_DISPATCH_ID),
82                         Value: aws.String(string(instanceSet.dispatcherID)),
83                 },
84         }
85         for k, v := range newTags {
86                 ec2tags = append(ec2tags, &ec2.Tag{
87                         Key:   aws.String(TAG_PREFIX + k),
88                         Value: aws.String(v),
89                 })
90         }
91
92         rsv, err := instanceSet.client.RunInstances(&ec2.RunInstancesInput{
93                 ImageId:          aws.String(string(imageID)),
94                 InstanceType:     &instanceType.ProviderType,
95                 MaxCount:         aws.Int64(1),
96                 MinCount:         aws.Int64(1),
97                 KeyName:          &instanceSet.ec2config.KeyPairName,
98                 SecurityGroupIds: []*string{&instanceSet.ec2config.SecurityGroupId},
99                 SubnetId:         &instanceSet.ec2config.SubnetId,
100                 TagSpecifications: []*ec2.TagSpecification{
101                         &ec2.TagSpecification{
102                                 ResourceType: aws.String("instance"),
103                                 Tags:         ec2tags,
104                         }},
105         })
106
107         if err != nil {
108                 return nil, err
109         }
110
111         return &ec2Instance{
112                 provider: instanceSet,
113                 instance: rsv.Instances[0],
114         }, nil
115 }
116
117 func (instanceSet *ec2InstanceSet) Instances(cloud.InstanceTags) (instances []cloud.Instance, err error) {
118         dii := &ec2.DescribeInstancesInput{
119                 Filters: []*ec2.Filter{&ec2.Filter{
120                         Name:   aws.String("tag:" + ARVADOS_DISPATCH_ID),
121                         Values: []*string{aws.String(string(instanceSet.dispatcherID))},
122                 }}}
123
124         for {
125                 dio, err := instanceSet.client.DescribeInstances(dii)
126                 if err != nil {
127                         return nil, err
128                 }
129
130                 for _, rsv := range dio.Reservations {
131                         for _, inst := range rsv.Instances {
132                                 instances = append(instances, &ec2Instance{instanceSet, inst})
133                         }
134                 }
135                 if dio.NextToken == nil {
136                         return instances, err
137                 }
138                 dii.NextToken = dio.NextToken
139         }
140 }
141
142 func (az *ec2InstanceSet) Stop() {
143 }
144
// ec2Instance pairs an EC2 instance description with the instance set
// that produced it.
type ec2Instance struct {
	// provider supplies the EC2 client, dispatcher ID and
	// configuration used by the instance's methods.
	provider *ec2InstanceSet
	// instance is the description returned by RunInstances or
	// DescribeInstances; it is not refreshed by any method here.
	instance *ec2.Instance
}
149
150 func (inst *ec2Instance) ID() cloud.InstanceID {
151         return cloud.InstanceID(*inst.instance.InstanceId)
152 }
153
154 func (inst *ec2Instance) String() string {
155         return *inst.instance.InstanceId
156 }
157
158 func (inst *ec2Instance) ProviderType() string {
159         return *inst.instance.InstanceType
160 }
161
162 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
163         ec2tags := []*ec2.Tag{
164                 &ec2.Tag{
165                         Key:   aws.String(ARVADOS_DISPATCH_ID),
166                         Value: aws.String(string(inst.provider.dispatcherID)),
167                 },
168         }
169         for k, v := range newTags {
170                 ec2tags = append(ec2tags, &ec2.Tag{
171                         Key:   aws.String(TAG_PREFIX + k),
172                         Value: aws.String(v),
173                 })
174         }
175
176         _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{
177                 Resources: []*string{inst.instance.InstanceId},
178                 Tags:      ec2tags,
179         })
180
181         return err
182 }
183
184 func (inst *ec2Instance) Tags() cloud.InstanceTags {
185         tags := make(map[string]string)
186
187         for _, t := range inst.instance.Tags {
188                 if strings.HasPrefix(*t.Key, TAG_PREFIX) {
189                         tags[(*t.Key)[len(TAG_PREFIX):]] = *t.Value
190                 }
191         }
192
193         return tags
194 }
195
196 func (inst *ec2Instance) Destroy() error {
197         _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{
198                 InstanceIds: []*string{inst.instance.InstanceId},
199         })
200         return err
201 }
202
203 func (inst *ec2Instance) Address() string {
204         if inst.instance.PrivateIpAddress != nil {
205                 return *inst.instance.PrivateIpAddress
206         } else {
207                 return ""
208         }
209 }
210
// RemoteUser returns the AdminUsername from the driver configuration;
// the value is the same for every instance in the set.
func (inst *ec2Instance) RemoteUser() string {
	return inst.provider.ec2config.AdminUsername
}
214
// VerifyHostKey is not supported by this driver: both arguments are
// ignored and cloud.ErrNotImplemented is always returned.
func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
	return cloud.ErrNotImplemented
}