// Copyright (C) The Arvados Authors. All rights reserved.
// SPDX-License-Identifier: AGPL-3.0

package test

import (

	check "gopkg.in/check.v1"

// LoadTestKey returns a public/private ssh keypair, read from the files
// identified by the path of the private key.
func LoadTestKey(c *check.C, fnm string) (ssh.PublicKey, ssh.Signer) {
	rawpubkey, err := ioutil.ReadFile(fnm + ".pub")
	c.Assert(err, check.IsNil)
	pubkey, _, _, _, err := ssh.ParseAuthorizedKey(rawpubkey)
	c.Assert(err, check.IsNil)
	rawprivkey, err := ioutil.ReadFile(fnm)
	c.Assert(err, check.IsNil)
	privkey, err := ssh.ParsePrivateKey(rawprivkey)
	c.Assert(err, check.IsNil)
	return pubkey, privkey

// SSHExecFunc is the callback invoked for an "exec" session on a
// multiplexed SSH connection. It receives the environment variables
// set for the session, the requested command, and the session's
// stdin, stdout and stderr streams. The value it returns is sent to
// the client as the command's exit status.
type SSHExecFunc func(
	env map[string]string,
	command string,
	stdin io.Reader,
	stdout io.Writer,
	stderr io.Writer,
) uint32

// An SSHService accepts SSH connections on an available TCP port and
// passes clients' "exec" sessions to the provided SSHExecFunc.
type SSHService struct {
	Exec           SSHExecFunc
	HostKey        ssh.Signer
	AuthorizedUser string
	AuthorizedKeys []ssh.PublicKey

	listener net.Listener
	conn     *ssh.ServerConn
	setup    sync.Once
	mtx      sync.Mutex
	started  chan bool
	closed   bool
	err      error

// Address returns the host:port where the SSH server is listening. It
// returns "" if called before the server is ready to accept
// connections.
func (ss *SSHService) Address() string {
	ln := ss.listener
	if ln == nil {
		return ""
	return ln.Addr().String()

// RemoteUser returns the username that will be accepted.
func (ss *SSHService) RemoteUser() string {
	return ss.AuthorizedUser

// Close shuts down the server and releases resources. Established
// connections are unaffected.
func (ss *SSHService) Close() {
	ln := ss.listener
	ss.closed = true
	if ln != nil {

// Start returns when the server is ready to accept connections.
func (ss *SSHService) Start() error {
	return ss.err

func (ss *SSHService) start() {
	ss.started = make(chan bool)
	go ss.run()

func (ss *SSHService) run() {
	defer close(ss.started)
	config := &ssh.ServerConfig{
		PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
			for _, ak := range ss.AuthorizedKeys {
				if bytes.Equal(ak.Marshal(), pubKey.Marshal()) {
					return &ssh.Permissions{}, nil
			return nil, fmt.Errorf("unknown public key for %q", c.User())

	listener, err := net.Listen("tcp", "")
	if err != nil {
		ss.err = err

	ss.listener = listener

	go func() {
		for {
			nConn, err := listener.Accept()
			if err != nil && strings.Contains(err.Error(), "use of closed network connection") && ss.closed {
			} else if err != nil {
				log.Printf("accept: %s", err)
			go ss.serveConn(nConn, config)

// serveConn performs the SSH server handshake on nConn using config,
// then services the channels the client opens on that connection.
// Connection-level requests are discarded. Channels whose type is not
// "session" are rejected. On a session channel, "env" requests are
// collected into a per-session environment map and an "exec" request
// is passed to ss.Exec, whose return value is sent to the client in
// an "exit-status" request. Requests arriving after the exec are
// refused.
//
// NOTE(review): the early exits expected after the error logs and
// after the channel rejection, and the closing delimiters of this
// function, are not visible in this view -- confirm against the
// complete file.
func (ss *SSHService) serveConn(nConn net.Conn, config *ssh.ServerConfig) {
	defer nConn.Close()
	conn, newchans, reqs, err := ssh.NewServerConn(nConn, config)
	if err != nil {
		log.Printf("ssh.NewServerConn: %s", err)
	defer conn.Close()
	// Connection-level requests are not supported; discard them.
	go ssh.DiscardRequests(reqs)
	for newch := range newchans {
		if newch.ChannelType() != "session" {
			newch.Reject(ssh.UnknownChannelType, "unknown channel type")
		// This reqs shadows the connection-level one above: it carries
		// the requests made on this channel.
		ch, reqs, err := newch.Accept()
		if err != nil {
			log.Printf("accept channel: %s", err)
		// didExec and sessionEnv are per-channel state shared with the
		// request-handling goroutine below.
		didExec := false
		sessionEnv := map[string]string{}
		go func() {
			for req := range reqs {
				// Request types other than "exec" and "env" match no
				// case below.
				switch {
				case didExec:
					// Reject anything after exec
					req.Reply(false, nil)
				case req.Type == "exec":
					var execReq struct {
						Command string
					// The request is acknowledged before its payload is
					// decoded; the Unmarshal error is not checked.
					req.Reply(true, nil)
					ssh.Unmarshal(req.Payload, &execReq)
					// Run the command in its own goroutine so this loop
					// keeps handling requests while it executes.
					go func() {
						var resp struct {
							Status uint32
						// The channel serves as both stdin and stdout;
						// stderr is the channel's Stderr() stream.
						resp.Status = ss.Exec(sessionEnv, execReq.Command, ch, ch, ch.Stderr())
						ch.SendRequest("exit-status", false, ssh.Marshal(&resp))
					didExec = true
				case req.Type == "env":
					var envReq struct {
						Name  string
						Value string
					// Acknowledged before decoding; the Unmarshal error is
					// not checked.
					req.Reply(true, nil)
					ssh.Unmarshal(req.Payload, &envReq)
					sessionEnv[envReq.Name] = envReq.Value