20755: Support multiple/alternate subnets on EC2.
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "crypto/md5"
9         "crypto/rsa"
10         "crypto/sha1"
11         "crypto/x509"
12         "encoding/base64"
13         "encoding/json"
14         "fmt"
15         "math/big"
16         "strconv"
17         "strings"
18         "sync"
19         "sync/atomic"
20         "time"
21
22         "git.arvados.org/arvados.git/lib/cloud"
23         "git.arvados.org/arvados.git/sdk/go/arvados"
24         "github.com/aws/aws-sdk-go/aws"
25         "github.com/aws/aws-sdk-go/aws/awserr"
26         "github.com/aws/aws-sdk-go/aws/credentials"
27         "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
28         "github.com/aws/aws-sdk-go/aws/ec2metadata"
29         "github.com/aws/aws-sdk-go/aws/request"
30         "github.com/aws/aws-sdk-go/aws/session"
31         "github.com/aws/aws-sdk-go/service/ec2"
32         "github.com/sirupsen/logrus"
33         "golang.org/x/crypto/ssh"
34 )
35
// Driver is the ec2 implementation of the cloud.Driver interface.
// Each instance set it creates is built by newEC2InstanceSet, which
// decodes the driver-specific JSON configuration into an
// ec2InstanceSetConfig.
var Driver = cloud.DriverFunc(newEC2InstanceSet)
38
// Bounds for the exponential backoff that wrapError applies after
// the EC2 API reports request throttling.
const (
	// throttleDelayMin is the shortest retry delay reported to
	// callers after a throttled call.
	throttleDelayMin time.Duration = 1 * time.Second
	// throttleDelayMax caps the retry delay no matter how many
	// consecutive calls are throttled.
	throttleDelayMax time.Duration = 60 * time.Second
)
43
// ec2InstanceSetConfig is the driver-specific configuration, decoded
// from JSON by newEC2InstanceSet.
type ec2InstanceSetConfig struct {
	// Static API credentials. They are tried first; if they are
	// not usable, credentials come from the IAM instance
	// profile/role of the host (see newEC2InstanceSet).
	AccessKeyID             string
	SecretAccessKey         string
	// AWS region used for all EC2 API calls.
	Region                  string
	// Security groups attached to the network interface of each
	// new instance.
	SecurityGroupIDs        arvados.StringSet
	// Subnet(s) for new instances, given as a single ID or a
	// list. Create starts with the subnet used (or last tried) by
	// the previous Create call and moves on to the next subnet
	// after an error that another subnet might avoid.
	SubnetID                sliceOrSingleString
	// Login user name reported by ec2Instance.RemoteUser.
	AdminUsername           string
	// Volume type of the added scratch volume; newEC2InstanceSet
	// defaults it to "gp2".
	EBSVolumeType           string
	// Price of added scratch storage per GiB, treated as a
	// monthly price; PriceHistory converts it to an hourly cost
	// assuming a 30-day month.
	EBSPrice                float64
	// If non-empty, name of the IAM instance profile attached to
	// new instances.
	IAMInstanceProfile      string
	// How often spot price history is refreshed. Zero disables
	// spot price lookups in Instances.
	SpotPriceUpdateInterval arvados.Duration
}
56
// sliceOrSingleString is a list of strings that may be written in
// JSON either as an array of strings or as a single string.
type sliceOrSingleString []string

