20755: Support multiple/alternate subnets on EC2.
authorTom Clegg <tom@curii.com>
Thu, 3 Aug 2023 19:03:51 +0000 (15:03 -0400)
committerTom Clegg <tom@curii.com>
Thu, 3 Aug 2023 19:03:51 +0000 (15:03 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

lib/cloud/ec2/ec2.go
lib/cloud/ec2/ec2_test.go
lib/config/config.default.yml

index e2cf5e0f1c3f35e881c882e0f005a241bd75ad8c..526fc1307dc29e8d77d588cf6dae4c9489c0464d 100644 (file)
@@ -14,6 +14,7 @@ import (
        "fmt"
        "math/big"
        "strconv"
+       "strings"
        "sync"
        "sync/atomic"
        "time"
@@ -45,7 +46,7 @@ type ec2InstanceSetConfig struct {
        SecretAccessKey         string
        Region                  string
        SecurityGroupIDs        arvados.StringSet
-       SubnetID                string
+       SubnetID                sliceOrSingleString
        AdminUsername           string
        EBSVolumeType           string
        EBSPrice                float64
@@ -53,6 +54,39 @@ type ec2InstanceSetConfig struct {
        SpotPriceUpdateInterval arvados.Duration
 }
 
+type sliceOrSingleString []string // a config value written as either one string or a list of strings
+
+// UnmarshalJSON decodes a JSON array of strings as-is, and a single JSON string "foo" as ["foo"].
+// Empty input, an empty array [], and the empty string "" all set the receiver to nil.
+func (ss *sliceOrSingleString) UnmarshalJSON(data []byte) error {
+       if len(data) == 0 {
+               *ss = nil
+       } else if data[0] == '[' { // array form
+               var slice []string
+               err := json.Unmarshal(data, &slice)
+               if err != nil {
+                       return err
+               }
+               if len(slice) == 0 {
+                       *ss = nil // normalize [] to nil so "no subnets" has a single representation
+               } else {
+                       *ss = slice
+               }
+       } else { // anything else is decoded as a single JSON string
+               var str string
+               err := json.Unmarshal(data, &str)
+               if err != nil {
+                       return err
+               }
+               if str == "" {
+                       *ss = nil // "" means "not configured", same as []
+               } else {
+                       *ss = []string{str}
+               }
+       }
+       return nil
+}
+
 type ec2Interface interface {
        DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error)
        ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error)
@@ -66,6 +100,7 @@ type ec2Interface interface {
 
 type ec2InstanceSet struct {
        ec2config              ec2InstanceSetConfig
+       currentSubnetIDIndex   int32
        instanceSetID          cloud.InstanceSetID
        logger                 logrus.FieldLogger
        client                 ec2Interface
@@ -174,7 +209,6 @@ func (instanceSet *ec2InstanceSet) Create(
                                DeleteOnTermination:      aws.Bool(true),
                                DeviceIndex:              aws.Int64(0),
                                Groups:                   aws.StringSlice(groups),
-                               SubnetId:                 &instanceSet.ec2config.SubnetID,
                        }},
                DisableApiTermination:             aws.Bool(false),
                InstanceInitiatedShutdownBehavior: aws.String("terminate"),
@@ -219,7 +253,36 @@ func (instanceSet *ec2InstanceSet) Create(
                }
        }
 
-       rsv, err := instanceSet.client.RunInstances(&rii)
+       var rsv *ec2.Reservation
+       var err error
+       subnets := instanceSet.ec2config.SubnetID
+       currentSubnetIDIndex := int(atomic.LoadInt32(&instanceSet.currentSubnetIDIndex))
+       for tryOffset := 0; ; tryOffset++ {
+               tryIndex := 0
+               if len(subnets) > 0 {
+                       tryIndex = (currentSubnetIDIndex + tryOffset) % len(subnets)
+                       rii.NetworkInterfaces[0].SubnetId = aws.String(subnets[tryIndex])
+               }
+               rsv, err = instanceSet.client.RunInstances(&rii)
+               if isErrorSubnetSpecific(err) &&
+                       tryOffset < len(subnets)-1 {
+                       instanceSet.logger.WithError(err).WithField("SubnetID", subnets[tryIndex]).
+                               Warn("RunInstances failed, trying next subnet")
+                       continue
+               }
+               // Succeeded, or exhausted all subnets, or got a
+               // non-subnet-related error.
+               //
+               // We intentionally update currentSubnetIDIndex even
+               // in the non-retryable-failure case here to avoid a
+               // situation where successive calls to Create() keep
+               // returning errors for the same subnet (perhaps
+               // "subnet full") and never reveal the errors for the
+               // other configured subnets (perhaps "subnet ID
+               // invalid").
+               atomic.StoreInt32(&instanceSet.currentSubnetIDIndex, int32(tryIndex))
+               break
+       }
        err = wrapError(err, &instanceSet.throttleDelayCreate)
        if err != nil {
                return nil, err
@@ -548,6 +611,8 @@ func (err rateLimitError) EarliestRetry() time.Time {
 }
 
 var isCodeCapacity = map[string]bool{
+       "InstanceLimitExceeded":             true,
+       "InsufficientAddressCapacity":       true,
        "InsufficientFreeAddressesInSubnet": true,
        "InsufficientInstanceCapacity":      true,
        "InsufficientVolumeCapacity":        true,
@@ -566,6 +631,19 @@ func isErrorCapacity(err error) bool {
        return false
 }
 
+// isErrorSubnetSpecific returns true if the problem encountered by RunInstances might be avoided
+// by trying a different subnet. Errors that are not an awserr.Error (including nil) return false.
+func isErrorSubnetSpecific(err error) bool {
+       aerr, ok := err.(awserr.Error)
+       if !ok {
+               return false
+       }
+       code := aerr.Code()
+       return strings.Contains(code, "Subnet") || // any error code that mentions a subnet
+               code == "InsufficientInstanceCapacity" ||
+               code == "InsufficientVolumeCapacity"
+}
+
 type ec2QuotaError struct {
        error
 }
index ede7f9de5d2cc8b3dd1ddb17416782c3ec21aee7..2f3d319e0be3b706e4fbdaef0000bdeba8972fc4 100644 (file)
@@ -24,7 +24,9 @@ package ec2
 
 import (
        "encoding/json"
+       "errors"
        "flag"
+       "fmt"
        "sync/atomic"
        "testing"
        "time"
@@ -33,9 +35,11 @@ import (
        "git.arvados.org/arvados.git/lib/dispatchcloud/test"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/config"
+       "git.arvados.org/arvados.git/sdk/go/ctxlog"
        "github.com/aws/aws-sdk-go/aws"
        "github.com/aws/aws-sdk-go/aws/awserr"
        "github.com/aws/aws-sdk-go/service/ec2"
+       "github.com/ghodss/yaml"
        "github.com/sirupsen/logrus"
        check "gopkg.in/check.v1"
 )
@@ -47,6 +51,34 @@ func Test(t *testing.T) {
        check.TestingT(t)
 }
 
+type sliceOrStringSuite struct{}
+
+var _ = check.Suite(&sliceOrStringSuite{})
+
+func (s *sliceOrStringSuite) TestUnmarshal(c *check.C) {
+       var conf ec2InstanceSetConfig // shared by all trials, so a trial that left SubnetID untouched would see the previous value
+       for _, trial := range []struct {
+               input  string
+               output sliceOrSingleString
+       }{
+               {``, nil}, // empty YAML value
+               {`""`, nil},
+               {`[]`, nil},
+               {`"foo"`, sliceOrSingleString{"foo"}},
+               {`["foo"]`, sliceOrSingleString{"foo"}},
+               {`[foo]`, sliceOrSingleString{"foo"}}, // unquoted YAML flow sequence
+               {`["foo", "bar"]`, sliceOrSingleString{"foo", "bar"}},
+               {`[foo-bar, baz]`, sliceOrSingleString{"foo-bar", "baz"}},
+       } {
+               c.Logf("trial: %+v", trial)
+               err := yaml.Unmarshal([]byte("SubnetID: "+trial.input+"\n"), &conf) // input is spliced in as the YAML value of SubnetID
+               if !c.Check(err, check.IsNil) {
+                       continue // skip the comparison when unmarshalling failed
+               }
+               c.Check(conf.SubnetID, check.DeepEquals, trial.output)
+       }
+}
+
 type EC2InstanceSetSuite struct{}
 
 var _ = check.Suite(&EC2InstanceSetSuite{})
@@ -61,6 +93,10 @@ type ec2stub struct {
        reftime               time.Time
        importKeyPairCalls    []*ec2.ImportKeyPairInput
        describeKeyPairsCalls []*ec2.DescribeKeyPairsInput
+       runInstancesCalls     []*ec2.RunInstancesInput
+       // {subnetID => error}: RunInstances returns error if subnetID
+       // matches.
+       subnetErrorOnRunInstances map[string]error
 }
 
 func (e *ec2stub) ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error) {
@@ -74,6 +110,13 @@ func (e *ec2stub) DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.Descr
 }
 
 func (e *ec2stub) RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error) {
+       e.runInstancesCalls = append(e.runInstancesCalls, input)
+       if len(input.NetworkInterfaces) > 0 && input.NetworkInterfaces[0].SubnetId != nil {
+               err := e.subnetErrorOnRunInstances[*input.NetworkInterfaces[0].SubnetId]
+               if err != nil {
+                       return nil, err
+               }
+       }
        return &ec2.Reservation{Instances: []*ec2.Instance{{
                InstanceId:   aws.String("i-123"),
                InstanceType: aws.String("t2.micro"),
@@ -154,6 +197,19 @@ func (e *ec2stub) TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.T
        return nil, nil
 }
 
+type ec2stubError struct { // stub error carrying a caller-chosen error code and message
+       code    string
+       message string
+}
+
+func (err *ec2stubError) Code() string    { return err.code }
+func (err *ec2stubError) Message() string { return err.message }
+func (err *ec2stubError) Error() string   { return fmt.Sprintf("%s: %s", err.code, err.message) }
+func (err *ec2stubError) OrigErr() error  { return errors.New("stub OrigErr") }
+
+// Ensure ec2stubError satisfies the awserr.Error interface
+var _ = awserr.Error(&ec2stubError{})
+
 func GetInstanceSet(c *check.C) (*ec2InstanceSet, cloud.ImageID, arvados.Cluster) {
        cluster := arvados.Cluster{
                InstanceTypes: arvados.InstanceTypeMap(map[string]arvados.InstanceType{
@@ -196,7 +252,7 @@ func GetInstanceSet(c *check.C) (*ec2InstanceSet, cloud.ImageID, arvados.Cluster
        }
        ap := ec2InstanceSet{
                instanceSetID: "test123",
-               logger:        logrus.StandardLogger(),
+               logger:        ctxlog.TestLogger(c),
                client:        &ec2stub{c: c, reftime: time.Now().UTC()},
                keys:          make(map[string]string),
        }
@@ -261,6 +317,61 @@ func (*EC2InstanceSetSuite) TestCreatePreemptible(c *check.C) {
 
 }
 
+func (*EC2InstanceSetSuite) TestCreateFailoverSecondSubnet(c *check.C) {
+       if *live != "" {
+               c.Skip("not applicable in live mode")
+               return
+       }
+
+       ap, img, cluster := GetInstanceSet(c)
+       ap.ec2config.SubnetID = sliceOrSingleString{"subnet-full", "subnet-good"}
+       ap.client.(*ec2stub).subnetErrorOnRunInstances = map[string]error{ // only the first subnet fails
+               "subnet-full": &ec2stubError{
+                       code:    "InsufficientFreeAddressesInSubnet",
+                       message: "subnet is full",
+               },
+       }
+       inst, err := ap.Create(cluster.InstanceTypes["tiny"], img, nil, "", nil)
+       c.Check(err, check.IsNil)
+       c.Check(inst, check.NotNil)
+       c.Check(ap.client.(*ec2stub).runInstancesCalls, check.HasLen, 2) // one failed attempt on subnet-full, then one on subnet-good
+
+       // Next RunInstances call should try the working subnet first
+       inst, err = ap.Create(cluster.InstanceTypes["tiny"], img, nil, "", nil)
+       c.Check(err, check.IsNil)
+       c.Check(inst, check.NotNil)
+       c.Check(ap.client.(*ec2stub).runInstancesCalls, check.HasLen, 3) // cumulative count: only one additional call
+}
+
+func (*EC2InstanceSetSuite) TestCreateAllSubnetsFailing(c *check.C) {
+       if *live != "" {
+               c.Skip("not applicable in live mode")
+               return
+       }
+
+       ap, img, cluster := GetInstanceSet(c)
+       ap.ec2config.SubnetID = sliceOrSingleString{"subnet-full", "subnet-broken"}
+       ap.client.(*ec2stub).subnetErrorOnRunInstances = map[string]error{ // every configured subnet fails, each with a different code
+               "subnet-full": &ec2stubError{
+                       code:    "InsufficientFreeAddressesInSubnet",
+                       message: "subnet is full",
+               },
+               "subnet-broken": &ec2stubError{
+                       code:    "InvalidSubnetId.NotFound",
+                       message: "bogus subnet id",
+               },
+       }
+       _, err := ap.Create(cluster.InstanceTypes["tiny"], img, nil, "", nil)
+       c.Check(err, check.NotNil)
+       c.Check(err, check.ErrorMatches, `.*InvalidSubnetId\.NotFound.*`) // error from the last subnet tried (subnet-broken)
+       c.Check(ap.client.(*ec2stub).runInstancesCalls, check.HasLen, 2)
+
+       _, err = ap.Create(cluster.InstanceTypes["tiny"], img, nil, "", nil)
+       c.Check(err, check.NotNil)
+       c.Check(err, check.ErrorMatches, `.*InsufficientFreeAddressesInSubnet.*`) // second call starts at subnet-broken, so subnet-full is tried last
+       c.Check(ap.client.(*ec2stub).runInstancesCalls, check.HasLen, 4) // cumulative count across both Create calls
+}
+
 func (*EC2InstanceSetSuite) TestTagInstances(c *check.C) {
        ap, _, _ := GetInstanceSet(c)
        l, err := ap.Instances(nil)
index 723e64ceabf6147a69833d5d53a68511aa1358eb..d14ab46619eac3dadd70c80e2a96caf549e520e5 100644 (file)
@@ -1531,10 +1531,23 @@ Clusters:
           SecretAccessKey: ""
 
           # (ec2) Instance configuration.
+
+          # (ec2) Region, like "us-east-1".
+          Region: ""
+
+          # (ec2) Security group IDs. Omit or use {} to use the
+          # default security group.
           SecurityGroupIDs:
             "SAMPLE": {}
+
+          # (ec2) One or more subnet IDs. Omit or leave empty to let
+          # AWS choose a default subnet from your default VPC. If
+          # multiple subnets are configured here (enclosed in brackets
+          # like [subnet-abc123, subnet-def456]) the cloud dispatcher
+          # will detect subnet-related errors and retry using a
+          # different subnet. Most sites specify one subnet.
           SubnetID: ""
-          Region: ""
+
           EBSVolumeType: gp2
           AdminUsername: debian
           # (ec2) name of the IAMInstanceProfile for instances started by