Merge branch '19972-go-client-retry'
[arvados.git] / sdk / go / arvados / client_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvados
6
7 import (
8         "bytes"
9         "context"
10         "fmt"
11         "io/ioutil"
12         "math/rand"
13         "net/http"
14         "net/http/httptest"
15         "net/url"
16         "os"
17         "strings"
18         "sync"
19         "testing/iotest"
20         "time"
21
22         check "gopkg.in/check.v1"
23 )
24
// stubTransport is an http.RoundTripper for tests. It answers each
// request from a canned table keyed by URL path and keeps a record of
// every request it was asked to send.
type stubTransport struct {
        // Responses maps a request URL path to the body returned for
        // that path.
        Responses map[string]string
        // Requests holds a copy of each request passed to RoundTrip,
        // in the order received.
        Requests []http.Request
        // The embedded mutex is held while appending to Requests.
        sync.Mutex
}

// RoundTrip records a copy of req and replies with the body
// registered in Responses for req.URL.Path as a 200 response. A path
// with no entry (or with an empty body) yields a 404 response whose
// body is "{}". The returned error is always nil.
func (stub *stubTransport) RoundTrip(req *http.Request) (*http.Response, error) {
        stub.Lock()
        stub.Requests = append(stub.Requests, *req)
        stub.Unlock()

        status, code := "200 OK", 200
        body := stub.Responses[req.URL.Path]
        if body == "" {
                status, code, body = "404 Not Found", 404, "{}"
        }

        payload := bytes.NewBufferString(body)
        return &http.Response{
                Status:        status,
                StatusCode:    code,
                Proto:         "HTTP/1.1",
                ProtoMajor:    1,
                ProtoMinor:    1,
                Request:       req,
                Body:          ioutil.NopCloser(payload),
                ContentLength: int64(payload.Len()),
        }, nil
}
55
// errorTransport is an http.RoundTripper whose every round trip
// fails.
type errorTransport struct{}

// RoundTrip ignores the request and returns a nil response together
// with a fixed error.
func (*errorTransport) RoundTrip(*http.Request) (*http.Response, error) {
        err := fmt.Errorf("something awful happened")
        return nil, err
}
61
// timeoutTransport is an http.RoundTripper that reports success but
// serves its payload through iotest.TimeoutReader, so reading the
// response body hits a timeout error partway through.
type timeoutTransport struct {
        // response is the payload offered as the response body.
        response []byte
}

// RoundTrip returns a 200 response for req whose body wraps
// stub.response in an iotest.TimeoutReader. The returned error is
// always nil; the failure only shows up while reading the body.
func (stub *timeoutTransport) RoundTrip(req *http.Request) (*http.Response, error) {
        flaky := iotest.TimeoutReader(bytes.NewReader(stub.response))
        resp := &http.Response{
                Status:     "200 OK",
                StatusCode: 200,
                Proto:      "HTTP/1.1",
                ProtoMajor: 1,
                ProtoMinor: 1,
                Request:    req,
        }
        resp.Body = ioutil.NopCloser(flaky)
        return resp, nil
}
77
// Hand a clientSuite to check.Suite so the gocheck runner sees its
// test methods.
var _ = check.Suite(&clientSuite{})

