14360: Merge branch 'master'
[arvados.git] / lib / dispatchcloud / ssh_executor / executor_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ssh_executor
6
7 import (
8         "bytes"
9         "io"
10         "io/ioutil"
11         "sync"
12         "testing"
13         "time"
14
15         "git.curoverse.com/arvados.git/lib/dispatchcloud/test"
16         "golang.org/x/crypto/ssh"
17         check "gopkg.in/check.v1"
18 )
19
20 // Gocheck boilerplate
21 func Test(t *testing.T) {
22         check.TestingT(t)
23 }
24
25 var _ = check.Suite(&ExecutorSuite{})
26
27 type testTarget struct {
28         test.SSHService
29 }
30
31 func (*testTarget) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
32         return nil
33 }
34
35 type ExecutorSuite struct{}
36
37 func (s *ExecutorSuite) TestExecute(c *check.C) {
38         command := `foo 'bar' "baz"`
39         stdinData := "foobar\nbaz\n"
40         _, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm")
41         clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch")
42         for _, exitcode := range []int{0, 1, 2} {
43                 srv := &testTarget{
44                         SSHService: test.SSHService{
45                                 Exec: func(cmd string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
46                                         c.Check(cmd, check.Equals, command)
47                                         var wg sync.WaitGroup
48                                         wg.Add(2)
49                                         go func() {
50                                                 io.WriteString(stdout, "stdout\n")
51                                                 wg.Done()
52                                         }()
53                                         go func() {
54                                                 io.WriteString(stderr, "stderr\n")
55                                                 wg.Done()
56                                         }()
57                                         buf, err := ioutil.ReadAll(stdin)
58                                         wg.Wait()
59                                         c.Check(err, check.IsNil)
60                                         if err != nil {
61                                                 return 99
62                                         }
63                                         _, err = stdout.Write(buf)
64                                         c.Check(err, check.IsNil)
65                                         return uint32(exitcode)
66                                 },
67                                 HostKey:        hostpriv,
68                                 AuthorizedKeys: []ssh.PublicKey{clientpub},
69                         },
70                 }
71                 err := srv.Start()
72                 c.Check(err, check.IsNil)
73                 c.Logf("srv address %q", srv.Address())
74                 defer srv.Close()
75
76                 exr := New(srv)
77                 exr.SetSigners(clientpriv)
78
79                 done := make(chan bool)
80                 go func() {
81                         stdout, stderr, err := exr.Execute(command, bytes.NewBufferString(stdinData))
82                         if exitcode == 0 {
83                                 c.Check(err, check.IsNil)
84                         } else {
85                                 c.Check(err, check.NotNil)
86                                 err, ok := err.(*ssh.ExitError)
87                                 c.Assert(ok, check.Equals, true)
88                                 c.Check(err.ExitStatus(), check.Equals, exitcode)
89                         }
90                         c.Check(stdout, check.DeepEquals, []byte("stdout\n"+stdinData))
91                         c.Check(stderr, check.DeepEquals, []byte("stderr\n"))
92                         close(done)
93                 }()
94
95                 timeout := time.NewTimer(time.Second)
96                 select {
97                 case <-done:
98                 case <-timeout.C:
99                         c.Fatal("timed out")
100                 }
101         }
102 }