14325: Add worker state diagram.
[arvados.git] / lib / cloud / azure.go
index 2ee66837218ed9e40017b0ccb8d00b96a6c042de..a194b33180b231cfb74964b11253d5aa6f8d0667 100644 (file)
@@ -8,7 +8,6 @@ import (
        "context"
        "encoding/base64"
        "fmt"
-       "log"
        "net/http"
        "regexp"
        "strconv"
@@ -27,23 +26,24 @@ import (
        "github.com/Azure/go-autorest/autorest/to"
        "github.com/jmcvetta/randutil"
        "github.com/mitchellh/mapstructure"
+       "github.com/sirupsen/logrus"
        "golang.org/x/crypto/ssh"
 )
 
 type AzureInstanceSetConfig struct {
-       SubscriptionID               string  `json:"subscription_id"`
-       ClientID                     string  `json:"key"`
-       ClientSecret                 string  `json:"secret"`
-       TenantID                     string  `json:"tenant_id"`
-       CloudEnv                     string  `json:"cloud_environment"`
-       ResourceGroup                string  `json:"resource_group"`
-       Location                     string  `json:"region"`
-       Network                      string  `json:"network"`
-       Subnet                       string  `json:"subnet"`
-       StorageAccount               string  `json:"storage_account"`
-       BlobContainer                string  `json:"blob_container"`
-       Image                        string  `json:"image"`
-       DeleteDanglingResourcesAfter float64 `json:"delete_dangling_resources_after"`
+       SubscriptionID               string  `mapstructure:"subscription_id"`
+       ClientID                     string  `mapstructure:"key"`
+       ClientSecret                 string  `mapstructure:"secret"`
+       TenantID                     string  `mapstructure:"tenant_id"`
+       CloudEnv                     string  `mapstructure:"cloud_environment"`
+       ResourceGroup                string  `mapstructure:"resource_group"`
+       Location                     string  `mapstructure:"region"`
+       Network                      string  `mapstructure:"network"`
+       Subnet                       string  `mapstructure:"subnet"`
+       StorageAccount               string  `mapstructure:"storage_account"`
+       BlobContainer                string  `mapstructure:"blob_container"`
+       Image                        string  `mapstructure:"image"`
+       DeleteDanglingResourcesAfter float64 `mapstructure:"delete_dangling_resources_after"`
 }
 
 type VirtualMachinesClientWrapper interface {
@@ -166,15 +166,14 @@ func WrapAzureError(err error) error {
                if parseErr != nil {
                        // Could not parse as a timestamp, must be number of seconds
                        dur, parseErr := strconv.ParseInt(ra, 10, 64)
-                       if parseErr != nil {
+                       if parseErr == nil {
                                earliestRetry = time.Now().Add(time.Duration(dur) * time.Second)
+                       } else {
+                               // Couldn't make sense of retry-after,
+                               // so set retry to 20 seconds
+                               earliestRetry = time.Now().Add(20 * time.Second)
                        }
                }
-               if parseErr != nil {
-                       // Couldn't make sense of retry-after,
-                       // so set retry to 20 seconds
-                       earliestRetry = time.Now().Add(20 * time.Second)
-               }
                return &AzureRateLimitError{*rq, earliestRetry}
        }
        if rq.ServiceError == nil {
@@ -195,14 +194,20 @@ type AzureInstanceSet struct {
        interfaces        map[string]network.Interface
        dispatcherID      string
        namePrefix        string
+       ctx               context.Context
+       stopFunc          context.CancelFunc
+       stopWg            sync.WaitGroup
+       deleteNIC         chan string
+       deleteBlob        chan storage.Blob
+       logger            logrus.FieldLogger
 }
 
-func NewAzureInstanceSet(config map[string]interface{}, dispatcherID InstanceSetID) (prv InstanceSet, err error) {
+func NewAzureInstanceSet(config map[string]interface{}, dispatcherID InstanceSetID, logger logrus.FieldLogger) (prv InstanceSet, err error) {
        azcfg := AzureInstanceSetConfig{}
        if err = mapstructure.Decode(config, &azcfg); err != nil {
                return nil, err
        }
-       ap := AzureInstanceSet{}
+       ap := AzureInstanceSet{logger: logger}
        err = ap.setup(azcfg, string(dispatcherID))
        if err != nil {
                return nil, err
@@ -243,15 +248,69 @@ func (az *AzureInstanceSet) setup(azcfg AzureInstanceSetConfig, dispatcherID str
        az.dispatcherID = dispatcherID
        az.namePrefix = fmt.Sprintf("compute-%s-", az.dispatcherID)
 
+       az.ctx, az.stopFunc = context.WithCancel(context.Background())
+       go func() {
+               az.stopWg.Add(1)
+               defer az.stopWg.Done()
+
+               tk := time.NewTicker(5 * time.Minute)
+               for {
+                       select {
+                       case <-az.ctx.Done():
+                               tk.Stop()
+                               return
+                       case <-tk.C:
+                               az.ManageBlobs()
+                       }
+               }
+       }()
+
+       az.deleteNIC = make(chan string)
+       az.deleteBlob = make(chan storage.Blob)
+
+       for i := 0; i < 4; i += 1 {
+               go func() {
+                       for {
+                               nicname, ok := <-az.deleteNIC
+                               if !ok {
+                                       return
+                               }
+                               _, delerr := az.netClient.Delete(context.Background(), az.azconfig.ResourceGroup, nicname)
+                               if delerr != nil {
+                                       az.logger.WithError(delerr).Warnf("Error deleting %v", nicname)
+                               } else {
+                                       az.logger.Printf("Deleted NIC %v", nicname)
+                               }
+                       }
+               }()
+               go func() {
+                       for {
+                               blob, ok := <-az.deleteBlob
+                               if !ok {
+                                       return
+                               }
+                               err := blob.Delete(nil)
+                               if err != nil {
+                                       az.logger.WithError(err).Warnf("Error deleting %v", blob.Name)
+                               } else {
+                                       az.logger.Printf("Deleted blob %v", blob.Name)
+                               }
+                       }
+               }()
+       }
+
        return nil
 }
 
-func (az *AzureInstanceSet) Create(ctx context.Context,
+func (az *AzureInstanceSet) Create(
        instanceType arvados.InstanceType,
        imageId ImageID,
        newTags InstanceTags,
        publicKey ssh.PublicKey) (Instance, error) {
 
+       az.stopWg.Add(1)
+       defer az.stopWg.Done()
+
        if len(newTags["node-token"]) == 0 {
                return nil, fmt.Errorf("Must provide tag 'node-token'")
        }
@@ -262,7 +321,6 @@ func (az *AzureInstanceSet) Create(ctx context.Context,
        }
 
        name = az.namePrefix + name
-       log.Printf("name is %v", name)
 
        timestamp := time.Now().Format(time.RFC3339Nano)
 
@@ -297,21 +355,17 @@ func (az *AzureInstanceSet) Create(ctx context.Context,
                        },
                },
        }
-       nic, err := az.netClient.CreateOrUpdate(ctx, az.azconfig.ResourceGroup, name+"-nic", nicParameters)
+       nic, err := az.netClient.CreateOrUpdate(az.ctx, az.azconfig.ResourceGroup, name+"-nic", nicParameters)
        if err != nil {
                return nil, WrapAzureError(err)
        }
 
-       log.Printf("Created NIC %v", *nic.ID)
-
        instance_vhd := fmt.Sprintf("https://%s.blob.%s/%s/%s-os.vhd",
                az.azconfig.StorageAccount,
                az.azureEnv.StorageEndpointSuffix,
                az.azconfig.BlobContainer,
                name)
 
-       log.Printf("URI instance vhd %v", instance_vhd)
-
        customData := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf(`#!/bin/sh
 echo '%s-%s' > /home/crunch/node-token`, name, newTags["node-token"])))
 
@@ -364,7 +418,7 @@ echo '%s-%s' > /home/crunch/node-token`, name, newTags["node-token"])))
                },
        }
 
-       vm, err := az.vmClient.CreateOrUpdate(ctx, az.azconfig.ResourceGroup, name, vmParameters)
+       vm, err := az.vmClient.CreateOrUpdate(az.ctx, az.azconfig.ResourceGroup, name, vmParameters)
        if err != nil {
                return nil, WrapAzureError(err)
        }
@@ -376,13 +430,16 @@ echo '%s-%s' > /home/crunch/node-token`, name, newTags["node-token"])))
        }, nil
 }
 
