Fix 2.4.2 upgrade notes formatting refs #19330
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "crypto/md5"
9         "crypto/rsa"
10         "crypto/sha1"
11         "crypto/x509"
12         "encoding/base64"
13         "encoding/json"
14         "fmt"
15         "math/big"
16         "sync"
17         "sync/atomic"
18         "time"
19
20         "git.arvados.org/arvados.git/lib/cloud"
21         "git.arvados.org/arvados.git/sdk/go/arvados"
22         "github.com/aws/aws-sdk-go/aws"
23         "github.com/aws/aws-sdk-go/aws/awserr"
24         "github.com/aws/aws-sdk-go/aws/credentials"
25         "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
26         "github.com/aws/aws-sdk-go/aws/ec2metadata"
27         "github.com/aws/aws-sdk-go/aws/request"
28         "github.com/aws/aws-sdk-go/aws/session"
29         "github.com/aws/aws-sdk-go/service/ec2"
30         "github.com/sirupsen/logrus"
31         "golang.org/x/crypto/ssh"
32 )
33
34 // Driver is the ec2 implementation of the cloud.Driver interface.
35 var Driver = cloud.DriverFunc(newEC2InstanceSet)
36
// Bounds for the exponential backoff applied by wrapError after the EC2
// API reports throttling.
const (
	throttleDelayMin = 1 * time.Second
	throttleDelayMax = 60 * time.Second
)
41
42 type ec2InstanceSetConfig struct {
43         AccessKeyID        string
44         SecretAccessKey    string
45         Region             string
46         SecurityGroupIDs   arvados.StringSet
47         SubnetID           string
48         AdminUsername      string
49         EBSVolumeType      string
50         IAMInstanceProfile string
51 }
52
53 type ec2Interface interface {
54         DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error)
55         ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error)
56         RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error)
57         DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
58         CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error)
59         TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error)
60 }
61
62 type ec2InstanceSet struct {
63         ec2config              ec2InstanceSetConfig
64         instanceSetID          cloud.InstanceSetID
65         logger                 logrus.FieldLogger
66         client                 ec2Interface
67         keysMtx                sync.Mutex
68         keys                   map[string]string
69         throttleDelayCreate    atomic.Value
70         throttleDelayInstances atomic.Value
71 }
72
73 func newEC2InstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
74         instanceSet := &ec2InstanceSet{
75                 instanceSetID: instanceSetID,
76                 logger:        logger,
77         }
78         err = json.Unmarshal(config, &instanceSet.ec2config)
79         if err != nil {
80                 return nil, err
81         }
82
83         sess, err := session.NewSession()
84         if err != nil {
85                 return nil, err
86         }
87         // First try any static credentials, fall back to an IAM instance profile/role
88         creds := credentials.NewChainCredentials(
89                 []credentials.Provider{
90                         &credentials.StaticProvider{Value: credentials.Value{AccessKeyID: instanceSet.ec2config.AccessKeyID, SecretAccessKey: instanceSet.ec2config.SecretAccessKey}},
91                         &ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(sess)},
92                 })
93
94         awsConfig := aws.NewConfig().WithCredentials(creds).WithRegion(instanceSet.ec2config.Region)
95         instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig)))
96         instanceSet.keys = make(map[string]string)
97         if instanceSet.ec2config.EBSVolumeType == "" {
98                 instanceSet.ec2config.EBSVolumeType = "gp2"
99         }
100         return instanceSet, nil
101 }
102
103 func awsKeyFingerprint(pk ssh.PublicKey) (md5fp string, sha1fp string, err error) {
104         // AWS key fingerprints don't use the usual key fingerprint
105         // you get from ssh-keygen or ssh.FingerprintLegacyMD5()
106         // (you can get that from md5.Sum(pk.Marshal())
107         //
108         // AWS uses the md5 or sha1 of the PKIX DER encoding of the
109         // public key, so calculate those fingerprints here.
110         var rsaPub struct {
111                 Name string
112                 E    *big.Int
113                 N    *big.Int
114         }
115         if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil {
116                 return "", "", fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
117         }
118         rsaPk := rsa.PublicKey{
119                 E: int(rsaPub.E.Int64()),
120                 N: rsaPub.N,
121         }
122         pkix, _ := x509.MarshalPKIXPublicKey(&rsaPk)
123         md5pkix := md5.Sum([]byte(pkix))
124         sha1pkix := sha1.Sum([]byte(pkix))
125         md5fp = ""
126         sha1fp = ""
127         for i := 0; i < len(md5pkix); i++ {
128                 md5fp += fmt.Sprintf(":%02x", md5pkix[i])
129         }
130         for i := 0; i < len(sha1pkix); i++ {
131                 sha1fp += fmt.Sprintf(":%02x", sha1pkix[i])
132         }
133         return md5fp[1:], sha1fp[1:], nil
134 }
135
136 func (instanceSet *ec2InstanceSet) Create(
137         instanceType arvados.InstanceType,
138         imageID cloud.ImageID,
139         newTags cloud.InstanceTags,
140         initCommand cloud.InitCommand,
141         publicKey ssh.PublicKey) (cloud.Instance, error) {
142
143         md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
144         if err != nil {
145                 return nil, fmt.Errorf("Could not make key fingerprint: %v", err)
146         }
147         instanceSet.keysMtx.Lock()
148         var keyname string
149         var ok bool
150         if keyname, ok = instanceSet.keys[md5keyFingerprint]; !ok {
151                 keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
152                         Filters: []*ec2.Filter{{
153                                 Name:   aws.String("fingerprint"),
154                                 Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
155                         }},
156                 })
157                 if err != nil {
158                         return nil, fmt.Errorf("Could not search for keypair: %v", err)
159                 }
160
161                 if len(keyout.KeyPairs) > 0 {
162                         keyname = *(keyout.KeyPairs[0].KeyName)
163                 } else {
164                         keyname = "arvados-dispatch-keypair-" + md5keyFingerprint
165                         _, err := instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
166                                 KeyName:           &keyname,
167                                 PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
168                         })
169                         if err != nil {
170                                 return nil, fmt.Errorf("Could not import keypair: %v", err)
171                         }
172                 }
173                 instanceSet.keys[md5keyFingerprint] = keyname
174         }
175         instanceSet.keysMtx.Unlock()
176
177         ec2tags := []*ec2.Tag{}
178         for k, v := range newTags {
179                 ec2tags = append(ec2tags, &ec2.Tag{
180                         Key:   aws.String(k),
181                         Value: aws.String(v),
182                 })
183         }
184
185         var groups []string
186         for sg := range instanceSet.ec2config.SecurityGroupIDs {
187                 groups = append(groups, sg)
188         }
189
190         rii := ec2.RunInstancesInput{
191                 ImageId:      aws.String(string(imageID)),
192                 InstanceType: &instanceType.ProviderType,
193                 MaxCount:     aws.Int64(1),
194                 MinCount:     aws.Int64(1),
195                 KeyName:      &keyname,
196
197                 NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
198                         {
199                                 AssociatePublicIpAddress: aws.Bool(false),
200                                 DeleteOnTermination:      aws.Bool(true),
201                                 DeviceIndex:              aws.Int64(0),
202                                 Groups:                   aws.StringSlice(groups),
203                                 SubnetId:                 &instanceSet.ec2config.SubnetID,
204                         }},
205                 DisableApiTermination:             aws.Bool(false),
206                 InstanceInitiatedShutdownBehavior: aws.String("terminate"),
207                 TagSpecifications: []*ec2.TagSpecification{
208                         {
209                                 ResourceType: aws.String("instance"),
210                                 Tags:         ec2tags,
211                         }},
212                 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
213         }
214
215         if instanceType.AddedScratch > 0 {
216                 rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{{
217                         DeviceName: aws.String("/dev/xvdt"),
218                         Ebs: &ec2.EbsBlockDevice{
219                                 DeleteOnTermination: aws.Bool(true),
220                                 VolumeSize:          aws.Int64((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30),
221                                 VolumeType:          &instanceSet.ec2config.EBSVolumeType,
222                         }}}
223         }
224
225         if instanceType.Preemptible {
226                 rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{
227                         MarketType: aws.String("spot"),
228                         SpotOptions: &ec2.SpotMarketOptions{
229                                 InstanceInterruptionBehavior: aws.String("terminate"),
230                                 MaxPrice:                     aws.String(fmt.Sprintf("%v", instanceType.Price)),
231                         }}
232         }
233
234         if instanceSet.ec2config.IAMInstanceProfile != "" {
235                 rii.IamInstanceProfile = &ec2.IamInstanceProfileSpecification{
236                         Name: aws.String(instanceSet.ec2config.IAMInstanceProfile),
237                 }
238         }
239
240         rsv, err := instanceSet.client.RunInstances(&rii)
241         err = wrapError(err, &instanceSet.throttleDelayCreate)
242         if err != nil {
243                 return nil, err
244         }
245
246         return &ec2Instance{
247                 provider: instanceSet,
248                 instance: rsv.Instances[0],
249         }, nil
250 }
251
252 func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
253         var filters []*ec2.Filter
254         for k, v := range tags {
255                 filters = append(filters, &ec2.Filter{
256                         Name:   aws.String("tag:" + k),
257                         Values: []*string{aws.String(v)},
258                 })
259         }
260         dii := &ec2.DescribeInstancesInput{Filters: filters}
261         for {
262                 dio, err := instanceSet.client.DescribeInstances(dii)
263                 err = wrapError(err, &instanceSet.throttleDelayInstances)
264                 if err != nil {
265                         return nil, err
266                 }
267
268                 for _, rsv := range dio.Reservations {
269                         for _, inst := range rsv.Instances {
270                                 if *inst.State.Name != "shutting-down" && *inst.State.Name != "terminated" {
271                                         instances = append(instances, &ec2Instance{instanceSet, inst})
272                                 }
273                         }
274                 }
275                 if dio.NextToken == nil {
276                         return instances, err
277                 }
278                 dii.NextToken = dio.NextToken
279         }
280 }
281
282 func (instanceSet *ec2InstanceSet) Stop() {
283 }
284
285 type ec2Instance struct {
286         provider *ec2InstanceSet
287         instance *ec2.Instance
288 }
289
290 func (inst *ec2Instance) ID() cloud.InstanceID {
291         return cloud.InstanceID(*inst.instance.InstanceId)
292 }
293
294 func (inst *ec2Instance) String() string {
295         return *inst.instance.InstanceId
296 }
297
298 func (inst *ec2Instance) ProviderType() string {
299         return *inst.instance.InstanceType
300 }
301
302 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
303         var ec2tags []*ec2.Tag
304         for k, v := range newTags {
305                 ec2tags = append(ec2tags, &ec2.Tag{
306                         Key:   aws.String(k),
307                         Value: aws.String(v),
308                 })
309         }
310
311         _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{
312                 Resources: []*string{inst.instance.InstanceId},
313                 Tags:      ec2tags,
314         })
315
316         return err
317 }
318
319 func (inst *ec2Instance) Tags() cloud.InstanceTags {
320         tags := make(map[string]string)
321
322         for _, t := range inst.instance.Tags {
323                 tags[*t.Key] = *t.Value
324         }
325
326         return tags
327 }
328
329 func (inst *ec2Instance) Destroy() error {
330         _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{
331                 InstanceIds: []*string{inst.instance.InstanceId},
332         })
333         return err
334 }
335
336 func (inst *ec2Instance) Address() string {
337         if inst.instance.PrivateIpAddress != nil {
338                 return *inst.instance.PrivateIpAddress
339         }
340         return ""
341 }
342
343 func (inst *ec2Instance) RemoteUser() string {
344         return inst.provider.ec2config.AdminUsername
345 }
346
347 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
348         return cloud.ErrNotImplemented
349 }
350
// rateLimitError wraps an EC2 throttling error together with the earliest
// time at which the caller should retry.
type rateLimitError struct {
	error
	earliestRetry time.Time
}

