Merge branch '15028-cwl-v1.1' refs #15028
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "crypto/md5"
9         "crypto/rsa"
10         "crypto/sha1"
11         "crypto/x509"
12         "encoding/base64"
13         "encoding/json"
14         "fmt"
15         "math/big"
16         "sync"
17
18         "git.curoverse.com/arvados.git/lib/cloud"
19         "git.curoverse.com/arvados.git/sdk/go/arvados"
20         "github.com/aws/aws-sdk-go/aws"
21         "github.com/aws/aws-sdk-go/aws/credentials"
22         "github.com/aws/aws-sdk-go/aws/session"
23         "github.com/aws/aws-sdk-go/service/ec2"
24         "github.com/sirupsen/logrus"
25         "golang.org/x/crypto/ssh"
26 )
27
// Driver is the ec2 implementation of the cloud.Driver interface.
// It wraps newEC2InstanceSet, which builds an instance set from the
// driver's JSON configuration.
var Driver = cloud.DriverFunc(newEC2InstanceSet)
30
// ec2InstanceSetConfig is the driver-specific configuration, decoded
// from the JSON passed to newEC2InstanceSet.
//
//   - AccessKeyID, SecretAccessKey: static AWS credentials used for all
//     API calls.
//   - Region: AWS region the API client talks to.
//   - SecurityGroupIDs, SubnetID: security groups and subnet of the
//     primary network interface of each instance created by Create.
//   - AdminUsername: login name reported by ec2Instance.RemoteUser.
//   - EBSVolumeType: volume type of the extra scratch EBS volume;
//     newEC2InstanceSet substitutes "gp2" when it is empty.
type ec2InstanceSetConfig struct {
	AccessKeyID      string
	SecretAccessKey  string
	Region           string
	SecurityGroupIDs []string
	SubnetID         string
	AdminUsername    string
	EBSVolumeType    string
}
40
// ec2Interface is the subset of EC2 API methods this driver calls.
// newEC2InstanceSet fills it with the client returned by ec2.New.
// NOTE(review): presumably declared as an interface so tests can
// substitute a stub client — confirm.
type ec2Interface interface {
	DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error)
	ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error)
	RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error)
	DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
	CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error)
	TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error)
}
49
// ec2InstanceSet is the cloud.InstanceSet returned by newEC2InstanceSet.
//
// keys caches EC2 key pair names indexed by the MD5 AWS fingerprint of
// the public key (see awsKeyFingerprint), so Create does not search for
// or import the same key on every call. keysMtx guards keys.
type ec2InstanceSet struct {
	ec2config     ec2InstanceSetConfig
	instanceSetID cloud.InstanceSetID
	logger        logrus.FieldLogger
	client        ec2Interface
	keysMtx       sync.Mutex
	keys          map[string]string
}
58
59 func newEC2InstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
60         instanceSet := &ec2InstanceSet{
61                 instanceSetID: instanceSetID,
62                 logger:        logger,
63         }
64         err = json.Unmarshal(config, &instanceSet.ec2config)
65         if err != nil {
66                 return nil, err
67         }
68         awsConfig := aws.NewConfig().
69                 WithCredentials(credentials.NewStaticCredentials(
70                         instanceSet.ec2config.AccessKeyID,
71                         instanceSet.ec2config.SecretAccessKey,
72                         "")).
73                 WithRegion(instanceSet.ec2config.Region)
74         instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig)))
75         instanceSet.keys = make(map[string]string)
76         if instanceSet.ec2config.EBSVolumeType == "" {
77                 instanceSet.ec2config.EBSVolumeType = "gp2"
78         }
79         return instanceSet, nil
80 }
81
82 func awsKeyFingerprint(pk ssh.PublicKey) (md5fp string, sha1fp string, err error) {
83         // AWS key fingerprints don't use the usual key fingerprint
84         // you get from ssh-keygen or ssh.FingerprintLegacyMD5()
85         // (you can get that from md5.Sum(pk.Marshal())
86         //
87         // AWS uses the md5 or sha1 of the PKIX DER encoding of the
88         // public key, so calculate those fingerprints here.
89         var rsaPub struct {
90                 Name string
91                 E    *big.Int
92                 N    *big.Int
93         }
94         if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil {
95                 return "", "", fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
96         }
97         rsaPk := rsa.PublicKey{
98                 E: int(rsaPub.E.Int64()),
99                 N: rsaPub.N,
100         }
101         pkix, _ := x509.MarshalPKIXPublicKey(&rsaPk)
102         md5pkix := md5.Sum([]byte(pkix))
103         sha1pkix := sha1.Sum([]byte(pkix))
104         md5fp = ""
105         sha1fp = ""
106         for i := 0; i < len(md5pkix); i += 1 {
107                 md5fp += fmt.Sprintf(":%02x", md5pkix[i])
108         }
109         for i := 0; i < len(sha1pkix); i += 1 {
110                 sha1fp += fmt.Sprintf(":%02x", sha1pkix[i])
111         }
112         return md5fp[1:], sha1fp[1:], nil
113 }
114
115 func (instanceSet *ec2InstanceSet) Create(
116         instanceType arvados.InstanceType,
117         imageID cloud.ImageID,
118         newTags cloud.InstanceTags,
119         initCommand cloud.InitCommand,
120         publicKey ssh.PublicKey) (cloud.Instance, error) {
121
122         md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
123         if err != nil {
124                 return nil, fmt.Errorf("Could not make key fingerprint: %v", err)
125         }
126         instanceSet.keysMtx.Lock()
127         var keyname string
128         var ok bool
129         if keyname, ok = instanceSet.keys[md5keyFingerprint]; !ok {
130                 keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
131                         Filters: []*ec2.Filter{&ec2.Filter{
132                                 Name:   aws.String("fingerprint"),
133                                 Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
134                         }},
135                 })
136                 if err != nil {
137                         return nil, fmt.Errorf("Could not search for keypair: %v", err)
138                 }
139
140                 if len(keyout.KeyPairs) > 0 {
141                         keyname = *(keyout.KeyPairs[0].KeyName)
142                 } else {
143                         keyname = "arvados-dispatch-keypair-" + md5keyFingerprint
144                         _, err := instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
145                                 KeyName:           &keyname,
146                                 PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
147                         })
148                         if err != nil {
149                                 return nil, fmt.Errorf("Could not import keypair: %v", err)
150                         }
151                 }
152                 instanceSet.keys[md5keyFingerprint] = keyname
153         }
154         instanceSet.keysMtx.Unlock()
155
156         ec2tags := []*ec2.Tag{}
157         for k, v := range newTags {
158                 ec2tags = append(ec2tags, &ec2.Tag{
159                         Key:   aws.String(k),
160                         Value: aws.String(v),
161                 })
162         }
163
164         rii := ec2.RunInstancesInput{
165                 ImageId:      aws.String(string(imageID)),
166                 InstanceType: &instanceType.ProviderType,
167                 MaxCount:     aws.Int64(1),
168                 MinCount:     aws.Int64(1),
169                 KeyName:      &keyname,
170
171                 NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
172                         &ec2.InstanceNetworkInterfaceSpecification{
173                                 AssociatePublicIpAddress: aws.Bool(false),
174                                 DeleteOnTermination:      aws.Bool(true),
175                                 DeviceIndex:              aws.Int64(0),
176                                 Groups:                   aws.StringSlice(instanceSet.ec2config.SecurityGroupIDs),
177                                 SubnetId:                 &instanceSet.ec2config.SubnetID,
178                         }},
179                 DisableApiTermination:             aws.Bool(false),
180                 InstanceInitiatedShutdownBehavior: aws.String("terminate"),
181                 TagSpecifications: []*ec2.TagSpecification{
182                         &ec2.TagSpecification{
183                                 ResourceType: aws.String("instance"),
184                                 Tags:         ec2tags,
185                         }},
186                 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
187         }
188
189         if instanceType.AddedScratch > 0 {
190                 rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{&ec2.BlockDeviceMapping{
191                         DeviceName: aws.String("/dev/xvdt"),
192                         Ebs: &ec2.EbsBlockDevice{
193                                 DeleteOnTermination: aws.Bool(true),
194                                 VolumeSize:          aws.Int64((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30),
195                                 VolumeType:          &instanceSet.ec2config.EBSVolumeType,
196                         }}}
197         }
198
199         if instanceType.Preemptible {
200                 rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{
201                         MarketType: aws.String("spot"),
202                         SpotOptions: &ec2.SpotMarketOptions{
203                                 InstanceInterruptionBehavior: aws.String("terminate"),
204                                 MaxPrice:                     aws.String(fmt.Sprintf("%v", instanceType.Price)),
205                         }}
206         }
207
208         rsv, err := instanceSet.client.RunInstances(&rii)
209
210         if err != nil {
211                 return nil, err
212         }
213
214         return &ec2Instance{
215                 provider: instanceSet,
216                 instance: rsv.Instances[0],
217         }, nil
218 }
219
220 func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
221         var filters []*ec2.Filter
222         for k, v := range tags {
223                 filters = append(filters, &ec2.Filter{
224                         Name:   aws.String("tag:" + k),
225                         Values: []*string{aws.String(v)},
226                 })
227         }
228         dii := &ec2.DescribeInstancesInput{Filters: filters}
229         for {
230                 dio, err := instanceSet.client.DescribeInstances(dii)
231                 if err != nil {
232                         return nil, err
233                 }
234
235                 for _, rsv := range dio.Reservations {
236                         for _, inst := range rsv.Instances {
237                                 if *inst.State.Name != "shutting-down" && *inst.State.Name != "terminated" {
238                                         instances = append(instances, &ec2Instance{instanceSet, inst})
239                                 }
240                         }
241                 }
242                 if dio.NextToken == nil {
243                         return instances, err
244                 }
245                 dii.NextToken = dio.NextToken
246         }
247 }
248
249 func (az *ec2InstanceSet) Stop() {
250 }
251
// ec2Instance is the cloud.Instance returned by Create and Instances.
//
// provider is the instance set that created or listed it; it supplies
// the API client and configuration. instance is the EC2 description
// captured at that time. Methods here read it but never refresh it, so
// it may be stale.
type ec2Instance struct {
	provider *ec2InstanceSet
	instance *ec2.Instance
}
256
257 func (inst *ec2Instance) ID() cloud.InstanceID {
258         return cloud.InstanceID(*inst.instance.InstanceId)
259 }
260
261 func (inst *ec2Instance) String() string {
262         return *inst.instance.InstanceId
263 }
264
265 func (inst *ec2Instance) ProviderType() string {
266         return *inst.instance.InstanceType
267 }
268
269 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
270         var ec2tags []*ec2.Tag
271         for k, v := range newTags {
272                 ec2tags = append(ec2tags, &ec2.Tag{
273                         Key:   aws.String(k),
274                         Value: aws.String(v),
275                 })
276         }
277
278         _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{
279                 Resources: []*string{inst.instance.InstanceId},
280                 Tags:      ec2tags,
281         })
282
283         return err
284 }
285
286 func (inst *ec2Instance) Tags() cloud.InstanceTags {
287         tags := make(map[string]string)
288
289         for _, t := range inst.instance.Tags {
290                 tags[*t.Key] = *t.Value
291         }
292
293         return tags
294 }
295
296 func (inst *ec2Instance) Destroy() error {
297         _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{
298                 InstanceIds: []*string{inst.instance.InstanceId},
299         })
300         return err
301 }
302
303 func (inst *ec2Instance) Address() string {
304         if inst.instance.PrivateIpAddress != nil {
305                 return *inst.instance.PrivateIpAddress
306         } else {
307                 return ""
308         }
309 }
310
311 func (inst *ec2Instance) RemoteUser() string {
312         return inst.provider.ec2config.AdminUsername
313 }
314
315 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
316         return cloud.ErrNotImplemented
317 }