// UnmarshalJSON decodes either a JSON array of strings or a single
// JSON string. An empty array, an empty string, null, or empty input
// all produce a nil slice; a non-empty string "foo" produces
// ["foo"]. On a decoding error the receiver is left unchanged.
func (ss *sliceOrSingleString) UnmarshalJSON(data []byte) error {
	var decoded []string
	switch {
	case len(data) == 0:
		// Nothing to decode.
	case data[0] == '[':
		if err := json.Unmarshal(data, &decoded); err != nil {
			return err
		}
	default:
		var single string
		if err := json.Unmarshal(data, &single); err != nil {
			return err
		}
		if single != "" {
			decoded = []string{single}
		}
	}
	if len(decoded) == 0 {
		// Normalize [] and "" to nil.
		decoded = nil
	}
	*ss = decoded
	return nil
}
89
90 type ec2Interface interface {
91         DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error)
92         ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error)
93         RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error)
94         DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
95         DescribeInstanceStatusPages(input *ec2.DescribeInstanceStatusInput, fn func(*ec2.DescribeInstanceStatusOutput, bool) bool) error
96         DescribeSpotPriceHistoryPages(input *ec2.DescribeSpotPriceHistoryInput, fn func(*ec2.DescribeSpotPriceHistoryOutput, bool) bool) error
97         CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error)
98         TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error)
99 }
100
// ec2InstanceSet is the cloud.InstanceSet returned by
// newEC2InstanceSet.
type ec2InstanceSet struct {
	ec2config              ec2InstanceSetConfig
	// Index into ec2config.SubnetID of the subnet Create tries
	// first. Only accessed via sync/atomic.
	currentSubnetIDIndex   int32
	instanceSetID          cloud.InstanceSetID
	logger                 logrus.FieldLogger
	client                 ec2Interface
	// keysMtx guards keys, which maps an MD5 key fingerprint to
	// the EC2 key pair name (see getKeyName).
	keysMtx                sync.Mutex
	keys                   map[string]string
	// Current backoff delays (time.Duration values) for Create
	// and Instances calls, maintained by wrapError.
	throttleDelayCreate    atomic.Value
	throttleDelayInstances atomic.Value

	// Spot price history and the time each series was last
	// refreshed. Both maps are guarded by pricesLock and are
	// allocated lazily by updateSpotPrices.
	prices        map[priceKey][]cloud.InstancePrice
	pricesLock    sync.Mutex
	pricesUpdated map[priceKey]time.Time
}
116
117 func newEC2InstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
118         instanceSet := &ec2InstanceSet{
119                 instanceSetID: instanceSetID,
120                 logger:        logger,
121         }
122         err = json.Unmarshal(config, &instanceSet.ec2config)
123         if err != nil {
124                 return nil, err
125         }
126
127         sess, err := session.NewSession()
128         if err != nil {
129                 return nil, err
130         }
131         // First try any static credentials, fall back to an IAM instance profile/role
132         creds := credentials.NewChainCredentials(
133                 []credentials.Provider{
134                         &credentials.StaticProvider{Value: credentials.Value{AccessKeyID: instanceSet.ec2config.AccessKeyID, SecretAccessKey: instanceSet.ec2config.SecretAccessKey}},
135                         &ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(sess)},
136                 })
137
138         awsConfig := aws.NewConfig().WithCredentials(creds).WithRegion(instanceSet.ec2config.Region)
139         instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig)))
140         instanceSet.keys = make(map[string]string)
141         if instanceSet.ec2config.EBSVolumeType == "" {
142                 instanceSet.ec2config.EBSVolumeType = "gp2"
143         }
144         return instanceSet, nil
145 }
146
147 func awsKeyFingerprint(pk ssh.PublicKey) (md5fp string, sha1fp string, err error) {
148         // AWS key fingerprints don't use the usual key fingerprint
149         // you get from ssh-keygen or ssh.FingerprintLegacyMD5()
150         // (you can get that from md5.Sum(pk.Marshal())
151         //
152         // AWS uses the md5 or sha1 of the PKIX DER encoding of the
153         // public key, so calculate those fingerprints here.
154         var rsaPub struct {
155                 Name string
156                 E    *big.Int
157                 N    *big.Int
158         }
159         if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil {
160                 return "", "", fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
161         }
162         rsaPk := rsa.PublicKey{
163                 E: int(rsaPub.E.Int64()),
164                 N: rsaPub.N,
165         }
166         pkix, _ := x509.MarshalPKIXPublicKey(&rsaPk)
167         md5pkix := md5.Sum([]byte(pkix))
168         sha1pkix := sha1.Sum([]byte(pkix))
169         md5fp = ""
170         sha1fp = ""
171         for i := 0; i < len(md5pkix); i++ {
172                 md5fp += fmt.Sprintf(":%02x", md5pkix[i])
173         }
174         for i := 0; i < len(sha1pkix); i++ {
175                 sha1fp += fmt.Sprintf(":%02x", sha1pkix[i])
176         }
177         return md5fp[1:], sha1fp[1:], nil
178 }
179
180 func (instanceSet *ec2InstanceSet) Create(
181         instanceType arvados.InstanceType,
182         imageID cloud.ImageID,
183         newTags cloud.InstanceTags,
184         initCommand cloud.InitCommand,
185         publicKey ssh.PublicKey) (cloud.Instance, error) {
186
187         ec2tags := []*ec2.Tag{}
188         for k, v := range newTags {
189                 ec2tags = append(ec2tags, &ec2.Tag{
190                         Key:   aws.String(k),
191                         Value: aws.String(v),
192                 })
193         }
194
195         var groups []string
196         for sg := range instanceSet.ec2config.SecurityGroupIDs {
197                 groups = append(groups, sg)
198         }
199
200         rii := ec2.RunInstancesInput{
201                 ImageId:      aws.String(string(imageID)),
202                 InstanceType: &instanceType.ProviderType,
203                 MaxCount:     aws.Int64(1),
204                 MinCount:     aws.Int64(1),
205
206                 NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
207                         {
208                                 AssociatePublicIpAddress: aws.Bool(false),
209                                 DeleteOnTermination:      aws.Bool(true),
210                                 DeviceIndex:              aws.Int64(0),
211                                 Groups:                   aws.StringSlice(groups),
212                         }},
213                 DisableApiTermination:             aws.Bool(false),
214                 InstanceInitiatedShutdownBehavior: aws.String("terminate"),
215                 TagSpecifications: []*ec2.TagSpecification{
216                         {
217                                 ResourceType: aws.String("instance"),
218                                 Tags:         ec2tags,
219                         }},
220                 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
221         }
222
223         if publicKey != nil {
224                 keyname, err := instanceSet.getKeyName(publicKey)
225                 if err != nil {
226                         return nil, err
227                 }
228                 rii.KeyName = &keyname
229         }
230
231         if instanceType.AddedScratch > 0 {
232                 rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{{
233                         DeviceName: aws.String("/dev/xvdt"),
234                         Ebs: &ec2.EbsBlockDevice{
235                                 DeleteOnTermination: aws.Bool(true),
236                                 VolumeSize:          aws.Int64((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30),
237                                 VolumeType:          &instanceSet.ec2config.EBSVolumeType,
238                         }}}
239         }
240
241         if instanceType.Preemptible {
242                 rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{
243                         MarketType: aws.String("spot"),
244                         SpotOptions: &ec2.SpotMarketOptions{
245                                 InstanceInterruptionBehavior: aws.String("terminate"),
246                                 MaxPrice:                     aws.String(fmt.Sprintf("%v", instanceType.Price)),
247                         }}
248         }
249
250         if instanceSet.ec2config.IAMInstanceProfile != "" {
251                 rii.IamInstanceProfile = &ec2.IamInstanceProfileSpecification{
252                         Name: aws.String(instanceSet.ec2config.IAMInstanceProfile),
253                 }
254         }
255
256         var rsv *ec2.Reservation
257         var err error
258         subnets := instanceSet.ec2config.SubnetID
259         currentSubnetIDIndex := int(atomic.LoadInt32(&instanceSet.currentSubnetIDIndex))
260         for tryOffset := 0; ; tryOffset++ {
261                 tryIndex := 0
262                 if len(subnets) > 0 {
263                         tryIndex = (currentSubnetIDIndex + tryOffset) % len(subnets)
264                         rii.NetworkInterfaces[0].SubnetId = aws.String(subnets[tryIndex])
265                 }
266                 rsv, err = instanceSet.client.RunInstances(&rii)
267                 if isErrorSubnetSpecific(err) &&
268                         tryOffset < len(subnets)-1 {
269                         instanceSet.logger.WithError(err).WithField("SubnetID", subnets[tryIndex]).
270                                 Warn("RunInstances failed, trying next subnet")
271                         continue
272                 }
273                 // Succeeded, or exhausted all subnets, or got a
274                 // non-subnet-related error.
275                 //
276                 // We intentionally update currentSubnetIDIndex even
277                 // in the non-retryable-failure case here to avoid a
278                 // situation where successive calls to Create() keep
279                 // returning errors for the same subnet (perhaps
280                 // "subnet full") and never reveal the errors for the
281                 // other configured subnets (perhaps "subnet ID
282                 // invalid").
283                 atomic.StoreInt32(&instanceSet.currentSubnetIDIndex, int32(tryIndex))
284                 break
285         }
286         err = wrapError(err, &instanceSet.throttleDelayCreate)
287         if err != nil {
288                 return nil, err
289         }
290         return &ec2Instance{
291                 provider: instanceSet,
292                 instance: rsv.Instances[0],
293         }, nil
294 }
295
296 func (instanceSet *ec2InstanceSet) getKeyName(publicKey ssh.PublicKey) (string, error) {
297         instanceSet.keysMtx.Lock()
298         defer instanceSet.keysMtx.Unlock()
299         md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
300         if err != nil {
301                 return "", fmt.Errorf("Could not make key fingerprint: %v", err)
302         }
303         if keyname, ok := instanceSet.keys[md5keyFingerprint]; ok {
304                 return keyname, nil
305         }
306         keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
307                 Filters: []*ec2.Filter{{
308                         Name:   aws.String("fingerprint"),
309                         Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
310                 }},
311         })
312         if err != nil {
313                 return "", fmt.Errorf("Could not search for keypair: %v", err)
314         }
315         if len(keyout.KeyPairs) > 0 {
316                 return *(keyout.KeyPairs[0].KeyName), nil
317         }
318         keyname := "arvados-dispatch-keypair-" + md5keyFingerprint
319         _, err = instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
320                 KeyName:           &keyname,
321                 PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
322         })
323         if err != nil {
324                 return "", fmt.Errorf("Could not import keypair: %v", err)
325         }
326         instanceSet.keys[md5keyFingerprint] = keyname
327         return keyname, nil
328 }
329
330 func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
331         var filters []*ec2.Filter
332         for k, v := range tags {
333                 filters = append(filters, &ec2.Filter{
334                         Name:   aws.String("tag:" + k),
335                         Values: []*string{aws.String(v)},
336                 })
337         }
338         needAZs := false
339         dii := &ec2.DescribeInstancesInput{Filters: filters}
340         for {
341                 dio, err := instanceSet.client.DescribeInstances(dii)
342                 err = wrapError(err, &instanceSet.throttleDelayInstances)
343                 if err != nil {
344                         return nil, err
345                 }
346
347                 for _, rsv := range dio.Reservations {
348                         for _, inst := range rsv.Instances {
349                                 if *inst.State.Name != "shutting-down" && *inst.State.Name != "terminated" {
350                                         instances = append(instances, &ec2Instance{
351                                                 provider: instanceSet,
352                                                 instance: inst,
353                                         })
354                                         if aws.StringValue(inst.InstanceLifecycle) == "spot" {
355                                                 needAZs = true
356                                         }
357                                 }
358                         }
359                 }
360                 if dio.NextToken == nil {
361                         break
362                 }
363                 dii.NextToken = dio.NextToken
364         }
365         if needAZs && instanceSet.ec2config.SpotPriceUpdateInterval > 0 {
366                 az := map[string]string{}
367                 err := instanceSet.client.DescribeInstanceStatusPages(&ec2.DescribeInstanceStatusInput{
368                         IncludeAllInstances: aws.Bool(true),
369                 }, func(page *ec2.DescribeInstanceStatusOutput, lastPage bool) bool {
370                         for _, ent := range page.InstanceStatuses {
371                                 az[*ent.InstanceId] = *ent.AvailabilityZone
372                         }
373                         return true
374                 })
375                 if err != nil {
376                         instanceSet.logger.Warnf("error getting instance statuses: %s", err)
377                 }
378                 for _, inst := range instances {
379                         inst := inst.(*ec2Instance)
380                         inst.availabilityZone = az[*inst.instance.InstanceId]
381                 }
382                 instanceSet.updateSpotPrices(instances)
383         }
384         return instances, err
385 }
386
// priceKey identifies one price history series in
// ec2InstanceSet.prices. A series is keyed by the EC2 instance type,
// whether the prices are spot prices, and the availability zone.
type priceKey struct {
	instanceType     string
	spot             bool
	availabilityZone string
}
392
393 // Refresh recent spot instance pricing data for the given instances,
394 // unless we already have recent pricing data for all relevant types.
395 func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance) {
396         if len(instances) == 0 {
397                 return
398         }
399
400         instanceSet.pricesLock.Lock()
401         defer instanceSet.pricesLock.Unlock()
402         if instanceSet.prices == nil {
403                 instanceSet.prices = map[priceKey][]cloud.InstancePrice{}
404                 instanceSet.pricesUpdated = map[priceKey]time.Time{}
405         }
406
407         updateTime := time.Now()
408         staleTime := updateTime.Add(-instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
409         needUpdate := false
410         allTypes := map[string]bool{}
411
412         for _, inst := range instances {
413                 ec2inst := inst.(*ec2Instance).instance
414                 if aws.StringValue(ec2inst.InstanceLifecycle) == "spot" {
415                         pk := priceKey{
416                                 instanceType:     *ec2inst.InstanceType,
417                                 spot:             true,
418                                 availabilityZone: inst.(*ec2Instance).availabilityZone,
419                         }
420                         if instanceSet.pricesUpdated[pk].Before(staleTime) {
421                                 needUpdate = true
422                         }
423                         allTypes[*ec2inst.InstanceType] = true
424                 }
425         }
426         if !needUpdate {
427                 return
428         }
429         var typeFilterValues []*string
430         for instanceType := range allTypes {
431                 typeFilterValues = append(typeFilterValues, aws.String(instanceType))
432         }
433         // Get 3x update interval worth of pricing data. (Ideally the
434         // AWS API would tell us "we have shown you all of the price
435         // changes up to time T", but it doesn't, so we'll just ask
436         // for 3 intervals worth of data on each update, de-duplicate
437         // the data points, and not worry too much about occasionally
438         // missing some data points when our lookups fail twice in a
439         // row.
440         dsphi := &ec2.DescribeSpotPriceHistoryInput{
441                 StartTime: aws.Time(updateTime.Add(-3 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())),
442                 Filters: []*ec2.Filter{
443                         &ec2.Filter{Name: aws.String("instance-type"), Values: typeFilterValues},
444                         &ec2.Filter{Name: aws.String("product-description"), Values: []*string{aws.String("Linux/UNIX")}},
445                 },
446         }
447         err := instanceSet.client.DescribeSpotPriceHistoryPages(dsphi, func(page *ec2.DescribeSpotPriceHistoryOutput, lastPage bool) bool {
448                 for _, ent := range page.SpotPriceHistory {
449                         if ent.InstanceType == nil || ent.SpotPrice == nil || ent.Timestamp == nil {
450                                 // bogus record?
451                                 continue
452                         }
453                         price, err := strconv.ParseFloat(*ent.SpotPrice, 64)
454                         if err != nil {
455                                 // bogus record?
456                                 continue
457                         }
458                         pk := priceKey{
459                                 instanceType:     *ent.InstanceType,
460                                 spot:             true,
461                                 availabilityZone: *ent.AvailabilityZone,
462                         }
463                         instanceSet.prices[pk] = append(instanceSet.prices[pk], cloud.InstancePrice{
464                                 StartTime: *ent.Timestamp,
465                                 Price:     price,
466                         })
467                         instanceSet.pricesUpdated[pk] = updateTime
468                 }
469                 return true
470         })
471         if err != nil {
472                 instanceSet.logger.Warnf("error retrieving spot instance prices: %s", err)
473         }
474
475         expiredTime := updateTime.Add(-64 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
476         for pk, last := range instanceSet.pricesUpdated {
477                 if last.Before(expiredTime) {
478                         delete(instanceSet.pricesUpdated, pk)
479                         delete(instanceSet.prices, pk)
480                 }
481         }
482         for pk, prices := range instanceSet.prices {
483                 instanceSet.prices[pk] = cloud.NormalizePriceHistory(prices)
484         }
485 }
486
// Stop does nothing: this instance set starts no goroutines and holds
// no resources that need releasing. (Presumably required by the
// cloud.InstanceSet interface.)
func (instanceSet *ec2InstanceSet) Stop() {
}
489
// ec2Instance is the cloud.Instance returned by Create and Instances.
// It wraps one EC2 instance description.
type ec2Instance struct {
	// Instance set that created or listed this instance; its
	// client and configuration are used by this type's methods.
	provider         *ec2InstanceSet
	// Description as returned by RunInstances or
	// DescribeInstances. It is not refreshed afterwards (e.g.
	// SetTags does not update it).
	instance         *ec2.Instance
	availabilityZone string // set by Instances only when spot price lookups are enabled; "" otherwise
}
495
496 func (inst *ec2Instance) ID() cloud.InstanceID {
497         return cloud.InstanceID(*inst.instance.InstanceId)
498 }
499
500 func (inst *ec2Instance) String() string {
501         return *inst.instance.InstanceId
502 }
503
504 func (inst *ec2Instance) ProviderType() string {
505         return *inst.instance.InstanceType
506 }
507
508 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
509         var ec2tags []*ec2.Tag
510         for k, v := range newTags {
511                 ec2tags = append(ec2tags, &ec2.Tag{
512                         Key:   aws.String(k),
513                         Value: aws.String(v),
514                 })
515         }
516
517         _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{
518                 Resources: []*string{inst.instance.InstanceId},
519                 Tags:      ec2tags,
520         })
521
522         return err
523 }
524
525 func (inst *ec2Instance) Tags() cloud.InstanceTags {
526         tags := make(map[string]string)
527
528         for _, t := range inst.instance.Tags {
529                 tags[*t.Key] = *t.Value
530         }
531
532         return tags
533 }
534
535 func (inst *ec2Instance) Destroy() error {
536         _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{
537                 InstanceIds: []*string{inst.instance.InstanceId},
538         })
539         return err
540 }
541
542 func (inst *ec2Instance) Address() string {
543         if inst.instance.PrivateIpAddress != nil {
544                 return *inst.instance.PrivateIpAddress
545         }
546         return ""
547 }
548
549 func (inst *ec2Instance) RemoteUser() string {
550         return inst.provider.ec2config.AdminUsername
551 }
552
553 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
554         return cloud.ErrNotImplemented
555 }
556
557 // PriceHistory returns the price history for this specific instance.
558 //
559 // AWS documentation is elusive about whether the hourly cost of a
560 // given spot instance changes as the current spot price changes for
561 // the corresponding instance type and availability zone. Our
562 // implementation assumes the answer is yes, based on the following
563 // hints.
564 //
565 // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-requests.html
566 // says: "After your Spot Instance is running, if the Spot price rises
567 // above your maximum price, Amazon EC2 interrupts your Spot
568 // Instance." (This doesn't address what happens when the spot price
569 // rises *without* exceeding your maximum price.)
570 //
571 // https://docs.aws.amazon.com/whitepapers/latest/cost-optimization-leveraging-ec2-spot-instances/how-spot-instances-work.html
572 // says: "You pay the Spot price that's in effect, billed to the
573 // nearest second." (But it's not explicitly stated whether "the price
574 // in effect" changes over time for a given instance.)
575 //
576 // The same page also says, in a discussion about the effect of
577 // specifying a maximum price: "Note that you never pay more than the
578 // Spot price that is in effect when your Spot Instance is running."
579 // (The use of the phrase "is running", as opposed to "was launched",
580 // hints that pricing is dynamic.)
581 func (inst *ec2Instance) PriceHistory(instType arvados.InstanceType) []cloud.InstancePrice {
582         inst.provider.pricesLock.Lock()
583         defer inst.provider.pricesLock.Unlock()
584         // Note updateSpotPrices currently populates
585         // inst.provider.prices only for spot instances, so if
586         // spot==false here, we will return no data.
587         pk := priceKey{
588                 instanceType:     *inst.instance.InstanceType,
589                 spot:             aws.StringValue(inst.instance.InstanceLifecycle) == "spot",
590                 availabilityZone: inst.availabilityZone,
591         }
592         var prices []cloud.InstancePrice
593         for _, price := range inst.provider.prices[pk] {
594                 // ceil(added scratch space in GiB)
595                 gib := (instType.AddedScratch + 1<<30 - 1) >> 30
596                 monthly := inst.provider.ec2config.EBSPrice * float64(gib)
597                 hourly := monthly / 30 / 24
598                 price.Price += hourly
599                 prices = append(prices, price)
600         }
601         return prices
602 }
603
// rateLimitError wraps an EC2 throttling error together with the
// earliest time at which the failed operation should be retried.
type rateLimitError struct {
	error
	earliestRetry time.Time
}

