Merge branch '21026-sanitize-html-doc'
[arvados.git] / sdk / go / arvados / client_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvados
6
7 import (
8         "bytes"
9         "context"
10         "fmt"
11         "io/ioutil"
12         "math"
13         "math/rand"
14         "net/http"
15         "net/http/httptest"
16         "net/url"
17         "os"
18         "strings"
19         "sync"
20         "testing/iotest"
21         "time"
22
23         check "gopkg.in/check.v1"
24 )
25
// stubTransport is an http.RoundTripper that records every request it
// sees and answers from a fixed table of response bodies keyed by URL
// path. A path with no entry (or an empty entry) gets "404 Not Found"
// with the body "{}".
type stubTransport struct {
	Responses map[string]string
	Requests  []http.Request
	sync.Mutex
}

// RoundTrip implements http.RoundTripper. It never returns an error.
func (stub *stubTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	stub.Lock()
	stub.Requests = append(stub.Requests, *req)
	stub.Unlock()

	code, status, body := http.StatusOK, "200 OK", stub.Responses[req.URL.Path]
	if body == "" {
		// Unknown path: answer like a server with no such route.
		code, status, body = http.StatusNotFound, "404 Not Found", "{}"
	}
	payload := bytes.NewBufferString(body)
	return &http.Response{
		Status:        status,
		StatusCode:    code,
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Request:       req,
		Body:          ioutil.NopCloser(payload),
		ContentLength: int64(payload.Len()),
	}, nil
}
56
// errorTransport is an http.RoundTripper whose round trips always fail,
// for exercising client error paths.
type errorTransport struct{}

// RoundTrip implements http.RoundTripper by returning no response and a
// fixed error, regardless of the request.
func (*errorTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
	failure := fmt.Errorf("something awful happened")
	return nil, failure
}
62
// timeoutTransport is an http.RoundTripper that answers every request
// with 200 OK and a body that serves response through
// iotest.TimeoutReader, so the second Read of the body fails with
// iotest.ErrTimeout.
type timeoutTransport struct {
	response []byte
}

