1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: Apache-2.0
23 check "gopkg.in/check.v1"
// stubTransport is a fake http.RoundTripper for tests. Its RoundTrip
// method records every request and answers from canned response bodies.
26 type stubTransport struct {
// Responses maps a request URL path to the response body to serve.
27 Responses map[string]string
// Requests collects a shallow copy of each request, in arrival order.
28 Requests []http.Request
// RoundTrip implements http.RoundTripper. It appends a copy of req to
// stub.Requests and builds a response whose body is
// stub.Responses[req.URL.Path]. In some case it sets a 404 status
// (presumably when the path has no canned body — the guarding condition
// and the return statement are not visible in this view).
32 func (stub *stubTransport) RoundTrip(req *http.Request) (*http.Response, error) {
34 stub.Requests = append(stub.Requests, *req)
37 resp := &http.Response{
45 str := stub.Responses[req.URL.Path]
// NOTE(review): presumably reached only for paths without a canned body;
// confirm against the full source.
47 resp.Status = "404 Not Found"
51 buf := bytes.NewBufferString(str)
// NOTE(review): ioutil.NopCloser is deprecated since Go 1.16; io.NopCloser
// is the drop-in replacement.
52 resp.Body = ioutil.NopCloser(buf)
// Advertise the exact body length.
53 resp.ContentLength = int64(buf.Len())
// errorTransport is an http.RoundTripper stub that never produces a
// response: its RoundTrip always fails with the same error. Tests install
// it to exercise client error handling.
57 type errorTransport struct{}
// RoundTrip ignores req and returns a nil response together with a fixed
// error.
// NOTE(review): the message has no format verbs, so errors.New would do.
59 func (stub *errorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
60 return nil, fmt.Errorf("something awful happened")
// timeoutTransport is an http.RoundTripper stub whose response body serves
// stub.response through iotest.TimeoutReader. Reading that body therefore
// fails with a timeout error instead of completing normally.
63 type timeoutTransport struct {
// RoundTrip returns a response whose Body wraps stub.response in
// iotest.TimeoutReader. Per the testing/iotest docs, the second Read
// returns no data and iotest.ErrTimeout. The response's other fields are
// not visible in this view.
67 func (stub *timeoutTransport) RoundTrip(req *http.Request) (*http.Response, error) {
68 return &http.Response{
75 Body: ioutil.NopCloser(iotest.TimeoutReader(bytes.NewReader(stub.response))),
// Register clientSuite with gocheck as a test suite.
79 var _ = check.Suite(&clientSuite{})
// clientSuite groups Client tests that need no per-test fixture; it
// carries no state.
81 type clientSuite struct{}
// TestCurrentUser checks that CurrentUser decodes the user record served
// by a stubTransport and sends "Authorization: Bearer xyzzy". It then
// checks that CurrentUser reports an error when the transport itself
// fails. The client construction is partly on lines not visible here.
83 func (*clientSuite) TestCurrentUser(c *check.C) {
84 stub := &stubTransport{
85 Responses: map[string]string{
86 "/arvados/v1/users/current": `{"uuid":"zzzzz-abcde-012340123401234"}`,
93 APIHost: "zzzzz.arvadosapi.com",
96 u, err := client.CurrentUser()
97 c.Check(err, check.IsNil)
98 c.Check(u.UUID, check.Equals, "zzzzz-abcde-012340123401234")
// Inspect the most recent request the stub recorded.
// NOTE(review): c.Check does not stop the test, so if no request was
// recorded, the index on the next line panics; c.Assert would fail cleanly.
99 c.Check(stub.Requests, check.Not(check.HasLen), 0)
100 hdr := stub.Requests[len(stub.Requests)-1].Header
101 c.Check(hdr.Get("Authorization"), check.Equals, "Bearer xyzzy")
// Swap in a transport that always fails; CurrentUser must surface an error.
103 client.Client.Transport = &errorTransport{}
// NOTE(review): u is not read after this call; `_, err =` would say so.
104 u, err = client.CurrentUser()
105 c.Check(err, check.NotNil)
// TestAnythingToValues runs anythingToValues over a table of inputs: plain
// strings, a large int, a float, a nested map and a url.Values. Each
// result is checked with the case's ok func; a case with ok == nil
// expects an error instead.
108 func (*clientSuite) TestAnythingToValues(c *check.C) {
109 type testCase struct {
111 // ok==nil means anythingToValues should return an
112 // error, otherwise it's a func that returns true if out is acceptable.
114 ok func(out url.Values) bool
116 for _, tc := range []testCase{
// A string value is passed through as-is.
118 in: map[string]interface{}{"foo": "bar"},
119 ok: func(out url.Values) bool {
120 return out.Get("foo") == "bar"
// 2147483647 (the largest int32) must come out as plain decimal digits.
124 in: map[string]interface{}{"foo": 2147483647},
125 ok: func(out url.Values) bool {
126 return out.Get("foo") == "2147483647"
// A float is rendered in its short decimal form.
130 in: map[string]interface{}{"foo": 1.234},
131 ok: func(out url.Values) bool {
132 return out.Get("foo") == "1.234"
// A numeric-looking string is passed through unchanged.
136 in: map[string]interface{}{"foo": "1.234"},
137 ok: func(out url.Values) bool {
138 return out.Get("foo") == "1.234"
// A nested map is encoded as JSON.
142 in: map[string]interface{}{"foo": map[string]interface{}{"bar": 1.234}},
143 ok: func(out url.Values) bool {
144 return out.Get("foo") == `{"bar":1.234}`
// url.Values input keeps its values.
148 in: url.Values{"foo": {"bar"}},
149 ok: func(out url.Values) bool {
150 return out.Get("foo") == "bar"
163 out, err := anythingToValues(tc.in)
// The branch on tc.ok == nil is on lines not visible here. Presumably the
// first check below is the expect-error case and the next two are the
// success case.
165 c.Check(err, check.NotNil)
168 c.Check(err, check.IsNil)
169 c.Check(tc.ok(out), check.Equals, true)
173 // select=["uuid"] is added automatically when RequestAndDecode's
174 // destination argument is nil.
// This holds for POST and GET, and it even overrides a select supplied in
// params. When a non-nil destination is given, the caller's select is sent
// unchanged.
175 func (*clientSuite) TestAutoSelectUUID(c *check.C) {
// req holds the request most recently seen by the test server (presumably
// assigned from r in the handler; that line is not visible here).
176 var req *http.Request
// The handler parses the form, so FormValue works on the saved request,
// and it replies with an empty JSON object.
178 server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
179 c.Check(r.ParseForm(), check.IsNil)
181 w.Write([]byte("{}"))
184 APIHost: strings.TrimPrefix(server.URL, "https://"),
187 Timeout: 2 * time.Second,
// nil destination, POST: select is forced to ["uuid"].
191 err = client.RequestAndDecode(nil, http.MethodPost, "test", nil, nil)
192 c.Check(err, check.IsNil)
193 c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
// nil destination, GET: select is forced to ["uuid"].
196 err = client.RequestAndDecode(nil, http.MethodGet, "test", nil, nil)
197 c.Check(err, check.IsNil)
198 c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
// nil destination with a caller-supplied select: it is still replaced.
201 err = client.RequestAndDecode(nil, http.MethodGet, "test", nil, map[string]interface{}{"select": []string{"blergh"}})
202 c.Check(err, check.IsNil)
203 c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
// Non-nil destination: the caller's select is sent as given.
206 err = client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, map[string]interface{}{"select": []string{"blergh"}})
207 c.Check(err, check.IsNil)
208 c.Check(req.FormValue("select"), check.Equals, `["blergh"]`)
// TestLoadConfig checks how clients pick up their settings. It covers:
//   - NewClientFromEnv reading $HOME/.config/arvados/settings.conf,
//     including comments, ignored lines and ARVADOS_API_HOST_INSECURE;
//   - environment variables overriding that file;
//   - NewClientFromConfig taking host and token from the cluster config,
//     while ARVADOS_KEEP_SERVICES still overrides the keep service list.
211 func (*clientSuite) TestLoadConfig(c *check.C) {
// Snapshot the environment. The loop below restores it, presumably from
// a deferred func; the defer line is not visible here.
212 oldenv := os.Environ()
215 for _, s := range oldenv {
216 i := strings.IndexRune(s, '=')
217 os.Setenv(s[:i], s[i+1:])
// Point HOME at a scratch directory (tmp is set on a line not visible
// here). Then clear ARVADOS_* variables so that only settings.conf applies
// (the unset call is presumably on a line not visible here).
222 os.Setenv("HOME", tmp)
223 for _, s := range os.Environ() {
224 if strings.HasPrefix(s, "ARVADOS_") {
225 i := strings.IndexRune(s, '=')
// NOTE(review): the errors from os.Mkdir and os.WriteFile are discarded,
// so a setup failure only shows up indirectly through the checks below.
// A single os.MkdirAll call would also create both directories.
229 os.Mkdir(tmp+"/.config", 0777)
230 os.Mkdir(tmp+"/.config/arvados", 0777)
232 // Use $HOME/.config/arvados/settings.conf if no env vars are set.
234 os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
235 ARVADOS_API_HOST = localhost:1
236 ARVADOS_API_TOKEN = token_from_settings_file1
238 client := NewClientFromEnv()
239 c.Check(client.AuthToken, check.Equals, "token_from_settings_file1")
240 c.Check(client.APIHost, check.Equals, "localhost:1")
241 c.Check(client.Insecure, check.Equals, false)
243 // ..._INSECURE=true, comments, ignored lines in settings.conf
244 os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
245 (ignored) = (ignored)
246 #ARVADOS_API_HOST = localhost:2
247 ARVADOS_API_TOKEN = token_from_settings_file2
248 ARVADOS_API_HOST_INSECURE = true
250 client = NewClientFromEnv()
// The '#'-prefixed host line is a comment, so no APIHost is set.
251 c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
252 c.Check(client.APIHost, check.Equals, "")
253 c.Check(client.Insecure, check.Equals, true)
255 // Environment variables override settings.conf
256 os.Setenv("ARVADOS_API_HOST", "[::]:3")
257 os.Setenv("ARVADOS_API_HOST_INSECURE", "0")
258 os.Setenv("ARVADOS_KEEP_SERVICES", "http://[::]:12345")
259 client = NewClientFromEnv()
// The token was not overridden, so it still comes from settings.conf.
260 c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
261 c.Check(client.APIHost, check.Equals, "[::]:3")
262 c.Check(client.Insecure, check.Equals, false)
263 c.Check(client.KeepServiceURIs, check.DeepEquals, []string{"http://[::]:12345"})
265 // ARVADOS_KEEP_SERVICES environment variable overrides
266 // cluster config, but ARVADOS_API_HOST/TOKEN do not.
// KEEP_SERVICES already holds this value from above; setting it again
// presumably keeps this scenario self-contained.
267 os.Setenv("ARVADOS_KEEP_SERVICES", "http://[::]:12345")
268 os.Setenv("ARVADOS_API_HOST", "wronghost.example")
269 os.Setenv("ARVADOS_API_TOKEN", "wrongtoken")
// Configure a controller URL and one keepstore (cfg is declared on a line
// not visible here).
271 cfg.Services.Controller.ExternalURL = URL{Scheme: "https", Host: "ctrl.example:55555", Path: "/"}
272 cfg.Services.Keepstore.InternalURLs = map[URL]ServiceInstance{
273 URL{Scheme: "https", Host: "keep0.example:55555", Path: "/"}: ServiceInstance{},
275 client, err := NewClientFromConfig(&cfg)
276 c.Check(err, check.IsNil)
// The token and host ignore the "wrong" env values set above, but the
// keep services still come from ARVADOS_KEEP_SERVICES.
277 c.Check(client.AuthToken, check.Equals, "")
278 c.Check(client.APIHost, check.Equals, "ctrl.example:55555")
279 c.Check(client.Insecure, check.Equals, false)
280 c.Check(client.KeepServiceURIs, check.DeepEquals, []string{"http://[::]:12345"})
// Register clientRetrySuite with gocheck as a test suite.
283 var _ = check.Suite(&clientRetrySuite{})
// clientRetrySuite tests the client's retry behavior against a local TLS
// test server, created in SetUpTest, whose responses each test controls.
285 type clientRetrySuite struct {
// server is the per-test TLS server started by SetUpTest.
286 server *httptest.Server
// respDelay, when set by a test, controls how long the server waits before
// replying (presumably; the code that reads it is not visible here).
290 respDelay time.Duration
// origLimiterQuietPeriod holds requestLimiterQuietPeriod as it was before
// SetUpTest changed it, so TearDownTest can restore it.
292 origLimiterQuietPeriod time.Duration
// SetUpTest runs before each test in the suite. It does four things:
//   - saves requestLimiterQuietPeriod and shortens it (TearDownTest
//     restores it);
//   - starts a TLS test server;
//   - creates a one-slot respStatus channel;
//   - configures a client (presumably s.client) to use that server with
//     a 2s Timeout.
295 func (s *clientRetrySuite) SetUpTest(c *check.C) {
296 // Test server: delay and return errors until a final status
297 // appears on the respStatus channel.
298 s.origLimiterQuietPeriod = requestLimiterQuietPeriod
// 10ms, presumably to keep tests that provoke retries fast.
299 requestLimiterQuietPeriod = time.Second / 100
300 s.server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Record each request so tests can count attempts.
// NOTE(review): this runs on the server's handler goroutine, while tests
// read s.reqs from their own goroutine. No lock is visible here; confirm
// the access is synchronized (e.g. run with -race).
301 s.reqs = append(s.reqs, r)
// Random delay in [0, 100ms). This is presumably used only when
// s.respDelay is zero; the surrounding condition is not visible here.
304 delay = time.Duration(rand.Int63n(int64(time.Second / 10)))
306 timer := time.NewTimer(delay)
// A status received from respStatus ends the delay/503 cycle (presumably
// it is written as the response status, followed by the `{}` body). If
// the timer fires first, the reply is presumably the 503 below.
309 case code, ok := <-s.respStatus:
314 w.Write([]byte(`{}`))
316 w.WriteHeader(http.StatusServiceUnavailable)
// One slot lets a test queue a status before making its request without
// blocking.
320 s.respStatus = make(chan int, 1)
// URL[8:] drops the "https://" prefix of the TLS server URL.
// NOTE(review): strings.TrimPrefix(s.server.URL, "https://") would state
// this intent more clearly.
322 APIHost: s.server.URL[8:],
325 Timeout: 2 * time.Second,
// TearDownTest runs after each test in the suite. It restores the
// requestLimiterQuietPeriod value saved by SetUpTest. Any other cleanup is
// on lines not visible in this view.
329 func (s *clientRetrySuite) TearDownTest(c *check.C) {
331 requestLimiterQuietPeriod = s.origLimiterQuietPeriod
// TestOK checks that when 200 OK is queued before the request, the call
// succeeds after exactly one request.
334 func (s *clientRetrySuite) TestOK(c *check.C) {
335 s.respStatus <- http.StatusOK
336 err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
337 c.Check(err, check.IsNil)
338 c.Check(s.reqs, check.HasLen, 1)
// TestNetworkError checks two things about a request to a closed server:
// the call does not give up before the context deadline (1s), and it then
// fails with the "connection refused" dial error.
341 func (s *clientRetrySuite) TestNetworkError(c *check.C) {
342 // Close the stub server to produce a "connection refused" error.
// start and the server Close call are on lines not visible in this view.
346 timeout := time.Second
347 ctx, cancel := context.WithDeadline(context.Background(), start.Add(timeout))
// A client Timeout twice the deadline, presumably so that the context
// deadline, not the client Timeout, bounds the attempts.
349 s.client.Timeout = timeout * 2
350 err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
351 c.Check(err, check.ErrorMatches, `.*dial tcp .* connection refused.*`)
// The call must not have returned before the deadline.
352 delta := time.Since(start)
353 c.Check(delta > timeout, check.Equals, true, check.Commentf("time.Since(start) == %v, timeout = %v", delta, timeout))
// TestNonRetryableError checks that a 400 Bad Request is returned as an
// error without any retry: exactly one request reaches the server.
356 func (s *clientRetrySuite) TestNonRetryableError(c *check.C) {
357 s.respStatus <- http.StatusBadRequest
358 err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
359 c.Check(err, check.ErrorMatches, `.*400 Bad Request.*`)
360 c.Check(s.reqs, check.HasLen, 1)
// TestNonRetryableStdlibError covers a request that net/http refuses to
// send because a header value contains an ESC (\033) control character.
// The call must fail after a single attempt, and no request may reach
// the server.
363 // As of 0.7.2, retryablehttp does not recognize this as a
364 // non-retryable error.
365 func (s *clientRetrySuite) TestNonRetryableStdlibError(c *check.C) {
366 s.respStatus <- http.StatusOK
367 req, err := http.NewRequest(http.MethodGet, "https://"+s.client.APIHost+"/test", nil)
368 c.Assert(err, check.IsNil)
369 req.Header.Set("Good-Header", "T\033rrible header value")
370 err = s.client.DoAndDecode(&struct{}{}, req)
371 c.Check(err, check.ErrorMatches, `.*after 1 attempt.*net/http: invalid header .*`)
// If a request did get through, log it to help diagnose the failure.
372 if !c.Check(s.reqs, check.HasLen, 0) {
373 c.Logf("%v", s.reqs[0])
// TestNonRetryableAfter503s checks that the client keeps retrying while
// the server answers 503. About one second in, a 404 is queued, and that
// 404 is then returned as the error.
377 func (s *clientRetrySuite) TestNonRetryableAfter503s(c *check.C) {
378 time.AfterFunc(time.Second, func() { s.respStatus <- http.StatusNotFound })
379 err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
380 c.Check(err, check.ErrorMatches, `.*404 Not Found.*`)
// TestOKAfter503s checks that the client retries 503 responses until a 200
// is queued after delay, and then succeeds. More than one request must
// have been made, and the call must not return before delay has elapsed.
// start and delay are set on lines not visible in this view.
383 func (s *clientRetrySuite) TestOKAfter503s(c *check.C) {
386 time.AfterFunc(delay, func() { s.respStatus <- http.StatusOK })
387 err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
388 c.Check(err, check.IsNil)
389 c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
390 c.Check(time.Since(start) > delay, check.Equals, true)
// TestTimeoutAfter503 checks that a retry which times out still reports
// the earlier 503. Presumably the queued 503 answers the first request.
// The 2s response delay then exceeds the 0.5s client Timeout on the retry.
// Exactly two requests must reach the server.
393 func (s *clientRetrySuite) TestTimeoutAfter503(c *check.C) {
394 s.respStatus <- http.StatusServiceUnavailable
395 s.respDelay = time.Second * 2
396 s.client.Timeout = time.Second / 2
397 err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
398 c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
399 c.Check(s.reqs, check.HasLen, 2)
// Test503Forever checks the case where no final status is ever queued, so
// the server keeps answering 503. The call must eventually fail with a 503
// error after more than one request (presumably bounded by the 2s client
// Timeout set in SetUpTest).
402 func (s *clientRetrySuite) Test503Forever(c *check.C) {
403 err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
404 c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
405 c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
// TestContextAlreadyCanceled checks that a request made with a canceled
// context fails with context.Canceled itself. The comparison uses ==, so
// the error must not be wrapped. The cancel() call is on a line not
// visible in this view.
408 func (s *clientRetrySuite) TestContextAlreadyCanceled(c *check.C) {
409 ctx, cancel := context.WithCancel(context.Background())
411 err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
412 c.Check(err, check.Equals, context.Canceled)
415 func (s *clientRetrySuite) TestExponentialBackoff(c *check.C) {
416 var min, max time.Duration
417 min, max = time.Second, 64*time.Second
419 t := exponentialBackoff(min, max, 0, nil)
420 c.Check(t, check.Equals, min)
422 for e := float64(1); e < 5; e += 1 {
424 for i := 0; i < 30; i++ {
425 t = exponentialBackoff(min, max, int(e), nil)
426 // Every returned value must be between min and min(2^e, max)
427 c.Check(t >= min, check.Equals, true)
428 c.Check(t <= min*time.Duration(math.Pow(2, e)), check.Equals, true)
429 c.Check(t <= max, check.Equals, true)
430 // Check that jitter is actually happening by
431 // checking that at least one in 20 trials is
432 // between min*2^(e-.75) and min*2^(e-.25)
433 jittermin := time.Duration(float64(min) * math.Pow(2, e-0.75))
434 jittermax := time.Duration(float64(min) * math.Pow(2, e-0.25))
435 c.Logf("min %v max %v e %v jittermin %v jittermax %v t %v", min, max, e, jittermin, jittermax, t)
436 if t > jittermin && t < jittermax {
441 c.Check(ok, check.Equals, true)
444 for i := 0; i < 20; i++ {
445 t := exponentialBackoff(min, max, 100, nil)
446 c.Check(t < max, check.Equals, true)
449 for _, trial := range []struct {
453 {"1", time.Second * 4}, // minimum enforced
454 {"5", time.Second * 5}, // header used
455 {"55", time.Second * 10}, // maximum enforced
456 {"eleventy-nine", time.Second * 4}, // invalid header, exponential backoff used
457 {time.Now().UTC().Add(time.Second).Format(time.RFC1123), time.Second * 4}, // minimum enforced
458 {time.Now().UTC().Add(time.Minute).Format(time.RFC1123), time.Second * 10}, // maximum enforced
459 {time.Now().UTC().Add(-time.Minute).Format(time.RFC1123), time.Second * 4}, // minimum enforced
461 c.Logf("trial %+v", trial)
462 t := exponentialBackoff(time.Second*4, time.Second*10, 0, &http.Response{
463 StatusCode: http.StatusTooManyRequests,
464 Header: http.Header{"Retry-After": {trial.retryAfter}}})
465 c.Check(t, check.Equals, trial.expect)
467 t = exponentialBackoff(time.Second*4, time.Second*10, 0, &http.Response{
468 StatusCode: http.StatusTooManyRequests,
470 c.Check(t, check.Equals, time.Second*4)
472 t = exponentialBackoff(0, max, 0, nil)
473 c.Check(t, check.Equals, time.Duration(0))
474 t = exponentialBackoff(0, max, 1, nil)
475 c.Check(t, check.Not(check.Equals), time.Duration(0))