package ec2
import (
+ "context"
"crypto/md5"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"encoding/base64"
"encoding/json"
+ "errors"
"fmt"
"math/big"
+ "os"
+ "regexp"
"strconv"
"strings"
"sync"
"git.arvados.org/arvados.git/lib/cloud"
"git.arvados.org/arvados.git/sdk/go/arvados"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/aws/awserr"
- "github.com/aws/aws-sdk-go/aws/credentials"
- "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
- "github.com/aws/aws-sdk-go/aws/ec2metadata"
- "github.com/aws/aws-sdk-go/aws/request"
- "github.com/aws/aws-sdk-go/aws/session"
- "github.com/aws/aws-sdk-go/service/ec2"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/aws/retry"
+ awsconfig "github.com/aws/aws-sdk-go-v2/config"
+ "github.com/aws/aws-sdk-go-v2/service/ec2"
+ "github.com/aws/aws-sdk-go-v2/service/ec2/types"
+ "github.com/aws/smithy-go"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
SecurityGroupIDs arvados.StringSet
SubnetID sliceOrSingleString
AdminUsername string
- EBSVolumeType string
+ EBSVolumeType types.VolumeType
EBSPrice float64
IAMInstanceProfile string
SpotPriceUpdateInterval arvados.Duration
}
// ec2Interface lists the EC2 API operations that this file invokes
// through instanceSet.client / inst.provider.client.
//
// The removed signatures are the aws-sdk-go v1 forms. The added
// signatures take a context.Context as the first argument and accept
// optional func(*ec2.Options) overrides. RunInstances now returns
// *ec2.RunInstancesOutput instead of *ec2.Reservation.
//
// The v1 paginating helpers (DescribeInstanceStatusPages and
// DescribeSpotPriceHistoryPages) are replaced by single-page calls;
// the callers in this file loop on NextToken themselves.
//
// NOTE(review): presumably this is the declared type of the client
// field that receives ec2.NewFromConfig(...) -- the field declaration
// is not visible here; confirm.
type ec2Interface interface {
- DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error)
- ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error)
- RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error)
- DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
- DescribeInstanceStatusPages(input *ec2.DescribeInstanceStatusInput, fn func(*ec2.DescribeInstanceStatusOutput, bool) bool) error
- DescribeSpotPriceHistoryPages(input *ec2.DescribeSpotPriceHistoryInput, fn func(*ec2.DescribeSpotPriceHistoryOutput, bool) bool) error
- CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error)
- TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error)
+ DescribeKeyPairs(context.Context, *ec2.DescribeKeyPairsInput, ...func(*ec2.Options)) (*ec2.DescribeKeyPairsOutput, error)
+ ImportKeyPair(context.Context, *ec2.ImportKeyPairInput, ...func(*ec2.Options)) (*ec2.ImportKeyPairOutput, error)
+ RunInstances(context.Context, *ec2.RunInstancesInput, ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error)
+ DescribeInstances(context.Context, *ec2.DescribeInstancesInput, ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error)
+ DescribeInstanceStatus(context.Context, *ec2.DescribeInstanceStatusInput, ...func(*ec2.Options)) (*ec2.DescribeInstanceStatusOutput, error)
+ DescribeSpotPriceHistory(context.Context, *ec2.DescribeSpotPriceHistoryInput, ...func(*ec2.Options)) (*ec2.DescribeSpotPriceHistoryOutput, error)
+ CreateTags(context.Context, *ec2.CreateTagsInput, ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error)
+ TerminateInstances(context.Context, *ec2.TerminateInstancesInput, ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error)
}
type ec2InstanceSet struct {
return nil, err
}
- sess, err := session.NewSession()
+ if len(instanceSet.ec2config.AccessKeyID)+len(instanceSet.ec2config.SecretAccessKey) > 0 {
+ // AWS SDK will use credentials in environment vars if
+ // present.
+ os.Setenv("AWS_ACCESS_KEY_ID", instanceSet.ec2config.AccessKeyID)
+ os.Setenv("AWS_SECRET_ACCESS_KEY", instanceSet.ec2config.SecretAccessKey)
+ } else {
+ os.Unsetenv("AWS_ACCESS_KEY_ID")
+ os.Unsetenv("AWS_SECRET_ACCESS_KEY")
+ }
+ awsConfig, err := awsconfig.LoadDefaultConfig(context.TODO(),
+ awsconfig.WithRegion(instanceSet.ec2config.Region))
if err != nil {
return nil, err
}
- // First try any static credentials, fall back to an IAM instance profile/role
- creds := credentials.NewChainCredentials(
- []credentials.Provider{
- &credentials.StaticProvider{Value: credentials.Value{AccessKeyID: instanceSet.ec2config.AccessKeyID, SecretAccessKey: instanceSet.ec2config.SecretAccessKey}},
- &ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(sess)},
- })
- awsConfig := aws.NewConfig().WithCredentials(creds).WithRegion(instanceSet.ec2config.Region)
- instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig)))
+ instanceSet.client = ec2.NewFromConfig(awsConfig)
instanceSet.keys = make(map[string]string)
if instanceSet.ec2config.EBSVolumeType == "" {
instanceSet.ec2config.EBSVolumeType = "gp2"
initCommand cloud.InitCommand,
publicKey ssh.PublicKey) (cloud.Instance, error) {
- ec2tags := []*ec2.Tag{}
+ ec2tags := []types.Tag{}
for k, v := range newTags {
- ec2tags = append(ec2tags, &ec2.Tag{
+ ec2tags = append(ec2tags, types.Tag{
Key: aws.String(k),
Value: aws.String(v),
})
rii := ec2.RunInstancesInput{
ImageId: aws.String(string(imageID)),
- InstanceType: &instanceType.ProviderType,
- MaxCount: aws.Int64(1),
- MinCount: aws.Int64(1),
-
- NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
- {
- AssociatePublicIpAddress: aws.Bool(false),
- DeleteOnTermination: aws.Bool(true),
- DeviceIndex: aws.Int64(0),
- Groups: aws.StringSlice(groups),
- }},
+ InstanceType: types.InstanceType(instanceType.ProviderType),
+ MaxCount: aws.Int32(1),
+ MinCount: aws.Int32(1),
+
+ NetworkInterfaces: []types.InstanceNetworkInterfaceSpecification{{
+ AssociatePublicIpAddress: aws.Bool(false),
+ DeleteOnTermination: aws.Bool(true),
+ DeviceIndex: aws.Int32(0),
+ Groups: groups,
+ }},
DisableApiTermination: aws.Bool(false),
- InstanceInitiatedShutdownBehavior: aws.String("terminate"),
- TagSpecifications: []*ec2.TagSpecification{
+ InstanceInitiatedShutdownBehavior: types.ShutdownBehaviorTerminate,
+ TagSpecifications: []types.TagSpecification{
{
- ResourceType: aws.String("instance"),
+ ResourceType: types.ResourceTypeInstance,
Tags: ec2tags,
}},
+ MetadataOptions: &types.InstanceMetadataOptionsRequest{
+ // Require IMDSv2, as described at
+ // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-IMDS-new-instances.html
+ HttpEndpoint: types.InstanceMetadataEndpointStateEnabled,
+ HttpTokens: types.HttpTokensStateRequired,
+ },
UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
}
}
if instanceType.AddedScratch > 0 {
- rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{{
+ rii.BlockDeviceMappings = []types.BlockDeviceMapping{{
DeviceName: aws.String("/dev/xvdt"),
- Ebs: &ec2.EbsBlockDevice{
+ Ebs: &types.EbsBlockDevice{
DeleteOnTermination: aws.Bool(true),
- VolumeSize: aws.Int64((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30),
- VolumeType: &instanceSet.ec2config.EBSVolumeType,
+ VolumeSize: aws.Int32(int32((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30)),
+ VolumeType: instanceSet.ec2config.EBSVolumeType,
}}}
}
if instanceType.Preemptible {
- rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{
- MarketType: aws.String("spot"),
- SpotOptions: &ec2.SpotMarketOptions{
- InstanceInterruptionBehavior: aws.String("terminate"),
+ rii.InstanceMarketOptions = &types.InstanceMarketOptionsRequest{
+ MarketType: types.MarketTypeSpot,
+ SpotOptions: &types.SpotMarketOptions{
+ InstanceInterruptionBehavior: types.InstanceInterruptionBehaviorTerminate,
MaxPrice: aws.String(fmt.Sprintf("%v", instanceType.Price)),
}}
}
if instanceSet.ec2config.IAMInstanceProfile != "" {
- rii.IamInstanceProfile = &ec2.IamInstanceProfileSpecification{
+ rii.IamInstanceProfile = &types.IamInstanceProfileSpecification{
Name: aws.String(instanceSet.ec2config.IAMInstanceProfile),
}
}
- var rsv *ec2.Reservation
- var err error
+ var rsv *ec2.RunInstancesOutput
+ var errToReturn error
subnets := instanceSet.ec2config.SubnetID
currentSubnetIDIndex := int(atomic.LoadInt32(&instanceSet.currentSubnetIDIndex))
for tryOffset := 0; ; tryOffset++ {
trySubnet = subnets[tryIndex]
rii.NetworkInterfaces[0].SubnetId = aws.String(trySubnet)
}
- rsv, err = instanceSet.client.RunInstances(&rii)
+ var err error
+ rsv, err = instanceSet.client.RunInstances(context.TODO(), &rii)
instanceSet.mInstanceStarts.WithLabelValues(trySubnet, boolLabelValue[err == nil]).Add(1)
+ if !isErrorCapacity(errToReturn) || isErrorCapacity(err) {
+ // We want to return the last capacity error,
+ // if any; otherwise the last non-capacity
+ // error.
+ errToReturn = err
+ }
if isErrorSubnetSpecific(err) &&
tryOffset < len(subnets)-1 {
instanceSet.logger.WithError(err).WithField("SubnetID", subnets[tryIndex]).
atomic.StoreInt32(&instanceSet.currentSubnetIDIndex, int32(tryIndex))
break
}
- err = wrapError(err, &instanceSet.throttleDelayCreate)
- if err != nil {
- return nil, err
+ if rsv == nil || len(rsv.Instances) == 0 {
+ return nil, wrapError(errToReturn, &instanceSet.throttleDelayCreate)
}
return &ec2Instance{
provider: instanceSet,
if keyname, ok := instanceSet.keys[md5keyFingerprint]; ok {
return keyname, nil
}
- keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
- Filters: []*ec2.Filter{{
+ keyout, err := instanceSet.client.DescribeKeyPairs(context.TODO(), &ec2.DescribeKeyPairsInput{
+ Filters: []types.Filter{{
Name: aws.String("fingerprint"),
- Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
+ Values: []string{md5keyFingerprint, sha1keyFingerprint},
}},
})
if err != nil {
return *(keyout.KeyPairs[0].KeyName), nil
}
keyname := "arvados-dispatch-keypair-" + md5keyFingerprint
- _, err = instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
+ _, err = instanceSet.client.ImportKeyPair(context.TODO(), &ec2.ImportKeyPairInput{
KeyName: &keyname,
PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
})
}
func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
- var filters []*ec2.Filter
+ var filters []types.Filter
for k, v := range tags {
- filters = append(filters, &ec2.Filter{
+ filters = append(filters, types.Filter{
Name: aws.String("tag:" + k),
- Values: []*string{aws.String(v)},
+ Values: []string{v},
})
}
needAZs := false
dii := &ec2.DescribeInstancesInput{Filters: filters}
for {
- dio, err := instanceSet.client.DescribeInstances(dii)
+ dio, err := instanceSet.client.DescribeInstances(context.TODO(), dii)
err = wrapError(err, &instanceSet.throttleDelayInstances)
if err != nil {
return nil, err
for _, rsv := range dio.Reservations {
for _, inst := range rsv.Instances {
- if *inst.State.Name != "shutting-down" && *inst.State.Name != "terminated" {
+ switch inst.State.Name {
+ case types.InstanceStateNameShuttingDown:
+ case types.InstanceStateNameTerminated:
+ default:
instances = append(instances, &ec2Instance{
provider: instanceSet,
instance: inst,
})
- if aws.StringValue(inst.InstanceLifecycle) == "spot" {
+ if inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot {
needAZs = true
}
}
}
if needAZs && instanceSet.ec2config.SpotPriceUpdateInterval > 0 {
az := map[string]string{}
- err := instanceSet.client.DescribeInstanceStatusPages(&ec2.DescribeInstanceStatusInput{
- IncludeAllInstances: aws.Bool(true),
- }, func(page *ec2.DescribeInstanceStatusOutput, lastPage bool) bool {
+ disi := &ec2.DescribeInstanceStatusInput{IncludeAllInstances: aws.Bool(true)}
+ for {
+ page, err := instanceSet.client.DescribeInstanceStatus(context.TODO(), disi)
+ if err != nil {
+ instanceSet.logger.Warnf("error getting instance statuses: %s", err)
+ break
+ }
for _, ent := range page.InstanceStatuses {
az[*ent.InstanceId] = *ent.AvailabilityZone
}
- return true
- })
- if err != nil {
- instanceSet.logger.Warnf("error getting instance statuses: %s", err)
+ if page.NextToken == nil {
+ break
+ }
+ disi.NextToken = page.NextToken
}
for _, inst := range instances {
inst := inst.(*ec2Instance)
updateTime := time.Now()
staleTime := updateTime.Add(-instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
needUpdate := false
- allTypes := map[string]bool{}
+ allTypes := map[types.InstanceType]bool{}
for _, inst := range instances {
ec2inst := inst.(*ec2Instance).instance
- if aws.StringValue(ec2inst.InstanceLifecycle) == "spot" {
+ if ec2inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot {
pk := priceKey{
- instanceType: *ec2inst.InstanceType,
+ instanceType: string(ec2inst.InstanceType),
spot: true,
availabilityZone: inst.(*ec2Instance).availabilityZone,
}
if instanceSet.pricesUpdated[pk].Before(staleTime) {
needUpdate = true
}
- allTypes[*ec2inst.InstanceType] = true
+ allTypes[ec2inst.InstanceType] = true
}
}
if !needUpdate {
return
}
- var typeFilterValues []*string
+ var typeFilterValues []string
for instanceType := range allTypes {
- typeFilterValues = append(typeFilterValues, aws.String(instanceType))
+ typeFilterValues = append(typeFilterValues, string(instanceType))
}
// Get 3x update interval worth of pricing data. (Ideally the
// AWS API would tell us "we have shown you all of the price
// row.
dsphi := &ec2.DescribeSpotPriceHistoryInput{
StartTime: aws.Time(updateTime.Add(-3 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())),
- Filters: []*ec2.Filter{
- &ec2.Filter{Name: aws.String("instance-type"), Values: typeFilterValues},
- &ec2.Filter{Name: aws.String("product-description"), Values: []*string{aws.String("Linux/UNIX")}},
+ Filters: []types.Filter{
+ types.Filter{Name: aws.String("instance-type"), Values: typeFilterValues},
+ types.Filter{Name: aws.String("product-description"), Values: []string{"Linux/UNIX"}},
},
}
- err := instanceSet.client.DescribeSpotPriceHistoryPages(dsphi, func(page *ec2.DescribeSpotPriceHistoryOutput, lastPage bool) bool {
+ for {
+ page, err := instanceSet.client.DescribeSpotPriceHistory(context.TODO(), dsphi)
+ if err != nil {
+ instanceSet.logger.Warnf("error retrieving spot instance prices: %s", err)
+ break
+ }
for _, ent := range page.SpotPriceHistory {
- if ent.InstanceType == nil || ent.SpotPrice == nil || ent.Timestamp == nil {
+ if ent.InstanceType == "" || ent.SpotPrice == nil || ent.Timestamp == nil {
// bogus record?
continue
}
continue
}
pk := priceKey{
- instanceType: *ent.InstanceType,
+ instanceType: string(ent.InstanceType),
spot: true,
availabilityZone: *ent.AvailabilityZone,
}
})
instanceSet.pricesUpdated[pk] = updateTime
}
- return true
- })
- if err != nil {
- instanceSet.logger.Warnf("error retrieving spot instance prices: %s", err)
+ if page.NextToken == nil {
+ break
+ }
+ dsphi.NextToken = page.NextToken
}
expiredTime := updateTime.Add(-64 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
// ec2Instance pairs one EC2 instance record with the instance set
// that produced it.
type ec2Instance struct {
// provider is the owning instance set; methods reach the EC2 client
// (and spot price data) through it.
provider *ec2InstanceSet
// instance changes from a *ec2.Instance pointer to a types.Instance
// value, so it is copied into the struct rather than shared.
- instance *ec2.Instance
+ instance types.Instance
availabilityZone string // sometimes available for spot instances
}
}
// ProviderType returns the instance's EC2 instance type as a plain
// string. After this change the InstanceType field is a
// types.InstanceType value rather than a *string, so it is converted
// instead of dereferenced.
func (inst *ec2Instance) ProviderType() string {
- return *inst.instance.InstanceType
+ return string(inst.instance.InstanceType)
}
func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
- var ec2tags []*ec2.Tag
+ var ec2tags []types.Tag
for k, v := range newTags {
- ec2tags = append(ec2tags, &ec2.Tag{
+ ec2tags = append(ec2tags, types.Tag{
Key: aws.String(k),
Value: aws.String(v),
})
}
- _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{
- Resources: []*string{inst.instance.InstanceId},
+ _, err := inst.provider.client.CreateTags(context.TODO(), &ec2.CreateTagsInput{
+ Resources: []string{*inst.instance.InstanceId},
Tags: ec2tags,
})
}
// Destroy asks EC2 to terminate this instance by calling
// TerminateInstances with the instance's ID. The error from that call
// is returned as-is (it is not passed through wrapError).
//
// No caller-supplied context exists, so context.TODO() is passed.
//
// NOTE(review): after this change InstanceId is dereferenced without
// a nil check; an instance record with no ID would panic here --
// confirm the API always populates it.
func (inst *ec2Instance) Destroy() error {
- _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{
- InstanceIds: []*string{inst.instance.InstanceId},
+ _, err := inst.provider.client.TerminateInstances(context.TODO(), &ec2.TerminateInstancesInput{
+ InstanceIds: []string{*inst.instance.InstanceId},
})
return err
}
// inst.provider.prices only for spot instances, so if
// spot==false here, we will return no data.
pk := priceKey{
- instanceType: *inst.instance.InstanceType,
- spot: aws.StringValue(inst.instance.InstanceLifecycle) == "spot",
+ instanceType: string(inst.instance.InstanceType),
+ spot: inst.instance.InstanceLifecycle == types.InstanceLifecycleTypeSpot,
availabilityZone: inst.availabilityZone,
}
var prices []cloud.InstancePrice
return err.earliestRetry
}
-var isCodeCapacity = map[string]bool{
// capacityError wraps an error that isErrorCapacity classified as a
// lack of capacity for the requested instance type. The embedded
// error supplies the Error() method. The only construction site
// visible in this file is wrapError.
+type capacityError struct {
+ error
// isInstanceTypeSpecific is the value reported by
// IsInstanceTypeSpecific; wrapError sets it to true.
+ isInstanceTypeSpecific bool
+}
+
// IsCapacityError always reports true: every capacityError value
// represents a capacity problem.
//
// NOTE(review): presumably satisfies a capacity-error interface
// declared in the cloud package -- not visible here; confirm.
+func (er *capacityError) IsCapacityError() bool {
+ return true
+}
+
// IsInstanceTypeSpecific reports the isInstanceTypeSpecific flag that
// was stored when the capacityError was constructed.
+func (er *capacityError) IsInstanceTypeSpecific() bool {
+ return er.isInstanceTypeSpecific
+}
+
// isCodeQuota is the set of AWS error codes that isErrorQuota treats
// as quota/limit errors. This change drops
// "InsufficientInstanceCapacity" from the set; that code is
// classified by isErrorCapacity instead.
+var isCodeQuota = map[string]bool{
"InstanceLimitExceeded": true,
"InsufficientAddressCapacity": true,
"InsufficientFreeAddressesInSubnet": true,
- "InsufficientInstanceCapacity": true,
"InsufficientVolumeCapacity": true,
"MaxSpotInstanceCountExceeded": true,
"VcpuLimitExceeded": true,
}
-// isErrorCapacity returns whether the error is to be throttled based on its code.
+// isErrorQuota returns whether the error indicates we have reached
+// some usage quota/limit -- i.e., immediately retrying with an equal
+// or larger instance type will probably not work.
+//
// Returns false if error is nil.
-func isErrorCapacity(err error) bool {
- if aerr, ok := err.(awserr.Error); ok && aerr != nil {
- if _, ok := isCodeCapacity[aerr.Code()]; ok {
+func isErrorQuota(err error) bool {
+ var aerr smithy.APIError
// Unlike the removed type assertion, errors.As also finds an API
// error wrapped inside err. It reports false for a nil err.
+ if errors.As(err, &aerr) {
// Only key presence is tested; every entry in the map is true.
+ if _, ok := isCodeQuota[aerr.ErrorCode()]; ok {
return true
}
}
return false
}
// reSubnetSpecificInvalidParameterMessage matches an error message
// that contains either " subnet " (with the surrounding spaces) or
// "sufficient free Ipv4 addresses" / "sufficient free Ipv6 addresses"
// (the leading "I" may be upper or lower case). The "s" flag lets the
// leading and trailing ".*" span newlines in a multi-line message.
// isErrorSubnetSpecific applies it to the message of errors whose
// code contains "InvalidParameter".
+var reSubnetSpecificInvalidParameterMessage = regexp.MustCompile(`(?ms).*( subnet |sufficient free [Ii]pv[46] addresses).*`)
+
// isErrorSubnetSpecific returns true if the problem encountered by
// RunInstances might be avoided by trying a different subnet.
//
// Only AWS API errors (smithy.APIError, possibly wrapped) can
// qualify; any other error, including nil, yields false.
func isErrorSubnetSpecific(err error) bool {
- aerr, ok := err.(awserr.Error)
- if !ok {
+ var aerr smithy.APIError
+ if !errors.As(err, &aerr) {
return false
}
- code := aerr.Code()
+ code := aerr.ErrorCode()
return strings.Contains(code, "Subnet") ||
code == "InsufficientInstanceCapacity" ||
- code == "InsufficientVolumeCapacity"
+ code == "InsufficientVolumeCapacity" ||
// "Unsupported" is accepted here regardless of message;
// isErrorCapacity additionally checks the message text.
+ code == "Unsupported" ||
+ // See TestIsErrorSubnetSpecific for examples of why
+ // we look for substrings in code/message instead of
+ // only using specific codes here.
+ (strings.Contains(code, "InvalidParameter") &&
+ reSubnetSpecificInvalidParameterMessage.MatchString(aerr.ErrorMessage()))
+}
+
+// isErrorCapacity returns true if the error indicates lack of
+// capacity (either temporary or permanent) to run a specific instance
+// type -- i.e., retrying with a different instance type might
+// succeed.
//
// Returns false if err is nil or is not (and does not wrap) an AWS
// API error.
+func isErrorCapacity(err error) bool {
+ var aerr smithy.APIError
+ if !errors.As(err, &aerr) {
+ return false
+ }
+ code := aerr.ErrorCode()
// Both codes below are also accepted by isErrorSubnetSpecific, so
// an error can be subnet-specific and a capacity error at once.
+ return code == "InsufficientInstanceCapacity" ||
+ (code == "Unsupported" && strings.Contains(aerr.ErrorMessage(), "requested instance type"))
}
type ec2QuotaError struct {
return true
}
// isThrottleError reports whether err is (or wraps) an AWS API error
// whose code appears in the SDK's retry.DefaultThrottleErrorCodes
// set. wrapError calls it in place of the v1 SDK's
// request.IsErrorThrottle.
//
// Returns false if err is nil.
+func isThrottleError(err error) bool {
+ var aerr smithy.APIError
+ if !errors.As(err, &aerr) {
+ return false
+ }
+ _, is := retry.DefaultThrottleErrorCodes[aerr.ErrorCode()]
+ return is
+}
+
func wrapError(err error, throttleValue *atomic.Value) error {
- if request.IsErrorThrottle(err) {
+ if isThrottleError(err) {
// Back off exponentially until an upstream call
// either succeeds or returns a non-throttle error.
d, _ := throttleValue.Load().(time.Duration)
}
throttleValue.Store(d)
return rateLimitError{error: err, earliestRetry: time.Now().Add(d)}
- } else if isErrorCapacity(err) {
+ } else if isErrorQuota(err) {
return &ec2QuotaError{err}
+ } else if isErrorCapacity(err) {
+ return &capacityError{err, true}
} else if err != nil {
throttleValue.Store(time.Duration(0))
return err