20541: Auto unselect fields if response body will not be read.
[arvados.git] / sdk / go / arvados / client_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvados
6
7 import (
8         "bytes"
9         "context"
10         "fmt"
11         "io/ioutil"
12         "math/rand"
13         "net/http"
14         "net/http/httptest"
15         "net/url"
16         "os"
17         "strings"
18         "sync"
19         "testing/iotest"
20         "time"
21
22         check "gopkg.in/check.v1"
23 )
24
// stubTransport is an http.RoundTripper that serves canned response
// bodies keyed by request URL path and records every request it sees.
// A path with no entry (or an empty one) in Responses gets a
// "404 Not Found" response with body "{}".
type stubTransport struct {
	Responses map[string]string // URL path -> response body
	Requests  []http.Request    // copies of received requests, appended under the embedded Mutex
	sync.Mutex
}

// RoundTrip implements http.RoundTripper. It never returns an error.
func (st *stubTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	st.Lock()
	st.Requests = append(st.Requests, *req)
	st.Unlock()

	body, status, code := st.Responses[req.URL.Path], "200 OK", 200
	if body == "" {
		body, status, code = "{}", "404 Not Found", 404
	}
	return &http.Response{
		Status:        status,
		StatusCode:    code,
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Request:       req,
		Body:          ioutil.NopCloser(bytes.NewBufferString(body)),
		ContentLength: int64(len(body)),
	}, nil
}
55
// errorTransport is an http.RoundTripper whose every round trip fails
// with the same error; the request is never inspected.
type errorTransport struct{}

// RoundTrip implements http.RoundTripper by always returning an error.
func (*errorTransport) RoundTrip(*http.Request) (*http.Response, error) {
	return nil, fmt.Errorf("something awful happened")
}
61
// timeoutTransport is an http.RoundTripper that answers every request
// with "200 OK" and a body that reads from response through
// iotest.TimeoutReader, so the second Read of the body fails with
// iotest.ErrTimeout.
type timeoutTransport struct {
	response []byte // body bytes; only the first Read of the body succeeds
}

