15370: Fix flaky test.
[arvados.git] / sdk / go / httpserver / logger_test.go
index 7d5eb2b64f326e2cd10b41bbeacd147a2bff682a..60768b3fc907681c8598a53ae5ff6f9985ef541f 100644 (file)
@@ -1,6 +1,6 @@
 // Copyright (C) The Arvados Authors. All rights reserved.
 //
-// SPDX-License-Identifier: AGPL-3.0
+// SPDX-License-Identifier: Apache-2.0
 
 package httpserver
 
@@ -9,6 +9,8 @@ import (
        "context"
        "encoding/json"
        "fmt"
+       "io/ioutil"
+       "net"
        "net/http"
        "net/http/httptest"
        "testing"
@@ -41,6 +43,61 @@ func (s *Suite) SetUpTest(c *check.C) {
        s.ctx = ctxlog.Context(context.Background(), s.log)
 }
 
+func (s *Suite) TestWithDeadline(c *check.C) {
+       req, err := http.NewRequest("GET", "https://foo.example/bar", nil)
+       c.Assert(err, check.IsNil)
+
+       // Short timeout cancels context in <1s
+       resp := httptest.NewRecorder()
+       HandlerWithDeadline(time.Millisecond, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+               select {
+               case <-req.Context().Done():
+                       w.Write([]byte("ok"))
+               case <-time.After(time.Second):
+                       c.Error("timed out")
+               }
+       })).ServeHTTP(resp, req.WithContext(s.ctx))
+       c.Check(resp.Body.String(), check.Equals, "ok")
+
+       // Long timeout does not cancel context in <1ms
+       resp = httptest.NewRecorder()
+       HandlerWithDeadline(time.Second, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+               select {
+               case <-req.Context().Done():
+                       c.Error("request context done too soon")
+               case <-time.After(time.Millisecond):
+                       w.Write([]byte("ok"))
+               }
+       })).ServeHTTP(resp, req.WithContext(s.ctx))
+       c.Check(resp.Body.String(), check.Equals, "ok")
+}
+
+func (s *Suite) TestNoDeadlineAfterHijacked(c *check.C) {
+       srv := Server{
+               Addr: ":",
+               Server: http.Server{
+                       Handler: HandlerWithDeadline(time.Millisecond, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+                               conn, _, err := w.(http.Hijacker).Hijack()
+                               c.Assert(err, check.IsNil)
+                               defer conn.Close()
+                               select {
+                               case <-req.Context().Done():
+                                       c.Error("request context done too soon")
+                               case <-time.After(time.Second / 10):
+                                       conn.Write([]byte("HTTP/1.1 200 OK\r\n\r\nok"))
+                               }
+                       })),
+                       BaseContext: func(net.Listener) context.Context { return s.ctx },
+               },
+       }
+       srv.Start()
+       defer srv.Close()
+       resp, err := http.Get("http://" + srv.Addr)
+       c.Assert(err, check.IsNil)
+       body, err := ioutil.ReadAll(resp.Body)
+       c.Check(string(body), check.Equals, "ok")
+}
+
 func (s *Suite) TestLogRequests(c *check.C) {
        h := AddRequestIDs(LogRequests(
                http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
@@ -52,7 +109,7 @@ func (s *Suite) TestLogRequests(c *check.C) {
        c.Assert(err, check.IsNil)
        resp := httptest.NewRecorder()
 
-       HandlerWithContext(s.ctx, h).ServeHTTP(resp, req)
+       h.ServeHTTP(resp, req.WithContext(s.ctx))
 
        dec := json.NewDecoder(s.logdata)
 
@@ -104,12 +161,12 @@ func (s *Suite) TestLogErrorBody(c *check.C) {
                c.Assert(err, check.IsNil)
                resp := httptest.NewRecorder()
 
-               HandlerWithContext(s.ctx, LogRequests(
+               LogRequests(
                        http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
                                w.WriteHeader(trial.statusCode)
                                w.Write([]byte(trial.sentBody))
                        }),
-               )).ServeHTTP(resp, req)
+               ).ServeHTTP(resp, req.WithContext(s.ctx))
 
                gotReq := make(map[string]interface{})
                err = dec.Decode(&gotReq)