16217: Refactor ws to use lib/service.
[arvados.git] / services / ws / service_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ws
6
7 import (
8         "bytes"
9         "context"
10         "encoding/json"
11         "flag"
12         "io/ioutil"
13         "net/http"
14         "net/http/httptest"
15         "os"
16         "sync"
17         "time"
18
19         "git.arvados.org/arvados.git/lib/config"
20         "git.arvados.org/arvados.git/lib/service"
21         "git.arvados.org/arvados.git/sdk/go/arvados"
22         "git.arvados.org/arvados.git/sdk/go/arvadostest"
23         "git.arvados.org/arvados.git/sdk/go/ctxlog"
24         "github.com/prometheus/client_golang/prometheus"
25         "github.com/sirupsen/logrus"
26         check "gopkg.in/check.v1"
27 )
28
29 var _ = check.Suite(&serviceSuite{})
30
31 type serviceSuite struct {
32         handler service.Handler
33         srv     *httptest.Server
34         cluster *arvados.Cluster
35         wg      sync.WaitGroup
36 }
37
38 func (s *serviceSuite) SetUpTest(c *check.C) {
39         var err error
40         s.cluster, err = s.testConfig(c)
41         c.Assert(err, check.IsNil)
42 }
43
44 func (s *serviceSuite) start() {
45         s.handler = newHandler(context.Background(), s.cluster, "", prometheus.NewRegistry())
46         s.srv = httptest.NewServer(s.handler)
47 }
48
49 func (s *serviceSuite) TearDownTest(c *check.C) {
50         if s.srv != nil {
51                 s.srv.Close()
52         }
53 }
54
55 func (*serviceSuite) testConfig(c *check.C) (*arvados.Cluster, error) {
56         ldr := config.NewLoader(nil, ctxlog.TestLogger(c))
57         cfg, err := ldr.Load()
58         if err != nil {
59                 return nil, err
60         }
61         cluster, err := cfg.GetCluster("")
62         if err != nil {
63                 return nil, err
64         }
65         client := arvados.NewClientFromEnv()
66         cluster.Services.Controller.ExternalURL.Host = client.APIHost
67         cluster.SystemRootToken = client.AuthToken
68         cluster.TLS.Insecure = client.Insecure
69         cluster.PostgreSQL.Connection = testDBConfig()
70         cluster.Services.Websocket.InternalURLs = map[arvados.URL]arvados.ServiceInstance{arvados.URL{Host: ":"}: arvados.ServiceInstance{}}
71         cluster.ManagementToken = arvadostest.ManagementToken
72         return cluster, nil
73 }
74
75 // TestBadDB ensures the server returns an error (instead of panicking
76 // or deadlocking) if it can't connect to the database server at
77 // startup.
78 func (s *serviceSuite) TestBadDB(c *check.C) {
79         s.cluster.PostgreSQL.Connection["password"] = "1234"
80         s.start()
81         resp, err := http.Get(s.srv.URL)
82         c.Check(err, check.IsNil)
83         c.Check(resp.StatusCode, check.Equals, http.StatusInternalServerError)
84         c.Check(s.handler.CheckHealth(), check.ErrorMatches, "database not connected")
85         c.Check(err, check.IsNil)
86         c.Check(resp.StatusCode, check.Equals, http.StatusInternalServerError)
87 }
88
89 func (s *serviceSuite) TestHealth(c *check.C) {
90         s.start()
91         for _, token := range []string{"", "foo", s.cluster.ManagementToken} {
92                 req, err := http.NewRequest("GET", s.srv.URL+"/_health/ping", nil)
93                 c.Assert(err, check.IsNil)
94                 if token != "" {
95                         req.Header.Add("Authorization", "Bearer "+token)
96                 }
97                 resp, err := http.DefaultClient.Do(req)
98                 c.Check(err, check.IsNil)
99                 if token == s.cluster.ManagementToken {
100                         c.Check(resp.StatusCode, check.Equals, http.StatusOK)
101                         buf, err := ioutil.ReadAll(resp.Body)
102                         c.Check(err, check.IsNil)
103                         c.Check(string(buf), check.Equals, `{"health":"OK"}`+"\n")
104                 } else {
105                         c.Check(resp.StatusCode, check.Not(check.Equals), http.StatusOK)
106                 }
107         }
108 }
109
110 func (s *serviceSuite) TestStatus(c *check.C) {
111         s.start()
112         req, err := http.NewRequest("GET", s.srv.URL+"/status.json", nil)
113         c.Assert(err, check.IsNil)
114         resp, err := http.DefaultClient.Do(req)
115         c.Check(err, check.IsNil)
116         c.Check(resp.StatusCode, check.Equals, http.StatusOK)
117         var status map[string]interface{}
118         err = json.NewDecoder(resp.Body).Decode(&status)
119         c.Check(err, check.IsNil)
120         c.Check(status["Version"], check.Not(check.Equals), "")
121 }
122
123 func (s *serviceSuite) TestHealthDisabled(c *check.C) {
124         s.cluster.ManagementToken = ""
125         s.start()
126         for _, token := range []string{"", "foo", arvadostest.ManagementToken} {
127                 req, err := http.NewRequest("GET", s.srv.URL+"/_health/ping", nil)
128                 c.Assert(err, check.IsNil)
129                 req.Header.Add("Authorization", "Bearer "+token)
130                 resp, err := http.DefaultClient.Do(req)
131                 c.Check(err, check.IsNil)
132                 c.Check(resp.StatusCode, check.Equals, http.StatusNotFound)
133         }
134 }
135
136 func (s *serviceSuite) TestLoadLegacyConfig(c *check.C) {
137         content := []byte(`
138 Client:
139   APIHost: example.com
140   AuthToken: abcdefg
141 Postgres:
142   "dbname": "arvados_production"
143   "user": "arvados"
144   "password": "xyzzy"
145   "host": "localhost"
146   "connect_timeout": "30"
147   "sslmode": "require"
148   "fallback_application_name": "arvados-ws"
149 PostgresPool: 63
150 Listen: ":8765"
151 LogLevel: "debug"
152 LogFormat: "text"
153 PingTimeout: 61s
154 ClientEventQueue: 62
155 ServerEventQueue:  5
156 ManagementToken: qqqqq
157 `)
158         tmpfile, err := ioutil.TempFile("", "example")
159         if err != nil {
160                 c.Error(err)
161         }
162
163         defer os.Remove(tmpfile.Name()) // clean up
164
165         if _, err := tmpfile.Write(content); err != nil {
166                 c.Error(err)
167         }
168         if err := tmpfile.Close(); err != nil {
169                 c.Error(err)
170
171         }
172         ldr := config.NewLoader(&bytes.Buffer{}, logrus.New())
173         flagset := flag.NewFlagSet("", flag.ContinueOnError)
174         ldr.SetupFlags(flagset)
175         flagset.Parse(ldr.MungeLegacyConfigArgs(ctxlog.TestLogger(c), []string{"-config", tmpfile.Name()}, "-legacy-ws-config"))
176         cfg, err := ldr.Load()
177         c.Check(err, check.IsNil)
178         cluster, err := cfg.GetCluster("")
179         c.Check(err, check.IsNil)
180         c.Check(cluster, check.NotNil)
181
182         c.Check(cluster.Services.Controller.ExternalURL, check.Equals, arvados.URL{Scheme: "https", Host: "example.com"})
183         c.Check(cluster.SystemRootToken, check.Equals, "abcdefg")
184
185         c.Check(cluster.PostgreSQL.Connection, check.DeepEquals, arvados.PostgreSQLConnection{
186                 "connect_timeout":           "30",
187                 "dbname":                    "arvados_production",
188                 "fallback_application_name": "arvados-ws",
189                 "host":                      "localhost",
190                 "password":                  "xyzzy",
191                 "sslmode":                   "require",
192                 "user":                      "arvados"})
193         c.Check(cluster.PostgreSQL.ConnectionPool, check.Equals, 63)
194         c.Check(cluster.Services.Websocket.InternalURLs[arvados.URL{Host: ":8765"}], check.NotNil)
195         c.Check(cluster.SystemLogs.LogLevel, check.Equals, "debug")
196         c.Check(cluster.SystemLogs.Format, check.Equals, "text")
197         c.Check(cluster.API.SendTimeout, check.Equals, arvados.Duration(61*time.Second))
198         c.Check(cluster.API.WebsocketClientEventQueue, check.Equals, 62)
199         c.Check(cluster.API.WebsocketServerEventQueue, check.Equals, 5)
200         c.Check(cluster.ManagementToken, check.Equals, "qqqqq")
201 }