// RoundTrip implements http.RoundTripper. It never returns an error.
func (tt *timeoutTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	body := iotest.TimeoutReader(bytes.NewReader(tt.response))
	resp := &http.Response{Status: "200 OK", StatusCode: 200, Request: req}
	resp.Proto, resp.ProtoMajor, resp.ProtoMinor = "HTTP/1.1", 1, 1
	resp.Body = ioutil.NopCloser(body)
	return resp, nil
}
77
78 var _ = check.Suite(&clientSuite{})
79
80 type clientSuite struct{}
81
82 func (*clientSuite) TestCurrentUser(c *check.C) {
83         stub := &stubTransport{
84                 Responses: map[string]string{
85                         "/arvados/v1/users/current": `{"uuid":"zzzzz-abcde-012340123401234"}`,
86                 },
87         }
88         client := &Client{
89                 Client: &http.Client{
90                         Transport: stub,
91                 },
92                 APIHost:   "zzzzz.arvadosapi.com",
93                 AuthToken: "xyzzy",
94         }
95         u, err := client.CurrentUser()
96         c.Check(err, check.IsNil)
97         c.Check(u.UUID, check.Equals, "zzzzz-abcde-012340123401234")
98         c.Check(stub.Requests, check.Not(check.HasLen), 0)
99         hdr := stub.Requests[len(stub.Requests)-1].Header
100         c.Check(hdr.Get("Authorization"), check.Equals, "OAuth2 xyzzy")
101
102         client.Client.Transport = &errorTransport{}
103         u, err = client.CurrentUser()
104         c.Check(err, check.NotNil)
105 }
106
107 func (*clientSuite) TestAnythingToValues(c *check.C) {
108         type testCase struct {
109                 in interface{}
110                 // ok==nil means anythingToValues should return an
111                 // error, otherwise it's a func that returns true if
112                 // out is correct
113                 ok func(out url.Values) bool
114         }
115         for _, tc := range []testCase{
116                 {
117                         in: map[string]interface{}{"foo": "bar"},
118                         ok: func(out url.Values) bool {
119                                 return out.Get("foo") == "bar"
120                         },
121                 },
122                 {
123                         in: map[string]interface{}{"foo": 2147483647},
124                         ok: func(out url.Values) bool {
125                                 return out.Get("foo") == "2147483647"
126                         },
127                 },
128                 {
129                         in: map[string]interface{}{"foo": 1.234},
130                         ok: func(out url.Values) bool {
131                                 return out.Get("foo") == "1.234"
132                         },
133                 },
134                 {
135                         in: map[string]interface{}{"foo": "1.234"},
136                         ok: func(out url.Values) bool {
137                                 return out.Get("foo") == "1.234"
138                         },
139                 },
140                 {
141                         in: map[string]interface{}{"foo": map[string]interface{}{"bar": 1.234}},
142                         ok: func(out url.Values) bool {
143                                 return out.Get("foo") == `{"bar":1.234}`
144                         },
145                 },
146                 {
147                         in: url.Values{"foo": {"bar"}},
148                         ok: func(out url.Values) bool {
149                                 return out.Get("foo") == "bar"
150                         },
151                 },
152                 {
153                         in: 1234,
154                         ok: nil,
155                 },
156                 {
157                         in: []string{"foo"},
158                         ok: nil,
159                 },
160         } {
161                 c.Logf("%#v", tc.in)
162                 out, err := anythingToValues(tc.in)
163                 if tc.ok == nil {
164                         c.Check(err, check.NotNil)
165                         continue
166                 }
167                 c.Check(err, check.IsNil)
168                 c.Check(tc.ok(out), check.Equals, true)
169         }
170 }
171
172 // select=["uuid"] is added automatically when RequestAndDecode's
173 // destination argument is nil.
174 func (*clientSuite) TestAutoSelectUUID(c *check.C) {
175         var req *http.Request
176         var err error
177         server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
178                 c.Check(r.ParseForm(), check.IsNil)
179                 req = r
180                 w.Write([]byte("{}"))
181         }))
182         client := Client{
183                 APIHost:   strings.TrimPrefix(server.URL, "https://"),
184                 AuthToken: "zzz",
185                 Insecure:  true,
186                 Timeout:   2 * time.Second,
187         }
188
189         req = nil
190         err = client.RequestAndDecode(nil, http.MethodPost, "test", nil, nil)
191         c.Check(err, check.IsNil)
192         c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
193
194         req = nil
195         err = client.RequestAndDecode(nil, http.MethodGet, "test", nil, nil)
196         c.Check(err, check.IsNil)
197         c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
198
199         req = nil
200         err = client.RequestAndDecode(nil, http.MethodGet, "test", nil, map[string]interface{}{"select": []string{"blergh"}})
201         c.Check(err, check.IsNil)
202         c.Check(req.FormValue("select"), check.Equals, `["uuid"]`)
203
204         req = nil
205         err = client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, map[string]interface{}{"select": []string{"blergh"}})
206         c.Check(err, check.IsNil)
207         c.Check(req.FormValue("select"), check.Equals, `["blergh"]`)
208 }
209
210 func (*clientSuite) TestLoadConfig(c *check.C) {
211         oldenv := os.Environ()
212         defer func() {
213                 os.Clearenv()
214                 for _, s := range oldenv {
215                         i := strings.IndexRune(s, '=')
216                         os.Setenv(s[:i], s[i+1:])
217                 }
218         }()
219
220         tmp := c.MkDir()
221         os.Setenv("HOME", tmp)
222         for _, s := range os.Environ() {
223                 if strings.HasPrefix(s, "ARVADOS_") {
224                         i := strings.IndexRune(s, '=')
225                         os.Unsetenv(s[:i])
226                 }
227         }
228         os.Mkdir(tmp+"/.config", 0777)
229         os.Mkdir(tmp+"/.config/arvados", 0777)
230
231         // Use $HOME/.config/arvados/settings.conf if no env vars are
232         // set
233         os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
234                 ARVADOS_API_HOST = localhost:1
235                 ARVADOS_API_TOKEN = token_from_settings_file1
236         `), 0777)
237         client := NewClientFromEnv()
238         c.Check(client.AuthToken, check.Equals, "token_from_settings_file1")
239         c.Check(client.APIHost, check.Equals, "localhost:1")
240         c.Check(client.Insecure, check.Equals, false)
241
242         // ..._INSECURE=true, comments, ignored lines in settings.conf
243         os.WriteFile(tmp+"/.config/arvados/settings.conf", []byte(`
244                 (ignored) = (ignored)
245                 #ARVADOS_API_HOST = localhost:2
246                 ARVADOS_API_TOKEN = token_from_settings_file2
247                 ARVADOS_API_HOST_INSECURE = true
248         `), 0777)
249         client = NewClientFromEnv()
250         c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
251         c.Check(client.APIHost, check.Equals, "")
252         c.Check(client.Insecure, check.Equals, true)
253
254         // Environment variables override settings.conf
255         os.Setenv("ARVADOS_API_HOST", "[::]:3")
256         os.Setenv("ARVADOS_API_HOST_INSECURE", "0")
257         client = NewClientFromEnv()
258         c.Check(client.AuthToken, check.Equals, "token_from_settings_file2")
259         c.Check(client.APIHost, check.Equals, "[::]:3")
260         c.Check(client.Insecure, check.Equals, false)
261 }
262
263 var _ = check.Suite(&clientRetrySuite{})
264
265 type clientRetrySuite struct {
266         server     *httptest.Server
267         client     Client
268         reqs       []*http.Request
269         respStatus chan int
270         respDelay  time.Duration
271
272         origLimiterQuietPeriod time.Duration
273 }
274
275 func (s *clientRetrySuite) SetUpTest(c *check.C) {
276         // Test server: delay and return errors until a final status
277         // appears on the respStatus channel.
278         s.origLimiterQuietPeriod = requestLimiterQuietPeriod
279         requestLimiterQuietPeriod = time.Second / 100
280         s.server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
281                 s.reqs = append(s.reqs, r)
282                 delay := s.respDelay
283                 if delay == 0 {
284                         delay = time.Duration(rand.Int63n(int64(time.Second / 10)))
285                 }
286                 timer := time.NewTimer(delay)
287                 defer timer.Stop()
288                 select {
289                 case code, ok := <-s.respStatus:
290                         if !ok {
291                                 code = http.StatusOK
292                         }
293                         w.WriteHeader(code)
294                         w.Write([]byte(`{}`))
295                 case <-timer.C:
296                         w.WriteHeader(http.StatusServiceUnavailable)
297                 }
298         }))
299         s.reqs = nil
300         s.respStatus = make(chan int, 1)
301         s.client = Client{
302                 APIHost:   s.server.URL[8:],
303                 AuthToken: "zzz",
304                 Insecure:  true,
305                 Timeout:   2 * time.Second,
306         }
307 }
308
309 func (s *clientRetrySuite) TearDownTest(c *check.C) {
310         s.server.Close()
311         requestLimiterQuietPeriod = s.origLimiterQuietPeriod
312 }
313
314 func (s *clientRetrySuite) TestOK(c *check.C) {
315         s.respStatus <- http.StatusOK
316         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
317         c.Check(err, check.IsNil)
318         c.Check(s.reqs, check.HasLen, 1)
319 }
320
321 func (s *clientRetrySuite) TestNetworkError(c *check.C) {
322         // Close the stub server to produce a "connection refused" error.
323         s.server.Close()
324
325         start := time.Now()
326         timeout := time.Second
327         ctx, cancel := context.WithDeadline(context.Background(), start.Add(timeout))
328         defer cancel()
329         s.client.Timeout = timeout * 2
330         err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
331         c.Check(err, check.ErrorMatches, `.*dial tcp .* connection refused.*`)
332         delta := time.Since(start)
333         c.Check(delta > timeout, check.Equals, true, check.Commentf("time.Since(start) == %v, timeout = %v", delta, timeout))
334 }
335
336 func (s *clientRetrySuite) TestNonRetryableError(c *check.C) {
337         s.respStatus <- http.StatusBadRequest
338         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
339         c.Check(err, check.ErrorMatches, `.*400 Bad Request.*`)
340         c.Check(s.reqs, check.HasLen, 1)
341 }
342
343 func (s *clientRetrySuite) TestNonRetryableAfter503s(c *check.C) {
344         time.AfterFunc(time.Second, func() { s.respStatus <- http.StatusNotFound })
345         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
346         c.Check(err, check.ErrorMatches, `.*404 Not Found.*`)
347 }
348
349 func (s *clientRetrySuite) TestOKAfter503s(c *check.C) {
350         start := time.Now()
351         delay := time.Second
352         time.AfterFunc(delay, func() { s.respStatus <- http.StatusOK })
353         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
354         c.Check(err, check.IsNil)
355         c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
356         c.Check(time.Since(start) > delay, check.Equals, true)
357 }
358
359 func (s *clientRetrySuite) TestTimeoutAfter503(c *check.C) {
360         s.respStatus <- http.StatusServiceUnavailable
361         s.respDelay = time.Second * 2
362         s.client.Timeout = time.Second / 2
363         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
364         c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
365         c.Check(s.reqs, check.HasLen, 2)
366 }
367
368 func (s *clientRetrySuite) Test503Forever(c *check.C) {
369         err := s.client.RequestAndDecode(&struct{}{}, http.MethodGet, "test", nil, nil)
370         c.Check(err, check.ErrorMatches, `.*503 Service Unavailable.*`)
371         c.Check(len(s.reqs) > 1, check.Equals, true, check.Commentf("len(s.reqs) == %d", len(s.reqs)))
372 }
373
374 func (s *clientRetrySuite) TestContextAlreadyCanceled(c *check.C) {
375         ctx, cancel := context.WithCancel(context.Background())
376         cancel()
377         err := s.client.RequestAndDecodeContext(ctx, &struct{}{}, http.MethodGet, "test", nil, nil)
378         c.Check(err, check.Equals, context.Canceled)
379 }