Merge branch '15107-prefer-domain-for-username'
[arvados.git] / sdk / go / arvados / client_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package arvados
6
7 import (
8         "bytes"
9         "fmt"
10         "io/ioutil"
11         "net/http"
12         "net/url"
13         "sync"
14         "testing"
15         "testing/iotest"
16 )
17
// stubTransport is a fake http.RoundTripper for tests. It keeps a copy of
// every request it is asked to send and answers from a table of canned
// response bodies keyed by URL path.
type stubTransport struct {
	// Responses maps a request URL path to the body RoundTrip returns.
	// A path with no entry (or an empty entry) gets a 404 with body "{}".
	Responses map[string]string
	// Requests holds a copy of each request, in the order received.
	// RoundTrip appends to it while holding the embedded mutex.
	Requests []http.Request
	sync.Mutex
}

// RoundTrip implements http.RoundTripper. It never returns an error:
// paths without a canned body produce a synthetic 404 response instead.
func (stub *stubTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// Record the request; the lock is held only for the append.
	func() {
		stub.Lock()
		defer stub.Unlock()
		stub.Requests = append(stub.Requests, *req)
	}()

	body := stub.Responses[req.URL.Path]
	status, code := "200 OK", 200
	if body == "" {
		// No canned body for this path: fake a not-found reply.
		status, code, body = "404 Not Found", 404, "{}"
	}

	buf := bytes.NewBufferString(body)
	return &http.Response{
		Status:        status,
		StatusCode:    code,
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Request:       req,
		Body:          ioutil.NopCloser(buf),
		ContentLength: int64(buf.Len()),
	}, nil
}
48
// errorTransport is a fake http.RoundTripper that fails every request
// without producing a response.
type errorTransport struct{}

// RoundTrip implements http.RoundTripper. It ignores the request and
// always returns a nil response together with a fixed error.
func (*errorTransport) RoundTrip(*http.Request) (*http.Response, error) {
	failure := fmt.Errorf("something awful happened")
	return nil, failure
}
54
// timeoutTransport is a fake http.RoundTripper whose responses always
// report "200 OK" but whose body reader is wrapped in
// iotest.TimeoutReader, so reading the body fails with iotest.ErrTimeout
// on the second Read call.
type timeoutTransport struct {
	// response is the body content. The first Read returns (a prefix of)
	// it; the second Read fails with the simulated timeout.
	response []byte
}

// RoundTrip implements http.RoundTripper. It never returns an error
// itself; the failure only surfaces when the caller reads the body.
func (stub *timeoutTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	flaky := iotest.TimeoutReader(bytes.NewReader(stub.response))

	resp := &http.Response{Request: req, Body: ioutil.NopCloser(flaky)}
	resp.Proto, resp.ProtoMajor, resp.ProtoMinor = "HTTP/1.1", 1, 1
	resp.Status, resp.StatusCode = "200 OK", 200
	return resp, nil
}
70
71 func TestCurrentUser(t *testing.T) {
72         t.Parallel()
73         stub := &stubTransport{
74                 Responses: map[string]string{
75                         "/arvados/v1/users/current": `{"uuid":"zzzzz-abcde-012340123401234"}`,
76                 },
77         }
78         c := &Client{
79                 Client: &http.Client{
80                         Transport: stub,
81                 },
82                 APIHost:   "zzzzz.arvadosapi.com",
83                 AuthToken: "xyzzy",
84         }
85         u, err := c.CurrentUser()
86         if err != nil {
87                 t.Fatal(err)
88         }
89         if x := "zzzzz-abcde-012340123401234"; u.UUID != x {
90                 t.Errorf("got uuid %q, expected %q", u.UUID, x)
91         }
92         if len(stub.Requests) < 1 {
93                 t.Fatal("empty stub.Requests")
94         }
95         hdr := stub.Requests[len(stub.Requests)-1].Header
96         if hdr.Get("Authorization") != "OAuth2 xyzzy" {
97                 t.Errorf("got headers %+q, expected Authorization header", hdr)
98         }
99
100         c.Client.Transport = &errorTransport{}
101         u, err = c.CurrentUser()
102         if err == nil {
103                 t.Errorf("got nil error, expected something awful")
104         }
105 }
106
107 func TestAnythingToValues(t *testing.T) {
108         type testCase struct {
109                 in interface{}
110                 // ok==nil means anythingToValues should return an
111                 // error, otherwise it's a func that returns true if
112                 // out is correct
113                 ok func(out url.Values) bool
114         }
115         for _, tc := range []testCase{
116                 {
117                         in: map[string]interface{}{"foo": "bar"},
118                         ok: func(out url.Values) bool {
119                                 return out.Get("foo") == "bar"
120                         },
121                 },
122                 {
123                         in: map[string]interface{}{"foo": 2147483647},
124                         ok: func(out url.Values) bool {
125                                 return out.Get("foo") == "2147483647"
126                         },
127                 },
128                 {
129                         in: map[string]interface{}{"foo": 1.234},
130                         ok: func(out url.Values) bool {
131                                 return out.Get("foo") == "1.234"
132                         },
133                 },
134                 {
135                         in: map[string]interface{}{"foo": "1.234"},
136                         ok: func(out url.Values) bool {
137                                 return out.Get("foo") == "1.234"
138                         },
139                 },
140                 {
141                         in: map[string]interface{}{"foo": map[string]interface{}{"bar": 1.234}},
142                         ok: func(out url.Values) bool {
143                                 return out.Get("foo") == `{"bar":1.234}`
144                         },
145                 },
146                 {
147                         in: url.Values{"foo": {"bar"}},
148                         ok: func(out url.Values) bool {
149                                 return out.Get("foo") == "bar"
150                         },
151                 },
152                 {
153                         in: 1234,
154                         ok: nil,
155                 },
156                 {
157                         in: []string{"foo"},
158                         ok: nil,
159                 },
160         } {
161                 t.Logf("%#v", tc.in)
162                 out, err := anythingToValues(tc.in)
163                 switch {
164                 case tc.ok == nil:
165                         if err == nil {
166                                 t.Errorf("got %#v, expected error", out)
167                         }
168                 case err != nil:
169                         t.Errorf("got err %#v, expected nil", err)
170                 case !tc.ok(out):
171                         t.Errorf("got %#v but tc.ok() says that is wrong", out)
172                 }
173         }
174 }