// RoundTrip implements http.RoundTripper. It never returns an error
// itself; the failure only shows up while reading the body.
func (stub *timeoutTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	stalling := iotest.TimeoutReader(bytes.NewReader(stub.response))
	resp := &http.Response{
		Status:     "200 OK",
		StatusCode: http.StatusOK,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Request:    req,
	}
	resp.Body = ioutil.NopCloser(stalling)
	return resp, nil
}
78
79 var _ = check.Suite(&clientSuite{})
80
81 type clientSuite struct{}
82
83 func (*clientSuite) TestCurrentUser(c *check.C) {
84         stub := &stubTransport{
85                 Responses: map[string]string{
86                         "/arvados/v1/users/current": `{"uuid":"zzzzz-abcde-012340123401234"}`,
87                 },
88         }
89         client := &Client{
90                 Client: &http.Client{
91                         Transport: stub,
92                 },
93                 APIHost:   "zzzzz.arvadosapi.com",
94                 AuthToken: "xyzzy",
95         }
96         u, err := client.CurrentUser()
97         c.Check(err, check.IsNil)
98         c.Check(u.UUID, check.Equals, "zzzzz-abcde-012340123401234")
99         c.Check(stub.Requests, check.Not(check.HasLen), 0)
100         hdr := stub.Requests[len(stub.Requests)-1].Header
101         c.Check(hdr.Get("Authorization"), check.Equals, "OAuth2 xyzzy")
102
103         client.Client.Transport = &errorTransport{}
104         u, err = client.CurrentUser()
105         c.Check(err, check.NotNil)
106 }
107
108 func (*clientSuite) TestAnythingToValues(c *check.C) {
109         type testCase struct {
110                 in interface{}
111                 // ok==nil means anythingToValues should return an
112                 // error, otherwise it's a func that returns true if
113                 // out is correct
114                 ok func(out url.Values) bool
115         }
116         for _, tc := range []testCase{
117                 {
118                         in: map[string]interface{}{"foo": "bar"},
119                         ok: func(out url.Values) bool {
120                                 return out.Get("foo") == "bar"
121                         },
122                 },
123                 {
124                         in: map[string]interface{}{"foo": 2147483647},
125                         ok: func(out url.Values) bool {
126                                 return out.Get("foo") == "2147483647"
127                         },
128                 },
129                 {
130                         in: map[string]interface{}{"foo": 1.234},
131                         ok: func(out url.Values) bool {
132                                 return out.Get("foo") == "1.234"
133                         },
134                 },
135                 {
136                         in: map[string]interface{}{"foo": "1.234"},
137                         ok: func(out url.Values) bool {
138                                 return out.Get("foo") == "1.234"
139                         },
140                 },
141                 {
142                         in: map[string]interface{}{"foo": map[string]interface{}{"bar": 1.234}},
143                         ok: func(out url.Values) bool {
144                                 return out.Get("foo") == `{"bar":1.234}`
145                         },
146                 },
147                 {
148                         in: url.Values{"foo": {"bar"}},
149                         ok: func(out url.Values) bool {
150                                 return out.Get("foo") == "bar"
151                         },
152                 },
153                 {
154                         in: 1234,
155                         ok: nil,
156                 },
157                 {
158                         in: []string{"foo"},
159                         ok: nil,
160                 },
161         } {
162                 c.Logf("%#v", tc.in)
163                 out, err := anythingToValues(tc.in)
164                 if tc.ok == nil {
165                         c.Check(err, check.NotNil)
166                         continue
167                 }
168                 c.Check(err, check.IsNil)
169                 c.Check(tc.ok(out), check.Equals, true)
170         }
171 }
172
173 // select=["uuid"] is added automatically when RequestAndDecode's
174 // destination argument is nil.
175 func (*clientSuite) TestAutoSelectUUID(c *check.C) {
176         var req *http.Request
177         var err error
178         server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
179                 c.Check(r.ParseForm(), check.IsNil)
180                 req = r
181                 w.Write([]byte("{}"))
182         }))
183         client := Client{
184                 APIHost:   strings.TrimPrefix(server.URL, "https://"),
185                 AuthToken: "zzz",
186                 Insecure:  true,
187                 Timeout:   2 * time.Second,
188         }
189
190         req = nil
191         err = client.RequestAndDecode(nil, http.MethodPost, "test", nil, nil)
192         c.Check(err, check.IsNil)
193         c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
194
195         req = nil
196         err = client.RequestAndDecode(nil, http.MethodGet, "test", nil, nil)
197         c.Check(err, check.IsNil)
198         c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
199
200         req = nil
201         err = client.RequestAndDecode(nil, http.MethodGet, "test", nil, map[string]interface{}{"select": []string{"blergh"}})
202         c.Check(err, check.IsNil)
203         c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
204
205         req = nil
206         err = client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, map[string]interface{}{"select": []string{"blergh"}})
207         c.Check(err, check.IsNil)
208         c.Check(req.FormValue("select"), check.Equals, `["blergh"]`)
209 }
210
211 func (*clientSuite) TestLoadConfig(c *check.C) {
212         oldenv := os.Environ()
213         defer func() {
214                 os.Clearenv()
215                 for _, s := range oldenv {
216                         i := strings.IndexRune(s, '=')
217                         os.Setenv(s[:i], s[i+1:])
218                 }
219         }()
220
221         tmp := c.MkDir()
222         os.Setenv("HOME", tmp)
223         for _, s := range os.Environ() {
224                 if strings.HasPrefix(s, "ARVADOS_") {
225                         i := strings.IndexRune(s, '=')
226                         os.Unsetenv(s[:i])
227                 }
228         }
229         os.Mkdir(tmp+"/.config", 0777)
230         os.Mkdir(tmp+"/.config/arvados", 0777)
231
232         // Use $HOME/.config/arvados/settings.conf if no env vars are
233         // set
234         os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
235                 ARVADOS_API_HOST = localhost:1
236                 ARVADOS_API_TOKEN = token_from_settings_file1
237         `), 0777)
238         client := NewClientFromEnv()
239         c.Check(client.AuthToken, check.Equals, "token_from_settings_file1")
240         c.Check(client.APIHost, check.Equals, "localhost:1")
241         c.Check(client.Insecure, check.Equals, false)
242
243         // ..._INSECURE=true, comments, ignored lines in settings.conf
244         os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
245                 (ignored) = (ignored)
246                 #ARVADOS_API_HOST = localhost:2
247                 ARVADOS_API_TOKEN = token_from_settings_file2
248                 ARVADOS_API_HOST_INSECURE = true
249         `), 0777)
250         client = NewClientFromEnv()
251         c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
252         c.Check(client.APIHost, check.Equals, "")
253         c.Check(client.Insecure, check.Equals, true)
254
255         // Environment variables override settings.conf
256         os.Setenv("ARVADOS_API_HOST", "[::]:3")
257         os.Setenv("ARVADOS_API_HOST_INSECURE", "0")
258         client = NewClientFromEnv()
259         c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
260         c.Check(client.APIHost, check.Equals, "[::]:3")
261         c.Check(client.Insecure, check.Equals, false)
262 }
263
264 var _ = check.Suite(&clientRetrySuite{})
265
266 type clientRetrySuite struct {
267         server     *httptest.Server
268         client     Client
269         reqs       []*http.Request
270         respStatus chan int
271         respDelay  time.Duration
272
273         origLimiterQuietPeriod time.Duration
274 }
275
276 func (s *clientRetrySuite) SetUpTest(c *check.C) {
277         // Test server: delay and return errors until a final status
278         // appears on the respStatus channel.
279         s.origLimiterQuietPeriod = requestLimiterQuietPeriod
280         requestLimiterQuietPeriod = time.Second / 100
281         s.server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
282                 s.reqs = append(s.reqs, r)
283                 delay := s.respDelay
284                 if delay == 0 {
285                         delay = time.Duration(rand.Int63n(int64(time.Second / 10)))
286                 }
287                 timer := time.NewTimer(delay)
288                 defer timer.Stop()
289                 select {
290                 case code, ok := <-s.respStatus:
291                         if !ok {
292                                 code = http.StatusOK
293                         }
294                         w.WriteHeader(code)
295                         w.Write([]byte(`{}`))
296                 case <-timer.C:
297                         w.WriteHeader(http.StatusServiceUnavailable)
298                 }
299         }))
300         s.reqs = nil
301         s.respStatus = make(chan int, 1)
302         s.client = Client{
303                 APIHost:   s.server.URL[8:],
304                 AuthToken: "zzz",
305                 Insecure:  true,
306                 Timeout:   2 * time.Second,
307         }
308 }
309
310 func (s *clientRetrySuite) TearDownTest(c *check.C) {
311         s.server.Close()
312         requestLimiterQuietPeriod = s.origLimiterQuietPeriod
313 }
314
315 func (s *clientRetrySuite) TestOK(c *check.C) {
316         s.respStatus <- http.StatusOK
317         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
318         c.Check(err, check.IsNil)
319         c.Check(s.reqs, check.HasLen, 1)
320 }
321
322 func (s *clientRetrySuite) TestNetworkError(c *check.C) {
323         // Close the stub server to produce a "connection refused" error.
324         s.server.Close()
325
326         start := time.Now()
327         timeout := time.Second
328         ctx, cancel := context.WithDeadline(context.Background(), start.Add(timeout))
329         defer cancel()
330         s.client.Timeout = timeout * 2
331         err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
332         c.Check(err, check.ErrorMatches, `.*dial tcp .* connection refused.*`)
333         delta := time.Since(start)
334         c.Check(delta > timeout, check.Equals, true, check.Commentf("time.Since(start) == %v, timeout = %v", delta, timeout))
335 }
336
337 func (s *clientRetrySuite) TestNonRetryableError(c *check.C) {
338         s.respStatus <- http.StatusBadRequest
339         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
340         c.Check(err, check.ErrorMatches, `.*400 Bad Request.*`)
341         c.Check(s.reqs, check.HasLen, 1)
342 }
343
344 func (s *clientRetrySuite) TestNonRetryableAfter503s(c *check.C) {
345         time.AfterFunc(time.Second, func() { s.respStatus <- http.StatusNotFound })
346         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
347         c.Check(err, check.ErrorMatches, `.*404 Not Found.*`)
348 }
349
350 func (s *clientRetrySuite) TestOKAfter503s(c *check.C) {
351         start := time.Now()
352         delay := time.Second
353         time.AfterFunc(delay, func() { s.respStatus <- http.StatusOK })
354         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
355         c.Check(err, check.IsNil)
356         c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
357         c.Check(time.Since(start) > delay, check.Equals, true)
358 }
359
360 func (s *clientRetrySuite) TestTimeoutAfter503(c *check.C) {
361         s.respStatus <- http.StatusServiceUnavailable
362         s.respDelay = time.Second * 2
363         s.client.Timeout = time.Second / 2
364         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
365         c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
366         c.Check(s.reqs, check.HasLen, 2)
367 }
368
369 func (s *clientRetrySuite) Test503Forever(c *check.C) {
370         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
371         c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
372         c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
373 }
374
375 func (s *clientRetrySuite) TestContextAlreadyCanceled(c *check.C) {
376         ctx, cancel := context.WithCancel(context.Background())
377         cancel()
378         err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
379         c.Check(err, check.Equals, context.Canceled)
380 }
381
382 func (s *clientRetrySuite) TestExponentialBackoff(c *check.C) {
383         var min, max time.Duration
384         min, max = time.Second, 64*time.Second
385
386         t := exponentialBackoff(min, max, 0, nil)
387         c.Check(t, check.Equals, min)
388
389         for e := float64(1); e < 5; e += 1 {
390                 ok := false
391                 for i := 0; i < 20; i++ {
392                         t = exponentialBackoff(min, max, int(e), nil)
393                         // Every returned value must be between min and min(2^e, max)
394                         c.Check(t >= min, check.Equals, true)
395                         c.Check(t <= min*time.Duration(math.Pow(2, e)), check.Equals, true)
396                         c.Check(t <= max, check.Equals, true)
397                         // Check that jitter is actually happening by
398                         // checking that at least one in 20 trials is
399                         // between min*2^(e-.75) and min*2^(e-.25)
400                         jittermin := time.Duration(float64(min) * math.Pow(2, e-0.75))
401                         jittermax := time.Duration(float64(min) * math.Pow(2, e-0.25))
402                         c.Logf("min %v max %v e %v jittermin %v jittermax %v t %v", min, max, e, jittermin, jittermax, t)
403                         if t > jittermin && t < jittermax {
404                                 ok = true
405                                 break
406                         }
407                 }
408                 c.Check(ok, check.Equals, true)
409         }
410
411         for i := 0; i < 20; i++ {
412                 t := exponentialBackoff(min, max, 100, nil)
413                 c.Check(t < max, check.Equals, true)
414         }
415
416         for _, trial := range []struct {
417                 retryAfter string
418                 expect     time.Duration
419         }{
420                 {"1", time.Second * 4},             // minimum enforced
421                 {"5", time.Second * 5},             // header used
422                 {"55", time.Second * 10},           // maximum enforced
423                 {"eleventy-nine", time.Second * 4}, // invalid header, exponential backoff used
424                 {time.Now().UTC().Add(time.Second).Format(time.RFC1123), time.Second * 4},  // minimum enforced
425                 {time.Now().UTC().Add(time.Minute).Format(time.RFC1123), time.Second * 10}, // maximum enforced
426                 {time.Now().UTC().Add(-time.Minute).Format(time.RFC1123), time.Second * 4}, // minimum enforced
427         } {
428                 c.Logf("trial %+v", trial)
429                 t := exponentialBackoff(time.Second*4, time.Second*10, 0, &http.Response{
430                         StatusCode: http.StatusTooManyRequests,
431                         Header:     http.Header{"Retry-After": {trial.retryAfter}}})
432                 c.Check(t, check.Equals, trial.expect)
433         }
434         t = exponentialBackoff(time.Second*4, time.Second*10, 0, &http.Response{
435                 StatusCode: http.StatusTooManyRequests,
436         })
437         c.Check(t, check.Equals, time.Second*4)
438
439         t = exponentialBackoff(0, max, 0, nil)
440         c.Check(t, check.Equals, time.Duration(0))
441         t = exponentialBackoff(0, max, 1, nil)
442         c.Check(t, check.Not(check.Equals), time.Duration(0))
443 }