import (
"bytes"
+ "context"
"fmt"
"io/ioutil"
+ "math/rand"
"net/http"
+ "net/http/httptest"
"net/url"
"os"
"strings"
"sync"
"testing/iotest"
+ "time"
check "gopkg.in/check.v1"
)
c.Check(client.APIHost, check.Equals, "[::]:3")
c.Check(client.Insecure, check.Equals, false)
}
+
+var _ = check.Suite(&clientRetrySuite{})
+
+// clientRetrySuite tests Client's retry behavior against a stub TLS
+// server whose responses each test controls via respStatus and
+// respDelay.
+//
+// Only one suite value is registered above, so every test runs against
+// the same struct: SetUpTest must reset all per-test state.
+type clientRetrySuite struct {
+	server *httptest.Server
+	client Client
+
+	// mtx guards reqs and respDelay, which are shared between the test
+	// goroutine and the stub server's handler goroutines.
+	mtx       sync.Mutex
+	reqs      []*http.Request
+	respDelay time.Duration
+
+	// respStatus delivers the status code the handler should reply
+	// with. It is assigned before the server starts and is not
+	// reassigned while the server runs.
+	respStatus chan int
+
+	origLimiterQuietPeriod time.Duration
+}
+
+// SetUpTest resets per-test state, shortens the request limiter's quiet
+// period, and starts a stub TLS server.
+//
+// The handler records every request in s.reqs. It then waits up to
+// s.respDelay (or a random duration under 100ms when s.respDelay is
+// zero) for a status code on s.respStatus and replies with it, or with
+// 200 if the channel has been closed. If nothing arrives in time, it
+// replies 503.
+func (s *clientRetrySuite) SetUpTest(c *check.C) {
+	// Reset everything the handler reads before the server can accept
+	// requests. Without the respDelay reset, a delay set by one test
+	// would carry over into the next.
+	s.reqs = nil
+	s.respDelay = 0
+	s.respStatus = make(chan int, 1)
+
+	s.origLimiterQuietPeriod = requestLimiterQuietPeriod
+	requestLimiterQuietPeriod = time.Second / 100
+
+	s.server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		s.mtx.Lock()
+		s.reqs = append(s.reqs, r)
+		delay := s.respDelay
+		s.mtx.Unlock()
+		if delay == 0 {
+			delay = time.Duration(rand.Int63n(int64(time.Second / 10)))
+		}
+		timer := time.NewTimer(delay)
+		defer timer.Stop()
+		select {
+		case code, ok := <-s.respStatus:
+			if !ok {
+				code = http.StatusOK
+			}
+			w.WriteHeader(code)
+			w.Write([]byte(`{}`))
+		case <-timer.C:
+			w.WriteHeader(http.StatusServiceUnavailable)
+		}
+	}))
+	s.client = Client{
+		APIHost:   strings.TrimPrefix(s.server.URL, "https://"),
+		AuthToken: "zzz",
+		Insecure:  true,
+		Timeout:   2 * time.Second,
+	}
+}
+
+// TearDownTest stops the stub server and restores the request
+// limiter's quiet period. httptest.Server.Close waits for outstanding
+// requests to complete, so no handler from this test outlives it.
+func (s *clientRetrySuite) TearDownTest(c *check.C) {
+	s.server.Close()
+	requestLimiterQuietPeriod = s.origLimiterQuietPeriod
+}
+
+// reqCount returns how many requests the stub server has received so
+// far. It is safe to call while handlers are running.
+func (s *clientRetrySuite) reqCount() int {
+	s.mtx.Lock()
+	defer s.mtx.Unlock()
+	return len(s.reqs)
+}
+
+// TestOK checks that a 200 on the first attempt succeeds without a
+// retry.
+func (s *clientRetrySuite) TestOK(c *check.C) {
+	s.respStatus <- http.StatusOK
+	err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
+	c.Check(err, check.IsNil)
+	c.Check(s.reqCount(), check.Equals, 1)
+}
+
+// TestNetworkError checks that a connection-refused error is reported,
+// and that the client did not give up before the context deadline.
+// The deadline is shorter than the client Timeout.
+func (s *clientRetrySuite) TestNetworkError(c *check.C) {
+	// Close the stub server to produce a "connection refused" error.
+	s.server.Close()
+
+	start := time.Now()
+	timeout := time.Second
+	ctx, cancel := context.WithDeadline(context.Background(), start.Add(timeout))
+	defer cancel()
+	s.client.Timeout = timeout * 2
+	err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
+	c.Check(err, check.ErrorMatches, `.*dial tcp .* connection refused.*`)
+	delta := time.Since(start)
+	c.Check(delta > timeout, check.Equals, true, check.Commentf("time.Since(start) == %v, timeout = %v", delta, timeout))
+}
+
+// TestNonRetryableError checks that a 400 is reported after a single
+// request, without a retry.
+func (s *clientRetrySuite) TestNonRetryableError(c *check.C) {
+	s.respStatus <- http.StatusBadRequest
+	err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
+	c.Check(err, check.ErrorMatches, `.*400 Bad Request.*`)
+	c.Check(s.reqCount(), check.Equals, 1)
+}
+
+// TestNonRetryableAfter503s checks the following sequence: the server
+// answers 503 until a 404 is delivered after one second, and the client
+// retries through the 503s and then reports the 404.
+func (s *clientRetrySuite) TestNonRetryableAfter503s(c *check.C) {
+	// Capture this test's channel and stop the timer on return, so a
+	// late firing cannot send into a later test's respStatus.
+	respStatus := s.respStatus
+	timer := time.AfterFunc(time.Second, func() { respStatus <- http.StatusNotFound })
+	defer timer.Stop()
+	err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
+	c.Check(err, check.ErrorMatches, `.*404 Not Found.*`)
+}
+
+// TestOKAfter503s checks the following sequence: the server answers 503
+// until a 200 is delivered after one second, and the client retries
+// (more than one request) until it succeeds.
+func (s *clientRetrySuite) TestOKAfter503s(c *check.C) {
+	start := time.Now()
+	delay := time.Second
+	// See TestNonRetryableAfter503s: don't let the timer outlive the test.
+	respStatus := s.respStatus
+	timer := time.AfterFunc(delay, func() { respStatus <- http.StatusOK })
+	defer timer.Stop()
+	err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
+	c.Check(err, check.IsNil)
+	n := s.reqCount()
+	c.Check(n > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", n))
+	c.Check(time.Since(start) > delay, check.Equals, true)
+}
+
+// TestTimeoutAfter503 checks what the client reports when a retry times
+// out.
+//
+// The first attempt gets an immediate 503. The retry's response is held
+// for 2s, which is past the 0.5s client timeout. The test expects the
+// 503 to be reported, after exactly two requests.
+func (s *clientRetrySuite) TestTimeoutAfter503(c *check.C) {
+	s.respStatus <- http.StatusServiceUnavailable
+	s.mtx.Lock()
+	s.respDelay = time.Second * 2
+	s.mtx.Unlock()
+	s.client.Timeout = time.Second / 2
+	err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
+	c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
+	c.Check(s.reqCount(), check.Equals, 2)
+}
+
+// Test503Forever checks the case where no status is ever delivered, so
+// every attempt gets a 503. The test expects more than one request and
+// a 503 in the final error.
+func (s *clientRetrySuite) Test503Forever(c *check.C) {
+	err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
+	c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
+	n := s.reqCount()
+	c.Check(n > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", n))
+}
+
+func (s *clientRetrySuite) TestContextAlreadyCanceled(c *check.C) {
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
+ c.Check(err, check.Equals, context.Canceled)
+}