// Copyright (C) The Arvados Authors. All rights reserved.
//
-// SPDX-License-Identifier: AGPL-3.0
+// SPDX-License-Identifier: Apache-2.0
package httpserver
import (
+ "bufio"
"context"
+ "net"
"net/http"
"time"
}
var (
+	// Keys for per-request context values installed by LogRequests
+	// (request start time and extra response log fields). Values are
+	// stored and looked up under the address of each variable, e.g.
+	// ctx.Value(&requestTimeContextKey), not under the variable's value.
- requestTimeContextKey = contextKey{"requestTime"}
+ requestTimeContextKey = contextKey{"requestTime"}
+ responseLogFieldsContextKey = contextKey{"responseLogFields"}
)
-// HandlerWithContext returns an http.Handler that changes the request
-// context to ctx (replacing http.Server's default
-// context.Background()), then calls next.
-func HandlerWithContext(ctx context.Context, next http.Handler) http.Handler {
+// hijacker is an http.ResponseWriter that also implements http.Hijacker.
+type hijacker interface {
+	http.ResponseWriter
+	http.Hijacker
+}
+
+// hijackNotifier wraps a hijackable ResponseWriter. The first call to
+// Hijack closes the hijacked channel (before delegating to the wrapped
+// Hijacker) so a watcher goroutine can tell that the handler has taken
+// over the connection.
+type hijackNotifier struct {
+	hijacker
+	hijacked chan struct{}
+}
+
+// Hijack closes hn.hijacked, then hijacks the wrapped connection.
+// The close is guarded so that a repeated Hijack call (which the
+// wrapped writer may reject with an error, e.g. http.ErrHijacked)
+// does not panic on a double close.
+func (hn hijackNotifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	select {
+	case <-hn.hijacked:
+		// Already closed by an earlier Hijack call.
+	default:
+		close(hn.hijacked)
+	}
+	return hn.hijacker.Hijack()
+}
+
+// HandlerWithDeadline cancels the request context if the request
+// takes longer than the specified timeout without having its
+// connection hijacked.
+//
+// The context passed to next is derived from r.Context() and is also
+// canceled when next returns. Only writers implementing http.Hijacker
+// are wrapped; the wrapper exposes just the ResponseWriter and Hijacker
+// methods of w.
+func HandlerWithDeadline(timeout time.Duration, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	ctx, cancel := context.WithCancel(r.Context())
+	defer cancel()
+	// Closed by hijackNotifier when the handler hijacks the connection.
+	nodeadline := make(chan struct{})
+	go func() {
+		// A stoppable timer (instead of time.After) is released as soon
+		// as the request finishes rather than lingering until timeout.
+		timer := time.NewTimer(timeout)
+		defer timer.Stop()
+		select {
+		case <-nodeadline:
+		case <-ctx.Done():
+		case <-timer.C:
+			cancel()
+		}
+	}()
+	if hj, ok := w.(hijacker); ok {
+		w = hijackNotifier{hj, nodeadline}
+	}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
+// SetResponseLogFields merges fields into the extra fields that
+// LogRequests adds to the response log entry for the request whose
+// context is ctx. A key that is already present is overwritten.
+//
+// It does nothing if ctx does not carry the logrus.Fields map that
+// LogRequests stores under responseLogFieldsContextKey.
+//
+// NOTE(review): the stored map is updated without any locking, so
+// concurrent calls for the same request would be a data race.
+func SetResponseLogFields(ctx context.Context, fields logrus.Fields) {
+	dst, ok := ctx.Value(&responseLogFieldsContextKey).(logrus.Fields)
+	if !ok {
+		return
+	}
+	for key, val := range fields {
+		dst[key] = val
+	}
+}
+
// LogRequests wraps an http.Handler, logging each request and
// response.
func LogRequests(h http.Handler) http.Handler {
})
ctx := req.Context()
ctx = context.WithValue(ctx, &requestTimeContextKey, time.Now())
+ ctx = context.WithValue(ctx, &responseLogFieldsContextKey, logrus.Fields{})
ctx = ctxlog.Context(ctx, lgr)
req = req.WithContext(ctx)
logRequest(w, req, lgr)
defer logResponse(w, req, lgr)
- h.ServeHTTP(w, req)
+ h.ServeHTTP(rewrapResponseWriter(w, wrapped), req)
})
}
+// rewrapResponseWriter returns w unchanged unless wrapped implements
+// http.Hijacker. In that case it returns a value whose ResponseWriter
+// methods go to w and whose Hijack method goes directly to wrapped.
+//
+// Only the ResponseWriter and Hijacker method sets are exposed by the
+// combined value; any other optional interfaces that w implements are
+// hidden by the anonymous struct.
+func rewrapResponseWriter(w http.ResponseWriter, wrapped http.ResponseWriter) http.ResponseWriter {
+	hj, canHijack := wrapped.(http.Hijacker)
+	if !canHijack {
+		return w
+	}
+	return struct {
+		http.ResponseWriter
+		http.Hijacker
+	}{ResponseWriter: w, Hijacker: hj}
+}
+
+// Logger returns the logger carried in req's context, as retrieved by
+// ctxlog.FromContext. For requests handled through LogRequests this is
+// the per-request logger that LogRequests attached with ctxlog.Context.
func Logger(req *http.Request) logrus.FieldLogger {
return ctxlog.FromContext(req.Context())
}
func logResponse(w *responseTimer, req *http.Request, lgr *logrus.Entry) {
if tStart, ok := req.Context().Value(&requestTimeContextKey).(time.Time); ok {
tDone := time.Now()
+ writeTime := w.writeTime
+ if !w.wrote {
+ // Empty response body. Header was sent when
+ // handler exited.
+ writeTime = tDone
+ }
lgr = lgr.WithFields(logrus.Fields{
"timeTotal": stats.Duration(tDone.Sub(tStart)),
- "timeToStatus": stats.Duration(w.writeTime.Sub(tStart)),
- "timeWriteBody": stats.Duration(tDone.Sub(w.writeTime)),
+ "timeToStatus": stats.Duration(writeTime.Sub(tStart)),
+ "timeWriteBody": stats.Duration(tDone.Sub(writeTime)),
})
}
+ if responseLogFields, ok := req.Context().Value(&responseLogFieldsContextKey).(logrus.Fields); ok {
+ lgr = lgr.WithFields(responseLogFields)
+ }
respCode := w.WroteStatus()
if respCode == 0 {
respCode = http.StatusOK