1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
26 "git.arvados.org/arvados.git/lib/cloud"
27 "git.arvados.org/arvados.git/sdk/go/arvados"
28 "github.com/aws/aws-sdk-go-v2/aws"
29 "github.com/aws/aws-sdk-go-v2/aws/retry"
30 awsconfig "github.com/aws/aws-sdk-go-v2/config"
31 "github.com/aws/aws-sdk-go-v2/service/ec2"
32 "github.com/aws/aws-sdk-go-v2/service/ec2/types"
33 "github.com/aws/smithy-go"
34 "github.com/prometheus/client_golang/prometheus"
35 "github.com/sirupsen/logrus"
36 "golang.org/x/crypto/ssh"
39 // Driver is the ec2 implementation of the cloud.Driver interface.
40 var Driver = cloud.DriverFunc(newEC2InstanceSet)
43 throttleDelayMin = time.Second
44 throttleDelayMax = time.Minute
47 type ec2InstanceSetConfig struct {
49 SecretAccessKey string
51 SecurityGroupIDs arvados.StringSet
52 SubnetID sliceOrSingleString
54 EBSVolumeType types.VolumeType
56 IAMInstanceProfile string
57 SpotPriceUpdateInterval arvados.Duration
60 type sliceOrSingleString []string
62 // UnmarshalJSON unmarshals an array of strings, and also accepts ""
63 // as [], and "foo" as ["foo"].
64 func (ss *sliceOrSingleString) UnmarshalJSON(data []byte) error {
67 } else if data[0] == '[' {
69 err := json.Unmarshal(data, &slice)
80 err := json.Unmarshal(data, &str)
93 type ec2Interface interface {
94 DescribeKeyPairs(context.Context, *ec2.DescribeKeyPairsInput, ...func(*ec2.Options)) (*ec2.DescribeKeyPairsOutput, error)
95 ImportKeyPair(context.Context, *ec2.ImportKeyPairInput, ...func(*ec2.Options)) (*ec2.ImportKeyPairOutput, error)
96 RunInstances(context.Context, *ec2.RunInstancesInput, ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error)
97 DescribeInstances(context.Context, *ec2.DescribeInstancesInput, ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error)
98 DescribeInstanceStatus(context.Context, *ec2.DescribeInstanceStatusInput, ...func(*ec2.Options)) (*ec2.DescribeInstanceStatusOutput, error)
99 DescribeSpotPriceHistory(context.Context, *ec2.DescribeSpotPriceHistoryInput, ...func(*ec2.Options)) (*ec2.DescribeSpotPriceHistoryOutput, error)
100 CreateTags(context.Context, *ec2.CreateTagsInput, ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error)
101 TerminateInstances(context.Context, *ec2.TerminateInstancesInput, ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error)
104 type ec2InstanceSet struct {
105 ec2config ec2InstanceSetConfig
106 currentSubnetIDIndex int32
107 instanceSetID cloud.InstanceSetID
108 logger logrus.FieldLogger
111 keys map[string]string
112 throttleDelayCreate atomic.Value
113 throttleDelayInstances atomic.Value
115 prices map[priceKey][]cloud.InstancePrice
116 pricesLock sync.Mutex
117 pricesUpdated map[priceKey]time.Time
119 mInstances *prometheus.GaugeVec
120 mInstanceStarts *prometheus.CounterVec
123 func newEC2InstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger, reg *prometheus.Registry) (prv cloud.InstanceSet, err error) {
124 instanceSet := &ec2InstanceSet{
125 instanceSetID: instanceSetID,
128 err = json.Unmarshal(config, &instanceSet.ec2config)
133 if len(instanceSet.ec2config.AccessKeyID)+len(instanceSet.ec2config.SecretAccessKey) > 0 {
134 // AWS SDK will use credentials in environment vars if
136 os.Setenv("AWS_ACCESS_KEY_ID", instanceSet.ec2config.AccessKeyID)
137 os.Setenv("AWS_SECRET_ACCESS_KEY", instanceSet.ec2config.SecretAccessKey)
139 os.Unsetenv("AWS_ACCESS_KEY_ID")
140 os.Unsetenv("AWS_SECRET_ACCESS_KEY")
142 awsConfig, err := awsconfig.LoadDefaultConfig(context.TODO(),
143 awsconfig.WithRegion(instanceSet.ec2config.Region))
148 instanceSet.client = ec2.NewFromConfig(awsConfig)
149 instanceSet.keys = make(map[string]string)
150 if instanceSet.ec2config.EBSVolumeType == "" {
151 instanceSet.ec2config.EBSVolumeType = "gp2"
155 instanceSet.mInstances = prometheus.NewGaugeVec(prometheus.GaugeOpts{
156 Namespace: "arvados",
157 Subsystem: "dispatchcloud",
158 Name: "ec2_instances",
159 Help: "Number of instances running",
160 }, []string{"subnet_id"})
161 instanceSet.mInstanceStarts = prometheus.NewCounterVec(prometheus.CounterOpts{
162 Namespace: "arvados",
163 Subsystem: "dispatchcloud",
164 Name: "ec2_instance_starts_total",
165 Help: "Number of attempts to start a new instance",
166 }, []string{"subnet_id", "success"})
167 // Initialize all of the series we'll be reporting. Otherwise
168 // the {subnet=A, success=0} series doesn't appear in metrics
169 // at all until there's a failure in subnet A.
170 for _, subnet := range instanceSet.ec2config.SubnetID {
171 instanceSet.mInstanceStarts.WithLabelValues(subnet, "0").Add(0)
172 instanceSet.mInstanceStarts.WithLabelValues(subnet, "1").Add(0)
174 if len(instanceSet.ec2config.SubnetID) == 0 {
175 instanceSet.mInstanceStarts.WithLabelValues("", "0").Add(0)
176 instanceSet.mInstanceStarts.WithLabelValues("", "1").Add(0)
179 reg.MustRegister(instanceSet.mInstances)
180 reg.MustRegister(instanceSet.mInstanceStarts)
183 return instanceSet, nil
186 func awsKeyFingerprint(pk ssh.PublicKey) (md5fp string, sha1fp string, err error) {
187 // AWS key fingerprints don't use the usual key fingerprint
188 // you get from ssh-keygen or ssh.FingerprintLegacyMD5()
189 // (you can get that from md5.Sum(pk.Marshal())
191 // AWS uses the md5 or sha1 of the PKIX DER encoding of the
192 // public key, so calculate those fingerprints here.
198 if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil {
199 return "", "", fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
201 rsaPk := rsa.PublicKey{
202 E: int(rsaPub.E.Int64()),
205 pkix, _ := x509.MarshalPKIXPublicKey(&rsaPk)
206 md5pkix := md5.Sum([]byte(pkix))
207 sha1pkix := sha1.Sum([]byte(pkix))
210 for i := 0; i < len(md5pkix); i++ {
211 md5fp += fmt.Sprintf(":%02x", md5pkix[i])
213 for i := 0; i < len(sha1pkix); i++ {
214 sha1fp += fmt.Sprintf(":%02x", sha1pkix[i])
216 return md5fp[1:], sha1fp[1:], nil
219 func (instanceSet *ec2InstanceSet) Create(
220 instanceType arvados.InstanceType,
221 imageID cloud.ImageID,
222 newTags cloud.InstanceTags,
223 initCommand cloud.InitCommand,
224 publicKey ssh.PublicKey) (cloud.Instance, error) {
226 ec2tags := []types.Tag{}
227 for k, v := range newTags {
228 ec2tags = append(ec2tags, types.Tag{
230 Value: aws.String(v),
235 for sg := range instanceSet.ec2config.SecurityGroupIDs {
236 groups = append(groups, sg)
239 rii := ec2.RunInstancesInput{
240 ImageId: aws.String(string(imageID)),
241 InstanceType: types.InstanceType(instanceType.ProviderType),
242 MaxCount: aws.Int32(1),
243 MinCount: aws.Int32(1),
245 NetworkInterfaces: []types.InstanceNetworkInterfaceSpecification{{
246 AssociatePublicIpAddress: aws.Bool(false),
247 DeleteOnTermination: aws.Bool(true),
248 DeviceIndex: aws.Int32(0),
251 DisableApiTermination: aws.Bool(false),
252 InstanceInitiatedShutdownBehavior: types.ShutdownBehaviorTerminate,
253 TagSpecifications: []types.TagSpecification{
255 ResourceType: types.ResourceTypeInstance,
258 MetadataOptions: &types.InstanceMetadataOptionsRequest{
259 // Require IMDSv2, as described at
260 // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-IMDS-new-instances.html
261 HttpEndpoint: types.InstanceMetadataEndpointStateEnabled,
262 HttpTokens: types.HttpTokensStateRequired,
264 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
267 if publicKey != nil {
268 keyname, err := instanceSet.getKeyName(publicKey)
272 rii.KeyName = &keyname
275 if instanceType.AddedScratch > 0 {
276 rii.BlockDeviceMappings = []types.BlockDeviceMapping{{
277 DeviceName: aws.String("/dev/xvdt"),
278 Ebs: &types.EbsBlockDevice{
279 DeleteOnTermination: aws.Bool(true),
280 VolumeSize: aws.Int32(int32((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30)),
281 VolumeType: instanceSet.ec2config.EBSVolumeType,
285 if instanceType.Preemptible {
286 rii.InstanceMarketOptions = &types.InstanceMarketOptionsRequest{
287 MarketType: types.MarketTypeSpot,
288 SpotOptions: &types.SpotMarketOptions{
289 InstanceInterruptionBehavior: types.InstanceInterruptionBehaviorTerminate,
290 MaxPrice: aws.String(fmt.Sprintf("%v", instanceType.Price)),
294 if instanceSet.ec2config.IAMInstanceProfile != "" {
295 rii.IamInstanceProfile = &types.IamInstanceProfileSpecification{
296 Name: aws.String(instanceSet.ec2config.IAMInstanceProfile),
300 var rsv *ec2.RunInstancesOutput
301 var errToReturn error
302 subnets := instanceSet.ec2config.SubnetID
303 currentSubnetIDIndex := int(atomic.LoadInt32(&instanceSet.currentSubnetIDIndex))
304 for tryOffset := 0; ; tryOffset++ {
307 if len(subnets) > 0 {
308 tryIndex = (currentSubnetIDIndex + tryOffset) % len(subnets)
309 trySubnet = subnets[tryIndex]
310 rii.NetworkInterfaces[0].SubnetId = aws.String(trySubnet)
313 rsv, err = instanceSet.client.RunInstances(context.TODO(), &rii)
314 instanceSet.mInstanceStarts.WithLabelValues(trySubnet, boolLabelValue[err == nil]).Add(1)
315 if !isErrorCapacity(errToReturn) || isErrorCapacity(err) {
316 // We want to return the last capacity error,
317 // if any; otherwise the last non-capacity
321 if isErrorSubnetSpecific(err) &&
322 tryOffset < len(subnets)-1 {
323 instanceSet.logger.WithError(err).WithField("SubnetID", subnets[tryIndex]).
324 Warn("RunInstances failed, trying next subnet")
327 // Succeeded, or exhausted all subnets, or got a
328 // non-subnet-related error.
330 // We intentionally update currentSubnetIDIndex even
331 // in the non-retryable-failure case here to avoid a
332 // situation where successive calls to Create() keep
333 // returning errors for the same subnet (perhaps
334 // "subnet full") and never reveal the errors for the
335 // other configured subnets (perhaps "subnet ID
337 atomic.StoreInt32(&instanceSet.currentSubnetIDIndex, int32(tryIndex))
340 if rsv == nil || len(rsv.Instances) == 0 {
341 return nil, wrapError(errToReturn, &instanceSet.throttleDelayCreate)
344 provider: instanceSet,
345 instance: rsv.Instances[0],
349 func (instanceSet *ec2InstanceSet) getKeyName(publicKey ssh.PublicKey) (string, error) {
350 instanceSet.keysMtx.Lock()
351 defer instanceSet.keysMtx.Unlock()
352 md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
354 return "", fmt.Errorf("Could not make key fingerprint: %v", err)
356 if keyname, ok := instanceSet.keys[md5keyFingerprint]; ok {
359 keyout, err := instanceSet.client.DescribeKeyPairs(context.TODO(), &ec2.DescribeKeyPairsInput{
360 Filters: []types.Filter{{
361 Name: aws.String("fingerprint"),
362 Values: []string{md5keyFingerprint, sha1keyFingerprint},
366 return "", fmt.Errorf("Could not search for keypair: %v", err)
368 if len(keyout.KeyPairs) > 0 {
369 return *(keyout.KeyPairs[0].KeyName), nil
371 keyname := "arvados-dispatch-keypair-" + md5keyFingerprint
372 _, err = instanceSet.client.ImportKeyPair(context.TODO(), &ec2.ImportKeyPairInput{
374 PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
377 return "", fmt.Errorf("Could not import keypair: %v", err)
379 instanceSet.keys[md5keyFingerprint] = keyname
383 func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
384 var filters []types.Filter
385 for k, v := range tags {
386 filters = append(filters, types.Filter{
387 Name: aws.String("tag:" + k),
392 dii := &ec2.DescribeInstancesInput{Filters: filters}
394 dio, err := instanceSet.client.DescribeInstances(context.TODO(), dii)
395 err = wrapError(err, &instanceSet.throttleDelayInstances)
400 for _, rsv := range dio.Reservations {
401 for _, inst := range rsv.Instances {
402 switch inst.State.Name {
403 case types.InstanceStateNameShuttingDown:
404 case types.InstanceStateNameTerminated:
406 instances = append(instances, &ec2Instance{
407 provider: instanceSet,
410 if inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot {
416 if dio.NextToken == nil {
419 dii.NextToken = dio.NextToken
421 if needAZs && instanceSet.ec2config.SpotPriceUpdateInterval > 0 {
422 az := map[string]string{}
423 disi := &ec2.DescribeInstanceStatusInput{IncludeAllInstances: aws.Bool(true)}
425 page, err := instanceSet.client.DescribeInstanceStatus(context.TODO(), disi)
427 instanceSet.logger.Warnf("error getting instance statuses: %s", err)
430 for _, ent := range page.InstanceStatuses {
431 az[*ent.InstanceId] = *ent.AvailabilityZone
433 if page.NextToken == nil {
436 disi.NextToken = page.NextToken
438 for _, inst := range instances {
439 inst := inst.(*ec2Instance)
440 inst.availabilityZone = az[*inst.instance.InstanceId]
442 instanceSet.updateSpotPrices(instances)
445 // Count instances in each subnet, and report in metrics.
446 subnetInstances := map[string]int{"": 0}
447 for _, subnet := range instanceSet.ec2config.SubnetID {
448 subnetInstances[subnet] = 0
450 for _, inst := range instances {
451 subnet := inst.(*ec2Instance).instance.SubnetId
453 subnetInstances[*subnet]++
455 subnetInstances[""]++
458 for subnet, count := range subnetInstances {
459 instanceSet.mInstances.WithLabelValues(subnet).Set(float64(count))
462 return instances, err
465 type priceKey struct {
468 availabilityZone string
471 // Refresh recent spot instance pricing data for the given instances,
472 // unless we already have recent pricing data for all relevant types.
473 func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance) {
474 if len(instances) == 0 {
478 instanceSet.pricesLock.Lock()
479 defer instanceSet.pricesLock.Unlock()
480 if instanceSet.prices == nil {
481 instanceSet.prices = map[priceKey][]cloud.InstancePrice{}
482 instanceSet.pricesUpdated = map[priceKey]time.Time{}
485 updateTime := time.Now()
486 staleTime := updateTime.Add(-instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
488 allTypes := map[types.InstanceType]bool{}
490 for _, inst := range instances {
491 ec2inst := inst.(*ec2Instance).instance
492 if ec2inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot {
494 instanceType: string(ec2inst.InstanceType),
496 availabilityZone: inst.(*ec2Instance).availabilityZone,
498 if instanceSet.pricesUpdated[pk].Before(staleTime) {
501 allTypes[ec2inst.InstanceType] = true
507 var typeFilterValues []string
508 for instanceType := range allTypes {
509 typeFilterValues = append(typeFilterValues, string(instanceType))
511 // Get 3x update interval worth of pricing data. (Ideally the
512 // AWS API would tell us "we have shown you all of the price
513 // changes up to time T", but it doesn't, so we'll just ask
514 // for 3 intervals worth of data on each update, de-duplicate
515 // the data points, and not worry too much about occasionally
516 // missing some data points when our lookups fail twice in a
518 dsphi := &ec2.DescribeSpotPriceHistoryInput{
519 StartTime: aws.Time(updateTime.Add(-3 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())),
520 Filters: []types.Filter{
521 types.Filter{Name: aws.String("instance-type"), Values: typeFilterValues},
522 types.Filter{Name: aws.String("product-description"), Values: []string{"Linux/UNIX"}},
526 page, err := instanceSet.client.DescribeSpotPriceHistory(context.TODO(), dsphi)
528 instanceSet.logger.Warnf("error retrieving spot instance prices: %s", err)
531 for _, ent := range page.SpotPriceHistory {
532 if ent.InstanceType == "" || ent.SpotPrice == nil || ent.Timestamp == nil {
536 price, err := strconv.ParseFloat(*ent.SpotPrice, 64)
542 instanceType: string(ent.InstanceType),
544 availabilityZone: *ent.AvailabilityZone,
546 instanceSet.prices[pk] = append(instanceSet.prices[pk], cloud.InstancePrice{
547 StartTime: *ent.Timestamp,
550 instanceSet.pricesUpdated[pk] = updateTime
552 if page.NextToken == nil {
555 dsphi.NextToken = page.NextToken
558 expiredTime := updateTime.Add(-64 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
559 for pk, last := range instanceSet.pricesUpdated {
560 if last.Before(expiredTime) {
561 delete(instanceSet.pricesUpdated, pk)
562 delete(instanceSet.prices, pk)
565 for pk, prices := range instanceSet.prices {
566 instanceSet.prices[pk] = cloud.NormalizePriceHistory(prices)
570 func (instanceSet *ec2InstanceSet) Stop() {
573 type ec2Instance struct {
574 provider *ec2InstanceSet
575 instance types.Instance
576 availabilityZone string // sometimes available for spot instances
579 func (inst *ec2Instance) ID() cloud.InstanceID {
580 return cloud.InstanceID(*inst.instance.InstanceId)
583 func (inst *ec2Instance) String() string {
584 return *inst.instance.InstanceId
587 func (inst *ec2Instance) ProviderType() string {
588 return string(inst.instance.InstanceType)
591 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
592 var ec2tags []types.Tag
593 for k, v := range newTags {
594 ec2tags = append(ec2tags, types.Tag{
596 Value: aws.String(v),
600 _, err := inst.provider.client.CreateTags(context.TODO(), &ec2.CreateTagsInput{
601 Resources: []string{*inst.instance.InstanceId},
608 func (inst *ec2Instance) Tags() cloud.InstanceTags {
609 tags := make(map[string]string)
611 for _, t := range inst.instance.Tags {
612 tags[*t.Key] = *t.Value
618 func (inst *ec2Instance) Destroy() error {
619 _, err := inst.provider.client.TerminateInstances(context.TODO(), &ec2.TerminateInstancesInput{
620 InstanceIds: []string{*inst.instance.InstanceId},
625 func (inst *ec2Instance) Address() string {
626 if inst.instance.PrivateIpAddress != nil {
627 return *inst.instance.PrivateIpAddress
632 func (inst *ec2Instance) RemoteUser() string {
633 return inst.provider.ec2config.AdminUsername
636 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
637 return cloud.ErrNotImplemented
640 // PriceHistory returns the price history for this specific instance.
642 // AWS documentation is elusive about whether the hourly cost of a
643 // given spot instance changes as the current spot price changes for
644 // the corresponding instance type and availability zone. Our
645 // implementation assumes the answer is yes, based on the following
648 // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-requests.html
649 // says: "After your Spot Instance is running, if the Spot price rises
650 // above your maximum price, Amazon EC2 interrupts your Spot
651 // Instance." (This doesn't address what happens when the spot price
652 // rises *without* exceeding your maximum price.)
654 // https://docs.aws.amazon.com/whitepapers/latest/cost-optimization-leveraging-ec2-spot-instances/how-spot-instances-work.html
655 // says: "You pay the Spot price that's in effect, billed to the
656 // nearest second." (But it's not explicitly stated whether "the price
657 // in effect" changes over time for a given instance.)
659 // The same page also says, in a discussion about the effect of
660 // specifying a maximum price: "Note that you never pay more than the
661 // Spot price that is in effect when your Spot Instance is running."
662 // (The use of the phrase "is running", as opposed to "was launched",
663 // hints that pricing is dynamic.)
664 func (inst *ec2Instance) PriceHistory(instType arvados.InstanceType) []cloud.InstancePrice {
665 inst.provider.pricesLock.Lock()
666 defer inst.provider.pricesLock.Unlock()
667 // Note updateSpotPrices currently populates
668 // inst.provider.prices only for spot instances, so if
669 // spot==false here, we will return no data.
671 instanceType: string(inst.instance.InstanceType),
672 spot: inst.instance.InstanceLifecycle == types.InstanceLifecycleTypeSpot,
673 availabilityZone: inst.availabilityZone,
675 var prices []cloud.InstancePrice
676 for _, price := range inst.provider.prices[pk] {
677 // ceil(added scratch space in GiB)
678 gib := (instType.AddedScratch + 1<<30 - 1) >> 30
679 monthly := inst.provider.ec2config.EBSPrice * float64(gib)
680 hourly := monthly / 30 / 24
681 price.Price += hourly
682 prices = append(prices, price)
687 type rateLimitError struct {
689 earliestRetry time.Time
692 func (err rateLimitError) EarliestRetry() time.Time {
693 return err.earliestRetry
696 type capacityError struct {
698 isInstanceTypeSpecific bool
701 func (er *capacityError) IsCapacityError() bool {
705 func (er *capacityError) IsInstanceTypeSpecific() bool {
706 return er.isInstanceTypeSpecific
709 var isCodeQuota = map[string]bool{
710 "InstanceLimitExceeded": true,
711 "InsufficientAddressCapacity": true,
712 "InsufficientFreeAddressesInSubnet": true,
713 "InsufficientVolumeCapacity": true,
714 "MaxSpotInstanceCountExceeded": true,
715 "VcpuLimitExceeded": true,
718 // isErrorQuota returns whether the error indicates we have reached
719 // some usage quota/limit -- i.e., immediately retrying with an equal
720 // or larger instance type will probably not work.
722 // Returns false if error is nil.
723 func isErrorQuota(err error) bool {
724 var aerr smithy.APIError
725 if errors.As(err, &aerr) {
726 if _, ok := isCodeQuota[aerr.ErrorCode()]; ok {
// reSubnetSpecificInvalidParameterMessage recognizes error messages
// that mention a subnet (" subnet ") or a shortage of free IPv4/IPv6
// addresses. It is applied to messages accompanying InvalidParameter
// error codes, to decide whether retrying RunInstances in a different
// subnet might help. The "s" flag lets ".*" run across newlines, so a
// keyword anywhere in a multi-line message is found.
var reSubnetSpecificInvalidParameterMessage = regexp.MustCompile("(?ms).*( subnet |sufficient free [Ii]pv[46] addresses).*")
735 // isErrorSubnetSpecific returns true if the problem encountered by
736 // RunInstances might be avoided by trying a different subnet.
737 func isErrorSubnetSpecific(err error) bool {
738 var aerr smithy.APIError
739 if !errors.As(err, &aerr) {
742 code := aerr.ErrorCode()
743 return strings.Contains(code, "Subnet") ||
744 code == "InsufficientInstanceCapacity" ||
745 code == "InsufficientVolumeCapacity" ||
746 code == "Unsupported" ||
747 // See TestIsErrorSubnetSpecific for examples of why
748 // we look for substrings in code/message instead of
749 // only using specific codes here.
750 (strings.Contains(code, "InvalidParameter") &&
751 reSubnetSpecificInvalidParameterMessage.MatchString(aerr.ErrorMessage()))
754 // isErrorCapacity returns true if the error indicates lack of
755 // capacity (either temporary or permanent) to run a specific instance
756 // type -- i.e., retrying with a different instance type might
758 func isErrorCapacity(err error) bool {
759 var aerr smithy.APIError
760 if !errors.As(err, &aerr) {
763 code := aerr.ErrorCode()
764 return code == "InsufficientInstanceCapacity" ||
765 (code == "Unsupported" && strings.Contains(aerr.ErrorMessage(), "requested instance type"))
768 type ec2QuotaError struct {
772 func (er *ec2QuotaError) IsQuotaError() bool {
776 func isThrottleError(err error) bool {
777 var aerr smithy.APIError
778 if !errors.As(err, &aerr) {
781 _, is := retry.DefaultThrottleErrorCodes[aerr.ErrorCode()]
785 func wrapError(err error, throttleValue *atomic.Value) error {
786 if isThrottleError(err) {
787 // Back off exponentially until an upstream call
788 // either succeeds or returns a non-throttle error.
789 d, _ := throttleValue.Load().(time.Duration)
790 d = d*3/2 + time.Second
791 if d < throttleDelayMin {
793 } else if d > throttleDelayMax {
796 throttleValue.Store(d)
797 return rateLimitError{error: err, earliestRetry: time.Now().Add(d)}
798 } else if isErrorQuota(err) {
799 return &ec2QuotaError{err}
800 } else if isErrorCapacity(err) {
801 return &capacityError{err, true}
802 } else if err != nil {
803 throttleValue.Store(time.Duration(0))
806 throttleValue.Store(time.Duration(0))
// boolLabelValue converts a boolean outcome into the metric label
// value recorded for it (e.g. the "success" label of
// ec2_instance_starts_total): "1" for true and "0" for false.
var boolLabelValue = map[bool]string{
	true:  "1",
	false: "0",
}