// clientSuite carries no state; it only serves as the receiver for
// the client tests below (CurrentUser, anythingToValues, config
// loading).
type clientSuite struct{}
81
82 func (*clientSuite) TestCurrentUser(c *check.C) {
83         stub := &stubTransport{
84                 Responses: map[string]string{
85                         "/arvados/v1/users/current": `{"uuid":"zzzzz-abcde-012340123401234"}`,
86                 },
87         }
88         client := &Client{
89                 Client: &http.Client{
90                         Transport: stub,
91                 },
92                 APIHost:   "zzzzz.arvadosapi.com",
93                 AuthToken: "xyzzy",
94         }
95         u, err := client.CurrentUser()
96         c.Check(err, check.IsNil)
97         c.Check(u.UUID, check.Equals, "zzzzz-abcde-012340123401234")
98         c.Check(stub.Requests, check.Not(check.HasLen), 0)
99         hdr := stub.Requests[len(stub.Requests)-1].Header
100         c.Check(hdr.Get("Authorization"), check.Equals, "OAuth2 xyzzy")
101
102         client.Client.Transport = &errorTransport{}
103         u, err = client.CurrentUser()
104         c.Check(err, check.NotNil)
105 }
106
107 func (*clientSuite) TestAnythingToValues(c *check.C) {
108         type testCase struct {
109                 in interface{}
110                 // ok==nil means anythingToValues should return an
111                 // error, otherwise it's a func that returns true if
112                 // out is correct
113                 ok func(out url.Values) bool
114         }
115         for _, tc := range []testCase{
116                 {
117                         in: map[string]interface{}{"foo": "bar"},
118                         ok: func(out url.Values) bool {
119                                 return out.Get("foo") == "bar"
120                         },
121                 },
122                 {
123                         in: map[string]interface{}{"foo": 2147483647},
124                         ok: func(out url.Values) bool {
125                                 return out.Get("foo") == "2147483647"
126                         },
127                 },
128                 {
129                         in: map[string]interface{}{"foo": 1.234},
130                         ok: func(out url.Values) bool {
131                                 return out.Get("foo") == "1.234"
132                         },
133                 },
134                 {
135                         in: map[string]interface{}{"foo": "1.234"},
136                         ok: func(out url.Values) bool {
137                                 return out.Get("foo") == "1.234"
138                         },
139                 },
140                 {
141                         in: map[string]interface{}{"foo": map[string]interface{}{"bar": 1.234}},
142                         ok: func(out url.Values) bool {
143                                 return out.Get("foo") == `{"bar":1.234}`
144                         },
145                 },
146                 {
147                         in: url.Values{"foo": {"bar"}},
148                         ok: func(out url.Values) bool {
149                                 return out.Get("foo") == "bar"
150                         },
151                 },
152                 {
153                         in: 1234,
154                         ok: nil,
155                 },
156                 {
157                         in: []string{"foo"},
158                         ok: nil,
159                 },
160         } {
161                 c.Logf("%#v", tc.in)
162                 out, err := anythingToValues(tc.in)
163                 if tc.ok == nil {
164                         c.Check(err, check.NotNil)
165                         continue
166                 }
167                 c.Check(err, check.IsNil)
168                 c.Check(tc.ok(out), check.Equals, true)
169         }
170 }
171
172 func (*clientSuite) TestLoadConfig(c *check.C) {
173         oldenv := os.Environ()
174         defer func() {
175                 os.Clearenv()
176                 for _, s := range oldenv {
177                         i := strings.IndexRune(s, '=')
178                         os.Setenv(s[:i], s[i+1:])
179                 }
180         }()
181
182         tmp := c.MkDir()
183         os.Setenv("HOME", tmp)
184         for _, s := range os.Environ() {
185                 if strings.HasPrefix(s, "ARVADOS_") {
186                         i := strings.IndexRune(s, '=')
187                         os.Unsetenv(s[:i])
188                 }
189         }
190         os.Mkdir(tmp+"/.config", 0777)
191         os.Mkdir(tmp+"/.config/arvados", 0777)
192
193         // Use $HOME/.config/arvados/settings.conf if no env vars are
194         // set
195         os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
196                 ARVADOS_API_HOST = localhost:1
197                 ARVADOS_API_TOKEN = token_from_settings_file1
198         `), 0777)
199         client := NewClientFromEnv()
200         c.Check(client.AuthToken, check.Equals, "token_from_settings_file1")
201         c.Check(client.APIHost, check.Equals, "localhost:1")
202         c.Check(client.Insecure, check.Equals, false)
203
204         // ..._INSECURE=true, comments, ignored lines in settings.conf
205         os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
206                 (ignored) = (ignored)
207                 #ARVADOS_API_HOST = localhost:2
208                 ARVADOS_API_TOKEN = token_from_settings_file2
209                 ARVADOS_API_HOST_INSECURE = true
210         `), 0777)
211         client = NewClientFromEnv()
212         c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
213         c.Check(client.APIHost, check.Equals, "")
214         c.Check(client.Insecure, check.Equals, true)
215
216         // Environment variables override settings.conf
217         os.Setenv("ARVADOS_API_HOST", "[::]:3")
218         os.Setenv("ARVADOS_API_HOST_INSECURE", "0")
219         client = NewClientFromEnv()
220         c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
221         c.Check(client.APIHost, check.Equals, "[::]:3")
222         c.Check(client.Insecure, check.Equals, false)
223 }
224
// Hand a clientRetrySuite to check.Suite so the gocheck runner sees
// its test methods.
var _ = check.Suite(&clientRetrySuite{})

