20032: Fix unnecessary race in test.
[arvados.git] / lib / controller / localdb / container_gateway_test.go
index 2a77357677b3d7f064832d471257866ef356b07a..3f63e7aa8a694359f7404689aee30aac7b0abc3b 100644 (file)
@@ -12,9 +12,14 @@ import (
        "io"
        "io/ioutil"
        "net"
+       "net/http/httptest"
+       "net/url"
+       "strings"
        "time"
 
        "git.arvados.org/arvados.git/lib/config"
+       "git.arvados.org/arvados.git/lib/controller/router"
+       "git.arvados.org/arvados.git/lib/controller/rpc"
        "git.arvados.org/arvados.git/lib/crunchrun"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/arvadostest"
@@ -55,13 +60,26 @@ func (s *ContainerGatewaySuite) SetUpSuite(c *check.C) {
        fmt.Fprint(h, s.ctrUUID)
        authKey := fmt.Sprintf("%x", h.Sum(nil))
 
+       rtr := router.New(s.localdb, router.Config{})
+       srv := httptest.NewUnstartedServer(rtr)
+       srv.StartTLS()
+       // the test setup doesn't use lib/service so
+       // service.URLFromContext() returns nothing -- instead, this
+       // is how we advertise our internal URL and enable
+       // proxy-to-other-controller mode.
+       forceInternalURLForTest = &arvados.URL{Scheme: "https", Host: srv.Listener.Addr().String()}
+       ac := &arvados.Client{
+               APIHost:   srv.Listener.Addr().String(),
+               AuthToken: arvadostest.Dispatch1Token,
+               Insecure:  true,
+       }
        s.gw = &crunchrun.Gateway{
-               DockerContainerID:  new(string),
-               ContainerUUID:      s.ctrUUID,
-               AuthSecret:         authKey,
-               Address:            "localhost:0",
-               Log:                ctxlog.TestLogger(c),
-               ContainerIPAddress: func() (string, error) { return "localhost", nil },
+               ContainerUUID: s.ctrUUID,
+               AuthSecret:    authKey,
+               Address:       "localhost:0",
+               Log:           ctxlog.TestLogger(c),
+               Target:        crunchrun.GatewayTargetStub{},
+               ArvadosClient: ac,
        }
        c.Assert(s.gw.Start(), check.IsNil)
        rootctx := auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{s.cluster.SystemRootToken}})
@@ -70,18 +88,25 @@ func (s *ContainerGatewaySuite) SetUpSuite(c *check.C) {
                Attrs: map[string]interface{}{
                        "state": arvados.ContainerStateLocked}})
        c.Assert(err, check.IsNil)
-       _, err = s.localdb.ContainerUpdate(rootctx, arvados.UpdateOptions{
+}
+
+// SetUpTest runs before each test case. It drops any tunnel sessions
+// left in s.localdb, puts the test container back in the Running
+// state with gateway_address pointing at s.gw, enables shell access
+// for both admins and users, and resets the container's
+// interactive_session_started flag to false.
+func (s *ContainerGatewaySuite) SetUpTest(c *check.C) {
+       // clear any tunnel sessions started by previous test cases
+       s.localdb.gwTunnelsLock.Lock()
+       s.localdb.gwTunnels = nil
+       s.localdb.gwTunnelsLock.Unlock()
+
+       // Restore state and gateway_address on every test, because a
+       // test case may have replaced gateway_address (e.g. with a
+       // "tunnel ..." value).
+       rootctx := auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{s.cluster.SystemRootToken}})
+       _, err := s.localdb.ContainerUpdate(rootctx, arvados.UpdateOptions{
                UUID: s.ctrUUID,
                Attrs: map[string]interface{}{
                        "state":           arvados.ContainerStateRunning,
                        "gateway_address": s.gw.Address}})
        c.Assert(err, check.IsNil)
-}
 
-func (s *ContainerGatewaySuite) SetUpTest(c *check.C) {
        s.cluster.Containers.ShellAccess.Admin = true
        s.cluster.Containers.ShellAccess.User = true
-       _, err := arvadostest.DB(c, s.cluster).Exec(`update containers set interactive_session_started=$1 where uuid=$2`, false, s.ctrUUID)
+       // Written directly to the database rather than through
+       // ContainerUpdate, so each test starts with the flag cleared.
+       _, err = arvadostest.DB(c, s.cluster).Exec(`update containers set interactive_session_started=$1 where uuid=$2`, false, s.ctrUUID)
        c.Check(err, check.IsNil)
 }
 