// EarliestRetry returns the time before which the call should not be
// retried.
func (e rateLimitError) EarliestRetry() time.Time { return e.earliestRetry }
359
// isCodeCapacity lists the AWS error codes that signal exhausted capacity
// or an account quota; isErrorCapacity checks membership.
var isCodeCapacity = map[string]bool{
	"MaxSpotInstanceCountExceeded": true,
	"InsufficientInstanceCapacity": true,
	"VcpuLimitExceeded":            true,
}
365
366 // isErrorCapacity returns whether the error is to be throttled based on its code.
367 // Returns false if error is nil.
368 func isErrorCapacity(err error) bool {
369         if aerr, ok := err.(awserr.Error); ok && aerr != nil {
370                 if _, ok := isCodeCapacity[aerr.Code()]; ok {
371                         return true
372                 }
373         }
374         return false
375 }
376
// ec2QuotaError wraps an EC2 error whose code indicates exhausted capacity
// or quota (see isErrorCapacity).
type ec2QuotaError struct{ error }

// IsQuotaError marks the error as a quota error; it always returns true.
func (*ec2QuotaError) IsQuotaError() bool { return true }
384
385 func wrapError(err error, throttleValue *atomic.Value) error {
386         if request.IsErrorThrottle(err) {
387                 // Back off exponentially until an upstream call
388                 // either succeeds or returns a non-throttle error.
389                 d, _ := throttleValue.Load().(time.Duration)
390                 d = d*3/2 + time.Second
391                 if d < throttleDelayMin {
392                         d = throttleDelayMin
393                 } else if d > throttleDelayMax {
394                         d = throttleDelayMax
395                 }
396                 throttleValue.Store(d)
397                 return rateLimitError{error: err, earliestRetry: time.Now().Add(d)}
398         } else if isErrorCapacity(err) {
399                 return &ec2QuotaError{err}
400         } else if err != nil {
401                 throttleValue.Store(time.Duration(0))
402                 return err
403         }
404         throttleValue.Store(time.Duration(0))
405         return nil
406 }