13697: Cancel request context after API.RequestTimeout.
[arvados.git] / sdk / go / httpserver / logger_test.go
index 3b2bc7758069b44345b3da522b8f80cc303c52fe..b623aa4eea33894d3fe01cbfa1f0d54cf57d226b 100644 (file)
@@ -1,6 +1,6 @@
 // Copyright (C) The Arvados Authors. All rights reserved.
 //
-// SPDX-License-Identifier: AGPL-3.0
+// SPDX-License-Identifier: Apache-2.0
 
 package httpserver
 
@@ -8,12 +8,13 @@ import (
        "bytes"
        "context"
        "encoding/json"
+       "fmt"
        "net/http"
        "net/http/httptest"
        "testing"
        "time"
 
-       "git.curoverse.com/arvados.git/sdk/go/ctxlog"
+       "git.arvados.org/arvados.git/sdk/go/ctxlog"
        "github.com/sirupsen/logrus"
        check "gopkg.in/check.v1"
 )
@@ -24,17 +25,52 @@ func Test(t *testing.T) {
 
 var _ = check.Suite(&Suite{})
 
-type Suite struct{}
+type Suite struct {
+       ctx     context.Context // base context for test requests; carries s.log via ctxlog
+       log     *logrus.Logger  // JSON-formatted logger writing into logdata
+       logdata *bytes.Buffer   // captured log output that the tests decode
+}
 
-func (s *Suite) TestLogRequests(c *check.C) {
-       captured := &bytes.Buffer{}
-       log := logrus.New()
-       log.Out = captured
-       log.Formatter = &logrus.JSONFormatter{
+func (s *Suite) SetUpTest(c *check.C) {
+       s.logdata = bytes.NewBuffer(nil) // fresh per test, so each test decodes only its own entries
+       s.log = logrus.New()
+       s.log.Out = s.logdata
+       s.log.Formatter = &logrus.JSONFormatter{
                TimestampFormat: time.RFC3339Nano,
        }
-       ctx := ctxlog.Context(context.Background(), log)
+       s.ctx = ctxlog.Context(context.Background(), s.log) // requests served with s.ctx log into logdata
+}
+
+func (s *Suite) TestWithDeadline(c *check.C) {
+       req, err := http.NewRequest("GET", "https://foo.example/bar", nil)
+       c.Assert(err, check.IsNil)
+
+       // Short timeout (1ms): the handler's context must be done well before the 1s fallback fires
+       resp := httptest.NewRecorder()
+       HandlerWithDeadline(time.Millisecond, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+               select {
+               case <-req.Context().Done():
+                       w.Write([]byte("ok"))
+               case <-time.After(time.Second):
+                       c.Error("timed out")
+               }
+       })).ServeHTTP(resp, req.WithContext(s.ctx))
+       c.Check(resp.Body.String(), check.Equals, "ok")
+
+       // Long timeout (1s): the handler's context must still be live after 1ms
+       resp = httptest.NewRecorder()
+       HandlerWithDeadline(time.Second, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+               select {
+               case <-req.Context().Done():
+                       c.Error("request context done too soon")
+               case <-time.After(time.Millisecond):
+                       w.Write([]byte("ok"))
+               }
+       })).ServeHTTP(resp, req.WithContext(s.ctx))
+       c.Check(resp.Body.String(), check.Equals, "ok")
+}
 
+func (s *Suite) TestLogRequests(c *check.C) {
        h := AddRequestIDs(LogRequests(
                http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
                        w.Write([]byte("hello world"))
@@ -45,12 +81,13 @@ func (s *Suite) TestLogRequests(c *check.C) {
        c.Assert(err, check.IsNil)
        resp := httptest.NewRecorder()
 
-       HandlerWithContext(ctx, h).ServeHTTP(resp, req)
+       h.ServeHTTP(resp, req.WithContext(s.ctx))
 
-       dec := json.NewDecoder(captured)
+       dec := json.NewDecoder(s.logdata)
 
        gotReq := make(map[string]interface{})
        err = dec.Decode(&gotReq)
+       c.Assert(err, check.IsNil) // nothing below is meaningful without a decoded request entry
        c.Logf("%#v", gotReq)
        c.Check(gotReq["RequestID"], check.Matches, "req-[a-z0-9]{20}")
        c.Check(gotReq["reqForwardedFor"], check.Equals, "1.2.3.4:12345")
@@ -58,6 +95,7 @@ func (s *Suite) TestLogRequests(c *check.C) {
 
        gotResp := make(map[string]interface{})
        err = dec.Decode(&gotResp)
+       c.Assert(err, check.IsNil) // a missing key would make gotResp[key].(float64) panic below
        c.Logf("%#v", gotResp)
        c.Check(gotResp["RequestID"], check.Equals, gotReq["RequestID"])
        c.Check(gotResp["reqForwardedFor"], check.Equals, "1.2.3.4:12345")
@@ -72,3 +110,55 @@ func (s *Suite) TestLogRequests(c *check.C) {
                c.Check(gotResp[key].(float64), check.Not(check.Equals), float64(0))
        }
 }
+
+// TestLogErrorBody checks that LogRequests records the response body
+// (up to sniffBytes of it) as "respBody" for the error responses
+// exercised here (4xx, 5xx), and omits respBody for the others.
+func (s *Suite) TestLogErrorBody(c *check.C) {
+       dec := json.NewDecoder(s.logdata)
+
+       for _, trial := range []struct {
+               label      string
+               statusCode int
+               sentBody   string
+               expectLog  bool
+               expectBody string
+       }{
+               {"ok", 200, "hello world", false, ""},
+               {"redir", 302, "<a href='http://foo.example/baz'>redir</a>", false, ""},
+               {"4xx short body", 400, "oops", true, "oops"},
+               {"4xx long body", 400, fmt.Sprintf("%0*d", sniffBytes*2, 1), true, fmt.Sprintf("%0*d", sniffBytes, 0)},
+               {"5xx empty body", 500, "", true, ""},
+       } {
+               comment := check.Commentf("in trial: %q", trial.label)
+
+               req, err := http.NewRequest("GET", "https://foo.example/bar", nil)
+               c.Assert(err, check.IsNil)
+               resp := httptest.NewRecorder()
+
+               LogRequests(
+                       http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+                               w.WriteHeader(trial.statusCode)
+                               w.Write([]byte(trial.sentBody))
+                       }),
+               ).ServeHTTP(resp, req.WithContext(s.ctx))
+               c.Check(resp.Code, check.Equals, trial.statusCode, comment)
+
+               // Each request yields two log entries (request, then
+               // response). If either fails to decode, the stream is out of
+               // step for every later trial, so stop rather than continue.
+               gotReq := make(map[string]interface{})
+               err = dec.Decode(&gotReq)
+               c.Assert(err, check.IsNil, comment)
+               c.Logf("%#v", gotReq)
+               gotResp := make(map[string]interface{})
+               err = dec.Decode(&gotResp)
+               c.Assert(err, check.IsNil, comment)
+               c.Logf("%#v", gotResp)
+               if trial.expectLog {
+                       c.Check(gotResp["respBody"], check.Equals, trial.expectBody, comment)
+               } else {
+                       c.Check(gotResp["respBody"], check.IsNil, comment)
+               }
+       }
+}