Merge branch 'main' into 22141-picking-dialog-search
[arvados.git] / sdk / go / arvados / client_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvados
6
7 import (
8         "bytes"
9         "context"
10         "fmt"
11         "io/ioutil"
12         "math"
13         "math/rand"
14         "net/http"
15         "net/http/httptest"
16         "net/url"
17         "os"
18         "strings"
19         "sync"
20         "testing/iotest"
21         "time"
22
23         check "gopkg.in/check.v1"
24 )
25
// stubTransport is an http.RoundTripper that never touches the
// network: it records each request it is given and answers from a
// canned table keyed by URL path.
type stubTransport struct {
	// Responses maps a request URL path to the body returned with
	// status 200. A path with no entry (or an empty entry) gets a
	// 404 with body "{}".
	Responses map[string]string
	// Requests holds a copy of each request's http.Request value, in
	// the order RoundTrip saw them. RoundTrip appends to it while
	// holding the embedded Mutex.
	Requests []http.Request
	sync.Mutex
}

// RoundTrip records req and returns the canned response for
// req.URL.Path. It never returns an error.
func (stub *stubTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	stub.Lock()
	stub.Requests = append(stub.Requests, *req)
	stub.Unlock()

	status, code, body := "200 OK", 200, stub.Responses[req.URL.Path]
	if body == "" {
		status, code, body = "404 Not Found", 404, "{}"
	}
	payload := bytes.NewBufferString(body)
	return &http.Response{
		Status:        status,
		StatusCode:    code,
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Request:       req,
		Body:          ioutil.NopCloser(payload),
		ContentLength: int64(payload.Len()),
	}, nil
}
56
// errorTransport is an http.RoundTripper whose every round trip fails
// with the same fixed error, for exercising transport-level failures.
type errorTransport struct{}

// RoundTrip ignores its request and always returns a nil response
// together with a non-nil error.
func (*errorTransport) RoundTrip(*http.Request) (*http.Response, error) {
	return nil, fmt.Errorf("something awful happened")
}
62
// timeoutTransport is an http.RoundTripper that answers every request
// with 200 OK. Its body is stub.response wrapped in
// iotest.TimeoutReader, so the second Read on the body fails with
// iotest.ErrTimeout.
type timeoutTransport struct {
	response []byte
}

