20485: Test handling of nil publickey argument.
[arvados.git] / lib / cloud / ec2 / ec2.go
index a74f12561003a6f8763311be4170c3e38e12d8ad..e2cf5e0f1c3f35e881c882e0f005a241bd75ad8c 100644 (file)
@@ -149,40 +149,6 @@ func (instanceSet *ec2InstanceSet) Create(
        initCommand cloud.InitCommand,
        publicKey ssh.PublicKey) (cloud.Instance, error) {
 
-       md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
-       if err != nil {
-               return nil, fmt.Errorf("Could not make key fingerprint: %v", err)
-       }
-       instanceSet.keysMtx.Lock()
-       var keyname string
-       var ok bool
-       if keyname, ok = instanceSet.keys[md5keyFingerprint]; !ok {
-               keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
-                       Filters: []*ec2.Filter{{
-                               Name:   aws.String("fingerprint"),
-                               Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
-                       }},
-               })
-               if err != nil {
-                       return nil, fmt.Errorf("Could not search for keypair: %v", err)
-               }
-
-               if len(keyout.KeyPairs) > 0 {
-                       keyname = *(keyout.KeyPairs[0].KeyName)
-               } else {
-                       keyname = "arvados-dispatch-keypair-" + md5keyFingerprint
-                       _, err := instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
-                               KeyName:           &keyname,
-                               PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
-                       })
-                       if err != nil {
-                               return nil, fmt.Errorf("Could not import keypair: %v", err)
-                       }
-               }
-               instanceSet.keys[md5keyFingerprint] = keyname
-       }
-       instanceSet.keysMtx.Unlock()
-
        ec2tags := []*ec2.Tag{}
        for k, v := range newTags {
                ec2tags = append(ec2tags, &ec2.Tag{
@@ -201,7 +167,6 @@ func (instanceSet *ec2InstanceSet) Create(
                InstanceType: &instanceType.ProviderType,
                MaxCount:     aws.Int64(1),
                MinCount:     aws.Int64(1),
-               KeyName:      &keyname,
 
                NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
                        {
@@ -221,6 +186,14 @@ func (instanceSet *ec2InstanceSet) Create(
                UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
        }
 
+       if publicKey != nil {
+               keyname, err := instanceSet.getKeyName(publicKey)
+               if err != nil {
+                       return nil, err
+               }
+               rii.KeyName = &keyname
+       }
+
        if instanceType.AddedScratch > 0 {
                rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{{
                        DeviceName: aws.String("/dev/xvdt"),
@@ -257,6 +230,53 @@ func (instanceSet *ec2InstanceSet) Create(
        }, nil
 }
 
+// getKeyName returns the name of an EC2 key pair matching publicKey's
+// fingerprint, importing the key as "arvados-dispatch-keypair-<md5>"
+// if no matching key pair exists.
+//
+// The result is cached in instanceSet.keys (keyed by the MD5
+// fingerprint) whether the key pair was found or imported, so later
+// calls with the same key make no AWS API requests. keysMtx is held
+// for the whole call, including the AWS requests, so concurrent
+// callers cannot race to import the same key.
+func (instanceSet *ec2InstanceSet) getKeyName(publicKey ssh.PublicKey) (string, error) {
+       instanceSet.keysMtx.Lock()
+       defer instanceSet.keysMtx.Unlock()
+       md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
+       if err != nil {
+               return "", fmt.Errorf("Could not make key fingerprint: %w", err)
+       }
+       if keyname, ok := instanceSet.keys[md5keyFingerprint]; ok {
+               return keyname, nil
+       }
+       keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
+               Filters: []*ec2.Filter{{
+                       Name:   aws.String("fingerprint"),
+                       Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
+               }},
+       })
+       if err != nil {
+               return "", fmt.Errorf("Could not search for keypair: %w", err)
+       }
+       if len(keyout.KeyPairs) > 0 {
+               // Cache found key pairs too; otherwise each call for a
+               // pre-existing key would repeat the DescribeKeyPairs request.
+               keyname := *(keyout.KeyPairs[0].KeyName)
+               instanceSet.keys[md5keyFingerprint] = keyname
+               return keyname, nil
+       }
+       keyname := "arvados-dispatch-keypair-" + md5keyFingerprint
+       _, err = instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
+               KeyName:           &keyname,
+               PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
+       })
+       if err != nil {
+               return "", fmt.Errorf("Could not import keypair: %w", err)
+       }
+       instanceSet.keys[md5keyFingerprint] = keyname
+       return keyname, nil
+}
+
 func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
        var filters []*ec2.Filter
        for k, v := range tags {