From 2b8fc576e242c0b8658eef9f1130143e009efc4d Mon Sep 17 00:00:00 2001 From: Tom Clegg Date: Thu, 23 Sep 2021 09:41:20 -0400 Subject: [PATCH] 13697: Cancel request context after API.RequestTimeout. Arvados-DCO-1.1-Signed-off-by: Tom Clegg --- lib/service/cmd.go | 5 +++-- sdk/go/httpserver/logger.go | 10 +++++++++ sdk/go/httpserver/logger_test.go | 35 +++++++++++++++++++++++++++++--- 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/lib/service/cmd.go b/lib/service/cmd.go index 40db4f9c7c..e67c24f65f 100644 --- a/lib/service/cmd.go +++ b/lib/service/cmd.go @@ -126,13 +126,14 @@ func (c *command) RunCommand(prog string, args []string, stdin io.Reader, stdout } instrumented := httpserver.Instrument(reg, log, - httpserver.HandlerWithContext(ctx, + httpserver.HandlerWithDeadline(cluster.API.RequestTimeout.Duration(), httpserver.AddRequestIDs( httpserver.LogRequests( httpserver.NewRequestLimiter(cluster.API.MaxConcurrentRequests, handler, reg))))) srv := &httpserver.Server{ Server: http.Server{ - Handler: instrumented.ServeAPI(cluster.ManagementToken, instrumented), + Handler: instrumented.ServeAPI(cluster.ManagementToken, instrumented), + BaseContext: func(net.Listener) context.Context { return ctx }, }, Addr: listenURL.Host, } diff --git a/sdk/go/httpserver/logger.go b/sdk/go/httpserver/logger.go index 78a1f77adb..1916880963 100644 --- a/sdk/go/httpserver/logger.go +++ b/sdk/go/httpserver/logger.go @@ -31,6 +31,16 @@ func HandlerWithContext(ctx context.Context, next http.Handler) http.Handler { }) } +// HandlerWithDeadline cancels the request context if the request +// takes longer than the specified timeout. +func HandlerWithDeadline(timeout time.Duration, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithDeadline(r.Context(), time.Now().Add(timeout)) + defer cancel() + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + // LogRequests wraps an http.Handler, logging each request and // response. func LogRequests(h http.Handler) http.Handler { diff --git a/sdk/go/httpserver/logger_test.go b/sdk/go/httpserver/logger_test.go index af45a640ca..b623aa4eea 100644 --- a/sdk/go/httpserver/logger_test.go +++ b/sdk/go/httpserver/logger_test.go @@ -41,6 +41,35 @@ func (s *Suite) SetUpTest(c *check.C) { s.ctx = ctxlog.Context(context.Background(), s.log) } +func (s *Suite) TestWithDeadline(c *check.C) { + req, err := http.NewRequest("GET", "https://foo.example/bar", nil) + c.Assert(err, check.IsNil) + + // Short timeout cancels context in <1s + resp := httptest.NewRecorder() + HandlerWithDeadline(time.Millisecond, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + select { + case <-req.Context().Done(): + w.Write([]byte("ok")) + case <-time.After(time.Second): + c.Error("timed out") + } + })).ServeHTTP(resp, req.WithContext(s.ctx)) + c.Check(resp.Body.String(), check.Equals, "ok") + + // Long timeout does not cancel context in <1ms + resp = httptest.NewRecorder() + HandlerWithDeadline(time.Second, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + select { + case <-req.Context().Done(): + c.Error("request context done too soon") + case <-time.After(time.Millisecond): + w.Write([]byte("ok")) + } + })).ServeHTTP(resp, req.WithContext(s.ctx)) + c.Check(resp.Body.String(), check.Equals, "ok") +} + func (s *Suite) TestLogRequests(c *check.C) { h := AddRequestIDs(LogRequests( http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { @@ -52,7 +81,7 @@ func (s *Suite) TestLogRequests(c *check.C) { c.Assert(err, check.IsNil) resp := httptest.NewRecorder() - HandlerWithContext(s.ctx, h).ServeHTTP(resp, req) + h.ServeHTTP(resp, req.WithContext(s.ctx)) dec := json.NewDecoder(s.logdata) @@ -104,12 +133,12 @@ func (s *Suite) TestLogErrorBody(c *check.C) { c.Assert(err, check.IsNil) resp := httptest.NewRecorder() - HandlerWithContext(s.ctx, LogRequests( + LogRequests( http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(trial.statusCode) w.Write([]byte(trial.sentBody)) }), - )).ServeHTTP(resp, req) + ).ServeHTTP(resp, req.WithContext(s.ctx)) gotReq := make(map[string]interface{}) err = dec.Decode(&gotReq) -- 2.30.2