20540: Ensure delay>0 at attempt>0, even if requested min=0.
[arvados.git] / sdk / go / arvados / client_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvados
6
7 import (
8         "bytes"
9         "context"
10         "fmt"
11         "io/ioutil"
12         "math"
13         "math/rand"
14         "net/http"
15         "net/http/httptest"
16         "net/url"
17         "os"
18         "strings"
19         "sync"
20         "testing/iotest"
21         "time"
22
23         check "gopkg.in/check.v1"
24 )
25
// stubTransport is an http.RoundTripper for tests. It answers each
// request with a canned body looked up by URL path, and keeps a copy
// of every request it receives.
type stubTransport struct {
	// Responses maps a request URL path to the response body to
	// return for it.
	Responses map[string]string
	// Requests accumulates a copy of each request passed to
	// RoundTrip, in arrival order.
	Requests []http.Request
	// Mutex guards appends to Requests.
	sync.Mutex
}

// RoundTrip records req, then returns a "200 OK" response whose body
// is Responses[req.URL.Path]. If that entry is missing or empty, the
// response is a "404 Not Found" with body "{}" instead. The returned
// error is always nil.
func (stub *stubTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	stub.Lock()
	stub.Requests = append(stub.Requests, *req)
	stub.Unlock()

	status, code := "200 OK", 200
	body := stub.Responses[req.URL.Path]
	if body == "" {
		status, code, body = "404 Not Found", 404, "{}"
	}
	payload := bytes.NewBufferString(body)
	return &http.Response{
		Status:        status,
		StatusCode:    code,
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Request:       req,
		Body:          ioutil.NopCloser(payload),
		ContentLength: int64(payload.Len()),
	}, nil
}
56
// errorTransport is an http.RoundTripper that fails every request.
type errorTransport struct{}

// RoundTrip ignores the request and always returns a nil response
// together with a fixed, non-nil error.
func (*errorTransport) RoundTrip(*http.Request) (*http.Response, error) {
	err := fmt.Errorf("something awful happened")
	return nil, err
}
62
// timeoutTransport is an http.RoundTripper whose response body is
// wrapped in iotest.TimeoutReader, so reading it eventually yields a
// timeout error.
type timeoutTransport struct {
	// response is the payload served as the response body.
	response []byte
}

