13697: Do not apply RequestTimeout to hijacked connections.
author Tom Clegg <tom@curii.com>
Thu, 23 Sep 2021 14:46:42 +0000 (10:46 -0400)
committer Tom Clegg <tom@curii.com>
Thu, 23 Sep 2021 15:12:21 +0000 (11:12 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

sdk/go/httpserver/logger.go
sdk/go/httpserver/logger_test.go

index a0ca6bf28d8d37daf2f8a0406d9ce15e7d4571e6..7eb7f0f03d57b571e314f8d87ca6714cf7d6563f 100644 (file)
@@ -5,7 +5,9 @@
 package httpserver
 
 import (
+       "bufio"
        "context"
+       "net"
        "net/http"
        "time"
 
@@ -22,12 +24,64 @@ var (
        requestTimeContextKey = contextKey{"requestTime"}
 )
 
+type hijacker interface {
+       http.ResponseWriter
+       http.Hijacker
+}
+
+// hijackNotifier wraps a ResponseWriter, closing the hijacked
+// channel if/when the wrapped Hijacker is successfully hijacked.
+type hijackNotifier struct {
+       hijacker
+       hijacked chan<- bool
+}
+
+// Hijack takes over the wrapped ResponseWriter's connection. The
+// hijacked channel is closed only after the underlying Hijack
+// succeeds, so a failed attempt leaves the request deadline in
+// force, and a repeated call that the underlying Hijacker rejects
+// cannot panic by closing the channel a second time.
+func (hn hijackNotifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+       conn, rw, err := hn.hijacker.Hijack()
+       if err == nil {
+               close(hn.hijacked)
+       }
+       return conn, rw, err
+}
+
 // HandlerWithDeadline cancels the request context if the request
-// takes longer than the specified timeout.
+// takes longer than the specified timeout without having its
+// connection hijacked.
+//
+// If w implements http.Hijacker, next receives a wrapper that
+// exposes only the http.ResponseWriter and http.Hijacker methods,
+// so other optional interfaces of w (such as http.Flusher) are
+// hidden. NOTE(review): confirm no downstream handler needs them.
 func HandlerWithDeadline(timeout time.Duration, next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-               ctx, cancel := context.WithDeadline(r.Context(), time.Now().Add(timeout))
+               ctx, cancel := context.WithCancel(r.Context())
                defer cancel()
+               nodeadline := make(chan bool)
+               go func() {
+                       // An explicit timer, stopped on every exit path, is
+                       // released as soon as the request ends or is hijacked;
+                       // before Go 1.23 a time.After timer was not reclaimed
+                       // until the full timeout elapsed.
+                       timer := time.NewTimer(timeout)
+                       defer timer.Stop()
+                       select {
+                       case <-nodeadline:
+                               // Connection hijacked: no deadline applies.
+                       case <-ctx.Done():
+                               // Handler returned (deferred cancel) or the
+                               // parent context ended.
+                       case <-timer.C:
+                               cancel()
+                       }
+               }()
+               if hj, ok := w.(hijacker); ok {
+                       w = hijackNotifier{hj, nodeadline}
+               }
                next.ServeHTTP(w, r.WithContext(ctx))
        })
 }
index b623aa4eea33894d3fe01cbfa1f0d54cf57d226b..60768b3fc907681c8598a53ae5ff6f9985ef541f 100644 (file)
@@ -9,6 +9,8 @@ import (
        "context"
        "encoding/json"
        "fmt"
+       "io/ioutil"
+       "net"
        "net/http"
        "net/http/httptest"
        "testing"
@@ -70,6 +72,41 @@ func (s *Suite) TestWithDeadline(c *check.C) {
        c.Check(resp.Body.String(), check.Equals, "ok")
 }
 
+// TestNoDeadlineAfterHijacked checks that a handler which hijacks its
+// connection can keep using it past the HandlerWithDeadline timeout
+// without its request context being cancelled.
+func (s *Suite) TestNoDeadlineAfterHijacked(c *check.C) {
+       srv := Server{
+               Addr: ":",
+               Server: http.Server{
+                       Handler: HandlerWithDeadline(time.Millisecond, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+                               conn, _, err := w.(http.Hijacker).Hijack()
+                               // Not c.Assert: this handler runs on a server
+                               // goroutine, not on the test's goroutine.
+                               if !c.Check(err, check.IsNil) {
+                                       return
+                               }
+                               defer conn.Close()
+                               select {
+                               case <-req.Context().Done():
+                                       c.Error("request context done too soon")
+                               case <-time.After(time.Second / 10):
+                                       conn.Write([]byte("HTTP/1.1 200 OK\r\n\r\nok"))
+                               }
+                       })),
+                       BaseContext: func(net.Listener) context.Context { return s.ctx },
+               },
+       }
+       srv.Start()
+       defer srv.Close()
+       resp, err := http.Get("http://" + srv.Addr)
+       c.Assert(err, check.IsNil)
+       defer resp.Body.Close()
+       body, err := ioutil.ReadAll(resp.Body)
+       c.Check(err, check.IsNil)
+       c.Check(string(body), check.Equals, "ok")
+}
+
 func (s *Suite) TestLogRequests(c *check.C) {
        h := AddRequestIDs(LogRequests(
                http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {