18896: controller logs the UUIDs of the tokens used in each request.
[arvados.git] / sdk / go / httpserver / logger.go
index 1916880963494333ce9d4356f1d1c9d2eba55f3b..437429611cb1ee1d0a5db75be353865863985bb6 100644 (file)
@@ -5,7 +5,9 @@
 package httpserver
 
 import (
+       "bufio"
        "context"
+       "net"
        "net/http"
        "time"
 
@@ -19,28 +21,59 @@ type contextKey struct {
 }
 
 var (
-       requestTimeContextKey = contextKey{"requestTime"}
+       // requestTimeContextKey (used via its address) maps to the
+       // time.Time recorded by LogRequests when it starts handling a
+       // request.
+       requestTimeContextKey       = contextKey{"requestTime"}
+       // responseLogFieldsContextKey (used via its address) maps to a
+       // logrus.Fields map that LogRequests attaches to each request
+       // context. Handlers add entries with SetResponseLogFields, and
+       // logResponse merges them into the response log entry.
+       responseLogFieldsContextKey = contextKey{"responseLogFields"}
 )
 
-// HandlerWithContext returns an http.Handler that changes the request
-// context to ctx (replacing http.Server's default
-// context.Background()), then calls next.
-func HandlerWithContext(ctx context.Context, next http.Handler) http.Handler {
-       return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-               next.ServeHTTP(w, r.WithContext(ctx))
-       })
+// hijacker is satisfied by an http.ResponseWriter that also supports
+// taking over the underlying connection via http.Hijacker.
+type hijacker interface {
+       http.ResponseWriter
+       http.Hijacker
+}
+
+// hijackNotifier wraps a ResponseWriter, closing the hijacked channel
+// if/when the wrapped Hijacker is successfully hijacked.
+//
+// Only the methods of http.ResponseWriter and http.Hijacker are
+// promoted from the embedded interface; other optional interfaces of
+// the wrapped value (e.g. http.Flusher) are not exposed.
+type hijackNotifier struct {
+       hijacker
+       hijacked chan<- bool
+}
+
+// Hijack hijacks the wrapped connection and, only if that succeeds,
+// closes hn.hijacked to notify the receiver.
+//
+// Closing after a successful hijack (rather than unconditionally,
+// before it) means a failed hijack does not report a hijack that never
+// happened. It also means a repeated call that the underlying Hijacker
+// rejects with an error returns that error instead of panicking on a
+// second close of the channel.
+func (hn hijackNotifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+       conn, rw, err := hn.hijacker.Hijack()
+       if err == nil {
+               close(hn.hijacked)
+       }
+       return conn, rw, err
 }
 
 // HandlerWithDeadline cancels the request context if the request
-// takes longer than the specified timeout.
+// takes longer than the specified timeout without having its
+// connection hijacked.
 func HandlerWithDeadline(timeout time.Duration, next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-               ctx, cancel := context.WithDeadline(r.Context(), time.Now().Add(timeout))
+               ctx, cancel := context.WithCancel(r.Context())
                defer cancel()
+               nodeadline := make(chan bool)
+               go func() {
+                       select {
+                       case <-nodeadline:
+                       case <-ctx.Done():
+                       case <-time.After(timeout):
+                               cancel()
+                       }
+               }()
+               if hj, ok := w.(hijacker); ok {
+                       w = hijackNotifier{hj, nodeadline}
+               }
                next.ServeHTTP(w, r.WithContext(ctx))
        })
 }
 
+func SetResponseLogFields(ctx context.Context, fields logrus.Fields) {
+       ctxfields := ctx.Value(&responseLogFieldsContextKey)
+       if c, ok := ctxfields.(logrus.Fields); ok {
+               for k, v := range fields {
+                       c[k] = v
+               }
+       }
+}
+
 // LogRequests wraps an http.Handler, logging each request and
 // response.
 func LogRequests(h http.Handler) http.Handler {
@@ -58,6 +91,7 @@ func LogRequests(h http.Handler) http.Handler {
                })
                ctx := req.Context()
                ctx = context.WithValue(ctx, &requestTimeContextKey, time.Now())
+               ctx = context.WithValue(ctx, &responseLogFieldsContextKey, logrus.Fields{})
                ctx = ctxlog.Context(ctx, lgr)
                req = req.WithContext(ctx)
 
@@ -101,6 +135,9 @@ func logResponse(w *responseTimer, req *http.Request, lgr *logrus.Entry) {
                        "timeWriteBody": stats.Duration(tDone.Sub(writeTime)),
                })
        }
+       if responseLogFields, ok := req.Context().Value(&responseLogFieldsContextKey).(logrus.Fields); ok {
+               lgr = lgr.WithFields(responseLogFields)
+       }
        respCode := w.WroteStatus()
        if respCode == 0 {
                respCode = http.StatusOK