package httpserver
import (
+ "bufio"
"context"
+ "net"
"net/http"
+ "sync"
"time"
"git.arvados.org/arvados.git/sdk/go/ctxlog"
}
var (
- requestTimeContextKey = contextKey{"requestTime"}
+ // requestTimeContextKey: LogRequests stores a time.Time (time.Now()
+ // at the point it sets up the request context) under this key.
+ requestTimeContextKey = contextKey{"requestTime"}
+ // responseLogFieldsContextKey: LogRequests stores a logrus.Fields
+ // map under this key; SetResponseLogFields writes into that map.
+ responseLogFieldsContextKey = contextKey{"responseLogFields"}
+ // mutexContextKey: LogRequests stores a *sync.Mutex under this key;
+ // SetResponseLogFields holds it while writing to the fields map.
+ mutexContextKey = contextKey{"mutex"}
)
-// HandlerWithContext returns an http.Handler that changes the request
-// context to ctx (replacing http.Server's default
-// context.Background()), then calls next.
-func HandlerWithContext(ctx context.Context, next http.Handler) http.Handler {
+// hijacker is an http.ResponseWriter that also implements
+// http.Hijacker.
+type hijacker interface {
+ http.ResponseWriter
+ http.Hijacker
+}
+
+// hijackNotifier wraps a ResponseWriter that supports hijacking. Its
+// Hijack method closes the hijacked channel, so whoever holds the
+// receiving end learns that the connection is being hijacked.
+type hijackNotifier struct {
+ hijacker
+ // hijacked is a send-only channel used purely as a signal: Hijack
+ // closes it rather than sending a value on it.
+ hijacked chan<- bool
+}
+
+// Hijack hijacks the wrapped connection and, if that succeeds, closes
+// hn.hijacked to signal that the connection has been taken over.
+//
+// The channel is closed only after the wrapped Hijack returns without
+// error. A failed attempt therefore leaves the signal pending, and a
+// repeated call that the wrapped Hijacker rejects with an error (as
+// net/http does with ErrHijacked) cannot close the channel a second
+// time, which would panic.
+func (hn hijackNotifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	conn, rw, err := hn.hijacker.Hijack()
+	if err != nil {
+		return conn, rw, err
+	}
+	close(hn.hijacked)
+	return conn, rw, nil
+}
+
+// HandlerWithDeadline cancels the request context if the request
+// takes longer than the specified timeout without having its
+// connection hijacked.
+//
+// If the ResponseWriter passed to the returned handler also
+// implements http.Hijacker, it is wrapped in a hijackNotifier so that
+// hijacking the connection disarms the deadline. The derived context
+// is also canceled when next.ServeHTTP returns.
+func HandlerWithDeadline(timeout time.Duration, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ctx, cancel := context.WithCancel(r.Context())
+		defer cancel()
+		nodeadline := make(chan bool)
+		// Use an explicit timer rather than time.After so it can be
+		// stopped as soon as the watchdog goroutine exits. A
+		// time.After timer cannot be stopped, and on Go versions
+		// before 1.23 it stays allocated until the timeout elapses,
+		// even if the request finished long before.
+		timer := time.NewTimer(timeout)
+		go func() {
+			defer timer.Stop()
+			select {
+			case <-nodeadline:
+				// Closed by hijackNotifier.Hijack: the deadline
+				// no longer applies.
+			case <-ctx.Done():
+				// The handler returned (deferred cancel) or the
+				// parent context ended first.
+			case <-timer.C:
+				cancel()
+			}
+		}()
+		if hj, ok := w.(hijacker); ok {
+			w = hijackNotifier{hj, nodeadline}
+		}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
+// SetResponseLogFields copies the given fields into the
+// logrus.Fields map stored in ctx under responseLogFieldsContextKey,
+// overwriting any existing entries with the same keys. LogRequests
+// installs that map (and the mutex guarding it) in the request
+// context and attaches the stored fields to its logger.
+//
+// The map is updated while holding the *sync.Mutex stored in ctx
+// under mutexContextKey. If ctx carries no such mutex or no such
+// map, SetResponseLogFields does nothing.
+func SetResponseLogFields(ctx context.Context, fields logrus.Fields) {
+	mtx, _ := ctx.Value(&mutexContextKey).(*sync.Mutex)
+	dst, _ := ctx.Value(&responseLogFieldsContextKey).(logrus.Fields)
+	if mtx != nil && dst != nil {
+		mtx.Lock()
+		defer mtx.Unlock()
+		for name, val := range fields {
+			dst[name] = val
+		}
+	}
+}
+
// LogRequests wraps an http.Handler, logging each request and
// response.
func LogRequests(h http.Handler) http.Handler {
})
ctx := req.Context()
ctx = context.WithValue(ctx, &requestTimeContextKey, time.Now())
+ ctx = context.WithValue(ctx, &responseLogFieldsContextKey, logrus.Fields{})
+ ctx = context.WithValue(ctx, &mutexContextKey, &sync.Mutex{})
ctx = ctxlog.Context(ctx, lgr)
req = req.WithContext(ctx)
"timeWriteBody": stats.Duration(tDone.Sub(writeTime)),
})
}
+ if responseLogFields, ok := req.Context().Value(&responseLogFieldsContextKey).(logrus.Fields); ok {
+ lgr = lgr.WithFields(responseLogFields)
+ }
respCode := w.WroteStatus()
if respCode == 0 {
respCode = http.StatusOK