}
instrumented := httpserver.Instrument(reg, log,
- httpserver.HandlerWithContext(ctx,
+ httpserver.HandlerWithDeadline(cluster.API.RequestTimeout.Duration(),
httpserver.AddRequestIDs(
httpserver.LogRequests(
httpserver.NewRequestLimiter(cluster.API.MaxConcurrentRequests, handler, reg)))))
srv := &httpserver.Server{
Server: http.Server{
- Handler: instrumented.ServeAPI(cluster.ManagementToken, instrumented),
+ Handler: instrumented.ServeAPI(cluster.ManagementToken, instrumented),
+ BaseContext: func(net.Listener) context.Context { return ctx },
},
Addr: listenURL.Host,
}
})
}
+// HandlerWithDeadline returns a handler that calls next with a request
+// context that is canceled once timeout has elapsed, or as soon as next
+// returns, whichever happens first. Cancellation of the incoming
+// request's own context still propagates as usual.
+//
+// If timeout is zero or negative, no deadline is imposed and next is
+// returned unchanged. Without this check, a zero value (for example an
+// unset configuration entry) would produce an already-expired deadline
+// and cancel every request's context immediately.
+func HandlerWithDeadline(timeout time.Duration, next http.Handler) http.Handler {
+	if timeout <= 0 {
+		return next
+	}
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ctx, cancel := context.WithTimeout(r.Context(), timeout)
+		// Release the timer as soon as next returns instead of
+		// leaving it pending until the deadline fires.
+		defer cancel()
+		next.ServeHTTP(w, r.WithContext(ctx))
+	})
+}
+
// LogRequests wraps an http.Handler, logging each request and
// response.
func LogRequests(h http.Handler) http.Handler {
s.ctx = ctxlog.Context(context.Background(), s.log)
}
+// TestWithDeadline checks that HandlerWithDeadline cancels the request
+// context once a short timeout elapses, and leaves it untouched while a
+// long timeout is still pending.
+func (s *Suite) TestWithDeadline(c *check.C) {
+	req, err := http.NewRequest("GET", "https://foo.example/bar", nil)
+	c.Assert(err, check.IsNil)
+
+	trials := []struct {
+		timeout    time.Duration // deadline given to HandlerWithDeadline
+		wait       time.Duration // how long the inner handler waits for cancellation
+		expectDone bool          // whether the context should be done before wait elapses
+		failure    string        // error reported when the expectation is not met
+	}{
+		// Short timeout cancels context in <1s
+		{time.Millisecond, time.Second, true, "timed out"},
+		// Long timeout does not cancel context in <1ms
+		{time.Second, time.Millisecond, false, "request context done too soon"},
+	}
+	for _, trial := range trials {
+		recorder := httptest.NewRecorder()
+		// The inner handler runs synchronously inside ServeHTTP, so
+		// capturing trial here is safe on every Go version.
+		inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			done := false
+			select {
+			case <-r.Context().Done():
+				done = true
+			case <-time.After(trial.wait):
+			}
+			if done != trial.expectDone {
+				c.Error(trial.failure)
+				return
+			}
+			w.Write([]byte("ok"))
+		})
+		HandlerWithDeadline(trial.timeout, inner).ServeHTTP(recorder, req.WithContext(s.ctx))
+		c.Check(recorder.Body.String(), check.Equals, "ok")
+	}
+}
+
func (s *Suite) TestLogRequests(c *check.C) {
h := AddRequestIDs(LogRequests(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
c.Assert(err, check.IsNil)
resp := httptest.NewRecorder()
- HandlerWithContext(s.ctx, h).ServeHTTP(resp, req)
+ h.ServeHTTP(resp, req.WithContext(s.ctx))
dec := json.NewDecoder(s.logdata)
c.Assert(err, check.IsNil)
resp := httptest.NewRecorder()
- HandlerWithContext(s.ctx, LogRequests(
+ LogRequests(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(trial.statusCode)
w.Write([]byte(trial.sentBody))
}),
- )).ServeHTTP(resp, req)
+ ).ServeHTTP(resp, req.WithContext(s.ctx))
gotReq := make(map[string]interface{})
err = dec.Decode(&gotReq)