// EarliestRetry returns the time before which the caller should not
// retry.
func (e rateLimitError) EarliestRetry() time.Time {
	retryAt := e.earliestRetry
	return retryAt
}
612
// isCodeCapacity is the set of EC2 error codes that isErrorCapacity
// treats as quota/capacity problems; wrapError turns such errors into
// *ec2QuotaError.
var isCodeCapacity = map[string]bool{
	// Account or spot request limits.
	"InstanceLimitExceeded":        true,
	"MaxSpotInstanceCountExceeded": true,
	"VcpuLimitExceeded":            true,

	// Not enough capacity or addresses available.
	"InsufficientAddressCapacity":       true,
	"InsufficientFreeAddressesInSubnet": true,
	"InsufficientInstanceCapacity":      true,
	"InsufficientVolumeCapacity":        true,
}
622
623 // isErrorCapacity returns whether the error is to be throttled based on its code.
624 // Returns false if error is nil.
625 func isErrorCapacity(err error) bool {
626         if aerr, ok := err.(awserr.Error); ok && aerr != nil {
627                 if _, ok := isCodeCapacity[aerr.Code()]; ok {
628                         return true
629                 }
630         }
631         return false
632 }
633
634 // isErrorSubnetSpecific returns true if the problem encountered by
635 // RunInstances might be avoided by trying a different subnet.
636 func isErrorSubnetSpecific(err error) bool {
637         aerr, ok := err.(awserr.Error)
638         if !ok {
639                 return false
640         }
641         code := aerr.Code()
642         return strings.Contains(code, "Subnet") ||
643                 code == "InsufficientInstanceCapacity" ||
644                 code == "InsufficientVolumeCapacity"
645 }
646
// ec2QuotaError wraps an EC2 error whose code is listed in
// isCodeCapacity (see wrapError), marking it as a quota error via
// IsQuotaError.
type ec2QuotaError struct {
	error
}

// IsQuotaError always reports true.
func (*ec2QuotaError) IsQuotaError() bool { return true }
654
655 func wrapError(err error, throttleValue *atomic.Value) error {
656         if request.IsErrorThrottle(err) {
657                 // Back off exponentially until an upstream call
658                 // either succeeds or returns a non-throttle error.
659                 d, _ := throttleValue.Load().(time.Duration)
660                 d = d*3/2 + time.Second
661                 if d < throttleDelayMin {
662                         d = throttleDelayMin
663                 } else if d > throttleDelayMax {
664                         d = throttleDelayMax
665                 }
666                 throttleValue.Store(d)
667                 return rateLimitError{error: err, earliestRetry: time.Now().Add(d)}
668         } else if isErrorCapacity(err) {
669                 return &ec2QuotaError{err}
670         } else if err != nil {
671                 throttleValue.Store(time.Duration(0))
672                 return err
673         }
674         throttleValue.Store(time.Duration(0))
675         return nil
676 }