21189: Revert exit code 64 to 2 for invalid command line argument.
[arvados.git] / lib / dispatchcloud / sshexecutor / executor_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package sshexecutor
6
7 import (
8         "bytes"
9         "errors"
10         "fmt"
11         "io"
12         "io/ioutil"
13         "net"
14         "sync"
15         "testing"
16         "time"
17
18         "git.arvados.org/arvados.git/lib/dispatchcloud/test"
19         "golang.org/x/crypto/ssh"
20         check "gopkg.in/check.v1"
21 )
22
23 // Gocheck boilerplate
24 func Test(t *testing.T) {
25         check.TestingT(t)
26 }
27
28 var _ = check.Suite(&ExecutorSuite{})
29
30 type testTarget struct {
31         test.SSHService
32 }
33
34 func (*testTarget) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
35         return nil
36 }
37
38 // Address returns the wrapped SSHService's host, with the port
39 // stripped. This ensures the executor won't work until
40 // SetTargetPort() is called -- see (*testTarget)Port().
41 func (tt *testTarget) Address() string {
42         h, _, err := net.SplitHostPort(tt.SSHService.Address())
43         if err != nil {
44                 panic(err)
45         }
46         return h
47 }
48
49 func (tt *testTarget) Port() string {
50         _, p, err := net.SplitHostPort(tt.SSHService.Address())
51         if err != nil {
52                 panic(err)
53         }
54         return p
55 }
56
57 type mitmTarget struct {
58         test.SSHService
59 }
60
61 func (*mitmTarget) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
62         return fmt.Errorf("host key failed verification: %#v", key)
63 }
64
// ExecutorSuite groups the gocheck tests for package sshexecutor. It
// has no fields; each test sets up its own SSH service.
type ExecutorSuite struct{}
66
67 func (s *ExecutorSuite) TestBadHostKey(c *check.C) {
68         _, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm")
69         clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch")
70         target := &mitmTarget{
71                 SSHService: test.SSHService{
72                         Exec: func(map[string]string, string, io.Reader, io.Writer, io.Writer) uint32 {
73                                 c.Error("Target Exec func called even though host key verification failed")
74                                 return 0
75                         },
76                         HostKey:        hostpriv,
77                         AuthorizedUser: "username",
78                         AuthorizedKeys: []ssh.PublicKey{clientpub},
79                 },
80         }
81
82         err := target.Start()
83         c.Check(err, check.IsNil)
84         c.Logf("target address %q", target.Address())
85         defer target.Close()
86
87         exr := New(target)
88         exr.SetSigners(clientpriv)
89
90         _, _, err = exr.Execute(nil, "true", nil)
91         c.Check(err, check.ErrorMatches, "host key failed verification: .*")
92 }
93
94 func (s *ExecutorSuite) TestExecute(c *check.C) {
95         command := `foo 'bar' "baz"`
96         stdinData := "foobar\nbaz\n"
97         _, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm")
98         clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch")
99         for _, exitcode := range []int{0, 1, 2} {
100                 target := &testTarget{
101                         SSHService: test.SSHService{
102                                 Exec: func(env map[string]string, cmd string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
103                                         c.Check(env["TESTVAR"], check.Equals, "test value")
104                                         c.Check(cmd, check.Equals, command)
105                                         var wg sync.WaitGroup
106                                         wg.Add(2)
107                                         go func() {
108                                                 io.WriteString(stdout, "stdout\n")
109                                                 wg.Done()
110                                         }()
111                                         go func() {
112                                                 io.WriteString(stderr, "stderr\n")
113                                                 wg.Done()
114                                         }()
115                                         buf, err := ioutil.ReadAll(stdin)
116                                         wg.Wait()
117                                         c.Check(err, check.IsNil)
118                                         if err != nil {
119                                                 return 99
120                                         }
121                                         _, err = stdout.Write(buf)
122                                         c.Check(err, check.IsNil)
123                                         return uint32(exitcode)
124                                 },
125                                 HostKey:        hostpriv,
126                                 AuthorizedUser: "username",
127                                 AuthorizedKeys: []ssh.PublicKey{clientpub},
128                         },
129                 }
130                 err := target.Start()
131                 c.Check(err, check.IsNil)
132                 c.Logf("target address %q", target.Address())
133                 defer target.Close()
134
135                 exr := New(target)
136                 exr.SetSigners(clientpriv)
137
138                 // Use the default target port (ssh). Execute will
139                 // return a connection error or an authentication
140                 // error, depending on whether the test host is
141                 // running an SSH server.
142                 _, _, err = exr.Execute(nil, command, nil)
143                 c.Check(err, check.ErrorMatches, `.*(unable to authenticate|connection refused).*`)
144
145                 // Use a bogus target port. Execute will return a
146                 // connection error.
147                 exr.SetTargetPort("0")
148                 _, _, err = exr.Execute(nil, command, nil)
149                 c.Check(err, check.ErrorMatches, `.*connection refused.*`)
150                 c.Check(errors.As(err, new(*net.OpError)), check.Equals, true)
151
152                 // Use the test server's listening port.
153                 exr.SetTargetPort(target.Port())
154
155                 done := make(chan bool)
156                 go func() {
157                         stdout, stderr, err := exr.Execute(map[string]string{"TESTVAR": "test value"}, command, bytes.NewBufferString(stdinData))
158                         if exitcode == 0 {
159                                 c.Check(err, check.IsNil)
160                         } else {
161                                 c.Check(err, check.NotNil)
162                                 err, ok := err.(*ssh.ExitError)
163                                 c.Assert(ok, check.Equals, true)
164                                 c.Check(err.ExitStatus(), check.Equals, exitcode)
165                         }
166                         c.Check(stdout, check.DeepEquals, []byte("stdout\n"+stdinData))
167                         c.Check(stderr, check.DeepEquals, []byte("stderr\n"))
168                         close(done)
169                 }()
170
171                 timeout := time.NewTimer(time.Second)
172                 select {
173                 case <-done:
174                 case <-timeout.C:
175                         c.Fatal("timed out")
176                 }
177         }
178 }