Merge branch '21535-multi-wf-delete'
[arvados.git] / sdk / go / arvados / client_test.go
index b0627fd27a665bf26250892c3fabd3319ff4e489..55e2f998c4a88589efe1d917272e152a29210247 100644 (file)
@@ -6,12 +6,21 @@ package arvados
 
 import (
        "bytes"
+       "context"
        "fmt"
        "io/ioutil"
+       "math"
+       "math/rand"
        "net/http"
+       "net/http/httptest"
        "net/url"
+       "os"
+       "strings"
        "sync"
-       "testing"
+       "testing/iotest"
+       "time"
+
+       check "gopkg.in/check.v1"
 )
 
 type stubTransport struct {
@@ -51,43 +60,52 @@ func (stub *errorTransport) RoundTrip(req *http.Request) (*http.Response, error)
        return nil, fmt.Errorf("something awful happened")
 }
 
-func TestCurrentUser(t *testing.T) {
-       t.Parallel()
+type timeoutTransport struct {
+       response []byte
+}
+
+func (stub *timeoutTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+       return &http.Response{
+               Status:     "200 OK",
+               StatusCode: 200,
+               Proto:      "HTTP/1.1",
+               ProtoMajor: 1,
+               ProtoMinor: 1,
+               Request:    req,
+               Body:       ioutil.NopCloser(iotest.TimeoutReader(bytes.NewReader(stub.response))),
+       }, nil
+}
+
+var _ = check.Suite(&clientSuite{})
+
+type clientSuite struct{}
+
+func (*clientSuite) TestCurrentUser(c *check.C) {
        stub := &stubTransport{
                Responses: map[string]string{
                        "/arvados/v1/users/current": `{"uuid":"zzzzz-abcde-012340123401234"}`,
                },
        }
-       c := &Client{
+       client := &Client{
                Client: &http.Client{
                        Transport: stub,
                },
                APIHost:   "zzzzz.arvadosapi.com",
                AuthToken: "xyzzy",
        }
-       u, err := c.CurrentUser()
-       if err != nil {
-               t.Fatal(err)
-       }
-       if x := "zzzzz-abcde-012340123401234"; u.UUID != x {
-               t.Errorf("got uuid %q, expected %q", u.UUID, x)
-       }
-       if len(stub.Requests) < 1 {
-               t.Fatal("empty stub.Requests")
-       }
+       u, err := client.CurrentUser()
+       c.Check(err, check.IsNil)
+       c.Check(u.UUID, check.Equals, "zzzzz-abcde-012340123401234")
+       c.Check(stub.Requests, check.Not(check.HasLen), 0)
        hdr := stub.Requests[len(stub.Requests)-1].Header
-       if hdr.Get("Authorization") != "OAuth2 xyzzy" {
-               t.Errorf("got headers %+q, expected Authorization header", hdr)
-       }
+       c.Check(hdr.Get("Authorization"), check.Equals, "OAuth2 xyzzy")
 
-       c.Client.Transport = &errorTransport{}
-       u, err = c.CurrentUser()
-       if err == nil {
-               t.Errorf("got nil error, expected something awful")
-       }
+       client.Client.Transport = &errorTransport{}
+       u, err = client.CurrentUser()
+       c.Check(err, check.NotNil)
 }
 
-func TestAnythingToValues(t *testing.T) {
+func (*clientSuite) TestAnythingToValues(c *check.C) {
        type testCase struct {
                in interface{}
                // ok==nil means anythingToValues should return an
@@ -141,17 +159,299 @@ func TestAnythingToValues(t *testing.T) {
                        ok: nil,
                },
        } {
-               t.Logf("%#v", tc.in)
+               c.Logf("%#v", tc.in)
                out, err := anythingToValues(tc.in)
-               switch {
-               case tc.ok == nil:
-                       if err == nil {
-                               t.Errorf("got %#v, expected error", out)
+               if tc.ok == nil {
+                       c.Check(err, check.NotNil)
+                       continue
+               }
+               c.Check(err, check.IsNil)
+               c.Check(tc.ok(out), check.Equals, true)
+       }
+}
+
+// select=["uuid"] is added automatically when RequestAndDecode's
+// destination argument is nil.
+func (*clientSuite) TestAutoSelectUUID(c *check.C) {
+       var req *http.Request
+       var err error
+       server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               c.Check(r.ParseForm(), check.IsNil)
+               req = r
+               w.Write([]byte("{}"))
+       }))
+       client := Client{
+               APIHost:   strings.TrimPrefix(server.URL, "https://"),
+               AuthToken: "zzz",
+               Insecure:  true,
+               Timeout:   2 * time.Second,
+       }
+
+       req = nil
+       err = client.RequestAndDecode(nil, http.MethodPost, "test", nil, nil)
+       c.Check(err, check.IsNil)
+       c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
+
+       req = nil
+       err = client.RequestAndDecode(nil, http.MethodGet, "test", nil, nil)
+       c.Check(err, check.IsNil)
+       c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
+
+       req = nil
+       err = client.RequestAndDecode(nil, http.MethodGet, "test", nil, map[string]interface{}{"select": []string{"blergh"}})
+       c.Check(err, check.IsNil)
+       c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
+
+       req = nil
+       err = client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, map[string]interface{}{"select": []string{"blergh"}})
+       c.Check(err, check.IsNil)
+       c.Check(req.FormValue("select"), check.Equals, `["blergh"]`)
+}
+
+func (*clientSuite) TestLoadConfig(c *check.C) {
+       oldenv := os.Environ()
+       defer func() {
+               os.Clearenv()
+               for _, s := range oldenv {
+                       i := strings.IndexRune(s, '=')
+                       os.Setenv(s[:i], s[i+1:])
+               }
+       }()
+
+       tmp := c.MkDir()
+       os.Setenv("HOME", tmp)
+       for _, s := range os.Environ() {
+               if strings.HasPrefix(s, "ARVADOS_") {
+                       i := strings.IndexRune(s, '=')
+                       os.Unsetenv(s[:i])
+               }
+       }
+       os.Mkdir(tmp+"/.config", 0777)
+       os.Mkdir(tmp+"/.config/arvados", 0777)
+
+       // Use $HOME/.config/arvados/settings.conf if no env vars are
+       // set
+       os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
+               ARVADOS_API_HOST = localhost:1
+               ARVADOS_API_TOKEN = token_from_settings_file1
+       `), 0777)
+       client := NewClientFromEnv()
+       c.Check(client.AuthToken, check.Equals, "token_from_settings_file1")
+       c.Check(client.APIHost, check.Equals, "localhost:1")
+       c.Check(client.Insecure, check.Equals, false)
+
+       // ..._INSECURE=true, comments, ignored lines in settings.conf
+       os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
+               (ignored) = (ignored)
+               #ARVADOS_API_HOST = localhost:2
+               ARVADOS_API_TOKEN = token_from_settings_file2
+               ARVADOS_API_HOST_INSECURE = true
+       `), 0777)
+       client = NewClientFromEnv()
+       c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
+       c.Check(client.APIHost, check.Equals, "")
+       c.Check(client.Insecure, check.Equals, true)
+
+       // Environment variables override settings.conf
+       os.Setenv("ARVADOS_API_HOST", "[::]:3")
+       os.Setenv("ARVADOS_API_HOST_INSECURE", "0")
+       client = NewClientFromEnv()
+       c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
+       c.Check(client.APIHost, check.Equals, "[::]:3")
+       c.Check(client.Insecure, check.Equals, false)
+}
+
+var _ = check.Suite(&clientRetrySuite{})
+
+type clientRetrySuite struct {
+       server     *httptest.Server
+       client     Client
+       reqs       []*http.Request
+       respStatus chan int
+       respDelay  time.Duration
+
+       origLimiterQuietPeriod time.Duration
+}
+
+func (s *clientRetrySuite) SetUpTest(c *check.C) {
+       // Test server: delay and return errors until a final status
+       // appears on the respStatus channel.
+       s.origLimiterQuietPeriod = requestLimiterQuietPeriod
+       requestLimiterQuietPeriod = time.Second / 100
+       s.server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               s.reqs = append(s.reqs, r)
+               delay := s.respDelay
+               if delay == 0 {
+                       delay = time.Duration(rand.Int63n(int64(time.Second / 10)))
+               }
+               timer := time.NewTimer(delay)
+               defer timer.Stop()
+               select {
+               case code, ok := <-s.respStatus:
+                       if !ok {
+                               code = http.StatusOK
                        }
-               case err != nil:
-                       t.Errorf("got err %#v, expected nil", err)
-               case !tc.ok(out):
-                       t.Errorf("got %#v but tc.ok() says that is wrong", out)
+                       w.WriteHeader(code)
+                       w.Write([]byte(`{}`))
+               case <-timer.C:
+                       w.WriteHeader(http.StatusServiceUnavailable)
                }
+       }))
+       s.reqs = nil
+       s.respStatus = make(chan int, 1)
+       s.client = Client{
+               APIHost:   s.server.URL[8:],
+               AuthToken: "zzz",
+               Insecure:  true,
+               Timeout:   2 * time.Second,
+       }
+}
+
+func (s *clientRetrySuite) TearDownTest(c *check.C) {
+       s.server.Close()
+       requestLimiterQuietPeriod = s.origLimiterQuietPeriod
+}
+
+func (s *clientRetrySuite) TestOK(c *check.C) {
+       s.respStatus <- http.StatusOK
+       err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
+       c.Check(err, check.IsNil)
+       c.Check(s.reqs, check.HasLen, 1)
+}
+
+func (s *clientRetrySuite) TestNetworkError(c *check.C) {
+       // Close the stub server to produce a "connection refused" error.
+       s.server.Close()
+
+       start := time.Now()
+       timeout := time.Second
+       ctx, cancel := context.WithDeadline(context.Background(), start.Add(timeout))
+       defer cancel()
+       s.client.Timeout = timeout * 2
+       err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
+       c.Check(err, check.ErrorMatches, `.*dial tcp .* connection refused.*`)
+       delta := time.Since(start)
+       c.Check(delta > timeout, check.Equals, true, check.Commentf("time.Since(start) == %v, timeout = %v", delta, timeout))
+}
+
+func (s *clientRetrySuite) TestNonRetryableError(c *check.C) {
+       s.respStatus <- http.StatusBadRequest
+       err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
+       c.Check(err, check.ErrorMatches, `.*400 Bad Request.*`)
+       c.Check(s.reqs, check.HasLen, 1)
+}
+
+// as of 0.7.2., retryablehttp does not recognize this as a
+// non-retryable error.
+func (s *clientRetrySuite) TestNonRetryableStdlibError(c *check.C) {
+       s.respStatus <- http.StatusOK
+       req, err := http.NewRequest(http.MethodGet, "https://"+s.client.APIHost+"/test", nil)
+       c.Assert(err, check.IsNil)
+       req.Header.Set("Good-Header", "T\033rrible header value")
+       err = s.client.DoAndDecode(&struct{}{}, req)
+       c.Check(err, check.ErrorMatches, `.*after 1 attempt.*net/http: invalid header .*`)
+       if !c.Check(s.reqs, check.HasLen, 0) {
+               c.Logf("%v", s.reqs[0])
        }
 }
