1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
17 "golang.org/x/crypto/ssh"
18 check "gopkg.in/check.v1"
21 // LoadTestKey returns a public/private ssh keypair, read from the files
22 // identified by the path of the private key.
23 func LoadTestKey(c *check.C, fnm string) (ssh.PublicKey, ssh.Signer) {
24 rawpubkey, err := ioutil.ReadFile(fnm + ".pub")
25 c.Assert(err, check.IsNil)
26 pubkey, _, _, _, err := ssh.ParseAuthorizedKey(rawpubkey)
27 c.Assert(err, check.IsNil)
28 rawprivkey, err := ioutil.ReadFile(fnm)
29 c.Assert(err, check.IsNil)
30 privkey, err := ssh.ParsePrivateKey(rawprivkey)
31 c.Assert(err, check.IsNil)
32 return pubkey, privkey
// An SSHExecFunc handles an "exec" session on a multiplexed SSH
// connection. env holds the variables collected from the session's
// "env" requests, command is the command string from the "exec"
// request, stdin/stdout are connected to the session channel and
// stderr to the channel's stderr stream. The returned value is sent
// back to the client as the session's "exit-status".
type SSHExecFunc func(
	env map[string]string,
	command string,
	stdin io.Reader,
	stdout io.Writer,
	stderr io.Writer,
) uint32
39 // An SSHService accepts SSH connections on an available TCP port and
40 // passes clients' "exec" sessions to the provided SSHExecFunc.
41 type SSHService struct {
45 AuthorizedKeys []ssh.PublicKey
// Address returns the host:port where the SSH server is listening. It
// returns "" if called before the server is ready to accept
// connections.
59 func (ss *SSHService) Address() string {
67 return ln.Addr().String()
70 // RemoteUser returns the username that will be accepted.
71 func (ss *SSHService) RemoteUser() string {
72 return ss.AuthorizedUser
75 // Close shuts down the server and releases resources. Established
76 // connections are unaffected.
77 func (ss *SSHService) Close() {
88 // Start returns when the server is ready to accept connections.
89 func (ss *SSHService) Start() error {
95 func (ss *SSHService) start() {
96 ss.started = make(chan bool)
100 func (ss *SSHService) run() {
101 defer close(ss.started)
102 config := &ssh.ServerConfig{
103 PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
104 for _, ak := range ss.AuthorizedKeys {
105 if bytes.Equal(ak.Marshal(), pubKey.Marshal()) {
106 return &ssh.Permissions{}, nil
109 return nil, fmt.Errorf("unknown public key for %q", c.User())
112 config.AddHostKey(ss.HostKey)
114 listener, err := net.Listen("tcp", "127.0.0.1:")
121 ss.listener = listener
126 nConn, err := listener.Accept()
127 if err != nil && strings.Contains(err.Error(), "use of closed network connection") && ss.closed {
129 } else if err != nil {
130 log.Printf("accept: %s", err)
133 go ss.serveConn(nConn, config)
138 func (ss *SSHService) serveConn(nConn net.Conn, config *ssh.ServerConfig) {
140 conn, newchans, reqs, err := ssh.NewServerConn(nConn, config)
142 log.Printf("ssh.NewServerConn: %s", err)
146 go ssh.DiscardRequests(reqs)
147 for newch := range newchans {
148 if newch.ChannelType() != "session" {
149 newch.Reject(ssh.UnknownChannelType, "unknown channel type")
152 ch, reqs, err := newch.Accept()
154 log.Printf("accept channel: %s", err)
158 sessionEnv := map[string]string{}
160 for req := range reqs {
163 // Reject anything after exec
164 req.Reply(false, nil)
165 case req.Type == "exec":
170 ssh.Unmarshal(req.Payload, &execReq)
175 resp.Status = ss.Exec(sessionEnv, execReq.Command, ch, ch, ch.Stderr())
176 ch.SendRequest("exit-status", false, ssh.Marshal(&resp))
180 case req.Type == "env":
186 ssh.Unmarshal(req.Payload, &envReq)
187 sessionEnv[envReq.Name] = envReq.Value