14360: Call ChooseType just once per container.
[arvados.git] / lib / dispatchcloud / test / ssh_service.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "bytes"
9         "fmt"
10         "io"
11         "io/ioutil"
12         "log"
13         "net"
14         "strings"
15         "sync"
16
17         "golang.org/x/crypto/ssh"
18         check "gopkg.in/check.v1"
19 )
20
21 func LoadTestKey(c *check.C, fnm string) (ssh.PublicKey, ssh.Signer) {
22         rawpubkey, err := ioutil.ReadFile(fnm + ".pub")
23         c.Assert(err, check.IsNil)
24         pubkey, _, _, _, err := ssh.ParseAuthorizedKey(rawpubkey)
25         c.Assert(err, check.IsNil)
26         rawprivkey, err := ioutil.ReadFile(fnm)
27         c.Assert(err, check.IsNil)
28         privkey, err := ssh.ParsePrivateKey(rawprivkey)
29         c.Assert(err, check.IsNil)
30         return pubkey, privkey
31 }
32
// SSHExecFunc is the callback an SSHService invokes for an "exec"
// request on a session channel of a multiplexed SSH connection. It is
// given the requested command line, a reader for the client's stdin,
// and writers for stdout and stderr. Its return value is sent back to
// the client as the command's exit status.
type SSHExecFunc func(cmdline string, stdin io.Reader, stdout io.Writer, stderr io.Writer) uint32
36
37 // An SSHService accepts SSH connections on an available TCP port and
38 // passes clients' "exec" sessions to the provided SSHExecFunc.
39 type SSHService struct {
40         Exec           SSHExecFunc
41         HostKey        ssh.Signer
42         AuthorizedKeys []ssh.PublicKey
43
44         listener net.Listener
45         conn     *ssh.ServerConn
46         setup    sync.Once
47         mtx      sync.Mutex
48         started  chan bool
49         closed   bool
50         err      error
51 }
52
53 // Address returns the host:port where the SSH server is listening. It
54 // returns "" if called before the server is ready to accept
55 // connections.
56 func (ss *SSHService) Address() string {
57         ss.setup.Do(ss.start)
58         ss.mtx.Lock()
59         ln := ss.listener
60         ss.mtx.Unlock()
61         if ln == nil {
62                 return ""
63         }
64         return ln.Addr().String()
65 }
66
67 // Close shuts down the server and releases resources. Established
68 // connections are unaffected.
69 func (ss *SSHService) Close() {
70         ss.Start()
71         ss.mtx.Lock()
72         ln := ss.listener
73         ss.closed = true
74         ss.mtx.Unlock()
75         if ln != nil {
76                 ln.Close()
77         }
78 }
79
80 // Start returns when the server is ready to accept connections.
81 func (ss *SSHService) Start() error {
82         ss.setup.Do(ss.start)
83         <-ss.started
84         return ss.err
85 }
86
87 func (ss *SSHService) start() {
88         ss.started = make(chan bool)
89         go ss.run()
90 }
91
92 func (ss *SSHService) run() {
93         defer close(ss.started)
94         config := &ssh.ServerConfig{
95                 PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
96                         for _, ak := range ss.AuthorizedKeys {
97                                 if bytes.Equal(ak.Marshal(), pubKey.Marshal()) {
98                                         return &ssh.Permissions{}, nil
99                                 }
100                         }
101                         return nil, fmt.Errorf("unknown public key for %q", c.User())
102                 },
103         }
104         config.AddHostKey(ss.HostKey)
105
106         listener, err := net.Listen("tcp", ":")
107         if err != nil {
108                 ss.err = err
109                 return
110         }
111
112         ss.mtx.Lock()
113         ss.listener = listener
114         ss.mtx.Unlock()
115
116         go func() {
117                 for {
118                         nConn, err := listener.Accept()
119                         if err != nil && strings.Contains(err.Error(), "use of closed network connection") && ss.closed {
120                                 return
121                         } else if err != nil {
122                                 log.Printf("accept: %s", err)
123                                 return
124                         }
125                         go ss.serveConn(nConn, config)
126                 }
127         }()
128 }
129
130 func (ss *SSHService) serveConn(nConn net.Conn, config *ssh.ServerConfig) {
131         defer nConn.Close()
132         conn, newchans, reqs, err := ssh.NewServerConn(nConn, config)
133         if err != nil {
134                 log.Printf("ssh.NewServerConn: %s", err)
135                 return
136         }
137         defer conn.Close()
138         go ssh.DiscardRequests(reqs)
139         for newch := range newchans {
140                 if newch.ChannelType() != "session" {
141                         newch.Reject(ssh.UnknownChannelType, "unknown channel type")
142                         continue
143                 }
144                 ch, reqs, err := newch.Accept()
145                 if err != nil {
146                         log.Printf("accept channel: %s", err)
147                         return
148                 }
149                 var execReq struct {
150                         Command string
151                 }
152                 go func() {
153                         for req := range reqs {
154                                 if req.Type == "exec" && execReq.Command == "" {
155                                         req.Reply(true, nil)
156                                         ssh.Unmarshal(req.Payload, &execReq)
157                                         go func() {
158                                                 var resp struct {
159                                                         Status uint32
160                                                 }
161                                                 resp.Status = ss.Exec(execReq.Command, ch, ch, ch.Stderr())
162                                                 ch.SendRequest("exit-status", false, ssh.Marshal(&resp))
163                                                 ch.Close()
164                                         }()
165                                 }
166                         }
167                 }()
168         }
169 }