19099: Update tests to new crunchrun.Gateway fields.
[arvados.git] / lib / controller / localdb / container_gateway_test.go
index aff569b0988d177b1be42eca9fdffc33a40d55d0..271760420153481daac1f0f129a63c684591b94b 100644 (file)
@@ -10,6 +10,8 @@ import (
        "crypto/sha256"
        "fmt"
        "io"
+       "io/ioutil"
+       "net"
        "time"
 
        "git.arvados.org/arvados.git/lib/config"
@@ -18,6 +20,7 @@ import (
        "git.arvados.org/arvados.git/sdk/go/arvadostest"
        "git.arvados.org/arvados.git/sdk/go/auth"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
+       "golang.org/x/crypto/ssh"
        check "gopkg.in/check.v1"
 )
 
@@ -53,11 +56,11 @@ func (s *ContainerGatewaySuite) SetUpSuite(c *check.C) {
        authKey := fmt.Sprintf("%x", h.Sum(nil))
 
        s.gw = &crunchrun.Gateway{
-               DockerContainerID: new(string),
-               ContainerUUID:     s.ctrUUID,
-               AuthSecret:        authKey,
-               Address:           "localhost:0",
-               Log:               ctxlog.TestLogger(c),
+               ContainerUUID: s.ctrUUID,
+               AuthSecret:    authKey,
+               Address:       "localhost:0",
+               Log:           ctxlog.TestLogger(c),
+               Target:        crunchrun.GatewayTargetStub{},
        }
        c.Assert(s.gw.Start(), check.IsNil)
        rootctx := auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{s.cluster.SystemRootToken}})
@@ -120,6 +123,68 @@ func (s *ContainerGatewaySuite) TestConfig(c *check.C) {
        }
 }
 
+// TestDirectTCP checks TCP forwarding ("direct-tcpip" channels)
+// through the container gateway's SSH server.
+//
+// It starts a few stub TCP servers, then dials each one's port through
+// the gateway twice: once with the hostname "foo", which is expected to
+// succeed but yield an empty stream, and once with "localhost", which is
+// expected to reach the stub server.
+func (s *ContainerGatewaySuite) TestDirectTCP(c *check.C) {
+       // Set up servers on a few TCP ports. Each server reads one word
+       // and writes its own listening address back only if the word
+       // matches that address, so a correct reply shows the connection
+       // reached the intended port.
+       var addrs []string
+       for i := 0; i < 3; i++ {
+               ln, err := net.Listen("tcp", ":0")
+               c.Assert(err, check.IsNil)
+               defer ln.Close()
+               addrs = append(addrs, ln.Addr().String())
+               go func() {
+                       for {
+                               conn, err := ln.Accept()
+                               if err != nil {
+                                       // Typically the listener was closed
+                                       // by the deferred ln.Close().
+                                       return
+                               }
+                               var gotAddr string
+                               fmt.Fscanf(conn, "%s\n", &gotAddr)
+                               c.Logf("stub server listening at %s received string %q from remote %s", ln.Addr().String(), gotAddr, conn.RemoteAddr())
+                               if gotAddr == ln.Addr().String() {
+                                       fmt.Fprintf(conn, "%s\n", ln.Addr().String())
+                               }
+                               conn.Close()
+                       }
+               }()
+       }
+
+       // within runs f and fails the test if f has not returned after
+       // timeout. The net.Conn returned by ssh.Client.Dial does not
+       // support deadlines (SetDeadline just returns an error), so a
+       // stalled read could otherwise hang the test indefinitely. On
+       // timeout f's goroutine is abandoned; tearing down the SSH
+       // connection (deferred below) should unblock it.
+       within := func(timeout time.Duration, f func()) {
+               done := make(chan struct{})
+               go func() {
+                       defer close(done)
+                       f()
+               }()
+               select {
+               case <-done:
+               case <-time.After(timeout):
+                       c.Fatalf("timed out after %v", timeout)
+               }
+       }
+
+       c.Logf("connecting to %s", s.gw.Address)
+       sshconn, err := s.localdb.ContainerSSH(s.ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
+       c.Assert(err, check.IsNil)
+       c.Assert(sshconn.Conn, check.NotNil)
+       defer sshconn.Conn.Close()
+       conn, chans, reqs, err := ssh.NewClientConn(sshconn.Conn, "zzzz-dz642-abcdeabcdeabcde", &ssh.ClientConfig{
+               // This test does not verify the gateway's host key.
+               HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+       })
+       c.Assert(err, check.IsNil)
+       client := ssh.NewClient(conn, chans, reqs)
+       defer client.Close()
+       for _, expectAddr := range addrs {
+               _, port, err := net.SplitHostPort(expectAddr)
+               c.Assert(err, check.IsNil)
+
+               c.Logf("trying foo:%s", port)
+               {
+                       // The dial is expected to succeed, but the
+                       // stream should end without delivering any data.
+                       conn, err := client.Dial("tcp", "foo:"+port)
+                       c.Assert(err, check.IsNil)
+                       var buf []byte
+                       within(time.Second, func() {
+                               buf, err = ioutil.ReadAll(conn)
+                       })
+                       c.Check(err, check.IsNil)
+                       c.Check(string(buf), check.Equals, "")
+                       conn.Close()
+               }
+
+               c.Logf("trying localhost:%s", port)
+               {
+                       // Send the expected address; the stub server
+                       // echoes it back only if we reached the right port.
+                       conn, err := client.Dial("tcp", "localhost:"+port)
+                       c.Assert(err, check.IsNil)
+                       var gotAddr string
+                       within(time.Second, func() {
+                               if _, err = conn.Write([]byte(expectAddr + "\n")); err == nil {
+                                       _, err = fmt.Fscanf(conn, "%s\n", &gotAddr)
+                               }
+                       })
+                       c.Check(err, check.IsNil)
+                       c.Check(gotAddr, check.Equals, expectAddr)
+                       conn.Close()
+               }
+       }
+}
+
 func (s *ContainerGatewaySuite) TestConnect(c *check.C) {
        c.Logf("connecting to %s", s.gw.Address)
        sshconn, err := s.localdb.ContainerSSH(s.ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
@@ -144,10 +209,9 @@ func (s *ContainerGatewaySuite) TestConnect(c *check.C) {
                // Receive binary
                _, err = io.ReadFull(sshconn.Conn, buf[:4])
                c.Check(err, check.IsNil)
-               c.Check(buf[:4], check.DeepEquals, []byte{0, 0, 1, 0xfc})
 
                // If we can get this far into an SSH handshake...
-               c.Log("success, tunnel is working")
+               c.Logf("was able to read %x -- success, tunnel is working", buf[:4])
        }()
        select {
        case <-done: