Merge branch 'master' into 16811-public-favs
[arvados.git] / lib / dispatchcloud / sshexecutor / executor_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package sshexecutor
6
7 import (
8         "bytes"
9         "fmt"
10         "io"
11         "io/ioutil"
12         "net"
13         "sync"
14         "testing"
15         "time"
16
17         "git.arvados.org/arvados.git/lib/dispatchcloud/test"
18         "golang.org/x/crypto/ssh"
19         check "gopkg.in/check.v1"
20 )
21
22 // Gocheck boilerplate
23 func Test(t *testing.T) {
24         check.TestingT(t)
25 }
26
27 var _ = check.Suite(&ExecutorSuite{})
28
29 type testTarget struct {
30         test.SSHService
31 }
32
33 func (*testTarget) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
34         return nil
35 }
36
37 // Address returns the wrapped SSHService's host, with the port
38 // stripped. This ensures the executor won't work until
39 // SetTargetPort() is called -- see (*testTarget)Port().
40 func (tt *testTarget) Address() string {
41         h, _, err := net.SplitHostPort(tt.SSHService.Address())
42         if err != nil {
43                 panic(err)
44         }
45         return h
46 }
47
48 func (tt *testTarget) Port() string {
49         _, p, err := net.SplitHostPort(tt.SSHService.Address())
50         if err != nil {
51                 panic(err)
52         }
53         return p
54 }
55
56 type mitmTarget struct {
57         test.SSHService
58 }
59
60 func (*mitmTarget) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
61         return fmt.Errorf("host key failed verification: %#v", key)
62 }
63
// ExecutorSuite holds the gocheck tests for the SSH executor. It has no
// state; each test builds its own target.
type ExecutorSuite struct{}
65
66 func (s *ExecutorSuite) TestBadHostKey(c *check.C) {
67         _, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm")
68         clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch")
69         target := &mitmTarget{
70                 SSHService: test.SSHService{
71                         Exec: func(map[string]string, string, io.Reader, io.Writer, io.Writer) uint32 {
72                                 c.Error("Target Exec func called even though host key verification failed")
73                                 return 0
74                         },
75                         HostKey:        hostpriv,
76                         AuthorizedUser: "username",
77                         AuthorizedKeys: []ssh.PublicKey{clientpub},
78                 },
79         }
80
81         err := target.Start()
82         c.Check(err, check.IsNil)
83         c.Logf("target address %q", target.Address())
84         defer target.Close()
85
86         exr := New(target)
87         exr.SetSigners(clientpriv)
88
89         _, _, err = exr.Execute(nil, "true", nil)
90         c.Check(err, check.ErrorMatches, "host key failed verification: .*")
91 }
92
93 func (s *ExecutorSuite) TestExecute(c *check.C) {
94         command := `foo 'bar' "baz"`
95         stdinData := "foobar\nbaz\n"
96         _, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm")
97         clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch")
98         for _, exitcode := range []int{0, 1, 2} {
99                 target := &testTarget{
100                         SSHService: test.SSHService{
101                                 Exec: func(env map[string]string, cmd string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
102                                         c.Check(env["TESTVAR"], check.Equals, "test value")
103                                         c.Check(cmd, check.Equals, command)
104                                         var wg sync.WaitGroup
105                                         wg.Add(2)
106                                         go func() {
107                                                 io.WriteString(stdout, "stdout\n")
108                                                 wg.Done()
109                                         }()
110                                         go func() {
111                                                 io.WriteString(stderr, "stderr\n")
112                                                 wg.Done()
113                                         }()
114                                         buf, err := ioutil.ReadAll(stdin)
115                                         wg.Wait()
116                                         c.Check(err, check.IsNil)
117                                         if err != nil {
118                                                 return 99
119                                         }
120                                         _, err = stdout.Write(buf)
121                                         c.Check(err, check.IsNil)
122                                         return uint32(exitcode)
123                                 },
124                                 HostKey:        hostpriv,
125                                 AuthorizedUser: "username",
126                                 AuthorizedKeys: []ssh.PublicKey{clientpub},
127                         },
128                 }
129                 err := target.Start()
130                 c.Check(err, check.IsNil)
131                 c.Logf("target address %q", target.Address())
132                 defer target.Close()
133
134                 exr := New(target)
135                 exr.SetSigners(clientpriv)
136
137                 // Use the default target port (ssh). Execute will
138                 // return a connection error or an authentication
139                 // error, depending on whether the test host is
140                 // running an SSH server.
141                 _, _, err = exr.Execute(nil, command, nil)
142                 c.Check(err, check.ErrorMatches, `.*(unable to authenticate|connection refused).*`)
143
144                 // Use a bogus target port. Execute will return a
145                 // connection error.
146                 exr.SetTargetPort("0")
147                 _, _, err = exr.Execute(nil, command, nil)
148                 c.Check(err, check.ErrorMatches, `.*connection refused.*`)
149
150                 // Use the test server's listening port.
151                 exr.SetTargetPort(target.Port())
152
153                 done := make(chan bool)
154                 go func() {
155                         stdout, stderr, err := exr.Execute(map[string]string{"TESTVAR": "test value"}, command, bytes.NewBufferString(stdinData))
156                         if exitcode == 0 {
157                                 c.Check(err, check.IsNil)
158                         } else {
159                                 c.Check(err, check.NotNil)
160                                 err, ok := err.(*ssh.ExitError)
161                                 c.Assert(ok, check.Equals, true)
162                                 c.Check(err.ExitStatus(), check.Equals, exitcode)
163                         }
164                         c.Check(stdout, check.DeepEquals, []byte("stdout\n"+stdinData))
165                         c.Check(stderr, check.DeepEquals, []byte("stderr\n"))
166                         close(done)
167                 }()
168
169                 timeout := time.NewTimer(time.Second)
170                 select {
171                 case <-done:
172                 case <-timeout.C:
173                         c.Fatal("timed out")
174                 }
175         }
176 }