19166: Set up tunnel for container gateway requests
[arvados.git] / lib / controller / localdb / container_gateway_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package localdb
6
7 import (
8         "context"
9         "crypto/hmac"
10         "crypto/sha256"
11         "fmt"
12         "io"
13         "io/ioutil"
14         "net"
15         "net/http/httptest"
16         "time"
17
18         "git.arvados.org/arvados.git/lib/config"
19         "git.arvados.org/arvados.git/lib/controller/router"
20         "git.arvados.org/arvados.git/lib/crunchrun"
21         "git.arvados.org/arvados.git/sdk/go/arvados"
22         "git.arvados.org/arvados.git/sdk/go/arvadostest"
23         "git.arvados.org/arvados.git/sdk/go/auth"
24         "git.arvados.org/arvados.git/sdk/go/ctxlog"
25         "golang.org/x/crypto/ssh"
26         check "gopkg.in/check.v1"
27 )
28
29 var _ = check.Suite(&ContainerGatewaySuite{})
30
31 type ContainerGatewaySuite struct {
32         cluster *arvados.Cluster
33         localdb *Conn
34         ctx     context.Context
35         ctrUUID string
36         gw      *crunchrun.Gateway
37 }
38
39 func (s *ContainerGatewaySuite) TearDownSuite(c *check.C) {
40         // Undo any changes/additions to the user database so they
41         // don't affect subsequent tests.
42         arvadostest.ResetEnv()
43         c.Check(arvados.NewClientFromEnv().RequestAndDecode(nil, "POST", "database/reset", nil, nil), check.IsNil)
44 }
45
46 func (s *ContainerGatewaySuite) SetUpSuite(c *check.C) {
47         cfg, err := config.NewLoader(nil, ctxlog.TestLogger(c)).Load()
48         c.Assert(err, check.IsNil)
49         s.cluster, err = cfg.GetCluster("")
50         c.Assert(err, check.IsNil)
51         s.localdb = NewConn(s.cluster)
52         s.ctx = auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{arvadostest.ActiveTokenV2}})
53
54         s.ctrUUID = arvadostest.QueuedContainerUUID
55
56         h := hmac.New(sha256.New, []byte(s.cluster.SystemRootToken))
57         fmt.Fprint(h, s.ctrUUID)
58         authKey := fmt.Sprintf("%x", h.Sum(nil))
59
60         rtr := router.New(s.localdb, router.Config{})
61         srv := httptest.NewUnstartedServer(rtr)
62         srv.StartTLS()
63         ac := &arvados.Client{
64                 APIHost:   srv.Listener.Addr().String(),
65                 AuthToken: arvadostest.Dispatch1Token,
66                 Insecure:  true,
67         }
68         s.gw = &crunchrun.Gateway{
69                 ContainerUUID: s.ctrUUID,
70                 AuthSecret:    authKey,
71                 Address:       "localhost:0",
72                 Log:           ctxlog.TestLogger(c),
73                 Target:        crunchrun.GatewayTargetStub{},
74                 ArvadosClient: ac,
75         }
76         c.Assert(s.gw.Start(), check.IsNil)
77         rootctx := auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{s.cluster.SystemRootToken}})
78         _, err = s.localdb.ContainerUpdate(rootctx, arvados.UpdateOptions{
79                 UUID: s.ctrUUID,
80                 Attrs: map[string]interface{}{
81                         "state": arvados.ContainerStateLocked}})
82         c.Assert(err, check.IsNil)
83 }
84
85 func (s *ContainerGatewaySuite) SetUpTest(c *check.C) {
86         // clear any tunnel sessions started by previous test cases
87         s.localdb.gwTunnelsLock.Lock()
88         s.localdb.gwTunnels = nil
89         s.localdb.gwTunnelsLock.Unlock()
90
91         rootctx := auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{s.cluster.SystemRootToken}})
92         _, err := s.localdb.ContainerUpdate(rootctx, arvados.UpdateOptions{
93                 UUID: s.ctrUUID,
94                 Attrs: map[string]interface{}{
95                         "state":           arvados.ContainerStateRunning,
96                         "gateway_address": s.gw.Address}})
97         c.Assert(err, check.IsNil)
98
99         s.cluster.Containers.ShellAccess.Admin = true
100         s.cluster.Containers.ShellAccess.User = true
101         _, err = arvadostest.DB(c, s.cluster).Exec(`update containers set interactive_session_started=$1 where uuid=$2`, false, s.ctrUUID)
102         c.Check(err, check.IsNil)
103 }
104
105 func (s *ContainerGatewaySuite) TestConfig(c *check.C) {
106         for _, trial := range []struct {
107                 configAdmin bool
108                 configUser  bool
109                 sendToken   string
110                 errorCode   int
111         }{
112                 {true, true, arvadostest.ActiveTokenV2, 0},
113                 {true, false, arvadostest.ActiveTokenV2, 503},
114                 {false, true, arvadostest.ActiveTokenV2, 0},
115                 {false, false, arvadostest.ActiveTokenV2, 503},
116                 {true, true, arvadostest.AdminToken, 0},
117                 {true, false, arvadostest.AdminToken, 0},
118                 {false, true, arvadostest.AdminToken, 403},
119                 {false, false, arvadostest.AdminToken, 503},
120         } {
121                 c.Logf("trial %#v", trial)
122                 s.cluster.Containers.ShellAccess.Admin = trial.configAdmin
123                 s.cluster.Containers.ShellAccess.User = trial.configUser
124                 ctx := auth.NewContext(s.ctx, &auth.Credentials{Tokens: []string{trial.sendToken}})
125                 sshconn, err := s.localdb.ContainerSSH(ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
126                 if trial.errorCode == 0 {
127                         if !c.Check(err, check.IsNil) {
128                                 continue
129                         }
130                         if !c.Check(sshconn.Conn, check.NotNil) {
131                                 continue
132                         }
133                         sshconn.Conn.Close()
134                 } else {
135                         c.Check(err, check.NotNil)
136                         err, ok := err.(interface{ HTTPStatus() int })
137                         if c.Check(ok, check.Equals, true) {
138                                 c.Check(err.HTTPStatus(), check.Equals, trial.errorCode)
139                         }
140                 }
141         }
142 }
143
144 func (s *ContainerGatewaySuite) TestDirectTCP(c *check.C) {
145         // Set up servers on a few TCP ports
146         var addrs []string
147         for i := 0; i < 3; i++ {
148                 ln, err := net.Listen("tcp", ":0")
149                 c.Assert(err, check.IsNil)
150                 defer ln.Close()
151                 addrs = append(addrs, ln.Addr().String())
152                 go func() {
153                         for {
154                                 conn, err := ln.Accept()
155                                 if err != nil {
156                                         return
157                                 }
158                                 var gotAddr string
159                                 fmt.Fscanf(conn, "%s\n", &gotAddr)
160                                 c.Logf("stub server listening at %s received string %q from remote %s", ln.Addr().String(), gotAddr, conn.RemoteAddr())
161                                 if gotAddr == ln.Addr().String() {
162                                         fmt.Fprintf(conn, "%s\n", ln.Addr().String())
163                                 }
164                                 conn.Close()
165                         }
166                 }()
167         }
168
169         c.Logf("connecting to %s", s.gw.Address)
170         sshconn, err := s.localdb.ContainerSSH(s.ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
171         c.Assert(err, check.IsNil)
172         c.Assert(sshconn.Conn, check.NotNil)
173         defer sshconn.Conn.Close()
174         conn, chans, reqs, err := ssh.NewClientConn(sshconn.Conn, "zzzz-dz642-abcdeabcdeabcde", &ssh.ClientConfig{
175                 HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil },
176         })
177         c.Assert(err, check.IsNil)
178         client := ssh.NewClient(conn, chans, reqs)
179         for _, expectAddr := range addrs {
180                 _, port, err := net.SplitHostPort(expectAddr)
181                 c.Assert(err, check.IsNil)
182
183                 c.Logf("trying foo:%s", port)
184                 {
185                         conn, err := client.Dial("tcp", "foo:"+port)
186                         c.Assert(err, check.IsNil)
187                         conn.SetDeadline(time.Now().Add(time.Second))
188                         buf, err := ioutil.ReadAll(conn)
189                         c.Check(err, check.IsNil)
190                         c.Check(string(buf), check.Equals, "")
191                 }
192
193                 c.Logf("trying localhost:%s", port)
194                 {
195                         conn, err := client.Dial("tcp", "localhost:"+port)
196                         c.Assert(err, check.IsNil)
197                         conn.SetDeadline(time.Now().Add(time.Second))
198                         conn.Write([]byte(expectAddr + "\n"))
199                         var gotAddr string
200                         fmt.Fscanf(conn, "%s\n", &gotAddr)
201                         c.Check(gotAddr, check.Equals, expectAddr)
202                 }
203         }
204 }
205
206 func (s *ContainerGatewaySuite) TestConnect(c *check.C) {
207         c.Logf("connecting to %s", s.gw.Address)
208         sshconn, err := s.localdb.ContainerSSH(s.ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
209         c.Assert(err, check.IsNil)
210         c.Assert(sshconn.Conn, check.NotNil)
211         defer sshconn.Conn.Close()
212
213         done := make(chan struct{})
214         go func() {
215                 defer close(done)
216
217                 // Receive text banner
218                 buf := make([]byte, 12)
219                 _, err := io.ReadFull(sshconn.Conn, buf)
220                 c.Check(err, check.IsNil)
221                 c.Check(string(buf), check.Equals, "SSH-2.0-Go\r\n")
222
223                 // Send text banner
224                 _, err = sshconn.Conn.Write([]byte("SSH-2.0-Fake\r\n"))
225                 c.Check(err, check.IsNil)
226
227                 // Receive binary
228                 _, err = io.ReadFull(sshconn.Conn, buf[:4])
229                 c.Check(err, check.IsNil)
230
231                 // If we can get this far into an SSH handshake...
232                 c.Logf("was able to read %x -- success, tunnel is working", buf[:4])
233         }()
234         select {
235         case <-done:
236         case <-time.After(time.Second):
237                 c.Fail()
238         }
239         ctr, err := s.localdb.ContainerGet(s.ctx, arvados.GetOptions{UUID: s.ctrUUID})
240         c.Check(err, check.IsNil)
241         c.Check(ctr.InteractiveSessionStarted, check.Equals, true)
242 }
243
244 func (s *ContainerGatewaySuite) TestConnectFail(c *check.C) {
245         c.Log("trying with no token")
246         ctx := auth.NewContext(context.Background(), &auth.Credentials{})
247         _, err := s.localdb.ContainerSSH(ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
248         c.Check(err, check.ErrorMatches, `.* 401 .*`)
249
250         c.Log("trying with anonymous token")
251         ctx = auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{arvadostest.AnonymousToken}})
252         _, err = s.localdb.ContainerSSH(ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
253         c.Check(err, check.ErrorMatches, `.* 404 .*`)
254 }
255
256 func (s *ContainerGatewaySuite) TestCreateTunnel(c *check.C) {
257         // no AuthSecret
258         conn, err := s.localdb.ContainerGatewayTunnel(s.ctx, arvados.ContainerGatewayTunnelOptions{
259                 UUID: s.ctrUUID,
260         })
261         c.Check(err, check.ErrorMatches, `authentication error`)
262         c.Check(conn.Conn, check.IsNil)
263
264         // bogus AuthSecret
265         conn, err = s.localdb.ContainerGatewayTunnel(s.ctx, arvados.ContainerGatewayTunnelOptions{
266                 UUID:       s.ctrUUID,
267                 AuthSecret: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
268         })
269         c.Check(err, check.ErrorMatches, `authentication error`)
270         c.Check(conn.Conn, check.IsNil)
271
272         // good AuthSecret
273         conn, err = s.localdb.ContainerGatewayTunnel(s.ctx, arvados.ContainerGatewayTunnelOptions{
274                 UUID:       s.ctrUUID,
275                 AuthSecret: s.gw.AuthSecret,
276         })
277         c.Check(err, check.IsNil)
278         c.Check(conn.Conn, check.NotNil)
279 }
280
281 func (s *ContainerGatewaySuite) TestConnectThroughTunnel(c *check.C) {
282         tungw := &crunchrun.Gateway{
283                 ContainerUUID: s.ctrUUID,
284                 AuthSecret:    s.gw.AuthSecret,
285                 Log:           ctxlog.TestLogger(c),
286                 Target:        crunchrun.GatewayTargetStub{},
287                 ArvadosClient: s.gw.ArvadosClient,
288         }
289         c.Assert(tungw.Start(), check.IsNil)
290
291         // We didn't supply an external hostname in the Address field,
292         // so Start() should assign a local address.
293         host, _, err := net.SplitHostPort(tungw.Address)
294         c.Assert(err, check.IsNil)
295         c.Check(host, check.Equals, "127.0.0.1")
296
297         // Set the gateway_address field to 127.0.0.1:badport to
298         // ensure the ContainerSSH() handler connects through the
299         // tunnel, rather than the gateway server on 127.0.0.1 (which
300         // wouldn't work IRL where controller and gateway are on
301         // different hosts, but would allow the test to cheat).
302         rootctx := auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{s.cluster.SystemRootToken}})
303         _, err = s.localdb.ContainerUpdate(rootctx, arvados.UpdateOptions{
304                 UUID: s.ctrUUID,
305                 Attrs: map[string]interface{}{
306                         "state":           arvados.ContainerStateRunning,
307                         "gateway_address": "127.0.0.1:0"}})
308         c.Assert(err, check.IsNil)
309
310         ctr, err := s.localdb.ContainerGet(s.ctx, arvados.GetOptions{UUID: s.ctrUUID})
311         c.Check(err, check.IsNil)
312         c.Check(ctr.InteractiveSessionStarted, check.Equals, false)
313         c.Check(ctr.GatewayAddress, check.Equals, "127.0.0.1:0")
314
315         c.Log("connecting to gateway through tunnel")
316         sshconn, err := s.localdb.ContainerSSH(s.ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
317         c.Assert(err, check.IsNil)
318         c.Assert(sshconn.Conn, check.NotNil)
319         defer sshconn.Conn.Close()
320
321         done := make(chan struct{})
322         go func() {
323                 defer close(done)
324
325                 // Receive text banner
326                 buf := make([]byte, 12)
327                 _, err := io.ReadFull(sshconn.Conn, buf)
328                 c.Check(err, check.IsNil)
329                 c.Check(string(buf), check.Equals, "SSH-2.0-Go\r\n")
330
331                 // Send text banner
332                 _, err = sshconn.Conn.Write([]byte("SSH-2.0-Fake\r\n"))
333                 c.Check(err, check.IsNil)
334
335                 // Receive binary
336                 _, err = io.ReadFull(sshconn.Conn, buf[:4])
337                 c.Check(err, check.IsNil)
338
339                 // If we can get this far into an SSH handshake...
340                 c.Logf("was able to read %x -- success, tunnel is working", buf[:4])
341         }()
342         select {
343         case <-done:
344         case <-time.After(time.Second):
345                 c.Fail()
346         }
347         ctr, err = s.localdb.ContainerGet(s.ctx, arvados.GetOptions{UUID: s.ctrUUID})
348         c.Check(err, check.IsNil)
349         c.Check(ctr.InteractiveSessionStarted, check.Equals, true)
350 }