21720:
[arvados.git] / lib / dispatchcloud / test / ssh_service.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "bytes"
9         "fmt"
10         "io"
11         "io/ioutil"
12         "log"
13         "net"
14         "strings"
15         "sync"
16
17         "golang.org/x/crypto/ssh"
18         check "gopkg.in/check.v1"
19 )
20
21 // LoadTestKey returns a public/private ssh keypair, read from the files
22 // identified by the path of the private key.
23 func LoadTestKey(c *check.C, fnm string) (ssh.PublicKey, ssh.Signer) {
24         rawpubkey, err := ioutil.ReadFile(fnm + ".pub")
25         c.Assert(err, check.IsNil)
26         pubkey, _, _, _, err := ssh.ParseAuthorizedKey(rawpubkey)
27         c.Assert(err, check.IsNil)
28         rawprivkey, err := ioutil.ReadFile(fnm)
29         c.Assert(err, check.IsNil)
30         privkey, err := ssh.ParsePrivateKey(rawprivkey)
31         c.Assert(err, check.IsNil)
32         return pubkey, privkey
33 }
34
// SSHExecFunc is the handler an SSHService invokes for a client's
// "exec" request on a session channel. env holds the variables the
// client set on that session with "env" requests, command is the
// requested command line, stdin/stdout are the session channel itself,
// and stderr is the channel's extended (stderr) stream. The returned
// value is reported to the client as the command's exit status.
type SSHExecFunc func(env map[string]string, command string, stdin io.Reader, stdout io.Writer, stderr io.Writer) uint32
38
39 // An SSHService accepts SSH connections on an available TCP port and
40 // passes clients' "exec" sessions to the provided SSHExecFunc.
41 type SSHService struct {
42         Exec           SSHExecFunc
43         HostKey        ssh.Signer
44         AuthorizedUser string
45         AuthorizedKeys []ssh.PublicKey
46
47         listener net.Listener
48         conn     *ssh.ServerConn
49         setup    sync.Once
50         mtx      sync.Mutex
51         started  chan bool
52         closed   bool
53         err      error
54 }
55
56 // Address returns the host:port where the SSH server is listening. It
57 // returns "" if called before the server is ready to accept
58 // connections.
59 func (ss *SSHService) Address() string {
60         ss.setup.Do(ss.start)
61         ss.mtx.Lock()
62         ln := ss.listener
63         ss.mtx.Unlock()
64         if ln == nil {
65                 return ""
66         }
67         return ln.Addr().String()
68 }
69
70 // RemoteUser returns the username that will be accepted.
71 func (ss *SSHService) RemoteUser() string {
72         return ss.AuthorizedUser
73 }
74
75 // Close shuts down the server and releases resources. Established
76 // connections are unaffected.
77 func (ss *SSHService) Close() {
78         ss.Start()
79         ss.mtx.Lock()
80         ln := ss.listener
81         ss.closed = true
82         ss.mtx.Unlock()
83         if ln != nil {
84                 ln.Close()
85         }
86 }
87
88 // Start returns when the server is ready to accept connections.
89 func (ss *SSHService) Start() error {
90         ss.setup.Do(ss.start)
91         <-ss.started
92         return ss.err
93 }
94
95 func (ss *SSHService) start() {
96         ss.started = make(chan bool)
97         go ss.run()
98 }
99
100 func (ss *SSHService) run() {
101         defer close(ss.started)
102         config := &ssh.ServerConfig{
103                 PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
104                         for _, ak := range ss.AuthorizedKeys {
105                                 if bytes.Equal(ak.Marshal(), pubKey.Marshal()) {
106                                         return &ssh.Permissions{}, nil
107                                 }
108                         }
109                         return nil, fmt.Errorf("unknown public key for %q", c.User())
110                 },
111         }
112         config.AddHostKey(ss.HostKey)
113
114         listener, err := net.Listen("tcp", "127.0.0.1:")
115         if err != nil {
116                 ss.err = err
117                 return
118         }
119
120         ss.mtx.Lock()
121         ss.listener = listener
122         ss.mtx.Unlock()
123
124         go func() {
125                 for {
126                         nConn, err := listener.Accept()
127                         if err != nil && strings.Contains(err.Error(), "use of closed network connection") && ss.closed {
128                                 return
129                         } else if err != nil {
130                                 log.Printf("accept: %s", err)
131                                 return
132                         }
133                         go ss.serveConn(nConn, config)
134                 }
135         }()
136 }
137
138 func (ss *SSHService) serveConn(nConn net.Conn, config *ssh.ServerConfig) {
139         defer nConn.Close()
140         conn, newchans, reqs, err := ssh.NewServerConn(nConn, config)
141         if err != nil {
142                 log.Printf("ssh.NewServerConn: %s", err)
143                 return
144         }
145         defer conn.Close()
146         go ssh.DiscardRequests(reqs)
147         for newch := range newchans {
148                 if newch.ChannelType() != "session" {
149                         newch.Reject(ssh.UnknownChannelType, "unknown channel type")
150                         continue
151                 }
152                 ch, reqs, err := newch.Accept()
153                 if err != nil {
154                         log.Printf("accept channel: %s", err)
155                         return
156                 }
157                 didExec := false
158                 sessionEnv := map[string]string{}
159                 go func() {
160                         for req := range reqs {
161                                 switch {
162                                 case didExec:
163                                         // Reject anything after exec
164                                         req.Reply(false, nil)
165                                 case req.Type == "exec":
166                                         var execReq struct {
167                                                 Command string
168                                         }
169                                         req.Reply(true, nil)
170                                         ssh.Unmarshal(req.Payload, &execReq)
171                                         go func() {
172                                                 var resp struct {
173                                                         Status uint32
174                                                 }
175                                                 resp.Status = ss.Exec(sessionEnv, execReq.Command, ch, ch, ch.Stderr())
176                                                 ch.SendRequest("exit-status", false, ssh.Marshal(&resp))
177                                                 ch.Close()
178                                         }()
179                                         didExec = true
180                                 case req.Type == "env":
181                                         var envReq struct {
182                                                 Name  string
183                                                 Value string
184                                         }
185                                         req.Reply(true, nil)
186                                         ssh.Unmarshal(req.Payload, &envReq)
187                                         sessionEnv[envReq.Name] = envReq.Value
188                                 }
189                         }
190                 }()
191         }
192 }