Merge branch '18184-singularity-374'
[arvados.git] / sdk / go / httpserver / logger.go
index 59981e3e55265be4eed1827d3570391533ac3a30..7eb7f0f03d57b571e314f8d87ca6714cf7d6563f 100644 (file)
@@ -1,11 +1,13 @@
 // Copyright (C) The Arvados Authors. All rights reserved.
 //
-// SPDX-License-Identifier: AGPL-3.0
+// SPDX-License-Identifier: Apache-2.0
 
 package httpserver
 
 import (
+       "bufio"
        "context"
+       "net"
        "net/http"
        "time"
 
@@ -22,11 +24,42 @@ var (
        requestTimeContextKey = contextKey{"requestTime"}
 )
 
-// HandlerWithContext returns an http.Handler that changes the request
-// context to ctx (replacing http.Server's default
-// context.Background()), then calls next.
-func HandlerWithContext(ctx context.Context, next http.Handler) http.Handler {
+type hijacker interface { // a ResponseWriter that also implements http.Hijacker
+       http.ResponseWriter
+       http.Hijacker
+}
+
+// hijackNotifier wraps a hijack-capable ResponseWriter, closing the
+// hijacked channel when its Hijack method is called.
+type hijackNotifier struct {
+       hijacker
+       hijacked chan<- bool // closed by Hijack to signal that the connection was taken over
+}
+
+func (hn hijackNotifier) Hijack() (net.Conn, *bufio.ReadWriter, error) { // signals via hn.hijacked, then delegates to the wrapped Hijacker
+       close(hn.hijacked) // NOTE(review): a second Hijack call would close an already-closed channel and panic -- confirm callers never call it twice
+       return hn.hijacker.Hijack() // the channel was closed above even if this call returns an error
+}
+
+// HandlerWithDeadline cancels the request context if the request
+// takes longer than the specified timeout without having its
+// connection hijacked.
+func HandlerWithDeadline(timeout time.Duration, next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               ctx, cancel := context.WithCancel(r.Context())
+               defer cancel() // also releases the watcher goroutine below once next returns
+               nodeadline := make(chan bool) // closed by hijackNotifier.Hijack to disarm the deadline
+               go func() {
+                       select {
+                       case <-nodeadline: // hijack was requested: exit without cancelling
+                       case <-ctx.Done(): // handler returned or the parent context ended: nothing left to do
+                       case <-time.After(timeout): // timeout elapsed first
+                               cancel() // only the context is cancelled; next must observe ctx to stop
+                       }
+               }()
+               if hj, ok := w.(hijacker); ok { // wrap only writers that support hijacking
+                       w = hijackNotifier{hj, nodeadline}
+               }
                next.ServeHTTP(w, r.WithContext(ctx))
        })
 }
@@ -64,9 +97,8 @@ func rewrapResponseWriter(w http.ResponseWriter, wrapped http.ResponseWriter) ht
                        http.ResponseWriter
                        http.Hijacker
                }{w, hijacker}
-       } else {
-               return w
        }
+       return w
 }
 
 func Logger(req *http.Request) logrus.FieldLogger {