// clientRetrySuite holds the per-test fixture for the tests that run
// a Client against a stub HTTPS server whose status code and latency
// the test controls.
type clientRetrySuite struct {
        // server is the stub TLS server started by SetUpTest.
        server     *httptest.Server
        // client is the Client under test, pointed at server.
        client     Client
        // reqs collects every request the stub server's handler saw.
        reqs       []*http.Request
        // respStatus feeds status codes to the handler; a request that
        // receives a code (or sees the channel closed, which counts as
        // 200) answers with it and a "{}" body.
        respStatus chan int
        // respDelay is how long the handler waits for a status code
        // before answering 503. Zero means a random wait below 100ms.
        respDelay  time.Duration

        // origLimiterQuietPeriod is the value of
        // requestLimiterQuietPeriod saved by SetUpTest and put back by
        // TearDownTest.
        origLimiterQuietPeriod time.Duration
}
236
237 func (s *clientRetrySuite) SetUpTest(c *check.C) {
238         // Test server: delay and return errors until a final status
239         // appears on the respStatus channel.
240         s.origLimiterQuietPeriod = requestLimiterQuietPeriod
241         requestLimiterQuietPeriod = time.Second / 100
242         s.server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
243                 s.reqs = append(s.reqs, r)
244                 delay := s.respDelay
245                 if delay == 0 {
246                         delay = time.Duration(rand.Int63n(int64(time.Second / 10)))
247                 }
248                 timer := time.NewTimer(delay)
249                 defer timer.Stop()
250                 select {
251                 case code, ok := <-s.respStatus:
252                         if !ok {
253                                 code = http.StatusOK
254                         }
255                         w.WriteHeader(code)
256                         w.Write([]byte(`{}`))
257                 case <-timer.C:
258                         w.WriteHeader(http.StatusServiceUnavailable)
259                 }
260         }))
261         s.reqs = nil
262         s.respStatus = make(chan int, 1)
263         s.client = Client{
264                 APIHost:   s.server.URL[8:],
265                 AuthToken: "zzz",
266                 Insecure:  true,
267                 Timeout:   2 * time.Second,
268         }
269 }
270
271 func (s *clientRetrySuite) TearDownTest(c *check.C) {
272         s.server.Close()
273         requestLimiterQuietPeriod = s.origLimiterQuietPeriod
274 }
275
276 func (s *clientRetrySuite) TestOK(c *check.C) {
277         s.respStatus <- http.StatusOK
278         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
279         c.Check(err, check.IsNil)
280         c.Check(s.reqs, check.HasLen, 1)
281 }
282
283 func (s *clientRetrySuite) TestNetworkError(c *check.C) {
284         // Close the stub server to produce a "connection refused" error.
285         s.server.Close()
286
287         start := time.Now()
288         timeout := time.Second
289         ctx, cancel := context.WithDeadline(context.Background(), start.Add(timeout))
290         defer cancel()
291         s.client.Timeout = timeout * 2
292         err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
293         c.Check(err, check.ErrorMatches, `.*dial tcp .* connection refused.*`)
294         delta := time.Since(start)
295         c.Check(delta > timeout, check.Equals, true, check.Commentf("time.Since(start) == %v, timeout = %v", delta, timeout))
296 }
297
298 func (s *clientRetrySuite) TestNonRetryableError(c *check.C) {
299         s.respStatus <- http.StatusBadRequest
300         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
301         c.Check(err, check.ErrorMatches, `.*400 Bad Request.*`)
302         c.Check(s.reqs, check.HasLen, 1)
303 }
304
305 func (s *clientRetrySuite) TestNonRetryableAfter503s(c *check.C) {
306         time.AfterFunc(time.Second, func() { s.respStatus <- http.StatusNotFound })
307         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
308         c.Check(err, check.ErrorMatches, `.*404 Not Found.*`)
309 }
310
311 func (s *clientRetrySuite) TestOKAfter503s(c *check.C) {
312         start := time.Now()
313         delay := time.Second
314         time.AfterFunc(delay, func() { s.respStatus <- http.StatusOK })
315         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
316         c.Check(err, check.IsNil)
317         c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
318         c.Check(time.Since(start) > delay, check.Equals, true)
319 }
320
321 func (s *clientRetrySuite) TestTimeoutAfter503(c *check.C) {
322         s.respStatus <- http.StatusServiceUnavailable
323         s.respDelay = time.Second * 2
324         s.client.Timeout = time.Second / 2
325         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
326         c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
327         c.Check(s.reqs, check.HasLen, 2)
328 }
329
330 func (s *clientRetrySuite) Test503Forever(c *check.C) {
331         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
332         c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
333         c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
334 }
335
336 func (s *clientRetrySuite) TestContextAlreadyCanceled(c *check.C) {
337         ctx, cancel := context.WithCancel(context.Background())
338         cancel()
339         err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
340         c.Check(err, check.Equals, context.Canceled)
341 }