+
+func (s *clientRetrySuite) TestNonRetryableAfter503s(c *check.C) {
+       time.AfterFunc(time.Second, func() { s.respStatus <- http.StatusNotFound })
+       err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
+       c.Check(err, check.ErrorMatches, `.*404 Not Found.*`)
+}
+
+func (s *clientRetrySuite) TestOKAfter503s(c *check.C) {
+       start := time.Now()
+       delay := time.Second
+       time.AfterFunc(delay, func() { s.respStatus <- http.StatusOK })
+       err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
+       c.Check(err, check.IsNil)
+       c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
+       c.Check(time.Since(start) > delay, check.Equals, true)
+}
+
+func (s *clientRetrySuite) TestTimeoutAfter503(c *check.C) {
+       s.respStatus <- http.StatusServiceUnavailable
+       s.respDelay = time.Second * 2
+       s.client.Timeout = time.Second / 2
+       err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
+       c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
+       c.Check(s.reqs, check.HasLen, 2)
+}
+
+func (s *clientRetrySuite) Test503Forever(c *check.C) {
+       err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
+       c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
+       c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
+}
+
+func (s *clientRetrySuite) TestContextAlreadyCanceled(c *check.C) {
+       ctx, cancel := context.WithCancel(context.Background())
+       cancel()
+       err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
+       c.Check(err, check.Equals, context.Canceled)
+}
+
+func (s *clientRetrySuite) TestExponentialBackoff(c *check.C) {
+       var min, max time.Duration
+       min, max = time.Second, 64*time.Second
+
+       t := exponentialBackoff(min, max, 0, nil)
+       c.Check(t, check.Equals, min)
+
+       for e := float64(1); e < 5; e += 1 {
+               ok := false
+               for i := 0; i < 20; i++ {
+                       t = exponentialBackoff(min, max, int(e), nil)
+                       // Every returned value must be between min and min(2^e, max)
+                       c.Check(t >= min, check.Equals, true)
+                       c.Check(t <= min*time.Duration(math.Pow(2, e)), check.Equals, true)
+                       c.Check(t <= max, check.Equals, true)
+                       // Check that jitter is actually happening by
+                       // checking that at least one in 20 trials is
+                       // between min*2^(e-.75) and min*2^(e-.25)
+                       jittermin := time.Duration(float64(min) * math.Pow(2, e-0.75))
+                       jittermax := time.Duration(float64(min) * math.Pow(2, e-0.25))
+                       c.Logf("min %v max %v e %v jittermin %v jittermax %v t %v", min, max, e, jittermin, jittermax, t)
+                       if t > jittermin && t < jittermax {
+                               ok = true
+                               break
+                       }
+               }
+               c.Check(ok, check.Equals, true)
+       }
+
+       for i := 0; i < 20; i++ {
+               t := exponentialBackoff(min, max, 100, nil)
+               c.Check(t < max, check.Equals, true)
+       }
+
+       for _, trial := range []struct {
+               retryAfter string
+               expect     time.Duration
+       }{
+               {"1", time.Second * 4},             // minimum enforced
+               {"5", time.Second * 5},             // header used
+               {"55", time.Second * 10},           // maximum enforced
+               {"eleventy-nine", time.Second * 4}, // invalid header, exponential backoff used
+               {time.Now().UTC().Add(time.Second).Format(time.RFC1123), time.Second * 4},  // minimum enforced
+               {time.Now().UTC().Add(time.Minute).Format(time.RFC1123), time.Second * 10}, // maximum enforced
+               {time.Now().UTC().Add(-time.Minute).Format(time.RFC1123), time.Second * 4}, // minimum enforced
+       } {
+               c.Logf("trial %+v", trial)
+               t := exponentialBackoff(time.Second*4, time.Second*10, 0, &http.Response{
+                       StatusCode: http.StatusTooManyRequests,
+                       Header:     http.Header{"Retry-After": {trial.retryAfter}}})
+               c.Check(t, check.Equals, trial.expect)
+       }
+       t = exponentialBackoff(time.Second*4, time.Second*10, 0, &http.Response{
+               StatusCode: http.StatusTooManyRequests,
+       })
+       c.Check(t, check.Equals, time.Second*4)
+
+       t = exponentialBackoff(0, max, 0, nil)
+       c.Check(t, check.Equals, time.Duration(0))
+       t = exponentialBackoff(0, max, 1, nil)
+       c.Check(t, check.Not(check.Equals), time.Duration(0))
+}