1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
18 "git.arvados.org/arvados.git/lib/dispatchcloud/test"
19 "golang.org/x/crypto/ssh"
20 check "gopkg.in/check.v1"
23 // Gocheck boilerplate
24 func Test(t *testing.T) {
28 var _ = check.Suite(&ExecutorSuite{})
30 type testTarget struct {
34 func (*testTarget) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
38 // Address returns the wrapped SSHService's host, with the port
39 // stripped. This ensures the executor won't work until
40 // SetTargetPort() is called -- see (*testTarget)Port().
41 func (tt *testTarget) Address() string {
42 h, _, err := net.SplitHostPort(tt.SSHService.Address())
49 func (tt *testTarget) Port() string {
50 _, p, err := net.SplitHostPort(tt.SSHService.Address())
57 type mitmTarget struct {
61 func (*mitmTarget) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
62 return fmt.Errorf("host key failed verification: %#v", key)
65 type ExecutorSuite struct{}
67 func (s *ExecutorSuite) TestBadHostKey(c *check.C) {
68 _, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm")
69 clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch")
70 target := &mitmTarget{
71 SSHService: test.SSHService{
72 Exec: func(map[string]string, string, io.Reader, io.Writer, io.Writer) uint32 {
73 c.Error("Target Exec func called even though host key verification failed")
77 AuthorizedUser: "username",
78 AuthorizedKeys: []ssh.PublicKey{clientpub},
83 c.Check(err, check.IsNil)
84 c.Logf("target address %q", target.Address())
88 exr.SetSigners(clientpriv)
90 _, _, err = exr.Execute(nil, "true", nil)
91 c.Check(err, check.ErrorMatches, "host key failed verification: .*")
94 func (s *ExecutorSuite) TestExecute(c *check.C) {
95 command := `foo 'bar' "baz"`
96 stdinData := "foobar\nbaz\n"
97 _, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm")
98 clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch")
99 for _, exitcode := range []int{0, 1, 2} {
100 target := &testTarget{
101 SSHService: test.SSHService{
102 Exec: func(env map[string]string, cmd string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
103 c.Check(env["TESTVAR"], check.Equals, "test value")
104 c.Check(cmd, check.Equals, command)
105 var wg sync.WaitGroup
108 io.WriteString(stdout, "stdout\n")
112 io.WriteString(stderr, "stderr\n")
115 buf, err := ioutil.ReadAll(stdin)
117 c.Check(err, check.IsNil)
121 _, err = stdout.Write(buf)
122 c.Check(err, check.IsNil)
123 return uint32(exitcode)
126 AuthorizedUser: "username",
127 AuthorizedKeys: []ssh.PublicKey{clientpub},
130 err := target.Start()
131 c.Check(err, check.IsNil)
132 c.Logf("target address %q", target.Address())
136 exr.SetSigners(clientpriv)
138 // Use the default target port (ssh). Execute will
139 // return a connection error or an authentication
140 // error, depending on whether the test host is
141 // running an SSH server.
142 _, _, err = exr.Execute(nil, command, nil)
143 c.Check(err, check.ErrorMatches, `.*(unable to authenticate|connection refused).*`)
145 // Use a bogus target port. Execute will return a
147 exr.SetTargetPort("0")
148 _, _, err = exr.Execute(nil, command, nil)
149 c.Check(err, check.ErrorMatches, `.*connection refused.*`)
150 c.Check(errors.As(err, new(*net.OpError)), check.Equals, true)
152 // Use the test server's listening port.
153 exr.SetTargetPort(target.Port())
155 done := make(chan bool)
157 stdout, stderr, err := exr.Execute(map[string]string{"TESTVAR": "test value"}, command, bytes.NewBufferString(stdinData))
159 c.Check(err, check.IsNil)
161 c.Check(err, check.NotNil)
162 err, ok := err.(*ssh.ExitError)
163 c.Assert(ok, check.Equals, true)
164 c.Check(err.ExitStatus(), check.Equals, exitcode)
166 c.Check(stdout, check.DeepEquals, []byte("stdout\n"+stdinData))
167 c.Check(stderr, check.DeepEquals, []byte("stderr\n"))
171 timeout := time.NewTimer(time.Second)