X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/7db3ceda16742b65d73ebbc05d02351a5e0496bd..refs/heads/18858-sync-users-tool:/sdk/go/httpserver/logger_test.go diff --git a/sdk/go/httpserver/logger_test.go b/sdk/go/httpserver/logger_test.go index eb71fcd814..60768b3fc9 100644 --- a/sdk/go/httpserver/logger_test.go +++ b/sdk/go/httpserver/logger_test.go @@ -1,6 +1,6 @@ // Copyright (C) The Arvados Authors. All rights reserved. // -// SPDX-License-Identifier: AGPL-3.0 +// SPDX-License-Identifier: Apache-2.0 package httpserver @@ -9,12 +9,14 @@ import ( "context" "encoding/json" "fmt" + "io/ioutil" + "net" "net/http" "net/http/httptest" "testing" "time" - "git.curoverse.com/arvados.git/sdk/go/ctxlog" + "git.arvados.org/arvados.git/sdk/go/ctxlog" "github.com/sirupsen/logrus" check "gopkg.in/check.v1" ) @@ -41,6 +43,61 @@ func (s *Suite) SetUpTest(c *check.C) { s.ctx = ctxlog.Context(context.Background(), s.log) } +func (s *Suite) TestWithDeadline(c *check.C) { + req, err := http.NewRequest("GET", "https://foo.example/bar", nil) + c.Assert(err, check.IsNil) + + // Short timeout cancels context in <1s + resp := httptest.NewRecorder() + HandlerWithDeadline(time.Millisecond, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + select { + case <-req.Context().Done(): + w.Write([]byte("ok")) + case <-time.After(time.Second): + c.Error("timed out") + } + })).ServeHTTP(resp, req.WithContext(s.ctx)) + c.Check(resp.Body.String(), check.Equals, "ok") + + // Long timeout does not cancel context in <1ms + resp = httptest.NewRecorder() + HandlerWithDeadline(time.Second, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + select { + case <-req.Context().Done(): + c.Error("request context done too soon") + case <-time.After(time.Millisecond): + w.Write([]byte("ok")) + } + })).ServeHTTP(resp, req.WithContext(s.ctx)) + c.Check(resp.Body.String(), check.Equals, "ok") +} + +func (s *Suite) TestNoDeadlineAfterHijacked(c *check.C) { + srv := Server{ + Addr: ":", + Server: http.Server{ + Handler: HandlerWithDeadline(time.Millisecond, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conn, _, err := w.(http.Hijacker).Hijack() + c.Assert(err, check.IsNil) + defer conn.Close() + select { + case <-req.Context().Done(): + c.Error("request context done too soon") + case <-time.After(time.Second / 10): + conn.Write([]byte("HTTP/1.1 200 OK\r\n\r\nok")) + } + })), + BaseContext: func(net.Listener) context.Context { return s.ctx }, + }, + } + srv.Start() + defer srv.Close() + resp, err := http.Get("http://" + srv.Addr) + c.Assert(err, check.IsNil) + body, err := ioutil.ReadAll(resp.Body) + c.Check(string(body), check.Equals, "ok") +} + func (s *Suite) TestLogRequests(c *check.C) { h := AddRequestIDs(LogRequests( http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { @@ -52,12 +109,13 @@ func (s *Suite) TestLogRequests(c *check.C) { c.Assert(err, check.IsNil) resp := httptest.NewRecorder() - HandlerWithContext(s.ctx, h).ServeHTTP(resp, req) + h.ServeHTTP(resp, req.WithContext(s.ctx)) dec := json.NewDecoder(s.logdata) gotReq := make(map[string]interface{}) err = dec.Decode(&gotReq) + c.Check(err, check.IsNil) c.Logf("%#v", gotReq) c.Check(gotReq["RequestID"], check.Matches, "req-[a-z0-9]{20}") c.Check(gotReq["reqForwardedFor"], check.Equals, "1.2.3.4:12345") @@ -65,6 +123,7 @@ func (s *Suite) TestLogRequests(c *check.C) { gotResp := make(map[string]interface{}) err = dec.Decode(&gotResp) + c.Check(err, check.IsNil) c.Logf("%#v", gotResp) c.Check(gotResp["RequestID"], check.Equals, gotReq["RequestID"]) c.Check(gotResp["reqForwardedFor"], check.Equals, "1.2.3.4:12345") @@ -102,18 +161,20 @@ func (s *Suite) TestLogErrorBody(c *check.C) { c.Assert(err, check.IsNil) resp := httptest.NewRecorder() - HandlerWithContext(s.ctx, LogRequests( + LogRequests( http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(trial.statusCode) w.Write([]byte(trial.sentBody)) }), - )).ServeHTTP(resp, req) + ).ServeHTTP(resp, req.WithContext(s.ctx)) gotReq := make(map[string]interface{}) err = dec.Decode(&gotReq) + c.Check(err, check.IsNil) c.Logf("%#v", gotReq) gotResp := make(map[string]interface{}) err = dec.Decode(&gotResp) + c.Check(err, check.IsNil) c.Logf("%#v", gotResp) if trial.expectLog { c.Check(gotResp["respBody"], check.Equals, trial.expectBody, comment)