20183: Deduplicate test suite setup.
[arvados.git] / lib / controller / localdb / container_gateway_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package localdb
6
7 import (
8         "context"
9         "crypto/hmac"
10         "crypto/sha256"
11         "fmt"
12         "io"
13         "io/ioutil"
14         "net"
15         "net/http/httptest"
16         "net/url"
17         "strings"
18         "time"
19
20         "git.arvados.org/arvados.git/lib/controller/router"
21         "git.arvados.org/arvados.git/lib/controller/rpc"
22         "git.arvados.org/arvados.git/lib/crunchrun"
23         "git.arvados.org/arvados.git/sdk/go/arvados"
24         "git.arvados.org/arvados.git/sdk/go/arvadostest"
25         "git.arvados.org/arvados.git/sdk/go/auth"
26         "git.arvados.org/arvados.git/sdk/go/ctxlog"
27         "golang.org/x/crypto/ssh"
28         check "gopkg.in/check.v1"
29 )
30
31 var _ = check.Suite(&ContainerGatewaySuite{})
32
33 type ContainerGatewaySuite struct {
34         localdbSuite
35         ctrUUID string
36         gw      *crunchrun.Gateway
37 }
38
39 func (s *ContainerGatewaySuite) SetUpTest(c *check.C) {
40         s.localdbSuite.SetUpTest(c)
41         s.ctx = auth.NewContext(s.ctx, &auth.Credentials{Tokens: []string{arvadostest.ActiveTokenV2}})
42
43         s.ctrUUID = arvadostest.QueuedContainerUUID
44
45         h := hmac.New(sha256.New, []byte(s.cluster.SystemRootToken))
46         fmt.Fprint(h, s.ctrUUID)
47         authKey := fmt.Sprintf("%x", h.Sum(nil))
48
49         rtr := router.New(s.localdb, router.Config{})
50         srv := httptest.NewUnstartedServer(rtr)
51         srv.StartTLS()
52         // the test setup doesn't use lib/service so
53         // service.URLFromContext() returns nothing -- instead, this
54         // is how we advertise our internal URL and enable
55         // proxy-to-other-controller mode,
56         forceInternalURLForTest = &arvados.URL{Scheme: "https", Host: srv.Listener.Addr().String()}
57         ac := &arvados.Client{
58                 APIHost:   srv.Listener.Addr().String(),
59                 AuthToken: arvadostest.Dispatch1Token,
60                 Insecure:  true,
61         }
62         s.gw = &crunchrun.Gateway{
63                 ContainerUUID: s.ctrUUID,
64                 AuthSecret:    authKey,
65                 Address:       "localhost:0",
66                 Log:           ctxlog.TestLogger(c),
67                 Target:        crunchrun.GatewayTargetStub{},
68                 ArvadosClient: ac,
69         }
70         c.Assert(s.gw.Start(), check.IsNil)
71         rootctx := auth.NewContext(s.ctx, &auth.Credentials{Tokens: []string{s.cluster.SystemRootToken}})
72         // OK if this line fails (because state is already Running
73         // from a previous test case) as long as the following line
74         // succeeds:
75         s.localdb.ContainerUpdate(rootctx, arvados.UpdateOptions{
76                 UUID: s.ctrUUID,
77                 Attrs: map[string]interface{}{
78                         "state": arvados.ContainerStateLocked}})
79         _, err := s.localdb.ContainerUpdate(rootctx, arvados.UpdateOptions{
80                 UUID: s.ctrUUID,
81                 Attrs: map[string]interface{}{
82                         "state":           arvados.ContainerStateRunning,
83                         "gateway_address": s.gw.Address}})
84         c.Assert(err, check.IsNil)
85
86         s.cluster.Containers.ShellAccess.Admin = true
87         s.cluster.Containers.ShellAccess.User = true
88         _, err = s.db.Exec(`update containers set interactive_session_started=$1 where uuid=$2`, false, s.ctrUUID)
89         c.Check(err, check.IsNil)
90 }
91
92 func (s *ContainerGatewaySuite) TestConfig(c *check.C) {
93         for _, trial := range []struct {
94                 configAdmin bool
95                 configUser  bool
96                 sendToken   string
97                 errorCode   int
98         }{
99                 {true, true, arvadostest.ActiveTokenV2, 0},
100                 {true, false, arvadostest.ActiveTokenV2, 503},
101                 {false, true, arvadostest.ActiveTokenV2, 0},
102                 {false, false, arvadostest.ActiveTokenV2, 503},
103                 {true, true, arvadostest.AdminToken, 0},
104                 {true, false, arvadostest.AdminToken, 0},
105                 {false, true, arvadostest.AdminToken, 403},
106                 {false, false, arvadostest.AdminToken, 503},
107         } {
108                 c.Logf("trial %#v", trial)
109                 s.cluster.Containers.ShellAccess.Admin = trial.configAdmin
110                 s.cluster.Containers.ShellAccess.User = trial.configUser
111                 ctx := auth.NewContext(s.ctx, &auth.Credentials{Tokens: []string{trial.sendToken}})
112                 sshconn, err := s.localdb.ContainerSSH(ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
113                 if trial.errorCode == 0 {
114                         if !c.Check(err, check.IsNil) {
115                                 continue
116                         }
117                         if !c.Check(sshconn.Conn, check.NotNil) {
118                                 continue
119                         }
120                         sshconn.Conn.Close()
121                 } else {
122                         c.Check(err, check.NotNil)
123                         err, ok := err.(interface{ HTTPStatus() int })
124                         if c.Check(ok, check.Equals, true) {
125                                 c.Check(err.HTTPStatus(), check.Equals, trial.errorCode)
126                         }
127                 }
128         }
129 }
130
131 func (s *ContainerGatewaySuite) TestDirectTCP(c *check.C) {
132         // Set up servers on a few TCP ports
133         var addrs []string
134         for i := 0; i < 3; i++ {
135                 ln, err := net.Listen("tcp", ":0")
136                 c.Assert(err, check.IsNil)
137                 defer ln.Close()
138                 addrs = append(addrs, ln.Addr().String())
139                 go func() {
140                         for {
141                                 conn, err := ln.Accept()
142                                 if err != nil {
143                                         return
144                                 }
145                                 var gotAddr string
146                                 fmt.Fscanf(conn, "%s\n", &gotAddr)
147                                 c.Logf("stub server listening at %s received string %q from remote %s", ln.Addr().String(), gotAddr, conn.RemoteAddr())
148                                 if gotAddr == ln.Addr().String() {
149                                         fmt.Fprintf(conn, "%s\n", ln.Addr().String())
150                                 }
151                                 conn.Close()
152                         }
153                 }()
154         }
155
156         c.Logf("connecting to %s", s.gw.Address)
157         sshconn, err := s.localdb.ContainerSSH(s.ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
158         c.Assert(err, check.IsNil)
159         c.Assert(sshconn.Conn, check.NotNil)
160         defer sshconn.Conn.Close()
161         conn, chans, reqs, err := ssh.NewClientConn(sshconn.Conn, "zzzz-dz642-abcdeabcdeabcde", &ssh.ClientConfig{
162                 HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil },
163         })
164         c.Assert(err, check.IsNil)
165         client := ssh.NewClient(conn, chans, reqs)
166         for _, expectAddr := range addrs {
167                 _, port, err := net.SplitHostPort(expectAddr)
168                 c.Assert(err, check.IsNil)
169
170                 c.Logf("trying foo:%s", port)
171                 {
172                         conn, err := client.Dial("tcp", "foo:"+port)
173                         c.Assert(err, check.IsNil)
174                         conn.SetDeadline(time.Now().Add(time.Second))
175                         buf, err := ioutil.ReadAll(conn)
176                         c.Check(err, check.IsNil)
177                         c.Check(string(buf), check.Equals, "")
178                 }
179
180                 c.Logf("trying localhost:%s", port)
181                 {
182                         conn, err := client.Dial("tcp", "localhost:"+port)
183                         c.Assert(err, check.IsNil)
184                         conn.SetDeadline(time.Now().Add(time.Second))
185                         conn.Write([]byte(expectAddr + "\n"))
186                         var gotAddr string
187                         fmt.Fscanf(conn, "%s\n", &gotAddr)
188                         c.Check(gotAddr, check.Equals, expectAddr)
189                 }
190         }
191 }
192
193 func (s *ContainerGatewaySuite) TestConnect(c *check.C) {
194         c.Logf("connecting to %s", s.gw.Address)
195         sshconn, err := s.localdb.ContainerSSH(s.ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
196         c.Assert(err, check.IsNil)
197         c.Assert(sshconn.Conn, check.NotNil)
198         defer sshconn.Conn.Close()
199
200         done := make(chan struct{})
201         go func() {
202                 defer close(done)
203
204                 // Receive text banner
205                 buf := make([]byte, 12)
206                 _, err := io.ReadFull(sshconn.Conn, buf)
207                 c.Check(err, check.IsNil)
208                 c.Check(string(buf), check.Equals, "SSH-2.0-Go\r\n")
209
210                 // Send text banner
211                 _, err = sshconn.Conn.Write([]byte("SSH-2.0-Fake\r\n"))
212                 c.Check(err, check.IsNil)
213
214                 // Receive binary
215                 _, err = io.ReadFull(sshconn.Conn, buf[:4])
216                 c.Check(err, check.IsNil)
217
218                 // If we can get this far into an SSH handshake...
219                 c.Logf("was able to read %x -- success, tunnel is working", buf[:4])
220         }()
221         select {
222         case <-done:
223         case <-time.After(time.Second):
224                 c.Fail()
225         }
226         ctr, err := s.localdb.ContainerGet(s.ctx, arvados.GetOptions{UUID: s.ctrUUID})
227         c.Check(err, check.IsNil)
228         c.Check(ctr.InteractiveSessionStarted, check.Equals, true)
229 }
230
231 func (s *ContainerGatewaySuite) TestConnectFail(c *check.C) {
232         c.Log("trying with no token")
233         ctx := auth.NewContext(context.Background(), &auth.Credentials{})
234         _, err := s.localdb.ContainerSSH(ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
235         c.Check(err, check.ErrorMatches, `.* 401 .*`)
236
237         c.Log("trying with anonymous token")
238         ctx = auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{arvadostest.AnonymousToken}})
239         _, err = s.localdb.ContainerSSH(ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
240         c.Check(err, check.ErrorMatches, `.* 404 .*`)
241 }
242
243 func (s *ContainerGatewaySuite) TestCreateTunnel(c *check.C) {
244         // no AuthSecret
245         conn, err := s.localdb.ContainerGatewayTunnel(s.ctx, arvados.ContainerGatewayTunnelOptions{
246                 UUID: s.ctrUUID,
247         })
248         c.Check(err, check.ErrorMatches, `authentication error`)
249         c.Check(conn.Conn, check.IsNil)
250
251         // bogus AuthSecret
252         conn, err = s.localdb.ContainerGatewayTunnel(s.ctx, arvados.ContainerGatewayTunnelOptions{
253                 UUID:       s.ctrUUID,
254                 AuthSecret: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
255         })
256         c.Check(err, check.ErrorMatches, `authentication error`)
257         c.Check(conn.Conn, check.IsNil)
258
259         // good AuthSecret
260         conn, err = s.localdb.ContainerGatewayTunnel(s.ctx, arvados.ContainerGatewayTunnelOptions{
261                 UUID:       s.ctrUUID,
262                 AuthSecret: s.gw.AuthSecret,
263         })
264         c.Check(err, check.IsNil)
265         c.Check(conn.Conn, check.NotNil)
266 }
267
268 func (s *ContainerGatewaySuite) TestConnectThroughTunnelWithProxyOK(c *check.C) {
269         forceProxyForTest = true
270         defer func() { forceProxyForTest = false }()
271         s.cluster.Services.Controller.InternalURLs[*forceInternalURLForTest] = arvados.ServiceInstance{}
272         defer delete(s.cluster.Services.Controller.InternalURLs, *forceInternalURLForTest)
273         s.testConnectThroughTunnel(c, "")
274 }
275
276 func (s *ContainerGatewaySuite) TestConnectThroughTunnelWithProxyError(c *check.C) {
277         forceProxyForTest = true
278         defer func() { forceProxyForTest = false }()
279         // forceInternalURLForTest shouldn't be used because it isn't
280         // listed in s.cluster.Services.Controller.InternalURLs
281         s.testConnectThroughTunnel(c, `.*tunnel endpoint is invalid.*`)
282 }
283
284 func (s *ContainerGatewaySuite) TestConnectThroughTunnelNoProxyOK(c *check.C) {
285         s.testConnectThroughTunnel(c, "")
286 }
287
288 func (s *ContainerGatewaySuite) testConnectThroughTunnel(c *check.C, expectErrorMatch string) {
289         rootctx := auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{s.cluster.SystemRootToken}})
290         // Until the tunnel starts up, set gateway_address to a value
291         // that can't work. We want to ensure the only way we can
292         // reach the gateway is through the tunnel.
293         tungw := &crunchrun.Gateway{
294                 ContainerUUID: s.ctrUUID,
295                 AuthSecret:    s.gw.AuthSecret,
296                 Log:           ctxlog.TestLogger(c),
297                 Target:        crunchrun.GatewayTargetStub{},
298                 ArvadosClient: s.gw.ArvadosClient,
299                 UpdateTunnelURL: func(url string) {
300                         c.Logf("UpdateTunnelURL(%q)", url)
301                         gwaddr := "tunnel " + url
302                         s.localdb.ContainerUpdate(rootctx, arvados.UpdateOptions{
303                                 UUID: s.ctrUUID,
304                                 Attrs: map[string]interface{}{
305                                         "gateway_address": gwaddr}})
306                 },
307         }
308         c.Assert(tungw.Start(), check.IsNil)
309
310         // We didn't supply an external hostname in the Address field,
311         // so Start() should assign a local address.
312         host, _, err := net.SplitHostPort(tungw.Address)
313         c.Assert(err, check.IsNil)
314         c.Check(host, check.Equals, "127.0.0.1")
315
316         _, err = s.localdb.ContainerUpdate(rootctx, arvados.UpdateOptions{
317                 UUID: s.ctrUUID,
318                 Attrs: map[string]interface{}{
319                         "state": arvados.ContainerStateRunning,
320                 }})
321         c.Assert(err, check.IsNil)
322
323         for deadline := time.Now().Add(5 * time.Second); time.Now().Before(deadline); time.Sleep(time.Second / 2) {
324                 ctr, err := s.localdb.ContainerGet(s.ctx, arvados.GetOptions{UUID: s.ctrUUID})
325                 c.Assert(err, check.IsNil)
326                 c.Check(ctr.InteractiveSessionStarted, check.Equals, false)
327                 c.Logf("ctr.GatewayAddress == %s", ctr.GatewayAddress)
328                 if strings.HasPrefix(ctr.GatewayAddress, "tunnel ") {
329                         break
330                 }
331         }
332
333         c.Log("connecting to gateway through tunnel")
334         arpc := rpc.NewConn("", &url.URL{Scheme: "https", Host: s.gw.ArvadosClient.APIHost}, true, rpc.PassthroughTokenProvider)
335         sshconn, err := arpc.ContainerSSH(s.ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
336         if expectErrorMatch != "" {
337                 c.Check(err, check.ErrorMatches, expectErrorMatch)
338                 return
339         }
340         c.Assert(err, check.IsNil)
341         c.Assert(sshconn.Conn, check.NotNil)
342         defer sshconn.Conn.Close()
343
344         done := make(chan struct{})
345         go func() {
346                 defer close(done)
347
348                 // Receive text banner
349                 buf := make([]byte, 12)
350                 _, err := io.ReadFull(sshconn.Conn, buf)
351                 c.Check(err, check.IsNil)
352                 c.Check(string(buf), check.Equals, "SSH-2.0-Go\r\n")
353
354                 // Send text banner
355                 _, err = sshconn.Conn.Write([]byte("SSH-2.0-Fake\r\n"))
356                 c.Check(err, check.IsNil)
357
358                 // Receive binary
359                 _, err = io.ReadFull(sshconn.Conn, buf[:4])
360                 c.Check(err, check.IsNil)
361
362                 // If we can get this far into an SSH handshake...
363                 c.Logf("was able to read %x -- success, tunnel is working", buf[:4])
364         }()
365         select {
366         case <-done:
367         case <-time.After(time.Second):
368                 c.Fail()
369         }
370         ctr, err := s.localdb.ContainerGet(s.ctx, arvados.GetOptions{UUID: s.ctrUUID})
371         c.Check(err, check.IsNil)
372         c.Check(ctr.InteractiveSessionStarted, check.Equals, true)
373 }