@@ -210,10 +235,9 @@ func (s *ContainerGatewaySuite) TestConnect(c *check.C) {
                // Receive binary
                _, err = io.ReadFull(sshconn.Conn, buf[:4])
                c.Check(err, check.IsNil)
-               c.Check(buf[:4], check.DeepEquals, []byte{0, 0, 1, 0xfc})
 
                // If we can get this far into an SSH handshake...
-               c.Log("success, tunnel is working")
+               c.Logf("was able to read %x -- success, tunnel is working", buf[:4])
        }()
        select {
        case <-done:
@@ -236,3 +260,135 @@ func (s *ContainerGatewaySuite) TestConnectFail(c *check.C) {
        _, err = s.localdb.ContainerSSH(ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
        c.Check(err, check.ErrorMatches, `.* 404 .*`)
 }
+
+// TestCreateTunnel checks that ContainerGatewayTunnel answers with
+// an authentication error (and no connection) when AuthSecret is
+// missing or wrong, and returns a connection when the gateway's own
+// AuthSecret is supplied.
+func (s *ContainerGatewaySuite) TestCreateTunnel(c *check.C) {
+       trials := []struct {
+               authSecret string
+               wantConn   bool
+       }{
+               // no AuthSecret
+               {authSecret: "", wantConn: false},
+               // bogus AuthSecret
+               {authSecret: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", wantConn: false},
+               // good AuthSecret
+               {authSecret: s.gw.AuthSecret, wantConn: true},
+       }
+       for _, trial := range trials {
+               opts := arvados.ContainerGatewayTunnelOptions{
+                       UUID:       s.ctrUUID,
+                       AuthSecret: trial.authSecret,
+               }
+               resp, err := s.localdb.ContainerGatewayTunnel(s.ctx, opts)
+               if trial.wantConn {
+                       c.Check(err, check.IsNil)
+                       c.Check(resp.Conn, check.NotNil)
+                       continue
+               }
+               c.Check(err, check.ErrorMatches, `authentication error`)
+               c.Check(resp.Conn, check.IsNil)
+       }
+}
+
+// TestConnectThroughTunnelWithProxyOK runs the tunnel connection
+// scenario with forceProxyForTest enabled and with the URL in
+// forceInternalURLForTest listed among the cluster's controller
+// InternalURLs, expecting the connection to succeed. Both settings
+// are undone when the test returns.
+func (s *ContainerGatewaySuite) TestConnectThroughTunnelWithProxyOK(c *check.C) {
+       forceProxyForTest = true
+       defer func() { forceProxyForTest = false }()
+
+       ownURL := *forceInternalURLForTest
+       internalURLs := s.cluster.Services.Controller.InternalURLs
+       internalURLs[ownURL] = arvados.ServiceInstance{}
+       defer func() { delete(internalURLs, ownURL) }()
+
+       const noExpectedError = ""
+       s.testConnectThroughTunnel(c, noExpectedError)
+}
+
+// TestConnectThroughTunnelWithProxyError runs the tunnel connection
+// scenario with forceProxyForTest enabled but without adding
+// forceInternalURLForTest to the cluster's controller InternalURLs,
+// and expects the connection attempt to be rejected with a "tunnel
+// endpoint is invalid" error.
+func (s *ContainerGatewaySuite) TestConnectThroughTunnelWithProxyError(c *check.C) {
+       const wantErr = `.*tunnel endpoint is invalid.*`
+
+       restoreProxyFlag := func() { forceProxyForTest = false }
+       forceProxyForTest = true
+       defer restoreProxyFlag()
+
+       // The URL in forceInternalURLForTest is deliberately not
+       // registered in s.cluster.Services.Controller.InternalURLs
+       // here, so it shouldn't be accepted as a tunnel endpoint.
+       s.testConnectThroughTunnel(c, wantErr)
+}
+
+// TestConnectThroughTunnelNoProxyOK runs the tunnel connection
+// scenario without touching forceProxyForTest, expecting the
+// connection to succeed.
+func (s *ContainerGatewaySuite) TestConnectThroughTunnelNoProxyOK(c *check.C) {
+       const noExpectedError = ""
+       s.testConnectThroughTunnel(c, noExpectedError)
+}
+
+// testConnectThroughTunnel starts a second gateway for the test
+// container, configured without an Address of its own and with an
+// UpdateTunnelURL callback that records "tunnel <url>" as the
+// container's gateway_address. It waits for that address to show up
+// in the container record, then opens a ContainerSSH connection
+// through the controller's HTTP API (the server referenced by
+// s.gw.ArvadosClient.APIHost).
+//
+// If expectErrorMatch is non-empty, the ContainerSSH call must fail
+// with an error matching that pattern, and nothing further is
+// checked. Otherwise the first steps of an SSH handshake are
+// exchanged over the connection, and the container must end up with
+// InteractiveSessionStarted set to true.
+func (s *ContainerGatewaySuite) testConnectThroughTunnel(c *check.C, expectErrorMatch string) {
+       rootctx := auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{s.cluster.SystemRootToken}})
+       // This gateway publishes its address only via UpdateTunnelURL.
+       // We want to ensure the connection made below reaches the
+       // gateway through the tunnel, so we wait (and insist) that
+       // the "tunnel ..." address has replaced whatever
+       // gateway_address was there before.
+       tungw := &crunchrun.Gateway{
+               ContainerUUID: s.ctrUUID,
+               AuthSecret:    s.gw.AuthSecret,
+               Log:           ctxlog.TestLogger(c),
+               Target:        crunchrun.GatewayTargetStub{},
+               ArvadosClient: s.gw.ArvadosClient,
+               UpdateTunnelURL: func(tunnelURL string) {
+                       c.Logf("UpdateTunnelURL(%q)", tunnelURL)
+                       gwaddr := "tunnel " + tunnelURL
+                       _, err := s.localdb.ContainerUpdate(rootctx, arvados.UpdateOptions{
+                               UUID: s.ctrUUID,
+                               Attrs: map[string]interface{}{
+                                       "gateway_address": gwaddr}})
+                       if err != nil {
+                               // Log instead of asserting: if the
+                               // update was lost, the wait loop
+                               // below reports the missing address.
+                               c.Logf("error updating gateway_address to %q: %s", gwaddr, err)
+                       }
+               },
+       }
+       c.Assert(tungw.Start(), check.IsNil)
+
+       // We didn't supply an external hostname in the Address field,
+       // so Start() should assign a local address.
+       host, _, err := net.SplitHostPort(tungw.Address)
+       c.Assert(err, check.IsNil)
+       c.Check(host, check.Equals, "127.0.0.1")
+
+       _, err = s.localdb.ContainerUpdate(rootctx, arvados.UpdateOptions{
+               UUID: s.ctrUUID,
+               Attrs: map[string]interface{}{
+                       "state": arvados.ContainerStateRunning,
+               }})
+       c.Assert(err, check.IsNil)
+
+       // Poll for up to 5 seconds until the tunnel address has been
+       // published in the container record.
+       tunnelReady := false
+       for deadline := time.Now().Add(5 * time.Second); time.Now().Before(deadline); time.Sleep(time.Second / 2) {
+               ctr, err := s.localdb.ContainerGet(s.ctx, arvados.GetOptions{UUID: s.ctrUUID})
+               c.Assert(err, check.IsNil)
+               c.Check(ctr.InteractiveSessionStarted, check.Equals, false)
+               c.Logf("ctr.GatewayAddress == %s", ctr.GatewayAddress)
+               if strings.HasPrefix(ctr.GatewayAddress, "tunnel ") {
+                       tunnelReady = true
+                       break
+               }
+       }
+       // Without this, a tunnel that never came up would go
+       // unnoticed and the connection below could be served some
+       // other way.
+       c.Assert(tunnelReady, check.Equals, true, check.Commentf("gateway_address was not updated to a tunnel address before the deadline"))
+
+       c.Log("connecting to gateway through tunnel")
+       arpc := rpc.NewConn("", &url.URL{Scheme: "https", Host: s.gw.ArvadosClient.APIHost}, true, rpc.PassthroughTokenProvider)
+       sshconn, err := arpc.ContainerSSH(s.ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
+       if expectErrorMatch != "" {
+               c.Check(err, check.ErrorMatches, expectErrorMatch)
+               return
+       }
+       c.Assert(err, check.IsNil)
+       c.Assert(sshconn.Conn, check.NotNil)
+       defer sshconn.Conn.Close()
+
+       done := make(chan struct{})
+       go func() {
+               defer close(done)
+
+               // Receive text banner
+               buf := make([]byte, 12)
+               _, err := io.ReadFull(sshconn.Conn, buf)
+               c.Check(err, check.IsNil)
+               c.Check(string(buf), check.Equals, "SSH-2.0-Go\r\n")
+
+               // Send text banner
+               _, err = sshconn.Conn.Write([]byte("SSH-2.0-Fake\r\n"))
+               c.Check(err, check.IsNil)
+
+               // Receive binary
+               _, err = io.ReadFull(sshconn.Conn, buf[:4])
+               c.Check(err, check.IsNil)
+
+               // If we can get this far into an SSH handshake...
+               c.Logf("was able to read %x -- success, tunnel is working", buf[:4])
+       }()
+       select {
+       case <-done:
+       case <-time.After(time.Second):
+               c.Error("timed out waiting for SSH handshake through tunnel")
+       }
+       ctr, err := s.localdb.ContainerGet(s.ctx, arvados.GetOptions{UUID: s.ctrUUID})
+       c.Check(err, check.IsNil)
+       c.Check(ctr.InteractiveSessionStarted, check.Equals, true)
+}