14931: Support extra tags on resources created by dispatchcloud.
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "crypto/md5"
9         "crypto/rsa"
10         "crypto/sha1"
11         "crypto/x509"
12         "encoding/base64"
13         "encoding/json"
14         "fmt"
15         "math/big"
16         "sync"
17
18         "git.curoverse.com/arvados.git/lib/cloud"
19         "git.curoverse.com/arvados.git/sdk/go/arvados"
20         "github.com/aws/aws-sdk-go/aws"
21         "github.com/aws/aws-sdk-go/aws/credentials"
22         "github.com/aws/aws-sdk-go/aws/session"
23         "github.com/aws/aws-sdk-go/service/ec2"
24         "github.com/sirupsen/logrus"
25         "golang.org/x/crypto/ssh"
26 )
27
const (
	// tagKeyInstanceSetID is the EC2 tag key under which every
	// instance records the ID of the instance set that owns it.
	// Create and SetTags always write this tag, and Instances uses
	// it as its filter.
	tagKeyInstanceSetID = "arvados-dispatch-id"
)
29
30 // Driver is the ec2 implementation of the cloud.Driver interface.
31 var Driver = cloud.DriverFunc(newEC2InstanceSet)
32
// ec2InstanceSetConfig holds the driver parameters that
// newEC2InstanceSet decodes from its JSON configuration blob.
type ec2InstanceSetConfig struct {
	// AccessKeyID and SecretAccessKey are the static AWS credentials
	// used for every EC2 API call.
	AccessKeyID     string
	SecretAccessKey string
	// Region is the AWS region the EC2 client is bound to.
	Region string
	// SecurityGroupIDs are attached to the primary network interface
	// of each new instance.
	SecurityGroupIDs []string
	// SubnetID is the subnet of the primary network interface of
	// each new instance.
	SubnetID string
	// AdminUsername is reported as the remote user of every instance.
	AdminUsername string
	// EBSVolumeType is the volume type of the extra scratch volume
	// added when an instance type requests one. newEC2InstanceSet
	// substitutes "gp2" when it is left empty.
	EBSVolumeType string
}
42
43 type ec2Interface interface {
44         DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error)
45         ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error)
46         RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error)
47         DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
48         CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error)
49         TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error)
50 }
51
52 type ec2InstanceSet struct {
53         ec2config     ec2InstanceSetConfig
54         instanceSetID cloud.InstanceSetID
55         logger        logrus.FieldLogger
56         client        ec2Interface
57         keysMtx       sync.Mutex
58         keys          map[string]string
59 }
60
61 func newEC2InstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
62         instanceSet := &ec2InstanceSet{
63                 instanceSetID: instanceSetID,
64                 logger:        logger,
65         }
66         err = json.Unmarshal(config, &instanceSet.ec2config)
67         if err != nil {
68                 return nil, err
69         }
70         awsConfig := aws.NewConfig().
71                 WithCredentials(credentials.NewStaticCredentials(
72                         instanceSet.ec2config.AccessKeyID,
73                         instanceSet.ec2config.SecretAccessKey,
74                         "")).
75                 WithRegion(instanceSet.ec2config.Region)
76         instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig)))
77         instanceSet.keys = make(map[string]string)
78         if instanceSet.ec2config.EBSVolumeType == "" {
79                 instanceSet.ec2config.EBSVolumeType = "gp2"
80         }
81         return instanceSet, nil
82 }
83
84 func awsKeyFingerprint(pk ssh.PublicKey) (md5fp string, sha1fp string, err error) {
85         // AWS key fingerprints don't use the usual key fingerprint
86         // you get from ssh-keygen or ssh.FingerprintLegacyMD5()
87         // (you can get that from md5.Sum(pk.Marshal())
88         //
89         // AWS uses the md5 or sha1 of the PKIX DER encoding of the
90         // public key, so calculate those fingerprints here.
91         var rsaPub struct {
92                 Name string
93                 E    *big.Int
94                 N    *big.Int
95         }
96         if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil {
97                 return "", "", fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
98         }
99         rsaPk := rsa.PublicKey{
100                 E: int(rsaPub.E.Int64()),
101                 N: rsaPub.N,
102         }
103         pkix, _ := x509.MarshalPKIXPublicKey(&rsaPk)
104         md5pkix := md5.Sum([]byte(pkix))
105         sha1pkix := sha1.Sum([]byte(pkix))
106         md5fp = ""
107         sha1fp = ""
108         for i := 0; i < len(md5pkix); i += 1 {
109                 md5fp += fmt.Sprintf(":%02x", md5pkix[i])
110         }
111         for i := 0; i < len(sha1pkix); i += 1 {
112                 sha1fp += fmt.Sprintf(":%02x", sha1pkix[i])
113         }
114         return md5fp[1:], sha1fp[1:], nil
115 }
116
117 func (instanceSet *ec2InstanceSet) Create(
118         instanceType arvados.InstanceType,
119         imageID cloud.ImageID,
120         newTags cloud.InstanceTags,
121         initCommand cloud.InitCommand,
122         publicKey ssh.PublicKey) (cloud.Instance, error) {
123
124         md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
125         if err != nil {
126                 return nil, fmt.Errorf("Could not make key fingerprint: %v", err)
127         }
128         instanceSet.keysMtx.Lock()
129         var keyname string
130         var ok bool
131         if keyname, ok = instanceSet.keys[md5keyFingerprint]; !ok {
132                 keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
133                         Filters: []*ec2.Filter{&ec2.Filter{
134                                 Name:   aws.String("fingerprint"),
135                                 Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
136                         }},
137                 })
138                 if err != nil {
139                         return nil, fmt.Errorf("Could not search for keypair: %v", err)
140                 }
141
142                 if len(keyout.KeyPairs) > 0 {
143                         keyname = *(keyout.KeyPairs[0].KeyName)
144                 } else {
145                         keyname = "arvados-dispatch-keypair-" + md5keyFingerprint
146                         _, err := instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
147                                 KeyName:           &keyname,
148                                 PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
149                         })
150                         if err != nil {
151                                 return nil, fmt.Errorf("Could not import keypair: %v", err)
152                         }
153                 }
154                 instanceSet.keys[md5keyFingerprint] = keyname
155         }
156         instanceSet.keysMtx.Unlock()
157
158         ec2tags := []*ec2.Tag{
159                 &ec2.Tag{
160                         Key:   aws.String(tagKeyInstanceSetID),
161                         Value: aws.String(string(instanceSet.instanceSetID)),
162                 },
163         }
164         for k, v := range newTags {
165                 ec2tags = append(ec2tags, &ec2.Tag{
166                         Key:   aws.String(k),
167                         Value: aws.String(v),
168                 })
169         }
170
171         rii := ec2.RunInstancesInput{
172                 ImageId:      aws.String(string(imageID)),
173                 InstanceType: &instanceType.ProviderType,
174                 MaxCount:     aws.Int64(1),
175                 MinCount:     aws.Int64(1),
176                 KeyName:      &keyname,
177
178                 NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
179                         &ec2.InstanceNetworkInterfaceSpecification{
180                                 AssociatePublicIpAddress: aws.Bool(false),
181                                 DeleteOnTermination:      aws.Bool(true),
182                                 DeviceIndex:              aws.Int64(0),
183                                 Groups:                   aws.StringSlice(instanceSet.ec2config.SecurityGroupIDs),
184                                 SubnetId:                 &instanceSet.ec2config.SubnetID,
185                         }},
186                 DisableApiTermination:             aws.Bool(false),
187                 InstanceInitiatedShutdownBehavior: aws.String("terminate"),
188                 TagSpecifications: []*ec2.TagSpecification{
189                         &ec2.TagSpecification{
190                                 ResourceType: aws.String("instance"),
191                                 Tags:         ec2tags,
192                         }},
193                 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
194         }
195
196         if instanceType.AddedScratch > 0 {
197                 rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{&ec2.BlockDeviceMapping{
198                         DeviceName: aws.String("/dev/xvdt"),
199                         Ebs: &ec2.EbsBlockDevice{
200                                 DeleteOnTermination: aws.Bool(true),
201                                 VolumeSize:          aws.Int64((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30),
202                                 VolumeType:          &instanceSet.ec2config.EBSVolumeType,
203                         }}}
204         }
205
206         if instanceType.Preemptible {
207                 rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{
208                         MarketType: aws.String("spot"),
209                         SpotOptions: &ec2.SpotMarketOptions{
210                                 InstanceInterruptionBehavior: aws.String("terminate"),
211                                 MaxPrice:                     aws.String(fmt.Sprintf("%v", instanceType.Price)),
212                         }}
213         }
214
215         rsv, err := instanceSet.client.RunInstances(&rii)
216
217         if err != nil {
218                 return nil, err
219         }
220
221         return &ec2Instance{
222                 provider: instanceSet,
223                 instance: rsv.Instances[0],
224         }, nil
225 }
226
227 func (instanceSet *ec2InstanceSet) Instances(cloud.InstanceTags) (instances []cloud.Instance, err error) {
228         dii := &ec2.DescribeInstancesInput{
229                 Filters: []*ec2.Filter{&ec2.Filter{
230                         Name:   aws.String("tag:" + tagKeyInstanceSetID),
231                         Values: []*string{aws.String(string(instanceSet.instanceSetID))},
232                 }}}
233
234         for {
235                 dio, err := instanceSet.client.DescribeInstances(dii)
236                 if err != nil {
237                         return nil, err
238                 }
239
240                 for _, rsv := range dio.Reservations {
241                         for _, inst := range rsv.Instances {
242                                 if *inst.State.Name != "shutting-down" && *inst.State.Name != "terminated" {
243                                         instances = append(instances, &ec2Instance{instanceSet, inst})
244                                 }
245                         }
246                 }
247                 if dio.NextToken == nil {
248                         return instances, err
249                 }
250                 dii.NextToken = dio.NextToken
251         }
252 }
253
254 func (az *ec2InstanceSet) Stop() {
255 }
256
257 type ec2Instance struct {
258         provider *ec2InstanceSet
259         instance *ec2.Instance
260 }
261
262 func (inst *ec2Instance) ID() cloud.InstanceID {
263         return cloud.InstanceID(*inst.instance.InstanceId)
264 }
265
266 func (inst *ec2Instance) String() string {
267         return *inst.instance.InstanceId
268 }
269
270 func (inst *ec2Instance) ProviderType() string {
271         return *inst.instance.InstanceType
272 }
273
274 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
275         ec2tags := []*ec2.Tag{
276                 &ec2.Tag{
277                         Key:   aws.String(tagKeyInstanceSetID),
278                         Value: aws.String(string(inst.provider.instanceSetID)),
279                 },
280         }
281         for k, v := range newTags {
282                 ec2tags = append(ec2tags, &ec2.Tag{
283                         Key:   aws.String(k),
284                         Value: aws.String(v),
285                 })
286         }
287
288         _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{
289                 Resources: []*string{inst.instance.InstanceId},
290                 Tags:      ec2tags,
291         })
292
293         return err
294 }
295
296 func (inst *ec2Instance) Tags() cloud.InstanceTags {
297         tags := make(map[string]string)
298
299         for _, t := range inst.instance.Tags {
300                 tags[*t.Key] = *t.Value
301         }
302
303         return tags
304 }
305
306 func (inst *ec2Instance) Destroy() error {
307         _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{
308                 InstanceIds: []*string{inst.instance.InstanceId},
309         })
310         return err
311 }
312
313 func (inst *ec2Instance) Address() string {
314         if inst.instance.PrivateIpAddress != nil {
315                 return *inst.instance.PrivateIpAddress
316         } else {
317                 return ""
318         }
319 }
320
321 func (inst *ec2Instance) RemoteUser() string {
322         return inst.provider.ec2config.AdminUsername
323 }
324
325 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
326         return cloud.ErrNotImplemented
327 }