X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/10b08e0c05d176775c4f0cf13f6e77025bb1e636..HEAD:/lib/cloud/ec2/ec2.go diff --git a/lib/cloud/ec2/ec2.go b/lib/cloud/ec2/ec2.go index 07a146d99f..a37522345d 100644 --- a/lib/cloud/ec2/ec2.go +++ b/lib/cloud/ec2/ec2.go @@ -5,14 +5,17 @@ package ec2 import ( + "context" "crypto/md5" "crypto/rsa" "crypto/sha1" "crypto/x509" "encoding/base64" "encoding/json" + "errors" "fmt" "math/big" + "regexp" "strconv" "strings" "sync" @@ -21,14 +24,13 @@ import ( "git.arvados.org/arvados.git/lib/cloud" "git.arvados.org/arvados.git/sdk/go/arvados" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" - "github.com/aws/aws-sdk-go/aws/ec2metadata" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + config "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/smithy-go" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" @@ -49,7 +51,7 @@ type ec2InstanceSetConfig struct { SecurityGroupIDs arvados.StringSet SubnetID sliceOrSingleString AdminUsername string - EBSVolumeType string + EBSVolumeType types.VolumeType EBSPrice float64 IAMInstanceProfile string SpotPriceUpdateInterval arvados.Duration @@ -89,14 +91,14 @@ func (ss *sliceOrSingleString) UnmarshalJSON(data []byte) error { } type ec2Interface interface { - DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error) - ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error) - RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error) - DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) - DescribeInstanceStatusPages(input *ec2.DescribeInstanceStatusInput, fn func(*ec2.DescribeInstanceStatusOutput, bool) bool) error - DescribeSpotPriceHistoryPages(input *ec2.DescribeSpotPriceHistoryInput, fn func(*ec2.DescribeSpotPriceHistoryOutput, bool) bool) error - CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) - TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error) + DescribeKeyPairs(context.Context, *ec2.DescribeKeyPairsInput, ...func(*ec2.Options)) (*ec2.DescribeKeyPairsOutput, error) + ImportKeyPair(context.Context, *ec2.ImportKeyPairInput, ...func(*ec2.Options)) (*ec2.ImportKeyPairOutput, error) + RunInstances(context.Context, *ec2.RunInstancesInput, ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) + DescribeInstances(context.Context, *ec2.DescribeInstancesInput, ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) + DescribeInstanceStatus(context.Context, *ec2.DescribeInstanceStatusInput, ...func(*ec2.Options)) (*ec2.DescribeInstanceStatusOutput, error) + DescribeSpotPriceHistory(context.Context, *ec2.DescribeSpotPriceHistoryInput, ...func(*ec2.Options)) (*ec2.DescribeSpotPriceHistoryOutput, error) + CreateTags(context.Context, *ec2.CreateTagsInput, ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) + TerminateInstances(context.Context, *ec2.TerminateInstancesInput, ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error) } type ec2InstanceSet struct { @@ -118,29 +120,40 @@ type ec2InstanceSet struct { mInstanceStarts *prometheus.CounterVec } -func newEC2InstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger, reg *prometheus.Registry) (prv cloud.InstanceSet, err error) { +func newEC2InstanceSet(confRaw json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger, reg *prometheus.Registry) (prv cloud.InstanceSet, err error) { instanceSet := &ec2InstanceSet{ instanceSetID: instanceSetID, logger: logger, } - err = json.Unmarshal(config, &instanceSet.ec2config) + err = json.Unmarshal(confRaw, &instanceSet.ec2config) if err != nil { return nil, err } - - sess, err := session.NewSession() + awsConfig, err := config.LoadDefaultConfig(context.Background(), + config.WithRegion(instanceSet.ec2config.Region), + config.WithCredentialsCacheOptions(func(o *aws.CredentialsCacheOptions) { + o.ExpiryWindow = 5 * time.Minute + }), + func(o *config.LoadOptions) error { + if instanceSet.ec2config.AccessKeyID == "" && instanceSet.ec2config.SecretAccessKey == "" { + // Use default SDK behavior (IAM role + // via IMDSv2) + return nil + } + o.Credentials = credentials.StaticCredentialsProvider{ + Value: aws.Credentials{ + AccessKeyID: instanceSet.ec2config.AccessKeyID, + SecretAccessKey: instanceSet.ec2config.SecretAccessKey, + Source: "Arvados configuration", + }, + } + return nil + }) if err != nil { return nil, err } - // First try any static credentials, fall back to an IAM instance profile/role - creds := credentials.NewChainCredentials( - []credentials.Provider{ - &credentials.StaticProvider{Value: credentials.Value{AccessKeyID: instanceSet.ec2config.AccessKeyID, SecretAccessKey: instanceSet.ec2config.SecretAccessKey}}, - &ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(sess)}, - }) - awsConfig := aws.NewConfig().WithCredentials(creds).WithRegion(instanceSet.ec2config.Region) - instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig))) + instanceSet.client = ec2.NewFromConfig(awsConfig) instanceSet.keys = make(map[string]string) if instanceSet.ec2config.EBSVolumeType == "" { instanceSet.ec2config.EBSVolumeType = "gp2" @@ -191,7 +204,7 @@ func awsKeyFingerprint(pk ssh.PublicKey) (md5fp string, sha1fp string, err error N *big.Int } if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil { - return "", "", fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err) + return "", "", fmt.Errorf("Unmarshal failed to parse public key: %w", err) } rsaPk := rsa.PublicKey{ E: int(rsaPub.E.Int64()), @@ -218,9 +231,9 @@ func (instanceSet *ec2InstanceSet) Create( initCommand cloud.InitCommand, publicKey ssh.PublicKey) (cloud.Instance, error) { - ec2tags := []*ec2.Tag{} + ec2tags := []types.Tag{} for k, v := range newTags { - ec2tags = append(ec2tags, &ec2.Tag{ + ec2tags = append(ec2tags, types.Tag{ Key: aws.String(k), Value: aws.String(v), }) @@ -233,24 +246,29 @@ func (instanceSet *ec2InstanceSet) Create( rii := ec2.RunInstancesInput{ ImageId: aws.String(string(imageID)), - InstanceType: &instanceType.ProviderType, - MaxCount: aws.Int64(1), - MinCount: aws.Int64(1), - - NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{ - { - AssociatePublicIpAddress: aws.Bool(false), - DeleteOnTermination: aws.Bool(true), - DeviceIndex: aws.Int64(0), - Groups: aws.StringSlice(groups), - }}, + InstanceType: types.InstanceType(instanceType.ProviderType), + MaxCount: aws.Int32(1), + MinCount: aws.Int32(1), + + NetworkInterfaces: []types.InstanceNetworkInterfaceSpecification{{ + AssociatePublicIpAddress: aws.Bool(false), + DeleteOnTermination: aws.Bool(true), + DeviceIndex: aws.Int32(0), + Groups: groups, + }}, DisableApiTermination: aws.Bool(false), - InstanceInitiatedShutdownBehavior: aws.String("terminate"), - TagSpecifications: []*ec2.TagSpecification{ + InstanceInitiatedShutdownBehavior: types.ShutdownBehaviorTerminate, + TagSpecifications: []types.TagSpecification{ { - ResourceType: aws.String("instance"), + ResourceType: types.ResourceTypeInstance, Tags: ec2tags, }}, + MetadataOptions: &types.InstanceMetadataOptionsRequest{ + // Require IMDSv2, as described at + // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-IMDS-new-instances.html + HttpEndpoint: types.InstanceMetadataEndpointStateEnabled, + HttpTokens: types.HttpTokensStateRequired, + }, UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))), } @@ -263,31 +281,31 @@ func (instanceSet *ec2InstanceSet) Create( } if instanceType.AddedScratch > 0 { - rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{{ + rii.BlockDeviceMappings = []types.BlockDeviceMapping{{ DeviceName: aws.String("/dev/xvdt"), - Ebs: &ec2.EbsBlockDevice{ + Ebs: &types.EbsBlockDevice{ DeleteOnTermination: aws.Bool(true), - VolumeSize: aws.Int64((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30), - VolumeType: &instanceSet.ec2config.EBSVolumeType, + VolumeSize: aws.Int32(int32((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30)), + VolumeType: instanceSet.ec2config.EBSVolumeType, }}} } if instanceType.Preemptible { - rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{ - MarketType: aws.String("spot"), - SpotOptions: &ec2.SpotMarketOptions{ - InstanceInterruptionBehavior: aws.String("terminate"), + rii.InstanceMarketOptions = &types.InstanceMarketOptionsRequest{ + MarketType: types.MarketTypeSpot, + SpotOptions: &types.SpotMarketOptions{ + InstanceInterruptionBehavior: types.InstanceInterruptionBehaviorTerminate, MaxPrice: aws.String(fmt.Sprintf("%v", instanceType.Price)), }} } if instanceSet.ec2config.IAMInstanceProfile != "" { - rii.IamInstanceProfile = &ec2.IamInstanceProfileSpecification{ + rii.IamInstanceProfile = &types.IamInstanceProfileSpecification{ Name: aws.String(instanceSet.ec2config.IAMInstanceProfile), } } - var rsv *ec2.Reservation + var rsv *ec2.RunInstancesOutput var errToReturn error subnets := instanceSet.ec2config.SubnetID currentSubnetIDIndex := int(atomic.LoadInt32(&instanceSet.currentSubnetIDIndex)) @@ -300,7 +318,7 @@ func (instanceSet *ec2InstanceSet) Create( rii.NetworkInterfaces[0].SubnetId = aws.String(trySubnet) } var err error - rsv, err = instanceSet.client.RunInstances(&rii) + rsv, err = instanceSet.client.RunInstances(context.Background(), &rii) instanceSet.mInstanceStarts.WithLabelValues(trySubnet, boolLabelValue[err == nil]).Add(1) if !isErrorCapacity(errToReturn) || isErrorCapacity(err) { // We want to return the last capacity error, @@ -341,47 +359,47 @@ func (instanceSet *ec2InstanceSet) getKeyName(publicKey ssh.PublicKey) (string, defer instanceSet.keysMtx.Unlock() md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey) if err != nil { - return "", fmt.Errorf("Could not make key fingerprint: %v", err) + return "", fmt.Errorf("Could not make key fingerprint: %w", err) } if keyname, ok := instanceSet.keys[md5keyFingerprint]; ok { return keyname, nil } - keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{ - Filters: []*ec2.Filter{{ + keyout, err := instanceSet.client.DescribeKeyPairs(context.Background(), &ec2.DescribeKeyPairsInput{ + Filters: []types.Filter{{ Name: aws.String("fingerprint"), - Values: []*string{&md5keyFingerprint, &sha1keyFingerprint}, + Values: []string{md5keyFingerprint, sha1keyFingerprint}, }}, }) if err != nil { - return "", fmt.Errorf("Could not search for keypair: %v", err) + return "", fmt.Errorf("Could not search for keypair: %w", err) } if len(keyout.KeyPairs) > 0 { return *(keyout.KeyPairs[0].KeyName), nil } keyname := "arvados-dispatch-keypair-" + md5keyFingerprint - _, err = instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{ + _, err = instanceSet.client.ImportKeyPair(context.Background(), &ec2.ImportKeyPairInput{ KeyName: &keyname, PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey), }) if err != nil { - return "", fmt.Errorf("Could not import keypair: %v", err) + return "", fmt.Errorf("Could not import keypair: %w", err) } instanceSet.keys[md5keyFingerprint] = keyname return keyname, nil } func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) { - var filters []*ec2.Filter + var filters []types.Filter for k, v := range tags { - filters = append(filters, &ec2.Filter{ + filters = append(filters, types.Filter{ Name: aws.String("tag:" + k), - Values: []*string{aws.String(v)}, + Values: []string{v}, }) } needAZs := false dii := &ec2.DescribeInstancesInput{Filters: filters} for { - dio, err := instanceSet.client.DescribeInstances(dii) + dio, err := instanceSet.client.DescribeInstances(context.Background(), dii) err = wrapError(err, &instanceSet.throttleDelayInstances) if err != nil { return nil, err @@ -389,12 +407,15 @@ func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances for _, rsv := range dio.Reservations { for _, inst := range rsv.Instances { - if *inst.State.Name != "shutting-down" && *inst.State.Name != "terminated" { + switch inst.State.Name { + case types.InstanceStateNameShuttingDown: + case types.InstanceStateNameTerminated: + default: instances = append(instances, &ec2Instance{ provider: instanceSet, instance: inst, }) - if aws.StringValue(inst.InstanceLifecycle) == "spot" { + if inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot { needAZs = true } } @@ -407,16 +428,20 @@ func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances } if needAZs && instanceSet.ec2config.SpotPriceUpdateInterval > 0 { az := map[string]string{} - err := instanceSet.client.DescribeInstanceStatusPages(&ec2.DescribeInstanceStatusInput{ - IncludeAllInstances: aws.Bool(true), - }, func(page *ec2.DescribeInstanceStatusOutput, lastPage bool) bool { + disi := &ec2.DescribeInstanceStatusInput{IncludeAllInstances: aws.Bool(true)} + for { + page, err := instanceSet.client.DescribeInstanceStatus(context.Background(), disi) + if err != nil { + instanceSet.logger.WithError(err).Warn("error getting instance statuses") + break + } for _, ent := range page.InstanceStatuses { az[*ent.InstanceId] = *ent.AvailabilityZone } - return true - }) - if err != nil { - instanceSet.logger.Warnf("error getting instance statuses: %s", err) + if page.NextToken == nil { + break + } + disi.NextToken = page.NextToken } for _, inst := range instances { inst := inst.(*ec2Instance) @@ -468,28 +493,28 @@ func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance) updateTime := time.Now() staleTime := updateTime.Add(-instanceSet.ec2config.SpotPriceUpdateInterval.Duration()) needUpdate := false - allTypes := map[string]bool{} + allTypes := map[types.InstanceType]bool{} for _, inst := range instances { ec2inst := inst.(*ec2Instance).instance - if aws.StringValue(ec2inst.InstanceLifecycle) == "spot" { + if ec2inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot { pk := priceKey{ - instanceType: *ec2inst.InstanceType, + instanceType: string(ec2inst.InstanceType), spot: true, availabilityZone: inst.(*ec2Instance).availabilityZone, } if instanceSet.pricesUpdated[pk].Before(staleTime) { needUpdate = true } - allTypes[*ec2inst.InstanceType] = true + allTypes[ec2inst.InstanceType] = true } } if !needUpdate { return } - var typeFilterValues []*string + var typeFilterValues []string for instanceType := range allTypes { - typeFilterValues = append(typeFilterValues, aws.String(instanceType)) + typeFilterValues = append(typeFilterValues, string(instanceType)) } // Get 3x update interval worth of pricing data. (Ideally the // AWS API would tell us "we have shown you all of the price @@ -500,14 +525,19 @@ func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance) // row. dsphi := &ec2.DescribeSpotPriceHistoryInput{ StartTime: aws.Time(updateTime.Add(-3 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())), - Filters: []*ec2.Filter{ - &ec2.Filter{Name: aws.String("instance-type"), Values: typeFilterValues}, - &ec2.Filter{Name: aws.String("product-description"), Values: []*string{aws.String("Linux/UNIX")}}, + Filters: []types.Filter{ + types.Filter{Name: aws.String("instance-type"), Values: typeFilterValues}, + types.Filter{Name: aws.String("product-description"), Values: []string{"Linux/UNIX"}}, }, } - err := instanceSet.client.DescribeSpotPriceHistoryPages(dsphi, func(page *ec2.DescribeSpotPriceHistoryOutput, lastPage bool) bool { + for { + page, err := instanceSet.client.DescribeSpotPriceHistory(context.Background(), dsphi) + if err != nil { + instanceSet.logger.WithError(err).Warn("error retrieving spot instance prices") + break + } for _, ent := range page.SpotPriceHistory { - if ent.InstanceType == nil || ent.SpotPrice == nil || ent.Timestamp == nil { + if ent.InstanceType == "" || ent.SpotPrice == nil || ent.Timestamp == nil { // bogus record? continue } @@ -517,7 +547,7 @@ func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance) continue } pk := priceKey{ - instanceType: *ent.InstanceType, + instanceType: string(ent.InstanceType), spot: true, availabilityZone: *ent.AvailabilityZone, } @@ -527,10 +557,10 @@ func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance) }) instanceSet.pricesUpdated[pk] = updateTime } - return true - }) - if err != nil { - instanceSet.logger.Warnf("error retrieving spot instance prices: %s", err) + if page.NextToken == nil { + break + } + dsphi.NextToken = page.NextToken } expiredTime := updateTime.Add(-64 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration()) @@ -550,7 +580,7 @@ func (instanceSet *ec2InstanceSet) Stop() { type ec2Instance struct { provider *ec2InstanceSet - instance *ec2.Instance + instance types.Instance availabilityZone string // sometimes available for spot instances } @@ -563,20 +593,20 @@ func (inst *ec2Instance) String() string { } func (inst *ec2Instance) ProviderType() string { - return *inst.instance.InstanceType + return string(inst.instance.InstanceType) } func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error { - var ec2tags []*ec2.Tag + var ec2tags []types.Tag for k, v := range newTags { - ec2tags = append(ec2tags, &ec2.Tag{ + ec2tags = append(ec2tags, types.Tag{ Key: aws.String(k), Value: aws.String(v), }) } - _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{ - Resources: []*string{inst.instance.InstanceId}, + _, err := inst.provider.client.CreateTags(context.Background(), &ec2.CreateTagsInput{ + Resources: []string{*inst.instance.InstanceId}, Tags: ec2tags, }) @@ -594,8 +624,8 @@ func (inst *ec2Instance) Tags() cloud.InstanceTags { } func (inst *ec2Instance) Destroy() error { - _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{ - InstanceIds: []*string{inst.instance.InstanceId}, + _, err := inst.provider.client.TerminateInstances(context.Background(), &ec2.TerminateInstancesInput{ + InstanceIds: []string{*inst.instance.InstanceId}, }) return err } @@ -646,8 +676,8 @@ func (inst *ec2Instance) PriceHistory(instType arvados.InstanceType) []cloud.Ins // inst.provider.prices only for spot instances, so if // spot==false here, we will return no data. pk := priceKey{ - instanceType: *inst.instance.InstanceType, - spot: aws.StringValue(inst.instance.InstanceLifecycle) == "spot", + instanceType: string(inst.instance.InstanceType), + spot: inst.instance.InstanceLifecycle == types.InstanceLifecycleTypeSpot, availabilityZone: inst.availabilityZone, } var prices []cloud.InstancePrice @@ -699,26 +729,34 @@ var isCodeQuota = map[string]bool{ // // Returns false if error is nil. func isErrorQuota(err error) bool { - if aerr, ok := err.(awserr.Error); ok && aerr != nil { - if _, ok := isCodeQuota[aerr.Code()]; ok { + var aerr smithy.APIError + if errors.As(err, &aerr) { + if _, ok := isCodeQuota[aerr.ErrorCode()]; ok { return true } } return false } +var reSubnetSpecificInvalidParameterMessage = regexp.MustCompile(`(?ms).*( subnet |sufficient free [Ii]pv[46] addresses).*`) + // isErrorSubnetSpecific returns true if the problem encountered by // RunInstances might be avoided by trying a different subnet. func isErrorSubnetSpecific(err error) bool { - aerr, ok := err.(awserr.Error) - if !ok { + var aerr smithy.APIError + if !errors.As(err, &aerr) { return false } - code := aerr.Code() + code := aerr.ErrorCode() return strings.Contains(code, "Subnet") || code == "InsufficientInstanceCapacity" || code == "InsufficientVolumeCapacity" || - code == "Unsupported" + code == "Unsupported" || + // See TestIsErrorSubnetSpecific for examples of why + // we look for substrings in code/message instead of + // only using specific codes here. + (strings.Contains(code, "InvalidParameter") && + reSubnetSpecificInvalidParameterMessage.MatchString(aerr.ErrorMessage())) } // isErrorCapacity returns true if the error indicates lack of @@ -726,13 +764,13 @@ func isErrorSubnetSpecific(err error) bool { // type -- i.e., retrying with a different instance type might // succeed. func isErrorCapacity(err error) bool { - aerr, ok := err.(awserr.Error) - if !ok { + var aerr smithy.APIError + if !errors.As(err, &aerr) { return false } - code := aerr.Code() + code := aerr.ErrorCode() return code == "InsufficientInstanceCapacity" || - (code == "Unsupported" && strings.Contains(aerr.Message(), "requested instance type")) + (code == "Unsupported" && strings.Contains(aerr.ErrorMessage(), "requested instance type")) } type ec2QuotaError struct { @@ -743,8 +781,17 @@ func (er *ec2QuotaError) IsQuotaError() bool { return true } +func isThrottleError(err error) bool { + var aerr smithy.APIError + if !errors.As(err, &aerr) { + return false + } + _, is := retry.DefaultThrottleErrorCodes[aerr.ErrorCode()] + return is +} + func wrapError(err error, throttleValue *atomic.Value) error { - if request.IsErrorThrottle(err) { + if isThrottleError(err) { // Back off exponentially until an upstream call // either succeeds or returns a non-throttle error. d, _ := throttleValue.Load().(time.Duration)