19146: Remove unneeded special case checks, explain the needed one.
[arvados.git] / lib / cloud / ec2 / ec2.go
index fb7afdda4290087e3d680f87f0ef8b2127fdf25d..52b73f781c6bc63c2e4e2d3242ddc33c642157dc 100644 (file)
@@ -40,13 +40,14 @@ const (
 )
 
 type ec2InstanceSetConfig struct {
-       AccessKeyID      string
-       SecretAccessKey  string
-       Region           string
-       SecurityGroupIDs arvados.StringSet
-       SubnetID         string
-       AdminUsername    string
-       EBSVolumeType    string
+       AccessKeyID        string
+       SecretAccessKey    string
+       Region             string
+       SecurityGroupIDs   arvados.StringSet
+       SubnetID           string
+       AdminUsername      string
+       EBSVolumeType      string
+       IAMInstanceProfile string
 }
 
 type ec2Interface interface {
@@ -230,6 +231,12 @@ func (instanceSet *ec2InstanceSet) Create(
                        }}
        }
 
+       if instanceSet.ec2config.IAMInstanceProfile != "" {
+               rii.IamInstanceProfile = &ec2.IamInstanceProfileSpecification{
+                       Name: aws.String(instanceSet.ec2config.IAMInstanceProfile),
+               }
+       }
+
        rsv, err := instanceSet.client.RunInstances(&rii)
        err = wrapError(err, &instanceSet.throttleDelayCreate)
        if err != nil {
@@ -350,28 +357,33 @@ func (err rateLimitError) EarliestRetry() time.Time {
        return err.earliestRetry
 }
 
-var capacityCodes = map[string]struct{}{
-       "InsufficientInstanceCapacity": {},
-       "VcpuLimitExceeded":            {},
-       "MaxSpotInstanceCountExceeded": {},
+var isCodeCapacity = map[string]bool{
+       "InsufficientInstanceCapacity": true,
+       "VcpuLimitExceeded":            true,
+       "MaxSpotInstanceCountExceeded": true,
 }
 
-// IsErrorCapacity returns whether the error is to be throttled based on its code.
+// isErrorCapacity returns whether the error is to be throttled based on its code.
 // Returns false if error is nil.
-func IsErrorCapacity(err error) bool {
+func isErrorCapacity(err error) bool {
        if aerr, ok := err.(awserr.Error); ok && aerr != nil {
-               return isCodeCapacity(aerr.Code())
+               if _, ok := isCodeCapacity[aerr.Code()]; ok {
+                       return true
+               }
        }
        return false
 }
 
-func isCodeCapacity(code string) bool {
-       _, ok := capacityCodes[code]
-       return ok
+type ec2QuotaError struct {
+       error
+}
+
+func (er *ec2QuotaError) IsQuotaError() bool {
+       return true
 }
 
 func wrapError(err error, throttleValue *atomic.Value) error {
-       if request.IsErrorThrottle(err) || IsErrorCapacity(err) {
+       if request.IsErrorThrottle(err) {
                // Back off exponentially until an upstream call
                // either succeeds or returns a non-throttle error.
                d, _ := throttleValue.Load().(time.Duration)
@@ -383,6 +395,8 @@ func wrapError(err error, throttleValue *atomic.Value) error {
                }
                throttleValue.Store(d)
                return rateLimitError{error: err, earliestRetry: time.Now().Add(d)}
+       } else if isErrorCapacity(err) {
+               return &ec2QuotaError{err}
        } else if err != nil {
                throttleValue.Store(time.Duration(0))
                return err