20485: Test handling of nil publickey argument.
[arvados.git] / lib / cloud / ec2 / ec2.go
index 2a5eea484590c646e8398762d5692a441f2dc15e..e2cf5e0f1c3f35e881c882e0f005a241bd75ad8c 100644 (file)
@@ -48,6 +48,7 @@ type ec2InstanceSetConfig struct {
        SubnetID                string
        AdminUsername           string
        EBSVolumeType           string
+       EBSPrice                float64
        IAMInstanceProfile      string
        SpotPriceUpdateInterval arvados.Duration
 }
@@ -148,40 +149,6 @@ func (instanceSet *ec2InstanceSet) Create(
        initCommand cloud.InitCommand,
        publicKey ssh.PublicKey) (cloud.Instance, error) {
 
-       md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
-       if err != nil {
-               return nil, fmt.Errorf("Could not make key fingerprint: %v", err)
-       }
-       instanceSet.keysMtx.Lock()
-       var keyname string
-       var ok bool
-       if keyname, ok = instanceSet.keys[md5keyFingerprint]; !ok {
-               keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
-                       Filters: []*ec2.Filter{{
-                               Name:   aws.String("fingerprint"),
-                               Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
-                       }},
-               })
-               if err != nil {
-                       return nil, fmt.Errorf("Could not search for keypair: %v", err)
-               }
-
-               if len(keyout.KeyPairs) > 0 {
-                       keyname = *(keyout.KeyPairs[0].KeyName)
-               } else {
-                       keyname = "arvados-dispatch-keypair-" + md5keyFingerprint
-                       _, err := instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
-                               KeyName:           &keyname,
-                               PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
-                       })
-                       if err != nil {
-                               return nil, fmt.Errorf("Could not import keypair: %v", err)
-                       }
-               }
-               instanceSet.keys[md5keyFingerprint] = keyname
-       }
-       instanceSet.keysMtx.Unlock()
-
        ec2tags := []*ec2.Tag{}
        for k, v := range newTags {
                ec2tags = append(ec2tags, &ec2.Tag{
@@ -200,7 +167,6 @@ func (instanceSet *ec2InstanceSet) Create(
                InstanceType: &instanceType.ProviderType,
                MaxCount:     aws.Int64(1),
                MinCount:     aws.Int64(1),
-               KeyName:      &keyname,
 
                NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
                        {
@@ -220,6 +186,14 @@ func (instanceSet *ec2InstanceSet) Create(
                UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
        }
 
+       if publicKey != nil {
+               keyname, err := instanceSet.getKeyName(publicKey)
+               if err != nil {
+                       return nil, err
+               }
+               rii.KeyName = &keyname
+       }
+
        if instanceType.AddedScratch > 0 {
                rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{{
                        DeviceName: aws.String("/dev/xvdt"),
@@ -245,10 +219,6 @@ func (instanceSet *ec2InstanceSet) Create(
                }
        }
 
-       if instanceSet.ec2config.SpotPriceUpdateInterval <= 0 {
-               instanceSet.ec2config.SpotPriceUpdateInterval = arvados.Duration(24 * time.Hour)
-       }
-
        rsv, err := instanceSet.client.RunInstances(&rii)
        err = wrapError(err, &instanceSet.throttleDelayCreate)
        if err != nil {
@@ -260,6 +230,40 @@ func (instanceSet *ec2InstanceSet) Create(
        }, nil
 }
 
+func (instanceSet *ec2InstanceSet) getKeyName(publicKey ssh.PublicKey) (string, error) {
+       instanceSet.keysMtx.Lock()
+       defer instanceSet.keysMtx.Unlock()
+       md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
+       if err != nil {
+               return "", fmt.Errorf("Could not make key fingerprint: %v", err)
+       }
+       if keyname, ok := instanceSet.keys[md5keyFingerprint]; ok {
+               return keyname, nil
+       }
+       keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
+               Filters: []*ec2.Filter{{
+                       Name:   aws.String("fingerprint"),
+                       Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
+               }},
+       })
+       if err != nil {
+               return "", fmt.Errorf("Could not search for keypair: %v", err)
+       }
+       if len(keyout.KeyPairs) > 0 {
+               return *(keyout.KeyPairs[0].KeyName), nil
+       }
+       keyname := "arvados-dispatch-keypair-" + md5keyFingerprint
+       _, err = instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
+               KeyName:           &keyname,
+               PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
+       })
+       if err != nil {
+               return "", fmt.Errorf("Could not import keypair: %v", err)
+       }
+       instanceSet.keys[md5keyFingerprint] = keyname
+       return keyname, nil
+}
+
 func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
        var filters []*ec2.Filter
        for k, v := range tags {
@@ -295,7 +299,7 @@ func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances
                }
                dii.NextToken = dio.NextToken
        }
-       if needAZs {
+       if needAZs && instanceSet.ec2config.SpotPriceUpdateInterval > 0 {
                az := map[string]string{}
                err := instanceSet.client.DescribeInstanceStatusPages(&ec2.DescribeInstanceStatusInput{
                        IncludeAllInstances: aws.Bool(true),
@@ -340,7 +344,8 @@ func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance)
        updateTime := time.Now()
        staleTime := updateTime.Add(-instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
        needUpdate := false
-       var typeFilterValues []*string
+       allTypes := map[string]bool{}
+
        for _, inst := range instances {
                ec2inst := inst.(*ec2Instance).instance
                if aws.StringValue(ec2inst.InstanceLifecycle) == "spot" {
@@ -352,12 +357,16 @@ func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance)
                        if instanceSet.pricesUpdated[pk].Before(staleTime) {
                                needUpdate = true
                        }
-                       typeFilterValues = append(typeFilterValues, ec2inst.InstanceType)
+                       allTypes[*ec2inst.InstanceType] = true
                }
        }
        if !needUpdate {
                return
        }
+       var typeFilterValues []*string
+       for instanceType := range allTypes {
+               typeFilterValues = append(typeFilterValues, aws.String(instanceType))
+       }
        // Get 3x update interval worth of pricing data. (Ideally the
        // AWS API would tell us "we have shown you all of the price
        // changes up to time T", but it doesn't, so we'll just ask
@@ -506,15 +515,27 @@ func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
 // Spot price that is in effect when your Spot Instance is running."
 // (The use of the phrase "is running", as opposed to "was launched",
 // hints that pricing is dynamic.)
-func (inst *ec2Instance) PriceHistory() []cloud.InstancePrice {
+func (inst *ec2Instance) PriceHistory(instType arvados.InstanceType) []cloud.InstancePrice {
        inst.provider.pricesLock.Lock()
        defer inst.provider.pricesLock.Unlock()
+       // Note updateSpotPrices currently populates
+       // inst.provider.prices only for spot instances, so if
+       // spot==false here, we will return no data.
        pk := priceKey{
                instanceType:     *inst.instance.InstanceType,
                spot:             aws.StringValue(inst.instance.InstanceLifecycle) == "spot",
                availabilityZone: inst.availabilityZone,
        }
-       return inst.provider.prices[pk]
+       var prices []cloud.InstancePrice
+       for _, price := range inst.provider.prices[pk] {
+               // ceil(added scratch space in GiB)
+               gib := (instType.AddedScratch + 1<<30 - 1) >> 30
+               monthly := inst.provider.ec2config.EBSPrice * float64(gib)
+               hourly := monthly / 30 / 24
+               price.Price += hourly
+               prices = append(prices, price)
+       }
+       return prices
 }
 
 type rateLimitError struct {
@@ -527,9 +548,11 @@ func (err rateLimitError) EarliestRetry() time.Time {
 }
 
 var isCodeCapacity = map[string]bool{
-       "InsufficientInstanceCapacity": true,
-       "VcpuLimitExceeded":            true,
-       "MaxSpotInstanceCountExceeded": true,
+       "InsufficientFreeAddressesInSubnet": true,
+       "InsufficientInstanceCapacity":      true,
+       "InsufficientVolumeCapacity":        true,
+       "MaxSpotInstanceCountExceeded":      true,
+       "VcpuLimitExceeded":                 true,
 }
 
 // isErrorCapacity returns whether the error is to be throttled based on its code.