// RoundTrip returns the 200 response described on timeoutTransport and
// never returns an error.
func (stub *timeoutTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	flaky := iotest.TimeoutReader(bytes.NewReader(stub.response))
	resp := &http.Response{
		Status:     "200 OK",
		StatusCode: 200,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Request:    req,
		Body:       ioutil.NopCloser(flaky),
	}
	return resp, nil
}
78
79 var _ = check.Suite(&clientSuite{})
80
81 type clientSuite struct{}
82
83 func (*clientSuite) TestCurrentUser(c *check.C) {
84         stub := &stubTransport{
85                 Responses: map[string]string{
86                         "/arvados/v1/users/current": `{"uuid":"zzzzz-abcde-012340123401234"}`,
87                 },
88         }
89         client := &Client{
90                 Client: &http.Client{
91                         Transport: stub,
92                 },
93                 APIHost:   "zzzzz.arvadosapi.com",
94                 AuthToken: "xyzzy",
95         }
96         u, err := client.CurrentUser()
97         c.Check(err, check.IsNil)
98         c.Check(u.UUID, check.Equals, "zzzzz-abcde-012340123401234")
99         c.Check(stub.Requests, check.Not(check.HasLen), 0)
100         hdr := stub.Requests[len(stub.Requests)-1].Header
101         c.Check(hdr.Get("Authorization"), check.Equals, "Bearer xyzzy")
102
103         client.Client.Transport = &errorTransport{}
104         u, err = client.CurrentUser()
105         c.Check(err, check.NotNil)
106 }
107
108 func (*clientSuite) TestAnythingToValues(c *check.C) {
109         type testCase struct {
110                 in interface{}
111                 // ok==nil means anythingToValues should return an
112                 // error, otherwise it's a func that returns true if
113                 // out is correct
114                 ok func(out url.Values) bool
115         }
116         for _, tc := range []testCase{
117                 {
118                         in: map[string]interface{}{"foo": "bar"},
119                         ok: func(out url.Values) bool {
120                                 return out.Get("foo") == "bar"
121                         },
122                 },
123                 {
124                         in: map[string]interface{}{"foo": 2147483647},
125                         ok: func(out url.Values) bool {
126                                 return out.Get("foo") == "2147483647"
127                         },
128                 },
129                 {
130                         in: map[string]interface{}{"foo": 1.234},
131                         ok: func(out url.Values) bool {
132                                 return out.Get("foo") == "1.234"
133                         },
134                 },
135                 {
136                         in: map[string]interface{}{"foo": "1.234"},
137                         ok: func(out url.Values) bool {
138                                 return out.Get("foo") == "1.234"
139                         },
140                 },
141                 {
142                         in: map[string]interface{}{"foo": map[string]interface{}{"bar": 1.234}},
143                         ok: func(out url.Values) bool {
144                                 return out.Get("foo") == `{"bar":1.234}`
145                         },
146                 },
147                 {
148                         in: url.Values{"foo": {"bar"}},
149                         ok: func(out url.Values) bool {
150                                 return out.Get("foo") == "bar"
151                         },
152                 },
153                 {
154                         in: 1234,
155                         ok: nil,
156                 },
157                 {
158                         in: []string{"foo"},
159                         ok: nil,
160                 },
161         } {
162                 c.Logf("%#v", tc.in)
163                 out, err := anythingToValues(tc.in)
164                 if tc.ok == nil {
165                         c.Check(err, check.NotNil)
166                         continue
167                 }
168                 c.Check(err, check.IsNil)
169                 c.Check(tc.ok(out), check.Equals, true)
170         }
171 }
172
173 // select=["uuid"] is added automatically when RequestAndDecode's
174 // destination argument is nil.
175 func (*clientSuite) TestAutoSelectUUID(c *check.C) {
176         var req *http.Request
177         var err error
178         server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
179                 c.Check(r.ParseForm(), check.IsNil)
180                 req = r
181                 w.Write([]byte("{}"))
182         }))
183         client := Client{
184                 APIHost:   strings.TrimPrefix(server.URL, "https://"),
185                 AuthToken: "zzz",
186                 Insecure:  true,
187                 Timeout:   2 * time.Second,
188         }
189
190         req = nil
191         err = client.RequestAndDecode(nil, http.MethodPost, "test", nil, nil)
192         c.Check(err, check.IsNil)
193         c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
194
195         req = nil
196         err = client.RequestAndDecode(nil, http.MethodGet, "test", nil, nil)
197         c.Check(err, check.IsNil)
198         c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
199
200         req = nil
201         err = client.RequestAndDecode(nil, http.MethodGet, "test", nil, map[string]interface{}{"select": []string{"blergh"}})
202         c.Check(err, check.IsNil)
203         c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
204
205         req = nil
206         err = client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, map[string]interface{}{"select": []string{"blergh"}})
207         c.Check(err, check.IsNil)
208         c.Check(req.FormValue("select"), check.Equals, `["blergh"]`)
209 }
210
211 func (*clientSuite) TestLoadConfig(c *check.C) {
212         oldenv := os.Environ()
213         defer func() {
214                 os.Clearenv()
215                 for _, s := range oldenv {
216                         i := strings.IndexRune(s, '=')
217                         os.Setenv(s[:i], s[i+1:])
218                 }
219         }()
220
221         tmp := c.MkDir()
222         os.Setenv("HOME", tmp)
223         for _, s := range os.Environ() {
224                 if strings.HasPrefix(s, "ARVADOS_") {
225                         i := strings.IndexRune(s, '=')
226                         os.Unsetenv(s[:i])
227                 }
228         }
229         os.Mkdir(tmp+"/.config", 0777)
230         os.Mkdir(tmp+"/.config/arvados", 0777)
231
232         // Use $HOME/.config/arvados/settings.conf if no env vars are
233         // set
234         os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
235                 ARVADOS_API_HOST = localhost:1
236                 ARVADOS_API_TOKEN = token_from_settings_file1
237         `), 0777)
238         client := NewClientFromEnv()
239         c.Check(client.AuthToken, check.Equals, "token_from_settings_file1")
240         c.Check(client.APIHost, check.Equals, "localhost:1")
241         c.Check(client.Insecure, check.Equals, false)
242
243         // ..._INSECURE=true, comments, ignored lines in settings.conf
244         os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
245                 (ignored) = (ignored)
246                 #ARVADOS_API_HOST = localhost:2
247                 ARVADOS_API_TOKEN = token_from_settings_file2
248                 ARVADOS_API_HOST_INSECURE = true
249         `), 0777)
250         client = NewClientFromEnv()
251         c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
252         c.Check(client.APIHost, check.Equals, "")
253         c.Check(client.Insecure, check.Equals, true)
254
255         // Environment variables override settings.conf
256         os.Setenv("ARVADOS_API_HOST", "[::]:3")
257         os.Setenv("ARVADOS_API_HOST_INSECURE", "0")
258         os.Setenv("ARVADOS_KEEP_SERVICES", "http://[::]:12345")
259         client = NewClientFromEnv()
260         c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
261         c.Check(client.APIHost, check.Equals, "[::]:3")
262         c.Check(client.Insecure, check.Equals, false)
263         c.Check(client.KeepServiceURIs, check.DeepEquals, []string{"http://[::]:12345"})
264
265         // ARVADOS_KEEP_SERVICES environment variable overrides
266         // cluster config, but ARVADOS_API_HOST/TOKEN do not.
267         os.Setenv("ARVADOS_KEEP_SERVICES", "http://[::]:12345")
268         os.Setenv("ARVADOS_API_HOST", "wronghost.example")
269         os.Setenv("ARVADOS_API_TOKEN", "wrongtoken")
270         cfg := Cluster{}
271         cfg.Services.Controller.ExternalURL = URL{Scheme: "https", Host: "ctrl.example:55555", Path: "/"}
272         cfg.Services.Keepstore.InternalURLs = map[URL]ServiceInstance{
273                 URL{Scheme: "https", Host: "keep0.example:55555", Path: "/"}: ServiceInstance{},
274         }
275         client, err := NewClientFromConfig(&cfg)
276         c.Check(err, check.IsNil)
277         c.Check(client.AuthToken, check.Equals, "")
278         c.Check(client.APIHost, check.Equals, "ctrl.example:55555")
279         c.Check(client.Insecure, check.Equals, false)
280         c.Check(client.KeepServiceURIs, check.DeepEquals, []string{"http://[::]:12345"})
281 }
282
283 var _ = check.Suite(&clientRetrySuite{})
284
285 type clientRetrySuite struct {
286         server     *httptest.Server
287         client     Client
288         reqs       []*http.Request
289         respStatus chan int
290         respDelay  time.Duration
291
292         origLimiterQuietPeriod time.Duration
293 }
294
295 func (s *clientRetrySuite) SetUpTest(c *check.C) {
296         // Test server: delay and return errors until a final status
297         // appears on the respStatus channel.
298         s.origLimiterQuietPeriod = requestLimiterQuietPeriod
299         requestLimiterQuietPeriod = time.Second / 100
300         s.server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
301                 s.reqs = append(s.reqs, r)
302                 delay := s.respDelay
303                 if delay == 0 {
304                         delay = time.Duration(rand.Int63n(int64(time.Second / 10)))
305                 }
306                 timer := time.NewTimer(delay)
307                 defer timer.Stop()
308                 select {
309                 case code, ok := <-s.respStatus:
310                         if !ok {
311                                 code = http.StatusOK
312                         }
313                         w.WriteHeader(code)
314                         w.Write([]byte(`{}`))
315                 case <-timer.C:
316                         w.WriteHeader(http.StatusServiceUnavailable)
317                 }
318         }))
319         s.reqs = nil
320         s.respStatus = make(chan int, 1)
321         s.client = Client{
322                 APIHost:   s.server.URL[8:],
323                 AuthToken: "zzz",
324                 Insecure:  true,
325                 Timeout:   2 * time.Second,
326         }
327 }
328
329 func (s *clientRetrySuite) TearDownTest(c *check.C) {
330         s.server.Close()
331         requestLimiterQuietPeriod = s.origLimiterQuietPeriod
332 }
333
334 func (s *clientRetrySuite) TestOK(c *check.C) {
335         s.respStatus <- http.StatusOK
336         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
337         c.Check(err, check.IsNil)
338         c.Check(s.reqs, check.HasLen, 1)
339 }
340
341 func (s *clientRetrySuite) TestNetworkError(c *check.C) {
342         // Close the stub server to produce a "connection refused" error.
343         s.server.Close()
344
345         start := time.Now()
346         timeout := time.Second
347         ctx, cancel := context.WithDeadline(context.Background(), start.Add(timeout))
348         defer cancel()
349         s.client.Timeout = timeout * 2
350         err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
351         c.Check(err, check.ErrorMatches, `.*dial tcp .* connection refused.*`)
352         delta := time.Since(start)
353         c.Check(delta > timeout, check.Equals, true, check.Commentf("time.Since(start) == %v, timeout = %v", delta, timeout))
354 }
355
356 func (s *clientRetrySuite) TestNonRetryableError(c *check.C) {
357         s.respStatus <- http.StatusBadRequest
358         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
359         c.Check(err, check.ErrorMatches, `.*400 Bad Request.*`)
360         c.Check(s.reqs, check.HasLen, 1)
361 }
362
363 // as of 0.7.2., retryablehttp does not recognize this as a
364 // non-retryable error.
365 func (s *clientRetrySuite) TestNonRetryableStdlibError(c *check.C) {
366         s.respStatus <- http.StatusOK
367         req, err := http.NewRequest(http.MethodGet, "https://"+s.client.APIHost+"/test", nil)
368         c.Assert(err, check.IsNil)
369         req.Header.Set("Good-Header", "T\033rrible header value")
370         err = s.client.DoAndDecode(&struct{}{}, req)
371         c.Check(err, check.ErrorMatches, `.*after 1 attempt.*net/http: invalid header .*`)
372         if !c.Check(s.reqs, check.HasLen, 0) {
373                 c.Logf("%v", s.reqs[0])
374         }
375 }
376
377 func (s *clientRetrySuite) TestNonRetryableAfter503s(c *check.C) {
378         time.AfterFunc(time.Second, func() { s.respStatus <- http.StatusNotFound })
379         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
380         c.Check(err, check.ErrorMatches, `.*404 Not Found.*`)
381 }
382
383 func (s *clientRetrySuite) TestOKAfter503s(c *check.C) {
384         start := time.Now()
385         delay := time.Second
386         time.AfterFunc(delay, func() { s.respStatus <- http.StatusOK })
387         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
388         c.Check(err, check.IsNil)
389         c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
390         c.Check(time.Since(start) > delay, check.Equals, true)
391 }
392
393 func (s *clientRetrySuite) TestTimeoutAfter503(c *check.C) {
394         s.respStatus <- http.StatusServiceUnavailable
395         s.respDelay = time.Second * 2
396         s.client.Timeout = time.Second / 2
397         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
398         c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
399         c.Check(s.reqs, check.HasLen, 2)
400 }
401
402 func (s *clientRetrySuite) Test503Forever(c *check.C) {
403         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
404         c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
405         c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
406 }
407
408 func (s *clientRetrySuite) TestContextAlreadyCanceled(c *check.C) {
409         ctx, cancel := context.WithCancel(context.Background())
410         cancel()
411         err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
412         c.Check(err, check.Equals, context.Canceled)
413 }
414
415 func (s *clientRetrySuite) TestExponentialBackoff(c *check.C) {
416         var min, max time.Duration
417         min, max = time.Second, 64*time.Second
418
419         t := exponentialBackoff(min, max, 0, nil)
420         c.Check(t, check.Equals, min)
421
422         for e := float64(1); e < 5; e += 1 {
423                 ok := false
424                 for i := 0; i < 30; i++ {
425                         t = exponentialBackoff(min, max, int(e), nil)
426                         // Every returned value must be between min and min(2^e, max)
427                         c.Check(t >= min, check.Equals, true)
428                         c.Check(t <= min*time.Duration(math.Pow(2, e)), check.Equals, true)
429                         c.Check(t <= max, check.Equals, true)
430                         // Check that jitter is actually happening by
431                         // checking that at least one in 20 trials is
432                         // between min*2^(e-.75) and min*2^(e-.25)
433                         jittermin := time.Duration(float64(min) * math.Pow(2, e-0.75))
434                         jittermax := time.Duration(float64(min) * math.Pow(2, e-0.25))
435                         c.Logf("min %v max %v e %v jittermin %v jittermax %v t %v", min, max, e, jittermin, jittermax, t)
436                         if t > jittermin && t < jittermax {
437                                 ok = true
438                                 break
439                         }
440                 }
441                 c.Check(ok, check.Equals, true)
442         }
443
444         for i := 0; i < 20; i++ {
445                 t := exponentialBackoff(min, max, 100, nil)
446                 c.Check(t < max, check.Equals, true)
447         }
448
449         for _, trial := range []struct {
450                 retryAfter string
451                 expect     time.Duration
452         }{
453                 {"1", time.Second * 4},             // minimum enforced
454                 {"5", time.Second * 5},             // header used
455                 {"55", time.Second * 10},           // maximum enforced
456                 {"eleventy-nine", time.Second * 4}, // invalid header, exponential backoff used
457                 {time.Now().UTC().Add(time.Second).Format(time.RFC1123), time.Second * 4},  // minimum enforced
458                 {time.Now().UTC().Add(time.Minute).Format(time.RFC1123), time.Second * 10}, // maximum enforced
459                 {time.Now().UTC().Add(-time.Minute).Format(time.RFC1123), time.Second * 4}, // minimum enforced
460         } {
461                 c.Logf("trial %+v", trial)
462                 t := exponentialBackoff(time.Second*4, time.Second*10, 0, &http.Response{
463                         StatusCode: http.StatusTooManyRequests,
464                         Header:     http.Header{"Retry-After": {trial.retryAfter}}})
465                 c.Check(t, check.Equals, trial.expect)
466         }
467         t = exponentialBackoff(time.Second*4, time.Second*10, 0, &http.Response{
468                 StatusCode: http.StatusTooManyRequests,
469         })
470         c.Check(t, check.Equals, time.Second*4)
471
472         t = exponentialBackoff(0, max, 0, nil)
473         c.Check(t, check.Equals, time.Duration(0))
474         t = exponentialBackoff(0, max, 1, nil)
475         c.Check(t, check.Not(check.Equals), time.Duration(0))
476 }