// RoundTrip returns a "200 OK" response for any request. The body
// serves stub.response through iotest.TimeoutReader. The returned
// error is always nil.
func (stub *timeoutTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	flaky := iotest.TimeoutReader(bytes.NewReader(stub.response))
	resp := &http.Response{
		Status:     "200 OK",
		StatusCode: 200,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Request:    req,
		Body:       ioutil.NopCloser(flaky),
	}
	return resp, nil
}
78
79 var _ = check.Suite(&clientSuite{})
80
81 type clientSuite struct{}
82
83 func (*clientSuite) TestCurrentUser(c *check.C) {
84         stub := &stubTransport{
85                 Responses: map[string]string{
86                         "/arvados/v1/users/current": `{"uuid":"zzzzz-abcde-012340123401234"}`,
87                 },
88         }
89         client := &Client{
90                 Client: &http.Client{
91                         Transport: stub,
92                 },
93                 APIHost:   "zzzzz.arvadosapi.com",
94                 AuthToken: "xyzzy",
95         }
96         u, err := client.CurrentUser()
97         c.Check(err, check.IsNil)
98         c.Check(u.UUID, check.Equals, "zzzzz-abcde-012340123401234")
99         c.Check(stub.Requests, check.Not(check.HasLen), 0)
100         hdr := stub.Requests[len(stub.Requests)-1].Header
101         c.Check(hdr.Get("Authorization"), check.Equals, "OAuth2 xyzzy")
102
103         client.Client.Transport = &errorTransport{}
104         u, err = client.CurrentUser()
105         c.Check(err, check.NotNil)
106 }
107
108 func (*clientSuite) TestAnythingToValues(c *check.C) {
109         type testCase struct {
110                 in interface{}
111                 // ok==nil means anythingToValues should return an
112                 // error, otherwise it's a func that returns true if
113                 // out is correct
114                 ok func(out url.Values) bool
115         }
116         for _, tc := range []testCase{
117                 {
118                         in: map[string]interface{}{"foo": "bar"},
119                         ok: func(out url.Values) bool {
120                                 return out.Get("foo") == "bar"
121                         },
122                 },
123                 {
124                         in: map[string]interface{}{"foo": 2147483647},
125                         ok: func(out url.Values) bool {
126                                 return out.Get("foo") == "2147483647"
127                         },
128                 },
129                 {
130                         in: map[string]interface{}{"foo": 1.234},
131                         ok: func(out url.Values) bool {
132                                 return out.Get("foo") == "1.234"
133                         },
134                 },
135                 {
136                         in: map[string]interface{}{"foo": "1.234"},
137                         ok: func(out url.Values) bool {
138                                 return out.Get("foo") == "1.234"
139                         },
140                 },
141                 {
142                         in: map[string]interface{}{"foo": map[string]interface{}{"bar": 1.234}},
143                         ok: func(out url.Values) bool {
144                                 return out.Get("foo") == `{"bar":1.234}`
145                         },
146                 },
147                 {
148                         in: url.Values{"foo": {"bar"}},
149                         ok: func(out url.Values) bool {
150                                 return out.Get("foo") == "bar"
151                         },
152                 },
153                 {
154                         in: 1234,
155                         ok: nil,
156                 },
157                 {
158                         in: []string{"foo"},
159                         ok: nil,
160                 },
161         } {
162                 c.Logf("%#v", tc.in)
163                 out, err := anythingToValues(tc.in)
164                 if tc.ok == nil {
165                         c.Check(err, check.NotNil)
166                         continue
167                 }
168                 c.Check(err, check.IsNil)
169                 c.Check(tc.ok(out), check.Equals, true)
170         }
171 }
172
173 func (*clientSuite) TestLoadConfig(c *check.C) {
174         oldenv := os.Environ()
175         defer func() {
176                 os.Clearenv()
177                 for _, s := range oldenv {
178                         i := strings.IndexRune(s, '=')
179                         os.Setenv(s[:i], s[i+1:])
180                 }
181         }()
182
183         tmp := c.MkDir()
184         os.Setenv("HOME", tmp)
185         for _, s := range os.Environ() {
186                 if strings.HasPrefix(s, "ARVADOS_") {
187                         i := strings.IndexRune(s, '=')
188                         os.Unsetenv(s[:i])
189                 }
190         }
191         os.Mkdir(tmp+"/.config", 0777)
192         os.Mkdir(tmp+"/.config/arvados", 0777)
193
194         // Use $HOME/.config/arvados/settings.conf if no env vars are
195         // set
196         os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
197                 ARVADOS_API_HOST = localhost:1
198                 ARVADOS_API_TOKEN = token_from_settings_file1
199         `), 0777)
200         client := NewClientFromEnv()
201         c.Check(client.AuthToken, check.Equals, "token_from_settings_file1")
202         c.Check(client.APIHost, check.Equals, "localhost:1")
203         c.Check(client.Insecure, check.Equals, false)
204
205         // ..._INSECURE=true, comments, ignored lines in settings.conf
206         os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
207                 (ignored) = (ignored)
208                 #ARVADOS_API_HOST = localhost:2
209                 ARVADOS_API_TOKEN = token_from_settings_file2
210                 ARVADOS_API_HOST_INSECURE = true
211         `), 0777)
212         client = NewClientFromEnv()
213         c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
214         c.Check(client.APIHost, check.Equals, "")
215         c.Check(client.Insecure, check.Equals, true)
216
217         // Environment variables override settings.conf
218         os.Setenv("ARVADOS_API_HOST", "[::]:3")
219         os.Setenv("ARVADOS_API_HOST_INSECURE", "0")
220         client = NewClientFromEnv()
221         c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
222         c.Check(client.APIHost, check.Equals, "[::]:3")
223         c.Check(client.Insecure, check.Equals, false)
224 }
225
226 var _ = check.Suite(&clientRetrySuite{})
227
228 type clientRetrySuite struct {
229         server     *httptest.Server
230         client     Client
231         reqs       []*http.Request
232         respStatus chan int
233         respDelay  time.Duration
234
235         origLimiterQuietPeriod time.Duration
236 }
237
238 func (s *clientRetrySuite) SetUpTest(c *check.C) {
239         // Test server: delay and return errors until a final status
240         // appears on the respStatus channel.
241         s.origLimiterQuietPeriod = requestLimiterQuietPeriod
242         requestLimiterQuietPeriod = time.Second / 100
243         s.server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
244                 s.reqs = append(s.reqs, r)
245                 delay := s.respDelay
246                 if delay == 0 {
247                         delay = time.Duration(rand.Int63n(int64(time.Second / 10)))
248                 }
249                 timer := time.NewTimer(delay)
250                 defer timer.Stop()
251                 select {
252                 case code, ok := <-s.respStatus:
253                         if !ok {
254                                 code = http.StatusOK
255                         }
256                         w.WriteHeader(code)
257                         w.Write([]byte(`{}`))
258                 case <-timer.C:
259                         w.WriteHeader(http.StatusServiceUnavailable)
260                 }
261         }))
262         s.reqs = nil
263         s.respStatus = make(chan int, 1)
264         s.client = Client{
265                 APIHost:   s.server.URL[8:],
266                 AuthToken: "zzz",
267                 Insecure:  true,
268                 Timeout:   2 * time.Second,
269         }
270 }
271
272 func (s *clientRetrySuite) TearDownTest(c *check.C) {
273         s.server.Close()
274         requestLimiterQuietPeriod = s.origLimiterQuietPeriod
275 }
276
277 func (s *clientRetrySuite) TestOK(c *check.C) {
278         s.respStatus <- http.StatusOK
279         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
280         c.Check(err, check.IsNil)
281         c.Check(s.reqs, check.HasLen, 1)
282 }
283
284 func (s *clientRetrySuite) TestNetworkError(c *check.C) {
285         // Close the stub server to produce a "connection refused" error.
286         s.server.Close()
287
288         start := time.Now()
289         timeout := time.Second
290         ctx, cancel := context.WithDeadline(context.Background(), start.Add(timeout))
291         defer cancel()
292         s.client.Timeout = timeout * 2
293         err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
294         c.Check(err, check.ErrorMatches, `.*dial tcp .* connection refused.*`)
295         delta := time.Since(start)
296         c.Check(delta > timeout, check.Equals, true, check.Commentf("time.Since(start) == %v, timeout = %v", delta, timeout))
297 }
298
299 func (s *clientRetrySuite) TestNonRetryableError(c *check.C) {
300         s.respStatus <- http.StatusBadRequest
301         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
302         c.Check(err, check.ErrorMatches, `.*400 Bad Request.*`)
303         c.Check(s.reqs, check.HasLen, 1)
304 }
305
306 func (s *clientRetrySuite) TestNonRetryableAfter503s(c *check.C) {
307         time.AfterFunc(time.Second, func() { s.respStatus <- http.StatusNotFound })
308         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
309         c.Check(err, check.ErrorMatches, `.*404 Not Found.*`)
310 }
311
312 func (s *clientRetrySuite) TestOKAfter503s(c *check.C) {
313         start := time.Now()
314         delay := time.Second
315         time.AfterFunc(delay, func() { s.respStatus <- http.StatusOK })
316         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
317         c.Check(err, check.IsNil)
318         c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
319         c.Check(time.Since(start) > delay, check.Equals, true)
320 }
321
322 func (s *clientRetrySuite) TestTimeoutAfter503(c *check.C) {
323         s.respStatus <- http.StatusServiceUnavailable
324         s.respDelay = time.Second * 2
325         s.client.Timeout = time.Second / 2
326         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
327         c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
328         c.Check(s.reqs, check.HasLen, 2)
329 }
330
331 func (s *clientRetrySuite) Test503Forever(c *check.C) {
332         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
333         c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
334         c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
335 }
336
337 func (s *clientRetrySuite) TestContextAlreadyCanceled(c *check.C) {
338         ctx, cancel := context.WithCancel(context.Background())
339         cancel()
340         err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
341         c.Check(err, check.Equals, context.Canceled)
342 }
343
344 func (s *clientRetrySuite) TestExponentialBackoff(c *check.C) {
345         var min, max time.Duration
346         min, max = time.Second, 64*time.Second
347
348         t := exponentialBackoff(min, max, 0, nil)
349         c.Check(t, check.Equals, min)
350
351         for e := float64(1); e < 5; e += 1 {
352                 ok := false
353                 for i := 0; i < 20; i++ {
354                         t = exponentialBackoff(min, max, int(e), nil)
355                         // Every returned value must be between min and min(2^e, max)
356                         c.Check(t >= min, check.Equals, true)
357                         c.Check(t <= min*time.Duration(math.Pow(2, e)), check.Equals, true)
358                         c.Check(t <= max, check.Equals, true)
359                         // Check that jitter is actually happening by
360                         // checking that at least one in 20 trials is
361                         // between min*2^(e-.75) and min*2^(e-.25)
362                         jittermin := time.Duration(float64(min) * math.Pow(2, e-0.75))
363                         jittermax := time.Duration(float64(min) * math.Pow(2, e-0.25))
364                         c.Logf("min %v max %v e %v jittermin %v jittermax %v t %v", min, max, e, jittermin, jittermax, t)
365                         if t > jittermin && t < jittermax {
366                                 ok = true
367                                 break
368                         }
369                 }
370                 c.Check(ok, check.Equals, true)
371         }
372
373         for i := 0; i < 20; i++ {
374                 t := exponentialBackoff(min, max, 100, nil)
375                 c.Check(t < max, check.Equals, true)
376         }
377
378         for _, trial := range []struct {
379                 retryAfter string
380                 expect     time.Duration
381         }{
382                 {"1", time.Second * 4},             // minimum enforced
383                 {"5", time.Second * 5},             // header used
384                 {"55", time.Second * 10},           // maximum enforced
385                 {"eleventy-nine", time.Second * 4}, // invalid header, exponential backoff used
386                 {time.Now().UTC().Add(time.Second).Format(time.RFC1123), time.Second * 4},  // minimum enforced
387                 {time.Now().UTC().Add(time.Minute).Format(time.RFC1123), time.Second * 10}, // maximum enforced
388                 {time.Now().UTC().Add(-time.Minute).Format(time.RFC1123), time.Second * 4}, // minimum enforced
389         } {
390                 c.Logf("trial %+v", trial)
391                 t := exponentialBackoff(time.Second*4, time.Second*10, 0, &http.Response{
392                         StatusCode: http.StatusTooManyRequests,
393                         Header:     http.Header{"Retry-After": {trial.retryAfter}}})
394                 c.Check(t, check.Equals, trial.expect)
395         }
396         t = exponentialBackoff(time.Second*4, time.Second*10, 0, &http.Response{
397                 StatusCode: http.StatusTooManyRequests,
398         })
399         c.Check(t, check.Equals, time.Second*4)
400
401         t = exponentialBackoff(0, max, 0, nil)
402         c.Check(t, check.Equals, time.Duration(0))
403         t = exponentialBackoff(0, max, 1, nil)
404         c.Check(t, check.Not(check.Equals), time.Duration(0))
405 }