X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/43ad590772de48fbc3a6a45654445bab79a0bdc1..0b471c74f2cc392a37aa4f8df8ed931bb5969236:/lib/cloud/azure/azure.go diff --git a/lib/cloud/azure/azure.go b/lib/cloud/azure/azure.go index d745e7e54d..ab14d6681e 100644 --- a/lib/cloud/azure/azure.go +++ b/lib/cloud/azure/azure.go @@ -47,6 +47,12 @@ type azureInstanceSetConfig struct { StorageAccount string BlobContainer string DeleteDanglingResourcesAfter arvados.Duration + AdminUsername string +} + +type containerWrapper interface { + GetBlobReference(name string) *storage.Blob + ListBlobs(params storage.ListBlobsParameters) (storage.BlobListResponse, error) } type virtualMachinesClientWrapper interface { @@ -189,35 +195,37 @@ func wrapAzureError(err error) error { } type azureInstanceSet struct { - azconfig azureInstanceSetConfig - vmClient virtualMachinesClientWrapper - netClient interfacesClientWrapper - storageAcctClient storageacct.AccountsClient - azureEnv azure.Environment - interfaces map[string]network.Interface - dispatcherID string - namePrefix string - ctx context.Context - stopFunc context.CancelFunc - stopWg sync.WaitGroup - deleteNIC chan string - deleteBlob chan storage.Blob - logger logrus.FieldLogger -} - -func newAzureInstanceSet(config json.RawMessage, dispatcherID cloud.InstanceSetID, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) { + azconfig azureInstanceSetConfig + vmClient virtualMachinesClientWrapper + netClient interfacesClientWrapper + blobcont containerWrapper + azureEnv azure.Environment + interfaces map[string]network.Interface + dispatcherID string + namePrefix string + ctx context.Context + stopFunc context.CancelFunc + stopWg sync.WaitGroup + deleteNIC chan string + deleteBlob chan storage.Blob + logger logrus.FieldLogger +} + +func newAzureInstanceSet(config json.RawMessage, dispatcherID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) { azcfg := azureInstanceSetConfig{} err = json.Unmarshal(config, &azcfg) if err != nil { return nil, err } - ap := azureInstanceSet{logger: logger} - err = ap.setup(azcfg, string(dispatcherID)) + az := azureInstanceSet{logger: logger} + az.ctx, az.stopFunc = context.WithCancel(context.Background()) + err = az.setup(azcfg, string(dispatcherID)) if err != nil { + az.stopFunc() return nil, err } - return &ap, nil + return &az, nil } func (az *azureInstanceSet) setup(azcfg azureInstanceSetConfig, dispatcherID string) (err error) { @@ -248,12 +256,26 @@ func (az *azureInstanceSet) setup(azcfg azureInstanceSetConfig, dispatcherID str az.vmClient = &virtualMachinesClientImpl{vmClient} az.netClient = &interfacesClientImpl{netClient} - az.storageAcctClient = storageAcctClient + + result, err := storageAcctClient.ListKeys(az.ctx, az.azconfig.ResourceGroup, az.azconfig.StorageAccount) + if err != nil { + az.logger.WithError(err).Warn("Couldn't get account keys") + return err + } + + key1 := *(*result.Keys)[0].Value + client, err := storage.NewBasicClientOnSovereignCloud(az.azconfig.StorageAccount, key1, az.azureEnv) + if err != nil { + az.logger.WithError(err).Warn("Couldn't make client") + return err + } + + blobsvc := client.GetBlobService() + az.blobcont = blobsvc.GetContainerReference(az.azconfig.BlobContainer) az.dispatcherID = dispatcherID az.namePrefix = fmt.Sprintf("compute-%s-", az.dispatcherID) - az.ctx, az.stopFunc = context.WithCancel(context.Background()) go func() { az.stopWg.Add(1) defer az.stopWg.Done() @@ -311,13 +333,14 @@ func (az *azureInstanceSet) Create( instanceType arvados.InstanceType, imageID cloud.ImageID, newTags cloud.InstanceTags, + initCommand cloud.InitCommand, publicKey ssh.PublicKey) (cloud.Instance, error) { az.stopWg.Add(1) defer az.stopWg.Done() - if len(newTags["node-token"]) == 0 { - return nil, fmt.Errorf("Must provide tag 'node-token'") + if instanceType.AddedScratch > 0 { + return nil, fmt.Errorf("cannot create instance type %q: driver does not implement non-zero AddedScratch (%d)", instanceType.Name, instanceType.AddedScratch) } name, err := randutil.String(15, "abcdefghijklmnopqrstuvwxyz0123456789") @@ -327,16 +350,11 @@ func (az *azureInstanceSet) Create( name = az.namePrefix + name - timestamp := time.Now().Format(time.RFC3339Nano) - - tags := make(map[string]*string) - tags["created-at"] = ×tamp + tags := map[string]*string{} for k, v := range newTags { - newstr := v - tags["dispatch-"+k] = &newstr + tags[k] = to.StringPtr(v) } - - tags["dispatch-instance-type"] = &instanceType.Name + tags["created-at"] = to.StringPtr(time.Now().Format(time.RFC3339Nano)) nicParameters := network.Interface{ Location: &az.azconfig.Location, @@ -365,14 +383,14 @@ func (az *azureInstanceSet) Create( return nil, wrapAzureError(err) } - instanceVhd := fmt.Sprintf("https://%s.blob.%s/%s/%s-os.vhd", + blobname := fmt.Sprintf("%s-os.vhd", name) + instanceVhd := fmt.Sprintf("https://%s.blob.%s/%s/%s", az.azconfig.StorageAccount, az.azureEnv.StorageEndpointSuffix, az.azconfig.BlobContainer, - name) + blobname) - customData := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf(`#!/bin/sh -echo '%s-%s' > /home/crunch/node-token`, name, newTags["node-token"]))) + customData := base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n")) vmParameters := compute.VirtualMachine{ Location: &az.azconfig.Location, @@ -406,13 +424,13 @@ echo '%s-%s' > /home/crunch/node-token`, name, newTags["node-token"]))) }, OsProfile: &compute.OSProfile{ ComputerName: &name, - AdminUsername: to.StringPtr("crunch"), + AdminUsername: to.StringPtr(az.azconfig.AdminUsername), LinuxConfiguration: &compute.LinuxConfiguration{ DisablePasswordAuthentication: to.BoolPtr(true), SSH: &compute.SSHConfiguration{ PublicKeys: &[]compute.SSHPublicKey{ - compute.SSHPublicKey{ - Path: to.StringPtr("/home/crunch/.ssh/authorized_keys"), + { + Path: to.StringPtr("/home/" + az.azconfig.AdminUsername + "/.ssh/authorized_keys"), KeyData: to.StringPtr(string(ssh.MarshalAuthorizedKey(publicKey))), }, }, @@ -425,6 +443,16 @@ echo '%s-%s' > /home/crunch/node-token`, name, newTags["node-token"]))) vm, err := az.vmClient.createOrUpdate(az.ctx, az.azconfig.ResourceGroup, name, vmParameters) if err != nil { + _, delerr := az.blobcont.GetBlobReference(blobname).DeleteIfExists(nil) + if delerr != nil { + az.logger.WithError(delerr).Warnf("Error cleaning up vhd blob after failed create") + } + + _, delerr = az.netClient.delete(context.Background(), az.azconfig.ResourceGroup, *nic.Name) + if delerr != nil { + az.logger.WithError(delerr).Warnf("Error cleaning up NIC after failed create") + } + return nil, wrapAzureError(err) } @@ -449,26 +477,24 @@ func (az *azureInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, err return nil, wrapAzureError(err) } - instances := make([]cloud.Instance, 0) - + var instances []cloud.Instance for ; result.NotDone(); err = result.Next() { if err != nil { return nil, wrapAzureError(err) } - if strings.HasPrefix(*result.Value().Name, az.namePrefix) { - instances = append(instances, &azureInstance{ - provider: az, - vm: result.Value(), - nic: interfaces[*(*result.Value().NetworkProfile.NetworkInterfaces)[0].ID]}) - } + instances = append(instances, &azureInstance{ + provider: az, + vm: result.Value(), + nic: interfaces[*(*result.Value().NetworkProfile.NetworkInterfaces)[0].ID], + }) } return instances, nil } // ManageNics returns a list of Azure network interface resources. -// Also performs garbage collection of NICs which have "namePrefix", are -// not associated with a virtual machine and have a "create-at" time -// more than DeleteDanglingResourcesAfter (to prevent racing and +// Also performs garbage collection of NICs which have "namePrefix", +// are not associated with a virtual machine and have a "created-at" +// time more than DeleteDanglingResourcesAfter (to prevent racing and // deleting newly created NICs) in the past are deleted. func (az *azureInstanceSet) manageNics() (map[string]network.Interface, error) { az.stopWg.Add(1) @@ -494,8 +520,8 @@ func (az *azureInstanceSet) manageNics() (map[string]network.Interface, error) { if result.Value().Tags["created-at"] != nil { createdAt, err := time.Parse(time.RFC3339Nano, *result.Value().Tags["created-at"]) if err == nil { - if timestamp.Sub(createdAt).Seconds() > az.azconfig.DeleteDanglingResourcesAfter.Duration().Seconds() { - az.logger.Printf("Will delete %v because it is older than %v s", *result.Value().Name, az.azconfig.DeleteDanglingResourcesAfter) + if timestamp.Sub(createdAt) > az.azconfig.DeleteDanglingResourcesAfter.Duration() { + az.logger.Printf("Will delete %v because it is older than %s", *result.Value().Name, az.azconfig.DeleteDanglingResourcesAfter) az.deleteNIC <- *result.Value().Name } } @@ -512,27 +538,12 @@ func (az *azureInstanceSet) manageNics() (map[string]network.Interface, error) { // leased to a VM) and haven't been modified for // DeleteDanglingResourcesAfter seconds. func (az *azureInstanceSet) manageBlobs() { - result, err := az.storageAcctClient.ListKeys(az.ctx, az.azconfig.ResourceGroup, az.azconfig.StorageAccount) - if err != nil { - az.logger.WithError(err).Warn("Couldn't get account keys") - return - } - - key1 := *(*result.Keys)[0].Value - client, err := storage.NewBasicClientOnSovereignCloud(az.azconfig.StorageAccount, key1, az.azureEnv) - if err != nil { - az.logger.WithError(err).Warn("Couldn't make client") - return - } - - blobsvc := client.GetBlobService() - blobcont := blobsvc.GetContainerReference(az.azconfig.BlobContainer) page := storage.ListBlobsParameters{Prefix: az.namePrefix} timestamp := time.Now() for { - response, err := blobcont.ListBlobs(page) + response, err := az.blobcont.ListBlobs(page) if err != nil { az.logger.WithError(err).Warn("Error listing blobs") return @@ -585,16 +596,12 @@ func (ai *azureInstance) SetTags(newTags cloud.InstanceTags) error { ai.provider.stopWg.Add(1) defer ai.provider.stopWg.Done() - tags := make(map[string]*string) - + tags := map[string]*string{} for k, v := range ai.vm.Tags { - if !strings.HasPrefix(k, "dispatch-") { - tags[k] = v - } + tags[k] = v } for k, v := range newTags { - newstr := v - tags["dispatch-"+k] = &newstr + tags[k] = to.StringPtr(v) } vmParameters := compute.VirtualMachine{ @@ -611,14 +618,10 @@ func (ai *azureInstance) SetTags(newTags cloud.InstanceTags) error { } func (ai *azureInstance) Tags() cloud.InstanceTags { - tags := make(map[string]string) - + tags := cloud.InstanceTags{} for k, v := range ai.vm.Tags { - if strings.HasPrefix(k, "dispatch-") { - tags[k[9:]] = *v - } + tags[k] = *v } - return tags } @@ -631,66 +634,23 @@ func (ai *azureInstance) Destroy() error { } func (ai *azureInstance) Address() string { - return *(*ai.nic.IPConfigurations)[0].PrivateIPAddress -} - -func (ai *azureInstance) VerifyHostKey(receivedKey ssh.PublicKey, client *ssh.Client) error { - ai.provider.stopWg.Add(1) - defer ai.provider.stopWg.Done() - - remoteFingerprint := ssh.FingerprintSHA256(receivedKey) - - tags := ai.Tags() - - tg := tags["ssh-pubkey-fingerprint"] - if tg != "" { - if remoteFingerprint == tg { - return nil - } - return fmt.Errorf("Key fingerprint did not match, expected %q got %q", tg, remoteFingerprint) - } - - nodetokenTag := tags["node-token"] - if nodetokenTag == "" { - return fmt.Errorf("Missing node token tag") - } - - sess, err := client.NewSession() - if err != nil { - return err + if iprops := ai.nic.InterfacePropertiesFormat; iprops == nil { + return "" + } else if ipconfs := iprops.IPConfigurations; ipconfs == nil || len(*ipconfs) == 0 { + return "" + } else if ipconfprops := (*ipconfs)[0].InterfaceIPConfigurationPropertiesFormat; ipconfprops == nil { + return "" + } else if addr := ipconfprops.PrivateIPAddress; addr == nil { + return "" + } else { + return *addr } +} - nodetokenbytes, err := sess.Output("cat /home/crunch/node-token") - if err != nil { - return err - } - - nodetoken := strings.TrimSpace(string(nodetokenbytes)) - - expectedToken := fmt.Sprintf("%s-%s", *ai.vm.Name, nodetokenTag) - - if strings.TrimSpace(nodetoken) != expectedToken { - return fmt.Errorf("Node token did not match, expected %q got %q", expectedToken, nodetoken) - } - - sess, err = client.NewSession() - if err != nil { - return err - } - - keyfingerprintbytes, err := sess.Output("ssh-keygen -E sha256 -l -f /etc/ssh/ssh_host_rsa_key.pub") - if err != nil { - return err - } - - sp := strings.Split(string(keyfingerprintbytes), " ") - - if remoteFingerprint != sp[1] { - return fmt.Errorf("Key fingerprint did not match, expected %q got %q", sp[1], remoteFingerprint) - } +func (ai *azureInstance) RemoteUser() string { + return ai.provider.azconfig.AdminUsername +} - tags["ssh-pubkey-fingerprint"] = sp[1] - delete(tags, "node-token") - ai.SetTags(tags) - return nil +func (ai *azureInstance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error { + return cloud.ErrNotImplemented }