Merge branch '14844-cdc-azure-fixes' closes #14844
[arvados.git] / lib / cloud / azure / azure.go
index d745e7e54d27473147e3f214a213a4514a596131..b88962714e709765f1c93e0a6a88dbcf860aabaa 100644 (file)
@@ -47,6 +47,14 @@ type azureInstanceSetConfig struct {
        StorageAccount               string
        BlobContainer                string
        DeleteDanglingResourcesAfter arvados.Duration
+       AdminUsername                string
+}
+
+// tagKeyInstanceSecret is the tag key "InstanceSecret".
+// NOTE(review): presumably the tag under which a per-instance secret is
+// stored; nothing in the hunks shown here references it — confirm usage.
+const tagKeyInstanceSecret = "InstanceSecret"
+
+// containerWrapper is the subset of blob container operations that
+// azureInstanceSet uses through its blobcont field: looking up a blob
+// by name and listing blobs.
+// NOTE(review): presumably an interface so tests can substitute a
+// stub for the real storage container — confirm.
+type containerWrapper interface {
+       // GetBlobReference returns a *storage.Blob handle for the given blob name.
+       GetBlobReference(name string) *storage.Blob
+       // ListBlobs returns a storage.BlobListResponse for the given listing parameters.
+       ListBlobs(params storage.ListBlobsParameters) (storage.BlobListResponse, error)
 }
 
 type virtualMachinesClientWrapper interface {
@@ -189,20 +197,20 @@ func wrapAzureError(err error) error {
 }
 
+// azureInstanceSet holds the state shared by the instance set's
+// methods: the configuration (azconfig), the VM and network interface
+// client wrappers, and the blob container handle (blobcont) that
+// setup() assigns, replacing the storageAcctClient field removed
+// here. Names of NICs to delete are sent on deleteNIC; stopWg is held
+// by Create for the duration of each call.
 type azureInstanceSet struct {
-       azconfig          azureInstanceSetConfig
-       vmClient          virtualMachinesClientWrapper
-       netClient         interfacesClientWrapper
-       storageAcctClient storageacct.AccountsClient
-       azureEnv          azure.Environment
-       interfaces        map[string]network.Interface
-       dispatcherID      string
-       namePrefix        string
-       ctx               context.Context
-       stopFunc          context.CancelFunc
-       stopWg            sync.WaitGroup
-       deleteNIC         chan string
-       deleteBlob        chan storage.Blob
-       logger            logrus.FieldLogger
+       azconfig     azureInstanceSetConfig
+       vmClient     virtualMachinesClientWrapper
+       netClient    interfacesClientWrapper
+       blobcont     containerWrapper
+       azureEnv     azure.Environment
+       interfaces   map[string]network.Interface
+       dispatcherID string
+       namePrefix   string
+       ctx          context.Context
+       stopFunc     context.CancelFunc
+       stopWg       sync.WaitGroup
+       deleteNIC    chan string
+       deleteBlob   chan storage.Blob
+       logger       logrus.FieldLogger
 }
 
 func newAzureInstanceSet(config json.RawMessage, dispatcherID cloud.InstanceSetID, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
@@ -248,7 +256,22 @@ func (az *azureInstanceSet) setup(azcfg azureInstanceSetConfig, dispatcherID str
 
        az.vmClient = &virtualMachinesClientImpl{vmClient}
        az.netClient = &interfacesClientImpl{netClient}
-       az.storageAcctClient = storageAcctClient
+
+       result, err := storageAcctClient.ListKeys(az.ctx, az.azconfig.ResourceGroup, az.azconfig.StorageAccount)
+       if err != nil {
+               az.logger.WithError(err).Warn("Couldn't get account keys")
+               return err
+       }
+
+       key1 := *(*result.Keys)[0].Value
+       client, err := storage.NewBasicClientOnSovereignCloud(az.azconfig.StorageAccount, key1, az.azureEnv)
+       if err != nil {
+               az.logger.WithError(err).Warn("Couldn't make client")
+               return err
+       }
+
+       blobsvc := client.GetBlobService()
+       az.blobcont = blobsvc.GetContainerReference(az.azconfig.BlobContainer)
 
        az.dispatcherID = dispatcherID
        az.namePrefix = fmt.Sprintf("compute-%s-", az.dispatcherID)
@@ -311,15 +334,12 @@ func (az *azureInstanceSet) Create(
        instanceType arvados.InstanceType,
        imageID cloud.ImageID,
        newTags cloud.InstanceTags,
+       initCommand cloud.InitCommand,
        publicKey ssh.PublicKey) (cloud.Instance, error) {
 
        az.stopWg.Add(1)
        defer az.stopWg.Done()
 
-       if len(newTags["node-token"]) == 0 {
-               return nil, fmt.Errorf("Must provide tag 'node-token'")
-       }
-
        name, err := randutil.String(15, "abcdefghijklmnopqrstuvwxyz0123456789")
        if err != nil {
                return nil, err
@@ -336,8 +356,6 @@ func (az *azureInstanceSet) Create(
                tags["dispatch-"+k] = &newstr
        }
 
-       tags["dispatch-instance-type"] = &instanceType.Name
-
        nicParameters := network.Interface{
                Location: &az.azconfig.Location,
                Tags:     tags,
@@ -365,14 +383,14 @@ func (az *azureInstanceSet) Create(
                return nil, wrapAzureError(err)
        }
 
-       instanceVhd := fmt.Sprintf("https://%s.blob.%s/%s/%s-os.vhd",
+       blobname := fmt.Sprintf("%s-os.vhd", name)
+       instanceVhd := fmt.Sprintf("https://%s.blob.%s/%s/%s",
                az.azconfig.StorageAccount,
                az.azureEnv.StorageEndpointSuffix,
                az.azconfig.BlobContainer,
-               name)
+               blobname)
 
-       customData := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf(`#!/bin/sh
-echo '%s-%s' > /home/crunch/node-token`, name, newTags["node-token"])))
+       customData := base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))
 
        vmParameters := compute.VirtualMachine{
                Location: &az.azconfig.Location,
@@ -406,13 +424,13 @@ echo '%s-%s' > /home/crunch/node-token`, name, newTags["node-token"])))
                        },
                        OsProfile: &compute.OSProfile{
                                ComputerName:  &name,
-                               AdminUsername: to.StringPtr("crunch"),
+                               AdminUsername: to.StringPtr(az.azconfig.AdminUsername),
                                LinuxConfiguration: &compute.LinuxConfiguration{
                                        DisablePasswordAuthentication: to.BoolPtr(true),
                                        SSH: &compute.SSHConfiguration{
                                                PublicKeys: &[]compute.SSHPublicKey{
-                                                       compute.SSHPublicKey{
-                                                               Path:    to.StringPtr("/home/crunch/.ssh/authorized_keys"),
+                                                       {
+                                                               Path:    to.StringPtr("/home/" + az.azconfig.AdminUsername + "/.ssh/authorized_keys"),
                                                                KeyData: to.StringPtr(string(ssh.MarshalAuthorizedKey(publicKey))),
                                                        },
                                                },
@@ -425,6 +443,16 @@ echo '%s-%s' > /home/crunch/node-token`, name, newTags["node-token"])))
 
        vm, err := az.vmClient.createOrUpdate(az.ctx, az.azconfig.ResourceGroup, name, vmParameters)
        if err != nil {
+               _, delerr := az.blobcont.GetBlobReference(blobname).DeleteIfExists(nil)
+               if delerr != nil {
+                       az.logger.WithError(delerr).Warnf("Error cleaning up vhd blob after failed create")
+               }
+
+               _, delerr = az.netClient.delete(context.Background(), az.azconfig.ResourceGroup, *nic.Name)
+               if delerr != nil {
+                       az.logger.WithError(delerr).Warnf("Error cleaning up NIC after failed create")
+               }
+
                return nil, wrapAzureError(err)
        }
 
