20485: Test handling of nil publickey argument.
[arvados.git] / lib / cloud / ec2 / ec2.go
index f80e9bd1a52553421befdc4cc504382202328b55..e2cf5e0f1c3f35e881c882e0f005a241bd75ad8c 100644 (file)
@@ -48,6 +48,7 @@ type ec2InstanceSetConfig struct {
        SubnetID                string
        AdminUsername           string
        EBSVolumeType           string
+       EBSPrice                float64
        IAMInstanceProfile      string
        SpotPriceUpdateInterval arvados.Duration
 }
@@ -148,40 +149,6 @@ func (instanceSet *ec2InstanceSet) Create(
        initCommand cloud.InitCommand,
        publicKey ssh.PublicKey) (cloud.Instance, error) {
 
-       md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
-       if err != nil {
-               return nil, fmt.Errorf("Could not make key fingerprint: %v", err)
-       }
-       instanceSet.keysMtx.Lock()
-       var keyname string
-       var ok bool
-       if keyname, ok = instanceSet.keys[md5keyFingerprint]; !ok {
-               keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
-                       Filters: []*ec2.Filter{{
-                               Name:   aws.String("fingerprint"),
-                               Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
-                       }},
-               })
-               if err != nil {
-                       return nil, fmt.Errorf("Could not search for keypair: %v", err)
-               }
-
-               if len(keyout.KeyPairs) > 0 {
-                       keyname = *(keyout.KeyPairs[0].KeyName)
-               } else {
-                       keyname = "arvados-dispatch-keypair-" + md5keyFingerprint
-                       _, err := instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
-                               KeyName:           &keyname,
-                               PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
-                       })
-                       if err != nil {
-                               return nil, fmt.Errorf("Could not import keypair: %v", err)
-                       }
-               }
-               instanceSet.keys[md5keyFingerprint] = keyname
-       }
-       instanceSet.keysMtx.Unlock()
-
        ec2tags := []*ec2.Tag{}
        for k, v := range newTags {
                ec2tags = append(ec2tags, &ec2.Tag{
@@ -200,7 +167,6 @@ func (instanceSet *ec2InstanceSet) Create(
                InstanceType: &instanceType.ProviderType,
                MaxCount:     aws.Int64(1),
                MinCount:     aws.Int64(1),
-               KeyName:      &keyname,
 
                NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
                        {
@@ -220,6 +186,14 @@ func (instanceSet *ec2InstanceSet) Create(
                UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
        }
 
+       if publicKey != nil {
+               keyname, err := instanceSet.getKeyName(publicKey)
+               if err != nil {
+                       return nil, err
+               }
+               rii.KeyName = &keyname
+       }
+
        if instanceType.AddedScratch > 0 {
                rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{{
                        DeviceName: aws.String("/dev/xvdt"),
@@ -245,10 +219,6 @@ func (instanceSet *ec2InstanceSet) Create(
                }
        }
 
-       if instanceSet.ec2config.SpotPriceUpdateInterval <= 0 {
-               instanceSet.ec2config.SpotPriceUpdateInterval = arvados.Duration(24 * time.Hour)
-       }
-
        rsv, err := instanceSet.client.RunInstances(&rii)
        err = wrapError(err, &instanceSet.throttleDelayCreate)
        if err != nil {
@@ -260,6 +230,40 @@ func (instanceSet *ec2InstanceSet) Create(
        }, nil
 }
 
+func (instanceSet *ec2InstanceSet) getKeyName(publicKey ssh.PublicKey) (string, error) {
+       instanceSet.keysMtx.Lock()
+       defer instanceSet.keysMtx.Unlock()
+       md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
+       if err != nil {
+               return "", fmt.Errorf("Could not make key fingerprint: %v", err)
+       }
+       if keyname, ok := instanceSet.keys[md5keyFingerprint]; ok {
+               return keyname, nil
+       }
+       keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
+               Filters: []*ec2.Filter{{
+                       Name:   aws.String("fingerprint"),
+                       Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
+               }},
+       })
+       if err != nil {
+               return "", fmt.Errorf("Could not search for keypair: %v", err)
+       }
+       if len(keyout.KeyPairs) > 0 {
+               return *(keyout.KeyPairs[0].KeyName), nil
+       }
+       keyname := "arvados-dispatch-keypair-" + md5keyFingerprint
+       _, err = instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
+               KeyName:           &keyname,
+               PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
+       })
+       if err != nil {
+               return "", fmt.Errorf("Could not import keypair: %v", err)
+       }
+       instanceSet.keys[md5keyFingerprint] = keyname
+       return keyname, nil
+}
+
 func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
        var filters []*ec2.Filter
        for k, v := range tags {
@@ -295,14 +299,19 @@ func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances
                }
                dii.NextToken = dio.NextToken
        }
-       if needAZs {
+       if needAZs && instanceSet.ec2config.SpotPriceUpdateInterval > 0 {
                az := map[string]string{}
-               instanceSet.client.DescribeInstanceStatusPages(&ec2.DescribeInstanceStatusInput{}, func(page *ec2.DescribeInstanceStatusOutput, lastPage bool) bool {
+               err := instanceSet.client.DescribeInstanceStatusPages(&ec2.DescribeInstanceStatusInput{
+                       IncludeAllInstances: aws.Bool(true),
+               }, func(page *ec2.DescribeInstanceStatusOutput, lastPage bool) bool {
                        for _, ent := range page.InstanceStatuses {
                                az[*ent.InstanceId] = *ent.AvailabilityZone
                        }
                        return true
                })
+               if err != nil {
+                       instanceSet.logger.Warnf("error getting instance statuses: %s", err)
+               }
                for _, inst := range instances {
                        inst := inst.(*ec2Instance)
                        inst.availabilityZone = az[*inst.instance.InstanceId]
@@ -335,7 +344,8 @@ func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance)
        updateTime := time.Now()
        staleTime := updateTime.Add(-instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
        needUpdate := false
-       var typeFilterValues []*string
+       allTypes := map[string]bool{}
+
        for _, inst := range instances {
                ec2inst := inst.(*ec2Instance).instance
                if aws.StringValue(ec2inst.InstanceLifecycle) == "spot" {
@@ -347,12 +357,16 @@ func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance)
                        if instanceSet.pricesUpdated[pk].Before(staleTime) {
                                needUpdate = true
                        }
-                       typeFilterValues = append(typeFilterValues, ec2inst.InstanceType)
+                       allTypes[*ec2inst.InstanceType] = true
                }
        }
        if !needUpdate {
                return
        }
+       var typeFilterValues []*string
+       for instanceType := range allTypes {
+               typeFilterValues = append(typeFilterValues, aws.String(instanceType))
+       }
        // Get 3x update interval worth of pricing data. (Ideally the
        // AWS API would tell us "we have shown you all of the price
        // changes up to time T", but it doesn't, so we'll just ask
@@ -363,7 +377,8 @@ func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance)
        dsphi := &ec2.DescribeSpotPriceHistoryInput{
                StartTime: aws.Time(updateTime.Add(-3 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())),
                Filters: []*ec2.Filter{
-                       &ec2.Filter{Name: aws.String("InstanceType"), Values: typeFilterValues},
+                       &ec2.Filter{Name: aws.String("instance-type"), Values: typeFilterValues},
+                       &ec2.Filter{Name: aws.String("product-description"), Values: []*string{aws.String("Linux/UNIX")}},
                },
        }
        err := instanceSet.client.DescribeSpotPriceHistoryPages(dsphi, func(page *ec2.DescribeSpotPriceHistoryOutput, lastPage bool) bool {
@@ -500,14 +515,27 @@ func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
 // Spot price that is in effect when your Spot Instance is running."
 // (The use of the phrase "is running", as opposed to "was launched",
 // hints that pricing is dynamic.)
-func (inst *ec2Instance) PriceHistory() []cloud.InstancePrice {
+func (inst *ec2Instance) PriceHistory(instType arvados.InstanceType) []cloud.InstancePrice {
        inst.provider.pricesLock.Lock()
        defer inst.provider.pricesLock.Unlock()
-       return inst.provider.prices[priceKey{
+       // Note updateSpotPrices currently populates
+       // inst.provider.prices only for spot instances, so if
+       // spot==false here, we will return no data.
+       pk := priceKey{
                instanceType:     *inst.instance.InstanceType,
                spot:             aws.StringValue(inst.instance.InstanceLifecycle) == "spot",
                availabilityZone: inst.availabilityZone,
-       }]
+       }
+       var prices []cloud.InstancePrice
+       for _, price := range inst.provider.prices[pk] {
+               // ceil(added scratch space in GiB)
+               gib := (instType.AddedScratch + 1<<30 - 1) >> 30
+               monthly := inst.provider.ec2config.EBSPrice * float64(gib)
+               hourly := monthly / 30 / 24
+               price.Price += hourly
+               prices = append(prices, price)
+       }
+       return prices
 }
 
 type rateLimitError struct {
@@ -520,9 +548,11 @@ func (err rateLimitError) EarliestRetry() time.Time {
 }
 
 var isCodeCapacity = map[string]bool{
-       "InsufficientInstanceCapacity": true,
-       "VcpuLimitExceeded":            true,
-       "MaxSpotInstanceCountExceeded": true,
+       "InsufficientFreeAddressesInSubnet": true,
+       "InsufficientInstanceCapacity":      true,
+       "InsufficientVolumeCapacity":        true,
+       "MaxSpotInstanceCountExceeded":      true,
+       "VcpuLimitExceeded":                 true,
 }
 
 // isErrorCapacity returns whether the error is to be throttled based on its code.