// Copyright (C) The Arvados Authors. All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0

package ssh_executor

import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"sync"
	"testing"
	"time"

	"git.curoverse.com/arvados.git/lib/dispatchcloud/test"
	"golang.org/x/crypto/ssh"
	check "gopkg.in/check.v1"
)

// Gocheck boilerplate
func Test(t *testing.T) {
	check.TestingT(t)
}

var _ = check.Suite(&ExecutorSuite{})

type testTarget struct {
	test.SSHService
}

func (*testTarget) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
	return nil
}

// Address returns the wrapped SSHService's host, with the port
// stripped. This ensures the executor won't work until
// SetTargetPort() is called -- see (*testTarget)Port().
func (tt *testTarget) Address() string {
	h, _, err := net.SplitHostPort(tt.SSHService.Address())
	if err != nil {
		panic(err)
	}
	return h
}

func (tt *testTarget) Port() string {
	_, p, err := net.SplitHostPort(tt.SSHService.Address())
	if err != nil {
		panic(err)
	}
	return p
}

type mitmTarget struct {
	test.SSHService
}

func (*mitmTarget) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
	return fmt.Errorf("host key failed verification: %#v", key)
}

type ExecutorSuite struct{}

func (s *ExecutorSuite) TestBadHostKey(c *check.C) {
	_, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm")
	clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch")
	target := &mitmTarget{
		SSHService: test.SSHService{
			Exec: func(map[string]string, string, io.Reader, io.Writer, io.Writer) uint32 {
				c.Error("Target Exec func called even though host key verification failed")
				return 0
			},
			HostKey:        hostpriv,
			AuthorizedUser: "username",
			AuthorizedKeys: []ssh.PublicKey{clientpub},
		},
	}

	err := target.Start()
	c.Check(err, check.IsNil)
	c.Logf("target address %q", target.Address())
	defer target.Close()

	exr := New(target)
	exr.SetSigners(clientpriv)

	_, _, err = exr.Execute(nil, "true", nil)
	c.Check(err, check.ErrorMatches, "host key failed verification: .*")
}

func (s *ExecutorSuite) TestExecute(c *check.C) {
	command := `foo 'bar' "baz"`
	stdinData := "foobar\nbaz\n"
	_, hostpriv := test.LoadTestKey(c, "../test/sshkey_vm")
	clientpub, clientpriv := test.LoadTestKey(c, "../test/sshkey_dispatch")
	for _, exitcode := range []int{0, 1, 2} {
		target := &testTarget{
			SSHService: test.SSHService{
				Exec: func(env map[string]string, cmd string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
					c.Check(env["TESTVAR"], check.Equals, "test value")
					c.Check(cmd, check.Equals, command)
					var wg sync.WaitGroup
					wg.Add(2)
					go func() {
						io.WriteString(stdout, "stdout\n")
						wg.Done()
					}()
					go func() {
						io.WriteString(stderr, "stderr\n")
						wg.Done()
					}()
					buf, err := ioutil.ReadAll(stdin)
					wg.Wait()
					c.Check(err, check.IsNil)
					if err != nil {
						return 99
					}
					_, err = stdout.Write(buf)
					c.Check(err, check.IsNil)
					return uint32(exitcode)
				},
				HostKey:        hostpriv,
				AuthorizedUser: "username",
				AuthorizedKeys: []ssh.PublicKey{clientpub},
			},
		}
		err := target.Start()
		c.Check(err, check.IsNil)
		c.Logf("target address %q", target.Address())
		defer target.Close()

		exr := New(target)
		exr.SetSigners(clientpriv)

		// Use the default target port (ssh). Execute will
		// return a connection error or an authentication
		// error, depending on whether the test host is
		// running an SSH server.
		_, _, err = exr.Execute(nil, command, nil)
		c.Check(err, check.ErrorMatches, `.*(unable to authenticate|connection refused).*`)

		// Use a bogus target port. Execute will return a
		// connection error.
		exr.SetTargetPort("0")
		_, _, err = exr.Execute(nil, command, nil)
		c.Check(err, check.ErrorMatches, `.*connection refused.*`)

		// Use the test server's listening port.
		exr.SetTargetPort(target.Port())

		done := make(chan bool)
		go func() {
			stdout, stderr, err := exr.Execute(map[string]string{"TESTVAR": "test value"}, command, bytes.NewBufferString(stdinData))
			if exitcode == 0 {
				c.Check(err, check.IsNil)
			} else {
				c.Check(err, check.NotNil)
				err, ok := err.(*ssh.ExitError)
				c.Assert(ok, check.Equals, true)
				c.Check(err.ExitStatus(), check.Equals, exitcode)
			}
			c.Check(stdout, check.DeepEquals, []byte("stdout\n"+stdinData))
			c.Check(stderr, check.DeepEquals, []byte("stderr\n"))
			close(done)
		}()

		timeout := time.NewTimer(time.Second)
		select {
		case <-done:
		case <-timeout.C:
			c.Fatal("timed out")
		}
	}
}