@@ -494,8 +522,8 @@ func (az *azureInstanceSet) manageNics() (map[string]network.Interface, error) {
                                if result.Value().Tags["created-at"] != nil {
                                        createdAt, err := time.Parse(time.RFC3339Nano, *result.Value().Tags["created-at"])
                                        if err == nil {
-                                               if timestamp.Sub(createdAt).Seconds() > az.azconfig.DeleteDanglingResourcesAfter.Duration().Seconds() {
-                                                       az.logger.Printf("Will delete %v because it is older than %s", *result.Value().Name, az.azconfig.DeleteDanglingResourcesAfter)
+                                               if timestamp.Sub(createdAt) > az.azconfig.DeleteDanglingResourcesAfter.Duration() {
+                                                       az.logger.Printf("Will delete %v because it is older than %s", *result.Value().Name, az.azconfig.DeleteDanglingResourcesAfter)
                                                        az.deleteNIC <- *result.Value().Name
                                                }
                                        }
@@ -512,27 +540,12 @@ func (az *azureInstanceSet) manageNics() (map[string]network.Interface, error) {
 // leased to a VM) and haven't been modified for
 // DeleteDanglingResourcesAfter seconds.
 func (az *azureInstanceSet) manageBlobs() {
-       result, err := az.storageAcctClient.ListKeys(az.ctx, az.azconfig.ResourceGroup, az.azconfig.StorageAccount)
-       if err != nil {
-               az.logger.WithError(err).Warn("Couldn't get account keys")
-               return
-       }
-
-       key1 := *(*result.Keys)[0].Value
-       client, err := storage.NewBasicClientOnSovereignCloud(az.azconfig.StorageAccount, key1, az.azureEnv)
-       if err != nil {
-               az.logger.WithError(err).Warn("Couldn't make client")
-               return
-       }
-
-       blobsvc := client.GetBlobService()
-       blobcont := blobsvc.GetContainerReference(az.azconfig.BlobContainer)
 
        page := storage.ListBlobsParameters{Prefix: az.namePrefix}
        timestamp := time.Now()
 
        for {
-               response, err := blobcont.ListBlobs(page)
+               response, err := az.blobcont.ListBlobs(page)
                if err != nil {
                        az.logger.WithError(err).Warn("Error listing blobs")
                        return
@@ -631,66 +644,19 @@ func (ai *azureInstance) Destroy() error {
 }
 
+// Address returns the private IP address of the first IP configuration
+// on the instance's NIC, or "" if the NIC has no IP configurations or
+// that address is not set.
 func (ai *azureInstance) Address() string {
-       return *(*ai.nic.IPConfigurations)[0].PrivateIPAddress
-}
-
-func (ai *azureInstance) VerifyHostKey(receivedKey ssh.PublicKey, client *ssh.Client) error {
-       ai.provider.stopWg.Add(1)
-       defer ai.provider.stopWg.Done()
-
-       remoteFingerprint := ssh.FingerprintSHA256(receivedKey)
-
-       tags := ai.Tags()
-
-       tg := tags["ssh-pubkey-fingerprint"]
-       if tg != "" {
-               if remoteFingerprint == tg {
-                       return nil
-               }
-               return fmt.Errorf("Key fingerprint did not match, expected %q got %q", tg, remoteFingerprint)
-       }
-
-       nodetokenTag := tags["node-token"]
-       if nodetokenTag == "" {
-               return fmt.Errorf("Missing node token tag")
-       }
-
-       sess, err := client.NewSession()
-       if err != nil {
-               return err
-       }
-
-       nodetokenbytes, err := sess.Output("cat /home/crunch/node-token")
-       if err != nil {
-               return err
-       }
-
-       nodetoken := strings.TrimSpace(string(nodetokenbytes))
+       if ai.nic.IPConfigurations != nil &&
+               len(*ai.nic.IPConfigurations) > 0 &&
+               (*ai.nic.IPConfigurations)[0].PrivateIPAddress != nil {
 
-       expectedToken := fmt.Sprintf("%s-%s", *ai.vm.Name, nodetokenTag)
-
-       if strings.TrimSpace(nodetoken) != expectedToken {
-               return fmt.Errorf("Node token did not match, expected %q got %q", expectedToken, nodetoken)
+               return *(*ai.nic.IPConfigurations)[0].PrivateIPAddress
        }
+       return ""
+}
 
-       sess, err = client.NewSession()
-       if err != nil {
-               return err
-       }
-
-       keyfingerprintbytes, err := sess.Output("ssh-keygen -E sha256 -l -f /etc/ssh/ssh_host_rsa_key.pub")
-       if err != nil {
-               return err
-       }
-
-       sp := strings.Split(string(keyfingerprintbytes), " ")
-
-       if remoteFingerprint != sp[1] {
-               return fmt.Errorf("Key fingerprint did not match, expected %q got %q", sp[1], remoteFingerprint)
-       }
+// RemoteUser returns the AdminUsername from the instance set's
+// configuration — the same value Create uses as the VM's admin
+// username and in the authorized_keys path.
+func (ai *azureInstance) RemoteUser() string {
+       return ai.provider.azconfig.AdminUsername
+}
 
-       tags["ssh-pubkey-fingerprint"] = sp[1]
-       delete(tags, "node-token")
-       ai.SetTags(tags)
-       return nil
+// VerifyHostKey is not implemented for Azure instances: it ignores
+// both arguments and always returns cloud.ErrNotImplemented.
+func (ai *azureInstance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
+       return cloud.ErrNotImplemented
 }