X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/257d60253246952b435cea23b1912af80ea2c6d6..d6446b03e2f5d5079a870bdd7b963456dc12b485:/sdk/go/arvados/client_test.go diff --git a/sdk/go/arvados/client_test.go b/sdk/go/arvados/client_test.go index 5011aa81f6..a196003b8f 100644 --- a/sdk/go/arvados/client_test.go +++ b/sdk/go/arvados/client_test.go @@ -1,13 +1,26 @@ +// Copyright (C) The Arvados Authors. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 + package arvados import ( "bytes" + "context" "fmt" "io/ioutil" + "math" + "math/rand" "net/http" + "net/http/httptest" "net/url" + "os" + "strings" "sync" - "testing" + "testing/iotest" + "time" + + check "gopkg.in/check.v1" ) type stubTransport struct { @@ -47,43 +60,52 @@ func (stub *errorTransport) RoundTrip(req *http.Request) (*http.Response, error) return nil, fmt.Errorf("something awful happened") } -func TestCurrentUser(t *testing.T) { - t.Parallel() +type timeoutTransport struct { + response []byte +} + +func (stub *timeoutTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: req, + Body: ioutil.NopCloser(iotest.TimeoutReader(bytes.NewReader(stub.response))), + }, nil +} + +var _ = check.Suite(&clientSuite{}) + +type clientSuite struct{} + +func (*clientSuite) TestCurrentUser(c *check.C) { stub := &stubTransport{ Responses: map[string]string{ "/arvados/v1/users/current": `{"uuid":"zzzzz-abcde-012340123401234"}`, }, } - c := &Client{ + client := &Client{ Client: &http.Client{ Transport: stub, }, APIHost: "zzzzz.arvadosapi.com", AuthToken: "xyzzy", } - u, err := c.CurrentUser() - if err != nil { - t.Fatal(err) - } - if x := "zzzzz-abcde-012340123401234"; u.UUID != x { - t.Errorf("got uuid %q, expected %q", u.UUID, x) - } - if len(stub.Requests) < 1 { - t.Fatal("empty stub.Requests") - } + u, err := client.CurrentUser() + c.Check(err, check.IsNil) + c.Check(u.UUID, check.Equals, "zzzzz-abcde-012340123401234") + c.Check(stub.Requests, check.Not(check.HasLen), 0) hdr := stub.Requests[len(stub.Requests)-1].Header - if hdr.Get("Authorization") != "OAuth2 xyzzy" { - t.Errorf("got headers %+q, expected Authorization header", hdr) - } + c.Check(hdr.Get("Authorization"), check.Equals, "OAuth2 xyzzy") - c.Client.Transport = &errorTransport{} - u, err = c.CurrentUser() - if err == nil { - t.Errorf("got nil error, expected something awful") - } + client.Client.Transport = &errorTransport{} + u, err = client.CurrentUser() + c.Check(err, check.NotNil) } -func TestAnythingToValues(t *testing.T) { +func (*clientSuite) TestAnythingToValues(c *check.C) { type testCase struct { in interface{} // ok==nil means anythingToValues should return an @@ -137,17 +159,285 @@ func TestAnythingToValues(t *testing.T) { ok: nil, }, } { - t.Logf("%#v", tc.in) + c.Logf("%#v", tc.in) out, err := anythingToValues(tc.in) - switch { - case tc.ok == nil: - if err == nil { - t.Errorf("got %#v, expected error", out) + if tc.ok == nil { + c.Check(err, check.NotNil) + continue + } + c.Check(err, check.IsNil) + c.Check(tc.ok(out), check.Equals, true) + } +} + +// select=["uuid"] is added automatically when RequestAndDecode's +// destination argument is nil. +func (*clientSuite) TestAutoSelectUUID(c *check.C) { + var req *http.Request + var err error + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.Check(r.ParseForm(), check.IsNil) + req = r + w.Write([]byte("{}")) + })) + client := Client{ + APIHost: strings.TrimPrefix(server.URL, "https://"), + AuthToken: "zzz", + Insecure: true, + Timeout: 2 * time.Second, + } + + req = nil + err = client.RequestAndDecode(nil, http.MethodPost, "test", nil, nil) + c.Check(err, check.IsNil) + c.Check(req.FormValue("select"), check.Equals, `["uuid"]`) + + req = nil + err = client.RequestAndDecode(nil, http.MethodGet, "test", nil, nil) + c.Check(err, check.IsNil) + c.Check(req.FormValue("select"), check.Equals, `["uuid"]`) + + req = nil + err = client.RequestAndDecode(nil, http.MethodGet, "test", nil, map[string]interface{}{"select": []string{"blergh"}}) + c.Check(err, check.IsNil) + c.Check(req.FormValue("select"), check.Equals, `["uuid"]`) + + req = nil + err = client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, map[string]interface{}{"select": []string{"blergh"}}) + c.Check(err, check.IsNil) + c.Check(req.FormValue("select"), check.Equals, `["blergh"]`) +} + +func (*clientSuite) TestLoadConfig(c *check.C) { + oldenv := os.Environ() + defer func() { + os.Clearenv() + for _, s := range oldenv { + i := strings.IndexRune(s, '=') + os.Setenv(s[:i], s[i+1:]) + } + }() + + tmp := c.MkDir() + os.Setenv("HOME", tmp) + for _, s := range os.Environ() { + if strings.HasPrefix(s, "ARVADOS_") { + i := strings.IndexRune(s, '=') + os.Unsetenv(s[:i]) + } + } + os.Mkdir(tmp+"/.config", 0777) + os.Mkdir(tmp+"/.config/arvados", 0777) + + // Use $HOME/.config/arvados/settings.conf if no env vars are + // set + os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(` + ARVADOS_API_HOST = localhost:1 + ARVADOS_API_TOKEN = token_from_settings_file1 + `), 0777) + client := NewClientFromEnv() + c.Check(client.AuthToken, check.Equals, "token_from_settings_file1") + c.Check(client.APIHost, check.Equals, "localhost:1") + c.Check(client.Insecure, check.Equals, false) + + // ..._INSECURE=true, comments, ignored lines in settings.conf + os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(` + (ignored) = (ignored) + #ARVADOS_API_HOST = localhost:2 + ARVADOS_API_TOKEN = token_from_settings_file2 + ARVADOS_API_HOST_INSECURE = true + `), 0777) + client = NewClientFromEnv() + c.Check(client.AuthToken, check.Equals, "token_from_settings_file2") + c.Check(client.APIHost, check.Equals, "") + c.Check(client.Insecure, check.Equals, true) + + // Environment variables override settings.conf + os.Setenv("ARVADOS_API_HOST", "[::]:3") + os.Setenv("ARVADOS_API_HOST_INSECURE", "0") + client = NewClientFromEnv() + c.Check(client.AuthToken, check.Equals, "token_from_settings_file2") + c.Check(client.APIHost, check.Equals, "[::]:3") + c.Check(client.Insecure, check.Equals, false) +} + +var _ = check.Suite(&clientRetrySuite{}) + +type clientRetrySuite struct { + server *httptest.Server + client Client + reqs []*http.Request + respStatus chan int + respDelay time.Duration + + origLimiterQuietPeriod time.Duration +} + +func (s *clientRetrySuite) SetUpTest(c *check.C) { + // Test server: delay and return errors until a final status + // appears on the respStatus channel. + s.origLimiterQuietPeriod = requestLimiterQuietPeriod + requestLimiterQuietPeriod = time.Second / 100 + s.server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.reqs = append(s.reqs, r) + delay := s.respDelay + if delay == 0 { + delay = time.Duration(rand.Int63n(int64(time.Second / 10))) + } + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case code, ok := <-s.respStatus: + if !ok { + code = http.StatusOK + } + w.WriteHeader(code) + w.Write([]byte(`{}`)) + case <-timer.C: + w.WriteHeader(http.StatusServiceUnavailable) + } + })) + s.reqs = nil + s.respStatus = make(chan int, 1) + s.client = Client{ + APIHost: s.server.URL[8:], + AuthToken: "zzz", + Insecure: true, + Timeout: 2 * time.Second, + } +} + +func (s *clientRetrySuite) TearDownTest(c *check.C) { + s.server.Close() + requestLimiterQuietPeriod = s.origLimiterQuietPeriod +} + +func (s *clientRetrySuite) TestOK(c *check.C) { + s.respStatus <- http.StatusOK + err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil) + c.Check(err, check.IsNil) + c.Check(s.reqs, check.HasLen, 1) +} + +func (s *clientRetrySuite) TestNetworkError(c *check.C) { + // Close the stub server to produce a "connection refused" error. + s.server.Close() + + start := time.Now() + timeout := time.Second + ctx, cancel := context.WithDeadline(context.Background(), start.Add(timeout)) + defer cancel() + s.client.Timeout = timeout * 2 + err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil) + c.Check(err, check.ErrorMatches, `.*dial tcp .* connection refused.*`) + delta := time.Since(start) + c.Check(delta > timeout, check.Equals, true, check.Commentf("time.Since(start) == %v, timeout = %v", delta, timeout)) +} + +func (s *clientRetrySuite) TestNonRetryableError(c *check.C) { + s.respStatus <- http.StatusBadRequest + err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil) + c.Check(err, check.ErrorMatches, `.*400 Bad Request.*`) + c.Check(s.reqs, check.HasLen, 1) +} + +func (s *clientRetrySuite) TestNonRetryableAfter503s(c *check.C) { + time.AfterFunc(time.Second, func() { s.respStatus <- http.StatusNotFound }) + err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil) + c.Check(err, check.ErrorMatches, `.*404 Not Found.*`) +} + +func (s *clientRetrySuite) TestOKAfter503s(c *check.C) { + start := time.Now() + delay := time.Second + time.AfterFunc(delay, func() { s.respStatus <- http.StatusOK }) + err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil) + c.Check(err, check.IsNil) + c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs))) + c.Check(time.Since(start) > delay, check.Equals, true) +} + +func (s *clientRetrySuite) TestTimeoutAfter503(c *check.C) { + s.respStatus <- http.StatusServiceUnavailable + s.respDelay = time.Second * 2 + s.client.Timeout = time.Second / 2 + err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil) + c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`) + c.Check(s.reqs, check.HasLen, 2) +} + +func (s *clientRetrySuite) Test503Forever(c *check.C) { + err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil) + c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`) + c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs))) +} + +func (s *clientRetrySuite) TestContextAlreadyCanceled(c *check.C) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil) + c.Check(err, check.Equals, context.Canceled) +} + +func (s *clientRetrySuite) TestExponentialBackoff(c *check.C) { + var min, max time.Duration + min, max = time.Second, 64*time.Second + + t := exponentialBackoff(min, max, 0, nil) + c.Check(t, check.Equals, min) + + for e := float64(1); e < 5; e += 1 { + ok := false + for i := 0; i < 20; i++ { + t = exponentialBackoff(min, max, int(e), nil) + // Every returned value must be between min and min(2^e, max) + c.Check(t >= min, check.Equals, true) + c.Check(t <= min*time.Duration(math.Pow(2, e)), check.Equals, true) + c.Check(t <= max, check.Equals, true) + // Check that jitter is actually happening by + // checking that at least one in 20 trials is + // between min*2^(e-.75) and min*2^(e-.25) + jittermin := time.Duration(float64(min) * math.Pow(2, e-0.75)) + jittermax := time.Duration(float64(min) * math.Pow(2, e-0.25)) + c.Logf("min %v max %v e %v jittermin %v jittermax %v t %v", min, max, e, jittermin, jittermax, t) + if t > jittermin && t < jittermax { + ok = true + break } - case err != nil: - t.Errorf("got err %#v, expected nil", err) - case !tc.ok(out): - t.Errorf("got %#v but tc.ok() says that is wrong", out) } + c.Check(ok, check.Equals, true) + } + + for i := 0; i < 20; i++ { + t := exponentialBackoff(min, max, 100, nil) + c.Check(t < max, check.Equals, true) } + + for _, trial := range []struct { + retryAfter string + expect time.Duration + }{ + {"1", time.Second * 4}, // minimum enforced + {"5", time.Second * 5}, // header used + {"55", time.Second * 10}, // maximum enforced + {"eleventy-nine", time.Second * 4}, // invalid header, exponential backoff used + {time.Now().UTC().Add(time.Second).Format(time.RFC1123), time.Second * 4}, // minimum enforced + {time.Now().UTC().Add(time.Minute).Format(time.RFC1123), time.Second * 10}, // maximum enforced + {time.Now().UTC().Add(-time.Minute).Format(time.RFC1123), time.Second * 4}, // minimum enforced + } { + c.Logf("trial %+v", trial) + t := exponentialBackoff(time.Second*4, time.Second*10, 0, &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{"Retry-After": {trial.retryAfter}}}) + c.Check(t, check.Equals, trial.expect) + } + t = exponentialBackoff(time.Second*4, time.Second*10, 0, &http.Response{ + StatusCode: http.StatusTooManyRequests, + }) + c.Check(t, check.Equals, time.Second*4) + + t = exponentialBackoff(0, max, 0, nil) + c.Check(t, check.Equals, time.Duration(0)) + t = exponentialBackoff(0, max, 1, nil) + c.Check(t, check.Not(check.Equals), time.Duration(0)) }