19961: Reuse metadata token within expected TTL, adjust retry logic.
author Tom Clegg <tom@curii.com>
Tue, 21 Feb 2023 21:14:04 +0000 (16:14 -0500)
committer Tom Clegg <tom@curii.com>
Tue, 21 Feb 2023 21:14:04 +0000 (16:14 -0500)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

lib/crunchrun/crunchrun.go
lib/crunchrun/crunchrun_test.go

index ea48512181b0134d441e15a4f048afcbb0726fdd..a9c65cca422922dac2c5a43a658c7c874ff0ce58 100644 (file)
@@ -1203,6 +1203,8 @@ func (runner *ContainerRunner) updateLogs() {
 var spotInterruptionCheckInterval = 5 * time.Second
 var ec2MetadataBaseURL = "http://169.254.169.254"
 
+const ec2TokenTTL = time.Second * 21600
+
 func (runner *ContainerRunner) checkSpotInterruptionNotices() {
        type ec2metadata struct {
                Action string    `json:"action"`
@@ -1210,38 +1212,47 @@ func (runner *ContainerRunner) checkSpotInterruptionNotices() {
        }
        runner.CrunchLog.Printf("Checking for spot interruptions every %v using instance metadata at %s", spotInterruptionCheckInterval, ec2MetadataBaseURL)
        var metadata ec2metadata
+       var token string
+       var tokenExp time.Time
        check := func() error {
                ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
                defer cancel()
-               req, err := http.NewRequestWithContext(ctx, http.MethodPut, ec2MetadataBaseURL+"/latest/api/token", nil)
-               if err != nil {
-                       return err
-               }
-               req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "21600")
-               resp, err := http.DefaultClient.Do(req)
-               if err != nil {
-                       return err
-               }
-               defer resp.Body.Close()
-               if resp.StatusCode != http.StatusOK {
-                       return fmt.Errorf("%s", resp.Status)
-               }
-               token, err := ioutil.ReadAll(resp.Body)
-               if err != nil {
-                       return err
+               if token == "" || tokenExp.Sub(time.Now()) < time.Minute {
+                       req, err := http.NewRequestWithContext(ctx, http.MethodPut, ec2MetadataBaseURL+"/latest/api/token", nil)
+                       if err != nil {
+                               return err
+                       }
+                       req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", fmt.Sprintf("%d", int(ec2TokenTTL/time.Second)))
+                       resp, err := http.DefaultClient.Do(req)
+                       if err != nil {
+                               return err
+                       }
+                       defer resp.Body.Close()
+                       if resp.StatusCode != http.StatusOK {
+                               return fmt.Errorf("%s", resp.Status)
+                       }
+                       newtoken, err := ioutil.ReadAll(resp.Body)
+                       if err != nil {
+                               return err
+                       }
+                       token = strings.TrimSpace(string(newtoken))
+                       tokenExp = time.Now().Add(ec2TokenTTL)
                }
-               req, err = http.NewRequestWithContext(ctx, http.MethodGet, ec2MetadataBaseURL+"/latest/meta-data/spot/instance-action", nil)
+               req, err := http.NewRequestWithContext(ctx, http.MethodGet, ec2MetadataBaseURL+"/latest/meta-data/spot/instance-action", nil)
                if err != nil {
                        return err
                }
-               req.Header.Set("X-aws-ec2-metadata-token", strings.TrimSpace(string(token)))
-               resp, err = http.DefaultClient.Do(req)
+               req.Header.Set("X-aws-ec2-metadata-token", token)
+               resp, err := http.DefaultClient.Do(req)
                if err != nil {
                        return err
                }
                defer resp.Body.Close()
                metadata = ec2metadata{}
-               if resp.StatusCode == http.StatusNotFound {
+               switch resp.StatusCode {
+               case http.StatusOK:
+                       break
+               case http.StatusNotFound:
                        // "If Amazon EC2 is not preparing to stop or
                        // terminate the instance, or if you
                        // terminated the instance yourself,
@@ -1249,7 +1260,10 @@ func (runner *ContainerRunner) checkSpotInterruptionNotices() {
                        // instance metadata and you receive an HTTP
                        // 404 error when you try to retrieve it."
                        return nil
-               } else if resp.StatusCode != http.StatusOK {
+               case http.StatusUnauthorized:
+                       token = ""
+                       return fmt.Errorf("%s", resp.Status)
+               default:
                        return fmt.Errorf("%s", resp.Status)
                }
                err = json.NewDecoder(resp.Body).Decode(&metadata)
@@ -1265,12 +1279,13 @@ func (runner *ContainerRunner) checkSpotInterruptionNotices() {
                if err != nil {
                        runner.CrunchLog.Printf("Error checking spot interruptions: %s", err)
                        failures++
-                       if failures > 3 {
-                               runner.CrunchLog.Printf("Giving up on checking spot interruptions after too many errors")
+                       if failures > 5 {
+                               runner.CrunchLog.Printf("Giving up on checking spot interruptions after too many consecutive failures")
                                return
                        }
                        continue
                }
+               failures = 0
                if metadata != lastmetadata {
                        lastmetadata = metadata
                        text := fmt.Sprintf("Cloud provider indicates instance action %q scheduled for time %q", metadata.Action, metadata.Time.UTC().Format(time.RFC3339))
index 786f9410a8ac77783edc8e04652960ba56a1217e..5b4c6827b965feba6456e3f4f47c5d82288a9d64 100644 (file)
@@ -13,6 +13,7 @@ import (
        "io"
        "io/ioutil"
        "log"
+       "math/rand"
        "net/http"
        "net/http/httptest"
        "os"
@@ -21,6 +22,7 @@ import (
        "runtime/pprof"
        "strings"
        "sync"
+       "sync/atomic"
        "syscall"
        "testing"
        "time"
@@ -777,38 +779,50 @@ func (s *TestSuite) TestRunAlreadyRunning(c *C) {
        c.Check(ran, Equals, false)
 }
 
-func (s *TestSuite) TestSpotInterruptionNotice(c *C) {
-       var failedOnce bool
-       var stoptime time.Time
-       token := "fake-ec2-metadata-token"
-       stub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-               if !failedOnce {
+func ec2MetadataServerStub(c *C, token *string, failureRate float64, stoptime *atomic.Value) *httptest.Server {
+       failedOnce := false
+       return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               if !failedOnce || rand.Float64() < failureRate {
                        w.WriteHeader(http.StatusServiceUnavailable)
                        failedOnce = true
                        return
                }
                switch r.URL.Path {
                case "/latest/api/token":
-                       fmt.Fprintln(w, token)
+                       fmt.Fprintln(w, *token)
                case "/latest/meta-data/spot/instance-action":
-                       if r.Header.Get("X-aws-ec2-metadata-token") != token {
+                       if r.Header.Get("X-aws-ec2-metadata-token") != *token {
                                w.WriteHeader(http.StatusUnauthorized)
-                       } else if stoptime.IsZero() {
+                       } else if t, _ := stoptime.Load().(time.Time); t.IsZero() {
                                w.WriteHeader(http.StatusNotFound)
                        } else {
-                               fmt.Fprintf(w, `{"action":"stop","time":"%s"}`, stoptime.Format(time.RFC3339))
+                               fmt.Fprintf(w, `{"action":"stop","time":"%s"}`, t.Format(time.RFC3339))
                        }
                default:
                        w.WriteHeader(http.StatusNotFound)
                }
        }))
+}
+
+func (s *TestSuite) TestSpotInterruptionNotice(c *C) {
+       s.testSpotInterruptionNotice(c, 0.1)
+}
+
+func (s *TestSuite) TestSpotInterruptionNoticeNotAvailable(c *C) {
+       s.testSpotInterruptionNotice(c, 1)
+}
+
+func (s *TestSuite) testSpotInterruptionNotice(c *C, failureRate float64) {
+       var stoptime atomic.Value
+       token := "fake-ec2-metadata-token"
+       stub := ec2MetadataServerStub(c, &token, failureRate, &stoptime)
        defer stub.Close()
 
        defer func(i time.Duration, u string) {
                spotInterruptionCheckInterval = i
                ec2MetadataBaseURL = u
        }(spotInterruptionCheckInterval, ec2MetadataBaseURL)
-       spotInterruptionCheckInterval = time.Second / 4
+       spotInterruptionCheckInterval = time.Second / 8
        ec2MetadataBaseURL = stub.URL
 
        go s.runner.checkSpotInterruptionNotices()
@@ -824,13 +838,18 @@ func (s *TestSuite) TestSpotInterruptionNotice(c *C) {
     "state": "Locked"
 }`, nil, func() int {
                time.Sleep(time.Second)
-               stoptime = time.Now().Add(time.Minute).UTC()
+               stoptime.Store(time.Now().Add(time.Minute).UTC())
+               token = "different-fake-ec2-metadata-token"
                time.Sleep(time.Second)
                return 0
        })
-       c.Check(s.api.Logs["crunch-run"].String(), Matches, `(?ms).*Checking for spot interruptions every 250ms using instance metadata at http://.*`)
+       c.Check(s.api.Logs["crunch-run"].String(), Matches, `(?ms).*Checking for spot interruptions every 125ms using instance metadata at http://.*`)
        c.Check(s.api.Logs["crunch-run"].String(), Matches, `(?ms).*Error checking spot interruptions: 503 Service Unavailable.*`)
-       c.Check(s.api.Logs["crunch-run"].String(), Matches, `(?ms).*Cloud provider indicates instance action "stop" scheduled for time "`+stoptime.Format(time.RFC3339)+`".*`)
+       if failureRate == 1 {
+               c.Check(s.api.Logs["crunch-run"].String(), Matches, `(?ms).*Giving up on checking spot interruptions after too many consecutive failures.*`)
+       } else {
+               c.Check(s.api.Logs["crunch-run"].String(), Matches, `(?ms).*Cloud provider indicates instance action "stop" scheduled for time "`+stoptime.Load().(time.Time).Format(time.RFC3339)+`".*`)
+       }
 }
 
 func (s *TestSuite) TestRunTimeExceeded(c *C) {