import (
"bytes"
+ "fmt"
"io"
"io/ioutil"
+ "net"
"sync"
"testing"
"time"
return nil
}
+// Address returns only the host part of the wrapped SSHService's
+// listening address; the port is deliberately omitted. As a result
+// the executor cannot reach the test server until SetTargetPort() is
+// called with the value of (*testTarget).Port().
+//
+// It panics if the service address cannot be split into host and port.
+func (tt *testTarget) Address() string {
+	listenAddr := tt.SSHService.Address()
+	host, _, err := net.SplitHostPort(listenAddr)
+	if err == nil {
+		return host
+	}
+	panic(err)
+}
+
+// Port returns the port part of the wrapped SSHService's listening
+// address. Tests pass it to the executor's SetTargetPort() so that
+// connections reach the test server even though Address() omits the
+// port.
+//
+// It panics if the service address cannot be split into host and port.
+func (tt *testTarget) Port() string {
+	listenAddr := tt.SSHService.Address()
+	_, port, err := net.SplitHostPort(listenAddr)
+	if err == nil {
+		return port
+	}
+	panic(err)
+}
+
+// mitmTarget is a test target backed by a real SSHService whose
+// VerifyHostKey rejects every host key. Any connection the executor
+// makes to it must therefore fail host key verification.
+type mitmTarget struct {
+	test.SSHService
+}
+
+// VerifyHostKey unconditionally rejects the offered key. The returned
+// error includes the key so test output shows what was presented. The
+// *ssh.Client argument is ignored.
+func (*mitmTarget) VerifyHostKey(key ssh.PublicKey, _ *ssh.Client) error {
+	return fmt.Errorf("host key failed verification: %#v", key)
+}
+
+// ExecutorSuite is the gocheck (check.C) suite holding the executor
+// tests in this file; it carries no state of its own.
type ExecutorSuite struct{}
+// TestBadHostKey checks that Execute returns the target's host key
+// verification error, and never runs the command on the server, when
+// the target (a mitmTarget) rejects every host key.
+func (s *ExecutorSuite) TestBadHostKey(c *check.C) {
+	_, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm")
+	clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch")
+	target := &mitmTarget{
+		SSHService: test.SSHService{
+			Exec: func(map[string]string, string, io.Reader, io.Writer, io.Writer) uint32 {
+				// Reaching this means the client went ahead despite
+				// the failed host key check.
+				c.Error("Target Exec func called even though host key verification failed")
+				return 0
+			},
+			HostKey:        hostpriv,
+			AuthorizedUser: "username",
+			AuthorizedKeys: []ssh.PublicKey{clientpub},
+		},
+	}
+
+	// Abort instead of continuing. Without a running server, the
+	// Execute check below would most likely fail with an unrelated
+	// connection error and obscure the real cause.
+	err := target.Start()
+	c.Assert(err, check.IsNil)
+	defer target.Close()
+	c.Logf("target address %q", target.Address())
+
+	exr := New(target)
+	exr.SetSigners(clientpriv)
+
+	_, _, err = exr.Execute(nil, "true", nil)
+	c.Check(err, check.ErrorMatches, "host key failed verification: .*")
+}
+
func (s *ExecutorSuite) TestExecute(c *check.C) {
command := `foo 'bar' "baz"`
stdinData := "foobar\nbaz\n"
_, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm")
clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch")
for _, exitcode := range []int{0, 1, 2} {
- srv := &testTarget{
+ target := &testTarget{
SSHService: test.SSHService{
- Exec: func(cmd string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
+ Exec: func(env map[string]string, cmd string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
+ c.Check(env["TESTVAR"], check.Equals, "test value")
c.Check(cmd, check.Equals, command)
var wg sync.WaitGroup
wg.Add(2)
return uint32(exitcode)
},
HostKey: hostpriv,
+ AuthorizedUser: "username",
AuthorizedKeys: []ssh.PublicKey{clientpub},
},
}
- err := srv.Start()
+ err := target.Start()
c.Check(err, check.IsNil)
- c.Logf("srv address %q", srv.Address())
- defer srv.Close()
+ c.Logf("target address %q", target.Address())
+ defer target.Close()
- exr := New(srv)
+ exr := New(target)
exr.SetSigners(clientpriv)
+ // Use the default target port (ssh). Execute will
+ // return a connection error or an authentication
+ // error, depending on whether the test host is
+ // running an SSH server.
+ _, _, err = exr.Execute(nil, command, nil)
+ c.Check(err, check.ErrorMatches, `.*(unable to authenticate|connection refused).*`)
+
+ // Use a bogus target port. Execute will return a
+ // connection error.
+ exr.SetTargetPort("0")
+ _, _, err = exr.Execute(nil, command, nil)
+ c.Check(err, check.ErrorMatches, `.*connection refused.*`)
+
+ // Use the test server's listening port.
+ exr.SetTargetPort(target.Port())
+
done := make(chan bool)
go func() {
- stdout, stderr, err := exr.Execute(command, bytes.NewBufferString(stdinData))
+ stdout, stderr, err := exr.Execute(map[string]string{"TESTVAR": "test value"}, command, bytes.NewBufferString(stdinData))
if exitcode == 0 {
c.Check(err, check.IsNil)
} else {