X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/e8f99cfef7cfbfcf1a1485d69250f24ced3fd609..0df5f0feeced5bff0adfb806dae2d3811257827f:/lib/dispatchcloud/ssh_executor/executor_test.go diff --git a/lib/dispatchcloud/ssh_executor/executor_test.go b/lib/dispatchcloud/ssh_executor/executor_test.go index 8dabfecad8..e7c023586b 100644 --- a/lib/dispatchcloud/ssh_executor/executor_test.go +++ b/lib/dispatchcloud/ssh_executor/executor_test.go @@ -6,8 +6,10 @@ package ssh_executor import ( "bytes" + "fmt" "io" "io/ioutil" + "net" "sync" "testing" "time" @@ -32,17 +34,72 @@ func (*testTarget) VerifyHostKey(ssh.PublicKey, *ssh.Client) error { return nil } +// Address returns the wrapped SSHService's host, with the port +// stripped. This ensures the executor won't work until +// SetTargetPort() is called -- see (*testTarget)Port(). +func (tt *testTarget) Address() string { + h, _, err := net.SplitHostPort(tt.SSHService.Address()) + if err != nil { + panic(err) + } + return h +} + +func (tt *testTarget) Port() string { + _, p, err := net.SplitHostPort(tt.SSHService.Address()) + if err != nil { + panic(err) + } + return p +} + +type mitmTarget struct { + test.SSHService +} + +func (*mitmTarget) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error { + return fmt.Errorf("host key failed verification: %#v", key) +} + type ExecutorSuite struct{} +func (s *ExecutorSuite) TestBadHostKey(c *check.C) { + _, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm") + clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch") + target := &mitmTarget{ + SSHService: test.SSHService{ + Exec: func(map[string]string, string, io.Reader, io.Writer, io.Writer) uint32 { + c.Error("Target Exec func called even though host key verification failed") + return 0 + }, + HostKey: hostpriv, + AuthorizedUser: "username", + AuthorizedKeys: []ssh.PublicKey{clientpub}, + }, + } + + err := target.Start() + c.Check(err, check.IsNil) + c.Logf("target address %q", target.Address()) + defer target.Close() + + exr := New(target) + exr.SetSigners(clientpriv) + + _, _, err = exr.Execute(nil, "true", nil) + c.Check(err, check.ErrorMatches, "host key failed verification: .*") +} + func (s *ExecutorSuite) TestExecute(c *check.C) { command := `foo 'bar' "baz"` stdinData := "foobar\nbaz\n" _, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm") clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch") for _, exitcode := range []int{0, 1, 2} { - srv := &testTarget{ + target := &testTarget{ SSHService: test.SSHService{ - Exec: func(cmd string, stdin io.Reader, stdout, stderr io.Writer) uint32 { + Exec: func(env map[string]string, cmd string, stdin io.Reader, stdout, stderr io.Writer) uint32 { + c.Check(env["TESTVAR"], check.Equals, "test value") c.Check(cmd, check.Equals, command) var wg sync.WaitGroup wg.Add(2) @@ -65,20 +122,37 @@ func (s *ExecutorSuite) TestExecute(c *check.C) { return uint32(exitcode) }, HostKey: hostpriv, + AuthorizedUser: "username", AuthorizedKeys: []ssh.PublicKey{clientpub}, }, } - err := srv.Start() + err := target.Start() c.Check(err, check.IsNil) - c.Logf("srv address %q", srv.Address()) - defer srv.Close() + c.Logf("target address %q", target.Address()) + defer target.Close() - exr := New(srv) + exr := New(target) exr.SetSigners(clientpriv) + // Use the default target port (ssh). Execute will + // return a connection error or an authentication + // error, depending on whether the test host is + // running an SSH server. + _, _, err = exr.Execute(nil, command, nil) + c.Check(err, check.ErrorMatches, `.*(unable to authenticate|connection refused).*`) + + // Use a bogus target port. Execute will return a + // connection error. + exr.SetTargetPort("0") + _, _, err = exr.Execute(nil, command, nil) + c.Check(err, check.ErrorMatches, `.*connection refused.*`) + + // Use the test server's listening port. + exr.SetTargetPort(target.Port()) + done := make(chan bool) go func() { - stdout, stderr, err := exr.Execute(command, bytes.NewBufferString(stdinData)) + stdout, stderr, err := exr.Execute(map[string]string{"TESTVAR": "test value"}, command, bytes.NewBufferString(stdinData)) if exitcode == 0 { c.Check(err, check.IsNil) } else {