X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/00f3f818027bdc69178bb54e49802baf20d9527d..HEAD:/lib/cloud/ec2/ec2.go diff --git a/lib/cloud/ec2/ec2.go b/lib/cloud/ec2/ec2.go index a74f125610..a37522345d 100644 --- a/lib/cloud/ec2/ec2.go +++ b/lib/cloud/ec2/ec2.go @@ -5,29 +5,33 @@ package ec2 import ( + "context" "crypto/md5" "crypto/rsa" "crypto/sha1" "crypto/x509" "encoding/base64" "encoding/json" + "errors" "fmt" "math/big" + "regexp" "strconv" + "strings" "sync" "sync/atomic" "time" "git.arvados.org/arvados.git/lib/cloud" "git.arvados.org/arvados.git/sdk/go/arvados" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" - "github.com/aws/aws-sdk-go/aws/ec2metadata" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + config "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/smithy-go" + "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) @@ -45,27 +49,61 @@ type ec2InstanceSetConfig struct { SecretAccessKey string Region string SecurityGroupIDs arvados.StringSet - SubnetID string + SubnetID sliceOrSingleString AdminUsername string - EBSVolumeType string + EBSVolumeType types.VolumeType EBSPrice float64 IAMInstanceProfile string SpotPriceUpdateInterval arvados.Duration } +type sliceOrSingleString []string + +// UnmarshalJSON unmarshals an array of strings, and also accepts "" +// as [], and "foo" as ["foo"]. +func (ss *sliceOrSingleString) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + *ss = nil + } else if data[0] == '[' { + var slice []string + err := json.Unmarshal(data, &slice) + if err != nil { + return err + } + if len(slice) == 0 { + *ss = nil + } else { + *ss = slice + } + } else { + var str string + err := json.Unmarshal(data, &str) + if err != nil { + return err + } + if str == "" { + *ss = nil + } else { + *ss = []string{str} + } + } + return nil +} + type ec2Interface interface { - DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error) - ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error) - RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error) - DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) - DescribeInstanceStatusPages(input *ec2.DescribeInstanceStatusInput, fn func(*ec2.DescribeInstanceStatusOutput, bool) bool) error - DescribeSpotPriceHistoryPages(input *ec2.DescribeSpotPriceHistoryInput, fn func(*ec2.DescribeSpotPriceHistoryOutput, bool) bool) error - CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) - TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error) + DescribeKeyPairs(context.Context, *ec2.DescribeKeyPairsInput, ...func(*ec2.Options)) (*ec2.DescribeKeyPairsOutput, error) + ImportKeyPair(context.Context, *ec2.ImportKeyPairInput, ...func(*ec2.Options)) (*ec2.ImportKeyPairOutput, error) + RunInstances(context.Context, *ec2.RunInstancesInput, ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) + DescribeInstances(context.Context, *ec2.DescribeInstancesInput, ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) + DescribeInstanceStatus(context.Context, *ec2.DescribeInstanceStatusInput, ...func(*ec2.Options)) (*ec2.DescribeInstanceStatusOutput, error) + DescribeSpotPriceHistory(context.Context, *ec2.DescribeSpotPriceHistoryInput, ...func(*ec2.Options)) (*ec2.DescribeSpotPriceHistoryOutput, error) + CreateTags(context.Context, *ec2.CreateTagsInput, ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) + TerminateInstances(context.Context, *ec2.TerminateInstancesInput, ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error) } type ec2InstanceSet struct { ec2config ec2InstanceSetConfig + currentSubnetIDIndex int32 instanceSetID cloud.InstanceSetID logger logrus.FieldLogger client ec2Interface @@ -77,35 +115,79 @@ type ec2InstanceSet struct { prices map[priceKey][]cloud.InstancePrice pricesLock sync.Mutex pricesUpdated map[priceKey]time.Time + + mInstances *prometheus.GaugeVec + mInstanceStarts *prometheus.CounterVec } -func newEC2InstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) { +func newEC2InstanceSet(confRaw json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger, reg *prometheus.Registry) (prv cloud.InstanceSet, err error) { instanceSet := &ec2InstanceSet{ instanceSetID: instanceSetID, logger: logger, } - err = json.Unmarshal(config, &instanceSet.ec2config) + err = json.Unmarshal(confRaw, &instanceSet.ec2config) if err != nil { return nil, err } - - sess, err := session.NewSession() + awsConfig, err := config.LoadDefaultConfig(context.Background(), + config.WithRegion(instanceSet.ec2config.Region), + config.WithCredentialsCacheOptions(func(o *aws.CredentialsCacheOptions) { + o.ExpiryWindow = 5 * time.Minute + }), + func(o *config.LoadOptions) error { + if instanceSet.ec2config.AccessKeyID == "" && instanceSet.ec2config.SecretAccessKey == "" { + // Use default SDK behavior (IAM role + // via IMDSv2) + return nil + } + o.Credentials = credentials.StaticCredentialsProvider{ + Value: aws.Credentials{ + AccessKeyID: instanceSet.ec2config.AccessKeyID, + SecretAccessKey: instanceSet.ec2config.SecretAccessKey, + Source: "Arvados configuration", + }, + } + return nil + }) if err != nil { return nil, err } - // First try any static credentials, fall back to an IAM instance profile/role - creds := credentials.NewChainCredentials( - []credentials.Provider{ - &credentials.StaticProvider{Value: credentials.Value{AccessKeyID: instanceSet.ec2config.AccessKeyID, SecretAccessKey: instanceSet.ec2config.SecretAccessKey}}, - &ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(sess)}, - }) - awsConfig := aws.NewConfig().WithCredentials(creds).WithRegion(instanceSet.ec2config.Region) - instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig))) + instanceSet.client = ec2.NewFromConfig(awsConfig) instanceSet.keys = make(map[string]string) if instanceSet.ec2config.EBSVolumeType == "" { instanceSet.ec2config.EBSVolumeType = "gp2" } + + // Set up metrics + instanceSet.mInstances = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "arvados", + Subsystem: "dispatchcloud", + Name: "ec2_instances", + Help: "Number of instances running", + }, []string{"subnet_id"}) + instanceSet.mInstanceStarts = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "arvados", + Subsystem: "dispatchcloud", + Name: "ec2_instance_starts_total", + Help: "Number of attempts to start a new instance", + }, []string{"subnet_id", "success"}) + // Initialize all of the series we'll be reporting. Otherwise + // the {subnet=A, success=0} series doesn't appear in metrics + // at all until there's a failure in subnet A. + for _, subnet := range instanceSet.ec2config.SubnetID { + instanceSet.mInstanceStarts.WithLabelValues(subnet, "0").Add(0) + instanceSet.mInstanceStarts.WithLabelValues(subnet, "1").Add(0) + } + if len(instanceSet.ec2config.SubnetID) == 0 { + instanceSet.mInstanceStarts.WithLabelValues("", "0").Add(0) + instanceSet.mInstanceStarts.WithLabelValues("", "1").Add(0) + } + if reg != nil { + reg.MustRegister(instanceSet.mInstances) + reg.MustRegister(instanceSet.mInstanceStarts) + } + return instanceSet, nil } @@ -122,7 +204,7 @@ func awsKeyFingerprint(pk ssh.PublicKey) (md5fp string, sha1fp string, err error N *big.Int } if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil { - return "", "", fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err) + return "", "", fmt.Errorf("Unmarshal failed to parse public key: %w", err) } rsaPk := rsa.PublicKey{ E: int(rsaPub.E.Int64()), @@ -149,43 +231,9 @@ func (instanceSet *ec2InstanceSet) Create( initCommand cloud.InitCommand, publicKey ssh.PublicKey) (cloud.Instance, error) { - md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey) - if err != nil { - return nil, fmt.Errorf("Could not make key fingerprint: %v", err) - } - instanceSet.keysMtx.Lock() - var keyname string - var ok bool - if keyname, ok = instanceSet.keys[md5keyFingerprint]; !ok { - keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{ - Filters: []*ec2.Filter{{ - Name: aws.String("fingerprint"), - Values: []*string{&md5keyFingerprint, &sha1keyFingerprint}, - }}, - }) - if err != nil { - return nil, fmt.Errorf("Could not search for keypair: %v", err) - } - - if len(keyout.KeyPairs) > 0 { - keyname = *(keyout.KeyPairs[0].KeyName) - } else { - keyname = "arvados-dispatch-keypair-" + md5keyFingerprint - _, err := instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{ - KeyName: &keyname, - PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey), - }) - if err != nil { - return nil, fmt.Errorf("Could not import keypair: %v", err) - } - } - instanceSet.keys[md5keyFingerprint] = keyname - } - instanceSet.keysMtx.Unlock() - - ec2tags := []*ec2.Tag{} + ec2tags := []types.Tag{} for k, v := range newTags { - ec2tags = append(ec2tags, &ec2.Tag{ + ec2tags = append(ec2tags, types.Tag{ Key: aws.String(k), Value: aws.String(v), }) @@ -198,58 +246,107 @@ func (instanceSet *ec2InstanceSet) Create( rii := ec2.RunInstancesInput{ ImageId: aws.String(string(imageID)), - InstanceType: &instanceType.ProviderType, - MaxCount: aws.Int64(1), - MinCount: aws.Int64(1), - KeyName: &keyname, - - NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{ - { - AssociatePublicIpAddress: aws.Bool(false), - DeleteOnTermination: aws.Bool(true), - DeviceIndex: aws.Int64(0), - Groups: aws.StringSlice(groups), - SubnetId: &instanceSet.ec2config.SubnetID, - }}, + InstanceType: types.InstanceType(instanceType.ProviderType), + MaxCount: aws.Int32(1), + MinCount: aws.Int32(1), + + NetworkInterfaces: []types.InstanceNetworkInterfaceSpecification{{ + AssociatePublicIpAddress: aws.Bool(false), + DeleteOnTermination: aws.Bool(true), + DeviceIndex: aws.Int32(0), + Groups: groups, + }}, DisableApiTermination: aws.Bool(false), - InstanceInitiatedShutdownBehavior: aws.String("terminate"), - TagSpecifications: []*ec2.TagSpecification{ + InstanceInitiatedShutdownBehavior: types.ShutdownBehaviorTerminate, + TagSpecifications: []types.TagSpecification{ { - ResourceType: aws.String("instance"), + ResourceType: types.ResourceTypeInstance, Tags: ec2tags, }}, + MetadataOptions: &types.InstanceMetadataOptionsRequest{ + // Require IMDSv2, as described at + // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-IMDS-new-instances.html + HttpEndpoint: types.InstanceMetadataEndpointStateEnabled, + HttpTokens: types.HttpTokensStateRequired, + }, UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))), } + if publicKey != nil { + keyname, err := instanceSet.getKeyName(publicKey) + if err != nil { + return nil, err + } + rii.KeyName = &keyname + } + if instanceType.AddedScratch > 0 { - rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{{ + rii.BlockDeviceMappings = []types.BlockDeviceMapping{{ DeviceName: aws.String("/dev/xvdt"), - Ebs: &ec2.EbsBlockDevice{ + Ebs: &types.EbsBlockDevice{ DeleteOnTermination: aws.Bool(true), - VolumeSize: aws.Int64((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30), - VolumeType: &instanceSet.ec2config.EBSVolumeType, + VolumeSize: aws.Int32(int32((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30)), + VolumeType: instanceSet.ec2config.EBSVolumeType, }}} } if instanceType.Preemptible { - rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{ - MarketType: aws.String("spot"), - SpotOptions: &ec2.SpotMarketOptions{ - InstanceInterruptionBehavior: aws.String("terminate"), + rii.InstanceMarketOptions = &types.InstanceMarketOptionsRequest{ + MarketType: types.MarketTypeSpot, + SpotOptions: &types.SpotMarketOptions{ + InstanceInterruptionBehavior: types.InstanceInterruptionBehaviorTerminate, MaxPrice: aws.String(fmt.Sprintf("%v", instanceType.Price)), }} } if instanceSet.ec2config.IAMInstanceProfile != "" { - rii.IamInstanceProfile = &ec2.IamInstanceProfileSpecification{ + rii.IamInstanceProfile = &types.IamInstanceProfileSpecification{ Name: aws.String(instanceSet.ec2config.IAMInstanceProfile), } } - rsv, err := instanceSet.client.RunInstances(&rii) - err = wrapError(err, &instanceSet.throttleDelayCreate) - if err != nil { - return nil, err + var rsv *ec2.RunInstancesOutput + var errToReturn error + subnets := instanceSet.ec2config.SubnetID + currentSubnetIDIndex := int(atomic.LoadInt32(&instanceSet.currentSubnetIDIndex)) + for tryOffset := 0; ; tryOffset++ { + tryIndex := 0 + trySubnet := "" + if len(subnets) > 0 { + tryIndex = (currentSubnetIDIndex + tryOffset) % len(subnets) + trySubnet = subnets[tryIndex] + rii.NetworkInterfaces[0].SubnetId = aws.String(trySubnet) + } + var err error + rsv, err = instanceSet.client.RunInstances(context.Background(), &rii) + instanceSet.mInstanceStarts.WithLabelValues(trySubnet, boolLabelValue[err == nil]).Add(1) + if !isErrorCapacity(errToReturn) || isErrorCapacity(err) { + // We want to return the last capacity error, + // if any; otherwise the last non-capacity + // error. + errToReturn = err + } + if isErrorSubnetSpecific(err) && + tryOffset < len(subnets)-1 { + instanceSet.logger.WithError(err).WithField("SubnetID", subnets[tryIndex]). + Warn("RunInstances failed, trying next subnet") + continue + } + // Succeeded, or exhausted all subnets, or got a + // non-subnet-related error. + // + // We intentionally update currentSubnetIDIndex even + // in the non-retryable-failure case here to avoid a + // situation where successive calls to Create() keep + // returning errors for the same subnet (perhaps + // "subnet full") and never reveal the errors for the + // other configured subnets (perhaps "subnet ID + // invalid"). + atomic.StoreInt32(&instanceSet.currentSubnetIDIndex, int32(tryIndex)) + break + } + if rsv == nil || len(rsv.Instances) == 0 { + return nil, wrapError(errToReturn, &instanceSet.throttleDelayCreate) } return &ec2Instance{ provider: instanceSet, @@ -257,18 +354,52 @@ func (instanceSet *ec2InstanceSet) Create( }, nil } +func (instanceSet *ec2InstanceSet) getKeyName(publicKey ssh.PublicKey) (string, error) { + instanceSet.keysMtx.Lock() + defer instanceSet.keysMtx.Unlock() + md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey) + if err != nil { + return "", fmt.Errorf("Could not make key fingerprint: %w", err) + } + if keyname, ok := instanceSet.keys[md5keyFingerprint]; ok { + return keyname, nil + } + keyout, err := instanceSet.client.DescribeKeyPairs(context.Background(), &ec2.DescribeKeyPairsInput{ + Filters: []types.Filter{{ + Name: aws.String("fingerprint"), + Values: []string{md5keyFingerprint, sha1keyFingerprint}, + }}, + }) + if err != nil { + return "", fmt.Errorf("Could not search for keypair: %w", err) + } + if len(keyout.KeyPairs) > 0 { + return *(keyout.KeyPairs[0].KeyName), nil + } + keyname := "arvados-dispatch-keypair-" + md5keyFingerprint + _, err = instanceSet.client.ImportKeyPair(context.Background(), &ec2.ImportKeyPairInput{ + KeyName: &keyname, + PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey), + }) + if err != nil { + return "", fmt.Errorf("Could not import keypair: %w", err) + } + instanceSet.keys[md5keyFingerprint] = keyname + return keyname, nil +} + func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) { - var filters []*ec2.Filter + var filters []types.Filter for k, v := range tags { - filters = append(filters, &ec2.Filter{ + filters = append(filters, types.Filter{ Name: aws.String("tag:" + k), - Values: []*string{aws.String(v)}, + Values: []string{v}, }) } needAZs := false dii := &ec2.DescribeInstancesInput{Filters: filters} for { - dio, err := instanceSet.client.DescribeInstances(dii) + dio, err := instanceSet.client.DescribeInstances(context.Background(), dii) err = wrapError(err, &instanceSet.throttleDelayInstances) if err != nil { return nil, err @@ -276,12 +407,15 @@ func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances for _, rsv := range dio.Reservations { for _, inst := range rsv.Instances { - if *inst.State.Name != "shutting-down" && *inst.State.Name != "terminated" { + switch inst.State.Name { + case types.InstanceStateNameShuttingDown: + case types.InstanceStateNameTerminated: + default: instances = append(instances, &ec2Instance{ provider: instanceSet, instance: inst, }) - if aws.StringValue(inst.InstanceLifecycle) == "spot" { + if inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot { needAZs = true } } @@ -294,16 +428,20 @@ func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances } if needAZs && instanceSet.ec2config.SpotPriceUpdateInterval > 0 { az := map[string]string{} - err := instanceSet.client.DescribeInstanceStatusPages(&ec2.DescribeInstanceStatusInput{ - IncludeAllInstances: aws.Bool(true), - }, func(page *ec2.DescribeInstanceStatusOutput, lastPage bool) bool { + disi := &ec2.DescribeInstanceStatusInput{IncludeAllInstances: aws.Bool(true)} + for { + page, err := instanceSet.client.DescribeInstanceStatus(context.Background(), disi) + if err != nil { + instanceSet.logger.WithError(err).Warn("error getting instance statuses") + break + } for _, ent := range page.InstanceStatuses { az[*ent.InstanceId] = *ent.AvailabilityZone } - return true - }) - if err != nil { - instanceSet.logger.Warnf("error getting instance statuses: %s", err) + if page.NextToken == nil { + break + } + disi.NextToken = page.NextToken } for _, inst := range instances { inst := inst.(*ec2Instance) @@ -311,6 +449,24 @@ func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances } instanceSet.updateSpotPrices(instances) } + + // Count instances in each subnet, and report in metrics. + subnetInstances := map[string]int{"": 0} + for _, subnet := range instanceSet.ec2config.SubnetID { + subnetInstances[subnet] = 0 + } + for _, inst := range instances { + subnet := inst.(*ec2Instance).instance.SubnetId + if subnet != nil { + subnetInstances[*subnet]++ + } else { + subnetInstances[""]++ + } + } + for subnet, count := range subnetInstances { + instanceSet.mInstances.WithLabelValues(subnet).Set(float64(count)) + } + return instances, err } @@ -337,28 +493,28 @@ func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance) updateTime := time.Now() staleTime := updateTime.Add(-instanceSet.ec2config.SpotPriceUpdateInterval.Duration()) needUpdate := false - allTypes := map[string]bool{} + allTypes := map[types.InstanceType]bool{} for _, inst := range instances { ec2inst := inst.(*ec2Instance).instance - if aws.StringValue(ec2inst.InstanceLifecycle) == "spot" { + if ec2inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot { pk := priceKey{ - instanceType: *ec2inst.InstanceType, + instanceType: string(ec2inst.InstanceType), spot: true, availabilityZone: inst.(*ec2Instance).availabilityZone, } if instanceSet.pricesUpdated[pk].Before(staleTime) { needUpdate = true } - allTypes[*ec2inst.InstanceType] = true + allTypes[ec2inst.InstanceType] = true } } if !needUpdate { return } - var typeFilterValues []*string + var typeFilterValues []string for instanceType := range allTypes { - typeFilterValues = append(typeFilterValues, aws.String(instanceType)) + typeFilterValues = append(typeFilterValues, string(instanceType)) } // Get 3x update interval worth of pricing data. (Ideally the // AWS API would tell us "we have shown you all of the price @@ -369,14 +525,19 @@ func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance) // row. dsphi := &ec2.DescribeSpotPriceHistoryInput{ StartTime: aws.Time(updateTime.Add(-3 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())), - Filters: []*ec2.Filter{ - &ec2.Filter{Name: aws.String("instance-type"), Values: typeFilterValues}, - &ec2.Filter{Name: aws.String("product-description"), Values: []*string{aws.String("Linux/UNIX")}}, + Filters: []types.Filter{ + types.Filter{Name: aws.String("instance-type"), Values: typeFilterValues}, + types.Filter{Name: aws.String("product-description"), Values: []string{"Linux/UNIX"}}, }, } - err := instanceSet.client.DescribeSpotPriceHistoryPages(dsphi, func(page *ec2.DescribeSpotPriceHistoryOutput, lastPage bool) bool { + for { + page, err := instanceSet.client.DescribeSpotPriceHistory(context.Background(), dsphi) + if err != nil { + instanceSet.logger.WithError(err).Warn("error retrieving spot instance prices") + break + } for _, ent := range page.SpotPriceHistory { - if ent.InstanceType == nil || ent.SpotPrice == nil || ent.Timestamp == nil { + if ent.InstanceType == "" || ent.SpotPrice == nil || ent.Timestamp == nil { // bogus record? continue } @@ -386,7 +547,7 @@ func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance) continue } pk := priceKey{ - instanceType: *ent.InstanceType, + instanceType: string(ent.InstanceType), spot: true, availabilityZone: *ent.AvailabilityZone, } @@ -396,10 +557,10 @@ func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance) }) instanceSet.pricesUpdated[pk] = updateTime } - return true - }) - if err != nil { - instanceSet.logger.Warnf("error retrieving spot instance prices: %s", err) + if page.NextToken == nil { + break + } + dsphi.NextToken = page.NextToken } expiredTime := updateTime.Add(-64 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration()) @@ -419,7 +580,7 @@ func (instanceSet *ec2InstanceSet) Stop() { type ec2Instance struct { provider *ec2InstanceSet - instance *ec2.Instance + instance types.Instance availabilityZone string // sometimes available for spot instances } @@ -432,20 +593,20 @@ func (inst *ec2Instance) String() string { } func (inst *ec2Instance) ProviderType() string { - return *inst.instance.InstanceType + return string(inst.instance.InstanceType) } func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error { - var ec2tags []*ec2.Tag + var ec2tags []types.Tag for k, v := range newTags { - ec2tags = append(ec2tags, &ec2.Tag{ + ec2tags = append(ec2tags, types.Tag{ Key: aws.String(k), Value: aws.String(v), }) } - _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{ - Resources: []*string{inst.instance.InstanceId}, + _, err := inst.provider.client.CreateTags(context.Background(), &ec2.CreateTagsInput{ + Resources: []string{*inst.instance.InstanceId}, Tags: ec2tags, }) @@ -463,8 +624,8 @@ func (inst *ec2Instance) Tags() cloud.InstanceTags { } func (inst *ec2Instance) Destroy() error { - _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{ - InstanceIds: []*string{inst.instance.InstanceId}, + _, err := inst.provider.client.TerminateInstances(context.Background(), &ec2.TerminateInstancesInput{ + InstanceIds: []string{*inst.instance.InstanceId}, }) return err } @@ -515,8 +676,8 @@ func (inst *ec2Instance) PriceHistory(instType arvados.InstanceType) []cloud.Ins // inst.provider.prices only for spot instances, so if // spot==false here, we will return no data. pk := priceKey{ - instanceType: *inst.instance.InstanceType, - spot: aws.StringValue(inst.instance.InstanceLifecycle) == "spot", + instanceType: string(inst.instance.InstanceType), + spot: inst.instance.InstanceLifecycle == types.InstanceLifecycleTypeSpot, availabilityZone: inst.availabilityZone, } var prices []cloud.InstancePrice @@ -540,25 +701,78 @@ func (err rateLimitError) EarliestRetry() time.Time { return err.earliestRetry } -var isCodeCapacity = map[string]bool{ +type capacityError struct { + error + isInstanceTypeSpecific bool +} + +func (er *capacityError) IsCapacityError() bool { + return true +} + +func (er *capacityError) IsInstanceTypeSpecific() bool { + return er.isInstanceTypeSpecific +} + +var isCodeQuota = map[string]bool{ + "InstanceLimitExceeded": true, + "InsufficientAddressCapacity": true, "InsufficientFreeAddressesInSubnet": true, - "InsufficientInstanceCapacity": true, "InsufficientVolumeCapacity": true, "MaxSpotInstanceCountExceeded": true, "VcpuLimitExceeded": true, } -// isErrorCapacity returns whether the error is to be throttled based on its code. +// isErrorQuota returns whether the error indicates we have reached +// some usage quota/limit -- i.e., immediately retrying with an equal +// or larger instance type will probably not work. +// // Returns false if error is nil. -func isErrorCapacity(err error) bool { - if aerr, ok := err.(awserr.Error); ok && aerr != nil { - if _, ok := isCodeCapacity[aerr.Code()]; ok { +func isErrorQuota(err error) bool { + var aerr smithy.APIError + if errors.As(err, &aerr) { + if _, ok := isCodeQuota[aerr.ErrorCode()]; ok { return true } } return false } +var reSubnetSpecificInvalidParameterMessage = regexp.MustCompile(`(?ms).*( subnet |sufficient free [Ii]pv[46] addresses).*`) + +// isErrorSubnetSpecific returns true if the problem encountered by +// RunInstances might be avoided by trying a different subnet. +func isErrorSubnetSpecific(err error) bool { + var aerr smithy.APIError + if !errors.As(err, &aerr) { + return false + } + code := aerr.ErrorCode() + return strings.Contains(code, "Subnet") || + code == "InsufficientInstanceCapacity" || + code == "InsufficientVolumeCapacity" || + code == "Unsupported" || + // See TestIsErrorSubnetSpecific for examples of why + // we look for substrings in code/message instead of + // only using specific codes here. + (strings.Contains(code, "InvalidParameter") && + reSubnetSpecificInvalidParameterMessage.MatchString(aerr.ErrorMessage())) +} + +// isErrorCapacity returns true if the error indicates lack of +// capacity (either temporary or permanent) to run a specific instance +// type -- i.e., retrying with a different instance type might +// succeed. +func isErrorCapacity(err error) bool { + var aerr smithy.APIError + if !errors.As(err, &aerr) { + return false + } + code := aerr.ErrorCode() + return code == "InsufficientInstanceCapacity" || + (code == "Unsupported" && strings.Contains(aerr.ErrorMessage(), "requested instance type")) +} + type ec2QuotaError struct { error } @@ -567,8 +781,17 @@ func (er *ec2QuotaError) IsQuotaError() bool { return true } +func isThrottleError(err error) bool { + var aerr smithy.APIError + if !errors.As(err, &aerr) { + return false + } + _, is := retry.DefaultThrottleErrorCodes[aerr.ErrorCode()] + return is +} + func wrapError(err error, throttleValue *atomic.Value) error { - if request.IsErrorThrottle(err) { + if isThrottleError(err) { // Back off exponentially until an upstream call // either succeeds or returns a non-throttle error. d, _ := throttleValue.Load().(time.Duration) @@ -580,8 +803,10 @@ func wrapError(err error, throttleValue *atomic.Value) error { } throttleValue.Store(d) return rateLimitError{error: err, earliestRetry: time.Now().Add(d)} - } else if isErrorCapacity(err) { + } else if isErrorQuota(err) { return &ec2QuotaError{err} + } else if isErrorCapacity(err) { + return &capacityError{err, true} } else if err != nil { throttleValue.Store(time.Duration(0)) return err @@ -589,3 +814,5 @@ func wrapError(err error, throttleValue *atomic.Value) error { throttleValue.Store(time.Duration(0)) return nil } + +var boolLabelValue = map[bool]string{false: "0", true: "1"}