14807: Allow driver to specify SSH username.
[arvados.git] / lib / cloud / azure / azure.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package azure
6
7 import (
8         "context"
9         "encoding/base64"
10         "encoding/json"
11         "fmt"
12         "net/http"
13         "regexp"
14         "strconv"
15         "strings"
16         "sync"
17         "time"
18
19         "git.curoverse.com/arvados.git/lib/cloud"
20         "git.curoverse.com/arvados.git/sdk/go/arvados"
21         "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2018-06-01/compute"
22         "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2018-06-01/network"
23         storageacct "github.com/Azure/azure-sdk-for-go/services/storage/mgmt/2018-02-01/storage"
24         "github.com/Azure/azure-sdk-for-go/storage"
25         "github.com/Azure/go-autorest/autorest"
26         "github.com/Azure/go-autorest/autorest/azure"
27         "github.com/Azure/go-autorest/autorest/azure/auth"
28         "github.com/Azure/go-autorest/autorest/to"
29         "github.com/jmcvetta/randutil"
30         "github.com/sirupsen/logrus"
31         "golang.org/x/crypto/ssh"
32 )
33
// Driver is the azure implementation of the cloud.Driver interface.
// It is produced by adapting newAzureInstanceSet with cloud.DriverFunc.
var Driver = cloud.DriverFunc(newAzureInstanceSet)
36
// azureInstanceSetConfig holds the driver configuration. It is
// populated by json.Unmarshal in newAzureInstanceSet.
//
// How the fields are used in this file:
//
// SubscriptionID is passed to the compute, network and storage
// account client constructors and is part of the subnet resource ID.
//
// ClientID, ClientSecret and TenantID are the client credentials used
// to build the authorizer shared by all clients.
//
// CloudEnvironment is resolved with azure.EnvironmentFromName.
//
// ResourceGroup is the resource group passed to every VM, NIC and
// storage key request.
//
// Location is set as the Location of created NICs and VMs.
//
// Network and Subnet name the virtual network and subnet that new
// NICs are attached to.
//
// StorageAccount and BlobContainer determine where the OS disk VHD of
// a new VM is stored, and which container manageBlobs cleans up.
//
// DeleteDanglingResourcesAfter is the minimum age before an unattached
// NIC or unleased blob is queued for deletion.
//
// AdminUsername is the VM admin account; it also determines the paths
// of the authorized_keys file and the node-token file, and is the
// value returned by RemoteUser.
type azureInstanceSetConfig struct {
	SubscriptionID               string
	ClientID                     string
	ClientSecret                 string
	TenantID                     string
	CloudEnvironment             string
	ResourceGroup                string
	Location                     string
	Network                      string
	Subnet                       string
	StorageAccount               string
	BlobContainer                string
	DeleteDanglingResourcesAfter arvados.Duration
	AdminUsername                string
}
52
// virtualMachinesClientWrapper is the subset of virtual machine
// operations used by azureInstanceSet. azureInstanceSet only talks to
// the VM API through this interface (see the vmClient field), so an
// alternative implementation can be substituted for
// virtualMachinesClientImpl.
type virtualMachinesClientWrapper interface {
	// createOrUpdate creates or updates the named VM and returns
	// the resulting VM.
	createOrUpdate(ctx context.Context,
		resourceGroupName string,
		VMName string,
		parameters compute.VirtualMachine) (result compute.VirtualMachine, err error)
	// delete deletes the named VM.
	delete(ctx context.Context, resourceGroupName string, VMName string) (result *http.Response, err error)
	// listComplete returns an iterator over the VMs in the
	// resource group.
	listComplete(ctx context.Context, resourceGroupName string) (result compute.VirtualMachineListResultIterator, err error)
}
61
// virtualMachinesClientImpl implements virtualMachinesClientWrapper
// on top of the Azure SDK's compute.VirtualMachinesClient.
type virtualMachinesClientImpl struct {
	// inner is the SDK client that performs the actual requests.
	inner compute.VirtualMachinesClient
}
65
66 func (cl *virtualMachinesClientImpl) createOrUpdate(ctx context.Context,
67         resourceGroupName string,
68         VMName string,
69         parameters compute.VirtualMachine) (result compute.VirtualMachine, err error) {
70
71         future, err := cl.inner.CreateOrUpdate(ctx, resourceGroupName, VMName, parameters)
72         if err != nil {
73                 return compute.VirtualMachine{}, wrapAzureError(err)
74         }
75         future.WaitForCompletionRef(ctx, cl.inner.Client)
76         r, err := future.Result(cl.inner)
77         return r, wrapAzureError(err)
78 }
79
80 func (cl *virtualMachinesClientImpl) delete(ctx context.Context, resourceGroupName string, VMName string) (result *http.Response, err error) {
81         future, err := cl.inner.Delete(ctx, resourceGroupName, VMName)
82         if err != nil {
83                 return nil, wrapAzureError(err)
84         }
85         err = future.WaitForCompletionRef(ctx, cl.inner.Client)
86         return future.Response(), wrapAzureError(err)
87 }
88
89 func (cl *virtualMachinesClientImpl) listComplete(ctx context.Context, resourceGroupName string) (result compute.VirtualMachineListResultIterator, err error) {
90         r, err := cl.inner.ListComplete(ctx, resourceGroupName)
91         return r, wrapAzureError(err)
92 }
93
// interfacesClientWrapper is the subset of network interface
// operations used by azureInstanceSet. azureInstanceSet only talks to
// the NIC API through this interface (see the netClient field), so an
// alternative implementation can be substituted for
// interfacesClientImpl.
type interfacesClientWrapper interface {
	// createOrUpdate creates or updates the named network
	// interface and returns the resulting interface.
	createOrUpdate(ctx context.Context,
		resourceGroupName string,
		networkInterfaceName string,
		parameters network.Interface) (result network.Interface, err error)
	// delete deletes the named network interface.
	delete(ctx context.Context, resourceGroupName string, networkInterfaceName string) (result *http.Response, err error)
	// listComplete returns an iterator over the network interfaces
	// in the resource group.
	listComplete(ctx context.Context, resourceGroupName string) (result network.InterfaceListResultIterator, err error)
}
102
// interfacesClientImpl implements interfacesClientWrapper on top of
// the Azure SDK's network.InterfacesClient.
type interfacesClientImpl struct {
	// inner is the SDK client that performs the actual requests.
	inner network.InterfacesClient
}
106
107 func (cl *interfacesClientImpl) delete(ctx context.Context, resourceGroupName string, VMName string) (result *http.Response, err error) {
108         future, err := cl.inner.Delete(ctx, resourceGroupName, VMName)
109         if err != nil {
110                 return nil, wrapAzureError(err)
111         }
112         err = future.WaitForCompletionRef(ctx, cl.inner.Client)
113         return future.Response(), wrapAzureError(err)
114 }
115
116 func (cl *interfacesClientImpl) createOrUpdate(ctx context.Context,
117         resourceGroupName string,
118         networkInterfaceName string,
119         parameters network.Interface) (result network.Interface, err error) {
120
121         future, err := cl.inner.CreateOrUpdate(ctx, resourceGroupName, networkInterfaceName, parameters)
122         if err != nil {
123                 return network.Interface{}, wrapAzureError(err)
124         }
125         future.WaitForCompletionRef(ctx, cl.inner.Client)
126         r, err := future.Result(cl.inner)
127         return r, wrapAzureError(err)
128 }
129
130 func (cl *interfacesClientImpl) listComplete(ctx context.Context, resourceGroupName string) (result network.InterfaceListResultIterator, err error) {
131         r, err := cl.inner.ListComplete(ctx, resourceGroupName)
132         return r, wrapAzureError(err)
133 }
134
// quotaRe matches, case-insensitively, the words "exceed", "quota" or
// "limit". wrapAzureError applies it to the code and message of a
// service error to decide whether the error is a quota error.
var quotaRe = regexp.MustCompile(`(?i:exceed|quota|limit)`)
136
137 type azureRateLimitError struct {
138         azure.RequestError
139         firstRetry time.Time
140 }
141
142 func (ar *azureRateLimitError) EarliestRetry() time.Time {
143         return ar.firstRetry
144 }
145
146 type azureQuotaError struct {
147         azure.RequestError
148 }
149
150 func (ar *azureQuotaError) IsQuotaError() bool {
151         return true
152 }
153
154 func wrapAzureError(err error) error {
155         de, ok := err.(autorest.DetailedError)
156         if !ok {
157                 return err
158         }
159         rq, ok := de.Original.(*azure.RequestError)
160         if !ok {
161                 return err
162         }
163         if rq.Response == nil {
164                 return err
165         }
166         if rq.Response.StatusCode == 429 || len(rq.Response.Header["Retry-After"]) >= 1 {
167                 // API throttling
168                 ra := rq.Response.Header["Retry-After"][0]
169                 earliestRetry, parseErr := http.ParseTime(ra)
170                 if parseErr != nil {
171                         // Could not parse as a timestamp, must be number of seconds
172                         dur, parseErr := strconv.ParseInt(ra, 10, 64)
173                         if parseErr == nil {
174                                 earliestRetry = time.Now().Add(time.Duration(dur) * time.Second)
175                         } else {
176                                 // Couldn't make sense of retry-after,
177                                 // so set retry to 20 seconds
178                                 earliestRetry = time.Now().Add(20 * time.Second)
179                         }
180                 }
181                 return &azureRateLimitError{*rq, earliestRetry}
182         }
183         if rq.ServiceError == nil {
184                 return err
185         }
186         if quotaRe.FindString(rq.ServiceError.Code) != "" || quotaRe.FindString(rq.ServiceError.Message) != "" {
187                 return &azureQuotaError{*rq}
188         }
189         return err
190 }
191
// azureInstanceSet is the value returned by newAzureInstanceSet as a
// cloud.InstanceSet. It manages VMs whose names start with
// namePrefix in the configured resource group.
//
// How the fields are used in this file:
//
// azconfig is the configuration passed to setup.
//
// vmClient and netClient perform VM and network interface requests;
// storageAcctClient is used to look up storage account keys.
//
// azureEnv is the Azure environment resolved from
// azconfig.CloudEnvironment.
//
// interfaces is not referenced by any code visible in this file.
//
// dispatcherID is the ID passed to setup, and namePrefix is
// "compute-<dispatcherID>-".
//
// ctx is passed to Azure requests, and stopFunc cancels it.
//
// stopWg counts in-progress operations so that Stop can wait for
// them.
//
// deleteNIC and deleteBlob carry NIC names and blobs to the
// goroutines, started by setup, that delete them.
//
// logger receives log messages.
type azureInstanceSet struct {
	azconfig          azureInstanceSetConfig
	vmClient          virtualMachinesClientWrapper
	netClient         interfacesClientWrapper
	storageAcctClient storageacct.AccountsClient
	azureEnv          azure.Environment
	interfaces        map[string]network.Interface
	dispatcherID      string
	namePrefix        string
	ctx               context.Context
	stopFunc          context.CancelFunc
	stopWg            sync.WaitGroup
	deleteNIC         chan string
	deleteBlob        chan storage.Blob
	logger            logrus.FieldLogger
}
208
209 func newAzureInstanceSet(config json.RawMessage, dispatcherID cloud.InstanceSetID, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
210         azcfg := azureInstanceSetConfig{}
211         err = json.Unmarshal(config, &azcfg)
212         if err != nil {
213                 return nil, err
214         }
215
216         ap := azureInstanceSet{logger: logger}
217         err = ap.setup(azcfg, string(dispatcherID))
218         if err != nil {
219                 return nil, err
220         }
221         return &ap, nil
222 }
223
224 func (az *azureInstanceSet) setup(azcfg azureInstanceSetConfig, dispatcherID string) (err error) {
225         az.azconfig = azcfg
226         vmClient := compute.NewVirtualMachinesClient(az.azconfig.SubscriptionID)
227         netClient := network.NewInterfacesClient(az.azconfig.SubscriptionID)
228         storageAcctClient := storageacct.NewAccountsClient(az.azconfig.SubscriptionID)
229
230         az.azureEnv, err = azure.EnvironmentFromName(az.azconfig.CloudEnvironment)
231         if err != nil {
232                 return err
233         }
234
235         authorizer, err := auth.ClientCredentialsConfig{
236                 ClientID:     az.azconfig.ClientID,
237                 ClientSecret: az.azconfig.ClientSecret,
238                 TenantID:     az.azconfig.TenantID,
239                 Resource:     az.azureEnv.ResourceManagerEndpoint,
240                 AADEndpoint:  az.azureEnv.ActiveDirectoryEndpoint,
241         }.Authorizer()
242         if err != nil {
243                 return err
244         }
245
246         vmClient.Authorizer = authorizer
247         netClient.Authorizer = authorizer
248         storageAcctClient.Authorizer = authorizer
249
250         az.vmClient = &virtualMachinesClientImpl{vmClient}
251         az.netClient = &interfacesClientImpl{netClient}
252         az.storageAcctClient = storageAcctClient
253
254         az.dispatcherID = dispatcherID
255         az.namePrefix = fmt.Sprintf("compute-%s-", az.dispatcherID)
256
257         az.ctx, az.stopFunc = context.WithCancel(context.Background())
258         go func() {
259                 az.stopWg.Add(1)
260                 defer az.stopWg.Done()
261
262                 tk := time.NewTicker(5 * time.Minute)
263                 for {
264                         select {
265                         case <-az.ctx.Done():
266                                 tk.Stop()
267                                 return
268                         case <-tk.C:
269                                 az.manageBlobs()
270                         }
271                 }
272         }()
273
274         az.deleteNIC = make(chan string)
275         az.deleteBlob = make(chan storage.Blob)
276
277         for i := 0; i < 4; i++ {
278                 go func() {
279                         for {
280                                 nicname, ok := <-az.deleteNIC
281                                 if !ok {
282                                         return
283                                 }
284                                 _, delerr := az.netClient.delete(context.Background(), az.azconfig.ResourceGroup, nicname)
285                                 if delerr != nil {
286                                         az.logger.WithError(delerr).Warnf("Error deleting %v", nicname)
287                                 } else {
288                                         az.logger.Printf("Deleted NIC %v", nicname)
289                                 }
290                         }
291                 }()
292                 go func() {
293                         for {
294                                 blob, ok := <-az.deleteBlob
295                                 if !ok {
296                                         return
297                                 }
298                                 err := blob.Delete(nil)
299                                 if err != nil {
300                                         az.logger.WithError(err).Warnf("Error deleting %v", blob.Name)
301                                 } else {
302                                         az.logger.Printf("Deleted blob %v", blob.Name)
303                                 }
304                         }
305                 }()
306         }
307
308         return nil
309 }
310
311 func (az *azureInstanceSet) Create(
312         instanceType arvados.InstanceType,
313         imageID cloud.ImageID,
314         newTags cloud.InstanceTags,
315         publicKey ssh.PublicKey) (cloud.Instance, error) {
316
317         az.stopWg.Add(1)
318         defer az.stopWg.Done()
319
320         if len(newTags["node-token"]) == 0 {
321                 return nil, fmt.Errorf("Must provide tag 'node-token'")
322         }
323
324         name, err := randutil.String(15, "abcdefghijklmnopqrstuvwxyz0123456789")
325         if err != nil {
326                 return nil, err
327         }
328
329         name = az.namePrefix + name
330
331         timestamp := time.Now().Format(time.RFC3339Nano)
332
333         tags := make(map[string]*string)
334         tags["created-at"] = &timestamp
335         for k, v := range newTags {
336                 newstr := v
337                 tags["dispatch-"+k] = &newstr
338         }
339
340         tags["dispatch-instance-type"] = &instanceType.Name
341
342         nicParameters := network.Interface{
343                 Location: &az.azconfig.Location,
344                 Tags:     tags,
345                 InterfacePropertiesFormat: &network.InterfacePropertiesFormat{
346                         IPConfigurations: &[]network.InterfaceIPConfiguration{
347                                 network.InterfaceIPConfiguration{
348                                         Name: to.StringPtr("ip1"),
349                                         InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{
350                                                 Subnet: &network.Subnet{
351                                                         ID: to.StringPtr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers"+
352                                                                 "/Microsoft.Network/virtualnetworks/%s/subnets/%s",
353                                                                 az.azconfig.SubscriptionID,
354                                                                 az.azconfig.ResourceGroup,
355                                                                 az.azconfig.Network,
356                                                                 az.azconfig.Subnet)),
357                                                 },
358                                                 PrivateIPAllocationMethod: network.Dynamic,
359                                         },
360                                 },
361                         },
362                 },
363         }
364         nic, err := az.netClient.createOrUpdate(az.ctx, az.azconfig.ResourceGroup, name+"-nic", nicParameters)
365         if err != nil {
366                 return nil, wrapAzureError(err)
367         }
368
369         instanceVhd := fmt.Sprintf("https://%s.blob.%s/%s/%s-os.vhd",
370                 az.azconfig.StorageAccount,
371                 az.azureEnv.StorageEndpointSuffix,
372                 az.azconfig.BlobContainer,
373                 name)
374
375         customData := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf(`#!/bin/sh
376 echo '%s-%s' > '/home/%s/node-token'`, name, newTags["node-token"], az.azconfig.AdminUsername)))
377
378         vmParameters := compute.VirtualMachine{
379                 Location: &az.azconfig.Location,
380                 Tags:     tags,
381                 VirtualMachineProperties: &compute.VirtualMachineProperties{
382                         HardwareProfile: &compute.HardwareProfile{
383                                 VMSize: compute.VirtualMachineSizeTypes(instanceType.ProviderType),
384                         },
385                         StorageProfile: &compute.StorageProfile{
386                                 OsDisk: &compute.OSDisk{
387                                         OsType:       compute.Linux,
388                                         Name:         to.StringPtr(name + "-os"),
389                                         CreateOption: compute.FromImage,
390                                         Image: &compute.VirtualHardDisk{
391                                                 URI: to.StringPtr(string(imageID)),
392                                         },
393                                         Vhd: &compute.VirtualHardDisk{
394                                                 URI: &instanceVhd,
395                                         },
396                                 },
397                         },
398                         NetworkProfile: &compute.NetworkProfile{
399                                 NetworkInterfaces: &[]compute.NetworkInterfaceReference{
400                                         compute.NetworkInterfaceReference{
401                                                 ID: nic.ID,
402                                                 NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{
403                                                         Primary: to.BoolPtr(true),
404                                                 },
405                                         },
406                                 },
407                         },
408                         OsProfile: &compute.OSProfile{
409                                 ComputerName:  &name,
410                                 AdminUsername: to.StringPtr(az.azconfig.AdminUsername),
411                                 LinuxConfiguration: &compute.LinuxConfiguration{
412                                         DisablePasswordAuthentication: to.BoolPtr(true),
413                                         SSH: &compute.SSHConfiguration{
414                                                 PublicKeys: &[]compute.SSHPublicKey{
415                                                         {
416                                                                 Path:    to.StringPtr("/home/" + az.azconfig.AdminUsername + "/.ssh/authorized_keys"),
417                                                                 KeyData: to.StringPtr(string(ssh.MarshalAuthorizedKey(publicKey))),
418                                                         },
419                                                 },
420                                         },
421                                 },
422                                 CustomData: &customData,
423                         },
424                 },
425         }
426
427         vm, err := az.vmClient.createOrUpdate(az.ctx, az.azconfig.ResourceGroup, name, vmParameters)
428         if err != nil {
429                 return nil, wrapAzureError(err)
430         }
431
432         return &azureInstance{
433                 provider: az,
434                 nic:      nic,
435                 vm:       vm,
436         }, nil
437 }
438
439 func (az *azureInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
440         az.stopWg.Add(1)
441         defer az.stopWg.Done()
442
443         interfaces, err := az.manageNics()
444         if err != nil {
445                 return nil, err
446         }
447
448         result, err := az.vmClient.listComplete(az.ctx, az.azconfig.ResourceGroup)
449         if err != nil {
450                 return nil, wrapAzureError(err)
451         }
452
453         instances := make([]cloud.Instance, 0)
454
455         for ; result.NotDone(); err = result.Next() {
456                 if err != nil {
457                         return nil, wrapAzureError(err)
458                 }
459                 if strings.HasPrefix(*result.Value().Name, az.namePrefix) {
460                         instances = append(instances, &azureInstance{
461                                 provider: az,
462                                 vm:       result.Value(),
463                                 nic:      interfaces[*(*result.Value().NetworkProfile.NetworkInterfaces)[0].ID]})
464                 }
465         }
466         return instances, nil
467 }
468
469 // ManageNics returns a list of Azure network interface resources.
470 // Also performs garbage collection of NICs which have "namePrefix", are
471 // not associated with a virtual machine and have a "create-at" time
472 // more than DeleteDanglingResourcesAfter (to prevent racing and
473 // deleting newly created NICs) in the past are deleted.
474 func (az *azureInstanceSet) manageNics() (map[string]network.Interface, error) {
475         az.stopWg.Add(1)
476         defer az.stopWg.Done()
477
478         result, err := az.netClient.listComplete(az.ctx, az.azconfig.ResourceGroup)
479         if err != nil {
480                 return nil, wrapAzureError(err)
481         }
482
483         interfaces := make(map[string]network.Interface)
484
485         timestamp := time.Now()
486         for ; result.NotDone(); err = result.Next() {
487                 if err != nil {
488                         az.logger.WithError(err).Warnf("Error listing nics")
489                         return interfaces, nil
490                 }
491                 if strings.HasPrefix(*result.Value().Name, az.namePrefix) {
492                         if result.Value().VirtualMachine != nil {
493                                 interfaces[*result.Value().ID] = result.Value()
494                         } else {
495                                 if result.Value().Tags["created-at"] != nil {
496                                         createdAt, err := time.Parse(time.RFC3339Nano, *result.Value().Tags["created-at"])
497                                         if err == nil {
498                                                 if timestamp.Sub(createdAt).Seconds() > az.azconfig.DeleteDanglingResourcesAfter.Duration().Seconds() {
499                                                         az.logger.Printf("Will delete %v because it is older than %v s", *result.Value().Name, az.azconfig.DeleteDanglingResourcesAfter)
500                                                         az.deleteNIC <- *result.Value().Name
501                                                 }
502                                         }
503                                 }
504                         }
505                 }
506         }
507         return interfaces, nil
508 }
509
510 // ManageBlobs garbage collects blobs (VM disk images) in the
511 // configured storage account container.  It will delete blobs which
512 // have "namePrefix", are "available" (which means they are not
513 // leased to a VM) and haven't been modified for
514 // DeleteDanglingResourcesAfter seconds.
515 func (az *azureInstanceSet) manageBlobs() {
516         result, err := az.storageAcctClient.ListKeys(az.ctx, az.azconfig.ResourceGroup, az.azconfig.StorageAccount)
517         if err != nil {
518                 az.logger.WithError(err).Warn("Couldn't get account keys")
519                 return
520         }
521
522         key1 := *(*result.Keys)[0].Value
523         client, err := storage.NewBasicClientOnSovereignCloud(az.azconfig.StorageAccount, key1, az.azureEnv)
524         if err != nil {
525                 az.logger.WithError(err).Warn("Couldn't make client")
526                 return
527         }
528
529         blobsvc := client.GetBlobService()
530         blobcont := blobsvc.GetContainerReference(az.azconfig.BlobContainer)
531
532         page := storage.ListBlobsParameters{Prefix: az.namePrefix}
533         timestamp := time.Now()
534
535         for {
536                 response, err := blobcont.ListBlobs(page)
537                 if err != nil {
538                         az.logger.WithError(err).Warn("Error listing blobs")
539                         return
540                 }
541                 for _, b := range response.Blobs {
542                         age := timestamp.Sub(time.Time(b.Properties.LastModified))
543                         if b.Properties.BlobType == storage.BlobTypePage &&
544                                 b.Properties.LeaseState == "available" &&
545                                 b.Properties.LeaseStatus == "unlocked" &&
546                                 age.Seconds() > az.azconfig.DeleteDanglingResourcesAfter.Duration().Seconds() {
547
548                                 az.logger.Printf("Blob %v is unlocked and not modified for %v seconds, will delete", b.Name, age.Seconds())
549                                 az.deleteBlob <- b
550                         }
551                 }
552                 if response.NextMarker != "" {
553                         page.Marker = response.NextMarker
554                 } else {
555                         break
556                 }
557         }
558 }
559
// Stop shuts down the instance set: it cancels the context used for
// Azure requests, waits for operations tracked by stopWg to finish,
// and then closes the deletion channels.
func (az *azureInstanceSet) Stop() {
	// Cancel az.ctx; this also tells the periodic blob management
	// goroutine to return.
	az.stopFunc()
	// Wait for operations that registered with stopWg.
	az.stopWg.Wait()
	// Closing the channels makes the NIC and blob deletion
	// goroutines exit. This is done last, after the senders
	// tracked by stopWg have finished.
	close(az.deleteNIC)
	close(az.deleteBlob)
}
566
// azureInstance is the value returned as a cloud.Instance by Create
// and Instances. It pairs an Azure VM with its network interface.
type azureInstance struct {
	// provider is the instance set that created or listed this
	// instance.
	provider *azureInstanceSet
	// nic is the network interface that Address reads the private
	// IP address from.
	nic network.Interface
	// vm is the Azure VM as most recently returned by the API.
	vm compute.VirtualMachine
}
572
573 func (ai *azureInstance) ID() cloud.InstanceID {
574         return cloud.InstanceID(*ai.vm.ID)
575 }
576
577 func (ai *azureInstance) String() string {
578         return *ai.vm.Name
579 }
580
581 func (ai *azureInstance) ProviderType() string {
582         return string(ai.vm.VirtualMachineProperties.HardwareProfile.VMSize)
583 }
584
585 func (ai *azureInstance) SetTags(newTags cloud.InstanceTags) error {
586         ai.provider.stopWg.Add(1)
587         defer ai.provider.stopWg.Done()
588
589         tags := make(map[string]*string)
590
591         for k, v := range ai.vm.Tags {
592                 if !strings.HasPrefix(k, "dispatch-") {
593                         tags[k] = v
594                 }
595         }
596         for k, v := range newTags {
597                 newstr := v
598                 tags["dispatch-"+k] = &newstr
599         }
600
601         vmParameters := compute.VirtualMachine{
602                 Location: &ai.provider.azconfig.Location,
603                 Tags:     tags,
604         }
605         vm, err := ai.provider.vmClient.createOrUpdate(ai.provider.ctx, ai.provider.azconfig.ResourceGroup, *ai.vm.Name, vmParameters)
606         if err != nil {
607                 return wrapAzureError(err)
608         }
609         ai.vm = vm
610
611         return nil
612 }
613
614 func (ai *azureInstance) Tags() cloud.InstanceTags {
615         tags := make(map[string]string)
616
617         for k, v := range ai.vm.Tags {
618                 if strings.HasPrefix(k, "dispatch-") {
619                         tags[k[9:]] = *v
620                 }
621         }
622
623         return tags
624 }
625
626 func (ai *azureInstance) Destroy() error {
627         ai.provider.stopWg.Add(1)
628         defer ai.provider.stopWg.Done()
629
630         _, err := ai.provider.vmClient.delete(ai.provider.ctx, ai.provider.azconfig.ResourceGroup, *ai.vm.Name)
631         return wrapAzureError(err)
632 }
633
634 func (ai *azureInstance) Address() string {
635         return *(*ai.nic.IPConfigurations)[0].PrivateIPAddress
636 }
637
638 func (ai *azureInstance) RemoteUser() string {
639         return ai.provider.azconfig.AdminUsername
640 }
641
642 func (ai *azureInstance) VerifyHostKey(receivedKey ssh.PublicKey, client *ssh.Client) error {
643         ai.provider.stopWg.Add(1)
644         defer ai.provider.stopWg.Done()
645
646         remoteFingerprint := ssh.FingerprintSHA256(receivedKey)
647
648         tags := ai.Tags()
649
650         tg := tags["ssh-pubkey-fingerprint"]
651         if tg != "" {
652                 if remoteFingerprint == tg {
653                         return nil
654                 }
655                 return fmt.Errorf("Key fingerprint did not match, expected %q got %q", tg, remoteFingerprint)
656         }
657
658         nodetokenTag := tags["node-token"]
659         if nodetokenTag == "" {
660                 return fmt.Errorf("Missing node token tag")
661         }
662
663         sess, err := client.NewSession()
664         if err != nil {
665                 return err
666         }
667
668         nodetokenbytes, err := sess.Output("cat /home/" + ai.provider.azconfig.AdminUsername + "/node-token")
669         if err != nil {
670                 return err
671         }
672
673         nodetoken := strings.TrimSpace(string(nodetokenbytes))
674
675         expectedToken := fmt.Sprintf("%s-%s", *ai.vm.Name, nodetokenTag)
676
677         if strings.TrimSpace(nodetoken) != expectedToken {
678                 return fmt.Errorf("Node token did not match, expected %q got %q", expectedToken, nodetoken)
679         }
680
681         sess, err = client.NewSession()
682         if err != nil {
683                 return err
684         }
685
686         keyfingerprintbytes, err := sess.Output("ssh-keygen -E sha256 -l -f /etc/ssh/ssh_host_rsa_key.pub")
687         if err != nil {
688                 return err
689         }
690
691         sp := strings.Split(string(keyfingerprintbytes), " ")
692
693         if remoteFingerprint != sp[1] {
694                 return fmt.Errorf("Key fingerprint did not match, expected %q got %q", sp[1], remoteFingerprint)
695         }
696
697         tags["ssh-pubkey-fingerprint"] = sp[1]
698         delete(tags, "node-token")
699         ai.SetTags(tags)
700         return nil
701 }