15954: Remove unused cmdArgs.
[arvados.git] / lib / dispatchcloud / test / ssh_service.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "bytes"
9         "fmt"
10         "io"
11         "io/ioutil"
12         "log"
13         "net"
14         "strings"
15         "sync"
16
17         "golang.org/x/crypto/ssh"
18         check "gopkg.in/check.v1"
19 )
20
21 func LoadTestKey(c *check.C, fnm string) (ssh.PublicKey, ssh.Signer) {
22         rawpubkey, err := ioutil.ReadFile(fnm + ".pub")
23         c.Assert(err, check.IsNil)
24         pubkey, _, _, _, err := ssh.ParseAuthorizedKey(rawpubkey)
25         c.Assert(err, check.IsNil)
26         rawprivkey, err := ioutil.ReadFile(fnm)
27         c.Assert(err, check.IsNil)
28         privkey, err := ssh.ParsePrivateKey(rawprivkey)
29         c.Assert(err, check.IsNil)
30         return pubkey, privkey
31 }
32
// SSHExecFunc is called by an SSHService for each "exec" request it
// accepts on a session channel. It receives:
//   - env: the variables sent with "env" requests earlier on that session
//   - command: the requested command string
//   - stdin, stdout: both the session channel itself
//   - stderr: the channel's extended-data (stderr) stream
//
// The returned value is sent to the client as the "exit-status".
type SSHExecFunc func(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32
36
37 // An SSHService accepts SSH connections on an available TCP port and
38 // passes clients' "exec" sessions to the provided SSHExecFunc.
39 type SSHService struct {
40         Exec           SSHExecFunc
41         HostKey        ssh.Signer
42         AuthorizedUser string
43         AuthorizedKeys []ssh.PublicKey
44
45         listener net.Listener
46         conn     *ssh.ServerConn
47         setup    sync.Once
48         mtx      sync.Mutex
49         started  chan bool
50         closed   bool
51         err      error
52 }
53
54 // Address returns the host:port where the SSH server is listening. It
55 // returns "" if called before the server is ready to accept
56 // connections.
57 func (ss *SSHService) Address() string {
58         ss.setup.Do(ss.start)
59         ss.mtx.Lock()
60         ln := ss.listener
61         ss.mtx.Unlock()
62         if ln == nil {
63                 return ""
64         }
65         return ln.Addr().String()
66 }
67
68 // RemoteUser returns the username that will be accepted.
69 func (ss *SSHService) RemoteUser() string {
70         return ss.AuthorizedUser
71 }
72
73 // Close shuts down the server and releases resources. Established
74 // connections are unaffected.
75 func (ss *SSHService) Close() {
76         ss.Start()
77         ss.mtx.Lock()
78         ln := ss.listener
79         ss.closed = true
80         ss.mtx.Unlock()
81         if ln != nil {
82                 ln.Close()
83         }
84 }
85
86 // Start returns when the server is ready to accept connections.
87 func (ss *SSHService) Start() error {
88         ss.setup.Do(ss.start)
89         <-ss.started
90         return ss.err
91 }
92
93 func (ss *SSHService) start() {
94         ss.started = make(chan bool)
95         go ss.run()
96 }
97
98 func (ss *SSHService) run() {
99         defer close(ss.started)
100         config := &ssh.ServerConfig{
101                 PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
102                         for _, ak := range ss.AuthorizedKeys {
103                                 if bytes.Equal(ak.Marshal(), pubKey.Marshal()) {
104                                         return &ssh.Permissions{}, nil
105                                 }
106                         }
107                         return nil, fmt.Errorf("unknown public key for %q", c.User())
108                 },
109         }
110         config.AddHostKey(ss.HostKey)
111
112         listener, err := net.Listen("tcp", "127.0.0.1:")
113         if err != nil {
114                 ss.err = err
115                 return
116         }
117
118         ss.mtx.Lock()
119         ss.listener = listener
120         ss.mtx.Unlock()
121
122         go func() {
123                 for {
124                         nConn, err := listener.Accept()
125                         if err != nil && strings.Contains(err.Error(), "use of closed network connection") && ss.closed {
126                                 return
127                         } else if err != nil {
128                                 log.Printf("accept: %s", err)
129                                 return
130                         }
131                         go ss.serveConn(nConn, config)
132                 }
133         }()
134 }
135
136 func (ss *SSHService) serveConn(nConn net.Conn, config *ssh.ServerConfig) {
137         defer nConn.Close()
138         conn, newchans, reqs, err := ssh.NewServerConn(nConn, config)
139         if err != nil {
140                 log.Printf("ssh.NewServerConn: %s", err)
141                 return
142         }
143         defer conn.Close()
144         go ssh.DiscardRequests(reqs)
145         for newch := range newchans {
146                 if newch.ChannelType() != "session" {
147                         newch.Reject(ssh.UnknownChannelType, "unknown channel type")
148                         continue
149                 }
150                 ch, reqs, err := newch.Accept()
151                 if err != nil {
152                         log.Printf("accept channel: %s", err)
153                         return
154                 }
155                 didExec := false
156                 sessionEnv := map[string]string{}
157                 go func() {
158                         for req := range reqs {
159                                 switch {
160                                 case didExec:
161                                         // Reject anything after exec
162                                         req.Reply(false, nil)
163                                 case req.Type == "exec":
164                                         var execReq struct {
165                                                 Command string
166                                         }
167                                         req.Reply(true, nil)
168                                         ssh.Unmarshal(req.Payload, &execReq)
169                                         go func() {
170                                                 var resp struct {
171                                                         Status uint32
172                                                 }
173                                                 resp.Status = ss.Exec(sessionEnv, execReq.Command, ch, ch, ch.Stderr())
174                                                 ch.SendRequest("exit-status", false, ssh.Marshal(&resp))
175                                                 ch.Close()
176                                         }()
177                                         didExec = true
178                                 case req.Type == "env":
179                                         var envReq struct {
180                                                 Name  string
181                                                 Value string
182                                         }
183                                         req.Reply(true, nil)
184                                         ssh.Unmarshal(req.Payload, &envReq)
185                                         sessionEnv[envReq.Name] = envReq.Value
186                                 }
187                         }
188                 }()
189         }
190 }