-func (az *AzureInstanceSet) Instances(ctx context.Context, _ InstanceTags) ([]Instance, error) {
-       interfaces, err := az.ManageNics(ctx)
+func (az *AzureInstanceSet) Instances(InstanceTags) ([]Instance, error) {
+       az.stopWg.Add(1)
+       defer az.stopWg.Done()
+
+       interfaces, err := az.ManageNics()
        if err != nil {
                return nil, err
        }
 
-       result, err := az.vmClient.ListComplete(ctx, az.azconfig.ResourceGroup)
+       result, err := az.vmClient.ListComplete(az.ctx, az.azconfig.ResourceGroup)
        if err != nil {
                return nil, WrapAzureError(err)
        }
@@ -403,8 +460,16 @@ func (az *AzureInstanceSet) Instances(ctx context.Context, _ InstanceTags) ([]In
        return instances, nil
 }
 
-func (az *AzureInstanceSet) ManageNics(ctx context.Context) (map[string]network.Interface, error) {
-       result, err := az.netClient.ListComplete(ctx, az.azconfig.ResourceGroup)
+// ManageNics returns a map of Azure network interface resources.
+// It also garbage collects NICs which have "namePrefix", are not
+// associated with a virtual machine, and have a "created-at" time
+// more than DeleteDanglingResourcesAfter seconds in the past (the
+// delay prevents racing with, and deleting, newly created NICs).
+func (az *AzureInstanceSet) ManageNics() (map[string]network.Interface, error) {
+       az.stopWg.Add(1)
+       defer az.stopWg.Done()
+
+       result, err := az.netClient.ListComplete(az.ctx, az.azconfig.ResourceGroup)
        if err != nil {
                return nil, WrapAzureError(err)
        }
@@ -412,33 +477,9 @@ func (az *AzureInstanceSet) ManageNics(ctx context.Context) (map[string]network.
        interfaces := make(map[string]network.Interface)
 
        timestamp := time.Now()
-       wg := sync.WaitGroup{}
-       deletechannel := make(chan string, 20)
-       defer func() {
-               wg.Wait()
-               close(deletechannel)
-       }()
-       for i := 0; i < 4; i += 1 {
-               go func() {
-                       for {
-                               nicname, ok := <-deletechannel
-                               if !ok {
-                                       return
-                               }
-                               _, delerr := az.netClient.Delete(context.Background(), az.azconfig.ResourceGroup, nicname)
-                               if delerr != nil {
-                                       log.Printf("Error deleting %v: %v", nicname, delerr)
-                               } else {
-                                       log.Printf("Deleted %v", nicname)
-                               }
-                               wg.Done()
-                       }
-               }()
-       }
-
        for ; result.NotDone(); err = result.Next() {
                if err != nil {
-                       log.Printf("Error listing nics: %v", err)
+                       az.logger.WithError(err).Warnf("Error listing nics")
                        return interfaces, nil
                }
                if strings.HasPrefix(*result.Value().Name, az.namePrefix) {
@@ -448,11 +489,9 @@ func (az *AzureInstanceSet) ManageNics(ctx context.Context) (map[string]network.
                                if result.Value().Tags["created-at"] != nil {
                                        created_at, err := time.Parse(time.RFC3339Nano, *result.Value().Tags["created-at"])
                                        if err == nil {
-                                               //log.Printf("found dangling NIC %v created %v seconds ago", *result.Value().Name, timestamp.Sub(created_at).Seconds())
                                                if timestamp.Sub(created_at).Seconds() > az.azconfig.DeleteDanglingResourcesAfter {
-                                                       log.Printf("Will delete %v because it is older than %v s", *result.Value().Name, az.azconfig.DeleteDanglingResourcesAfter)
-                                                       wg.Add(1)
-                                                       deletechannel <- *result.Value().Name
+                                                       az.logger.Printf("Will delete %v because it is older than %v s", *result.Value().Name, az.azconfig.DeleteDanglingResourcesAfter)
+                                                       az.deleteNIC <- *result.Value().Name
                                                }
                                        }
                                }
@@ -462,54 +501,35 @@ func (az *AzureInstanceSet) ManageNics(ctx context.Context) (map[string]network.
        return interfaces, nil
 }
 
-func (az *AzureInstanceSet) ManageBlobs(ctx context.Context) {
-       result, err := az.storageAcctClient.ListKeys(ctx, az.azconfig.ResourceGroup, az.azconfig.StorageAccount)
+// ManageBlobs garbage collects blobs (VM disk images) in the
+// configured storage account container.  It will delete blobs which
+// have "namePrefix", are "available" (which means they are not
+// leased to a VM) and haven't been modified for
+// DeleteDanglingResourcesAfter seconds.
+func (az *AzureInstanceSet) ManageBlobs() {
+       result, err := az.storageAcctClient.ListKeys(az.ctx, az.azconfig.ResourceGroup, az.azconfig.StorageAccount)
        if err != nil {
-               log.Printf("Couldn't get account keys %v", err)
+               az.logger.WithError(err).Warn("Couldn't get account keys")
                return
        }
 
        key1 := *(*result.Keys)[0].Value
        client, err := storage.NewBasicClientOnSovereignCloud(az.azconfig.StorageAccount, key1, az.azureEnv)
        if err != nil {
-               log.Printf("Couldn't make client %v", err)
+               az.logger.WithError(err).Warn("Couldn't make client")
                return
        }
 
        blobsvc := client.GetBlobService()
        blobcont := blobsvc.GetContainerReference(az.azconfig.BlobContainer)
 
-       timestamp := time.Now()
-       wg := sync.WaitGroup{}
-       deletechannel := make(chan storage.Blob, 20)
-       defer func() {
-               wg.Wait()
-               close(deletechannel)
-       }()
-       for i := 0; i < 4; i += 1 {
-               go func() {
-                       for {
-                               blob, ok := <-deletechannel
-                               if !ok {
-                                       return
-                               }
-                               err := blob.Delete(nil)
-                               if err != nil {
-                                       log.Printf("error deleting %v: %v", blob.Name, err)
-                               } else {
-                                       log.Printf("Deleted blob %v", blob.Name)
-                               }
-                               wg.Done()
-                       }
-               }()
-       }
-
        page := storage.ListBlobsParameters{Prefix: az.namePrefix}
+       timestamp := time.Now()
 
        for {
                response, err := blobcont.ListBlobs(page)
                if err != nil {
-                       log.Printf("Error listing blobs %v", err)
+                       az.logger.WithError(err).Warn("Error listing blobs")
                        return
                }
                for _, b := range response.Blobs {
@@ -519,9 +539,8 @@ func (az *AzureInstanceSet) ManageBlobs(ctx context.Context) {
                                b.Properties.LeaseStatus == "unlocked" &&
                                age.Seconds() > az.azconfig.DeleteDanglingResourcesAfter {
 
-                               log.Printf("Blob %v is unlocked and not modified for %v seconds, will delete", b.Name, age.Seconds())
-                               wg.Add(1)
-                               deletechannel <- b
+                               az.logger.Printf("Blob %v is unlocked and not modified for %v seconds, will delete", b.Name, age.Seconds())
+                               az.deleteBlob <- b
                        }
                }
                if response.NextMarker != "" {
@@ -533,6 +552,10 @@ func (az *AzureInstanceSet) ManageBlobs(ctx context.Context) {
 }
 
 func (az *AzureInstanceSet) Stop() {
+       az.stopFunc()
+       az.stopWg.Wait()
+       close(az.deleteNIC)
+       close(az.deleteBlob)
 }
 
 type AzureInstance struct {
@@ -553,7 +576,10 @@ func (ai *AzureInstance) ProviderType() string {
        return string(ai.vm.VirtualMachineProperties.HardwareProfile.VMSize)
 }
 
-func (ai *AzureInstance) SetTags(ctx context.Context, newTags InstanceTags) error {
+func (ai *AzureInstance) SetTags(newTags InstanceTags) error {
+       ai.provider.stopWg.Add(1)
+       defer ai.provider.stopWg.Done()
+
        tags := make(map[string]*string)
 
        for k, v := range ai.vm.Tags {
@@ -570,7 +596,7 @@ func (ai *AzureInstance) SetTags(ctx context.Context, newTags InstanceTags) erro
                Location: &ai.provider.azconfig.Location,
                Tags:     tags,
        }
-       vm, err := ai.provider.vmClient.CreateOrUpdate(ctx, ai.provider.azconfig.ResourceGroup, *ai.vm.Name, vmParameters)
+       vm, err := ai.provider.vmClient.CreateOrUpdate(ai.provider.ctx, ai.provider.azconfig.ResourceGroup, *ai.vm.Name, vmParameters)
        if err != nil {
                return WrapAzureError(err)
        }
@@ -591,8 +617,11 @@ func (ai *AzureInstance) Tags() InstanceTags {
        return tags
 }
 
-func (ai *AzureInstance) Destroy(ctx context.Context) error {
-       _, err := ai.provider.vmClient.Delete(ctx, ai.provider.azconfig.ResourceGroup, *ai.vm.Name)
+func (ai *AzureInstance) Destroy() error {
+       ai.provider.stopWg.Add(1)
+       defer ai.provider.stopWg.Done()
+
+       _, err := ai.provider.vmClient.Delete(ai.provider.ctx, ai.provider.azconfig.ResourceGroup, *ai.vm.Name)
        return WrapAzureError(err)
 }
 
@@ -600,7 +629,10 @@ func (ai *AzureInstance) Address() string {
        return *(*ai.nic.IPConfigurations)[0].PrivateIPAddress
 }
 
-func (ai *AzureInstance) VerifyHostKey(ctx context.Context, receivedKey ssh.PublicKey, client *ssh.Client) error {
+func (ai *AzureInstance) VerifyHostKey(receivedKey ssh.PublicKey, client *ssh.Client) error {
+       ai.provider.stopWg.Add(1)
+       defer ai.provider.stopWg.Done()
+
        remoteFingerprint := ssh.FingerprintSHA256(receivedKey)
 
        tags := ai.Tags()
@@ -632,7 +664,6 @@ func (ai *AzureInstance) VerifyHostKey(ctx context.Context, receivedKey ssh.Publ
        nodetoken := strings.TrimSpace(string(nodetokenbytes))
 
        expectedToken := fmt.Sprintf("%s-%s", *ai.vm.Name, nodetokenTag)
-       log.Printf("%q %q", nodetoken, expectedToken)
 
        if strings.TrimSpace(nodetoken) != expectedToken {
                return fmt.Errorf("Node token did not match, expected %q got %q", expectedToken, nodetoken)
@@ -650,14 +681,12 @@ func (ai *AzureInstance) VerifyHostKey(ctx context.Context, receivedKey ssh.Publ
 
        sp := strings.Split(string(keyfingerprintbytes), " ")
 
-       log.Printf("%q %q", remoteFingerprint, sp[1])
-
        if remoteFingerprint != sp[1] {
                return fmt.Errorf("Key fingerprint did not match, expected %q got %q", sp[1], remoteFingerprint)
        }
 
        tags["ssh-pubkey-fingerprint"] = sp[1]
        delete(tags, "node-token")
-       ai.SetTags(ctx, tags)
+       ai.SetTags(tags)
        return nil
 }