13697: Cancel request context after API.RequestTimeout.
author Tom Clegg <tom@curii.com>
Thu, 23 Sep 2021 13:41:20 +0000 (09:41 -0400)
committer Tom Clegg <tom@curii.com>
Thu, 23 Sep 2021 15:12:21 +0000 (11:12 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

lib/service/cmd.go
sdk/go/httpserver/logger.go
sdk/go/httpserver/logger_test.go

index 40db4f9c7c7f80f744ab3c44da874794d925c9c1..e67c24f65f39cea4929c95fe30abbdc5ab98a901 100644 (file)
@@ -126,13 +126,14 @@ func (c *command) RunCommand(prog string, args []string, stdin io.Reader, stdout
        }
 
        instrumented := httpserver.Instrument(reg, log,
-               httpserver.HandlerWithContext(ctx,
+               httpserver.HandlerWithDeadline(cluster.API.RequestTimeout.Duration(),
                        httpserver.AddRequestIDs(
                                httpserver.LogRequests(
                                        httpserver.NewRequestLimiter(cluster.API.MaxConcurrentRequests, handler, reg)))))
        srv := &httpserver.Server{
                Server: http.Server{
-                       Handler: instrumented.ServeAPI(cluster.ManagementToken, instrumented),
+                       Handler:     instrumented.ServeAPI(cluster.ManagementToken, instrumented),
+                       BaseContext: func(net.Listener) context.Context { return ctx },
                },
                Addr: listenURL.Host,
        }
index 78a1f77adb9ca9942d00fe559c3192b658e8e5fe..1916880963494333ce9d4356f1d1c9d2eba55f3b 100644 (file)
@@ -31,6 +31,16 @@ func HandlerWithContext(ctx context.Context, next http.Handler) http.Handler {
        })
 }
 
+// HandlerWithDeadline cancels the request context if the request
+// takes longer than the specified timeout.
+func HandlerWithDeadline(timeout time.Duration, next http.Handler) http.Handler {
+       return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               ctx, cancel := context.WithDeadline(r.Context(), time.Now().Add(timeout))
+               defer cancel()
+               next.ServeHTTP(w, r.WithContext(ctx))
+       })
+}
+
 // LogRequests wraps an http.Handler, logging each request and
 // response.
 func LogRequests(h http.Handler) http.Handler {
index af45a640ca38e2bb8baa76244070d4f5753324b0..b623aa4eea33894d3fe01cbfa1f0d54cf57d226b 100644 (file)
@@ -41,6 +41,35 @@ func (s *Suite) SetUpTest(c *check.C) {
        s.ctx = ctxlog.Context(context.Background(), s.log)
 }
 
+func (s *Suite) TestWithDeadline(c *check.C) {
+       req, err := http.NewRequest("GET", "https://foo.example/bar", nil)
+       c.Assert(err, check.IsNil)
+
+       // Short timeout cancels context in <1s
+       resp := httptest.NewRecorder()
+       HandlerWithDeadline(time.Millisecond, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+               select {
+               case <-req.Context().Done():
+                       w.Write([]byte("ok"))
+               case <-time.After(time.Second):
+                       c.Error("timed out")
+               }
+       })).ServeHTTP(resp, req.WithContext(s.ctx))
+       c.Check(resp.Body.String(), check.Equals, "ok")
+
+       // Long timeout does not cancel context in <1ms
+       resp = httptest.NewRecorder()
+       HandlerWithDeadline(time.Second, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+               select {
+               case <-req.Context().Done():
+                       c.Error("request context done too soon")
+               case <-time.After(time.Millisecond):
+                       w.Write([]byte("ok"))
+               }
+       })).ServeHTTP(resp, req.WithContext(s.ctx))
+       c.Check(resp.Body.String(), check.Equals, "ok")
+}
+
 func (s *Suite) TestLogRequests(c *check.C) {
        h := AddRequestIDs(LogRequests(
                http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
@@ -52,7 +81,7 @@ func (s *Suite) TestLogRequests(c *check.C) {
        c.Assert(err, check.IsNil)
        resp := httptest.NewRecorder()
 
-       HandlerWithContext(s.ctx, h).ServeHTTP(resp, req)
+       h.ServeHTTP(resp, req.WithContext(s.ctx))
 
        dec := json.NewDecoder(s.logdata)
 
@@ -104,12 +133,12 @@ func (s *Suite) TestLogErrorBody(c *check.C) {
                c.Assert(err, check.IsNil)
                resp := httptest.NewRecorder()
 
-               HandlerWithContext(s.ctx, LogRequests(
+               LogRequests(
                        http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
                                w.WriteHeader(trial.statusCode)
                                w.Write([]byte(trial.sentBody))
                        }),
-               )).ServeHTTP(resp, req)
+               ).ServeHTTP(resp, req.WithContext(s.ctx))
 
                gotReq := make(map[string]interface{})
                err = dec.Decode(&gotReq)