15275: Avoids running manage_versioning when not needed.
[arvados.git] / lib / dispatchcloud / ssh_executor / executor_test.go
index 8dabfecad86451d6d3b178b50becffff46fa179e..e7c023586b4bb3c09ac8968c35c2cc3f1ed01ee2 100644 (file)
@@ -6,8 +6,10 @@ package ssh_executor
 
 import (
        "bytes"
+       "fmt"
        "io"
        "io/ioutil"
+       "net"
        "sync"
        "testing"
        "time"
@@ -32,17 +34,72 @@ func (*testTarget) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
        return nil
 }
 
+// Address returns the wrapped SSHService's host, with the port
+// stripped. This ensures the executor won't work until
+// SetTargetPort() is called -- see (*testTarget).Port().
+func (tt *testTarget) Address() string {
+       h, _, err := net.SplitHostPort(tt.SSHService.Address())
+       if err != nil {
+               panic(err)
+       }
+       return h
+}
+
+func (tt *testTarget) Port() string {
+       _, p, err := net.SplitHostPort(tt.SSHService.Address())
+       if err != nil {
+               panic(err)
+       }
+       return p
+}
+
+type mitmTarget struct {
+       test.SSHService
+}
+
+func (*mitmTarget) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
+       return fmt.Errorf("host key failed verification: %#v", key)
+}
+
 type ExecutorSuite struct{}
 
+func (s *ExecutorSuite) TestBadHostKey(c *check.C) {
+       _, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm")
+       clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch")
+       target := &mitmTarget{
+               SSHService: test.SSHService{
+                       Exec: func(map[string]string, string, io.Reader, io.Writer, io.Writer) uint32 {
+                               c.Error("Target Exec func called even though host key verification failed")
+                               return 0
+                       },
+                       HostKey:        hostpriv,
+                       AuthorizedUser: "username",
+                       AuthorizedKeys: []ssh.PublicKey{clientpub},
+               },
+       }
+
+       err := target.Start()
+       c.Check(err, check.IsNil)
+       c.Logf("target address %q", target.Address())
+       defer target.Close()
+
+       exr := New(target)
+       exr.SetSigners(clientpriv)
+
+       _, _, err = exr.Execute(nil, "true", nil)
+       c.Check(err, check.ErrorMatches, "host key failed verification: .*")
+}
+
 func (s *ExecutorSuite) TestExecute(c *check.C) {
        command := `foo 'bar' "baz"`
        stdinData := "foobar\nbaz\n"
        _, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm")
        clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch")
        for _, exitcode := range []int{0, 1, 2} {
-               srv := &testTarget{
+               target := &testTarget{
                        SSHService: test.SSHService{
-                               Exec: func(cmd string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
+                               Exec: func(env map[string]string, cmd string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
+                                       c.Check(env["TESTVAR"], check.Equals, "test value")
                                        c.Check(cmd, check.Equals, command)
                                        var wg sync.WaitGroup
                                        wg.Add(2)
@@ -65,20 +122,37 @@ func (s *ExecutorSuite) TestExecute(c *check.C) {
                                        return uint32(exitcode)
                                },
                                HostKey:        hostpriv,
+                               AuthorizedUser: "username",
                                AuthorizedKeys: []ssh.PublicKey{clientpub},
                        },
                }
-               err := srv.Start()
+               err := target.Start()
                c.Check(err, check.IsNil)
-               c.Logf("srv address %q", srv.Address())
-               defer srv.Close()
+               c.Logf("target address %q", target.Address())
+               defer target.Close()
 
-               exr := New(srv)
+               exr := New(target)
                exr.SetSigners(clientpriv)
 
+               // Use the default target port (ssh). Execute will
+               // return a connection error or an authentication
+               // error, depending on whether the test host is
+               // running an SSH server.
+               _, _, err = exr.Execute(nil, command, nil)
+               c.Check(err, check.ErrorMatches, `.*(unable to authenticate|connection refused).*`)
+
+               // Use a bogus target port. Execute will return a
+               // connection error.
+               exr.SetTargetPort("0")
+               _, _, err = exr.Execute(nil, command, nil)
+               c.Check(err, check.ErrorMatches, `.*connection refused.*`)
+
+               // Use the test server's listening port.
+               exr.SetTargetPort(target.Port())
+
                done := make(chan bool)
                go func() {
-                       stdout, stderr, err := exr.Execute(command, bytes.NewBufferString(stdinData))
+                       stdout, stderr, err := exr.Execute(map[string]string{"TESTVAR": "test value"}, command, bytes.NewBufferString(stdinData))
                        if exitcode == 0 {
                                c.Check(err, check.IsNil)
                        } else {