1116c4bb1285d11ae39b4e194f21df6fd01c5d24
[arvados.git] / lib / crunchrun / container_gateway.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package crunchrun
6
7 import (
8         "crypto/hmac"
9         "crypto/rand"
10         "crypto/rsa"
11         "crypto/sha256"
12         "crypto/tls"
13         "fmt"
14         "io"
15         "net"
16         "net/http"
17         "os"
18         "os/exec"
19         "sync"
20         "syscall"
21
22         "git.arvados.org/arvados.git/lib/selfsigned"
23         "git.arvados.org/arvados.git/sdk/go/httpserver"
24         "github.com/creack/pty"
25         "github.com/google/shlex"
26         "golang.org/x/crypto/ssh"
27 )
28
29 type Gateway struct {
30         DockerContainerID *string
31         ContainerUUID     string
32         Address           string // listen host:port; if port=0, Start() will change it to the selected port
33         AuthSecret        string
34         Log               interface {
35                 Printf(fmt string, args ...interface{})
36         }
37
38         sshConfig   ssh.ServerConfig
39         requestAuth string
40         respondAuth string
41 }
42
43 // startGatewayServer starts an http server that allows authenticated
44 // clients to open an interactive "docker exec" session and (in
45 // future) connect to tcp ports inside the docker container.
46 func (gw *Gateway) Start() error {
47         gw.sshConfig = ssh.ServerConfig{
48                 NoClientAuth: true,
49                 PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
50                         if c.User() == "_" {
51                                 return nil, nil
52                         } else {
53                                 return nil, fmt.Errorf("cannot specify user %q via ssh client", c.User())
54                         }
55                 },
56                 PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
57                         if c.User() == "_" {
58                                 return &ssh.Permissions{
59                                         Extensions: map[string]string{
60                                                 "pubkey-fp": ssh.FingerprintSHA256(pubKey),
61                                         },
62                                 }, nil
63                         } else {
64                                 return nil, fmt.Errorf("cannot specify user %q via ssh client", c.User())
65                         }
66                 },
67         }
68         pvt, err := rsa.GenerateKey(rand.Reader, 2048)
69         if err != nil {
70                 return err
71         }
72         err = pvt.Validate()
73         if err != nil {
74                 return err
75         }
76         signer, err := ssh.NewSignerFromKey(pvt)
77         if err != nil {
78                 return err
79         }
80         gw.sshConfig.AddHostKey(signer)
81
82         // Address (typically provided by arvados-dispatch-cloud) is
83         // HOST:PORT where HOST is our IP address or hostname as seen
84         // from arvados-controller, and PORT is either the desired
85         // port where we should run our gateway server, or "0" if we
86         // should choose an available port.
87         host, port, err := net.SplitHostPort(gw.Address)
88         if err != nil {
89                 return err
90         }
91         cert, err := selfsigned.CertGenerator{}.Generate()
92         if err != nil {
93                 return err
94         }
95         h := hmac.New(sha256.New, []byte(gw.AuthSecret))
96         h.Write(cert.Certificate[0])
97         gw.requestAuth = fmt.Sprintf("%x", h.Sum(nil))
98         h.Reset()
99         h.Write([]byte(gw.requestAuth))
100         gw.respondAuth = fmt.Sprintf("%x", h.Sum(nil))
101
102         srv := &httpserver.Server{
103                 Server: http.Server{
104                         Handler: http.HandlerFunc(gw.handleSSH),
105                         TLSConfig: &tls.Config{
106                                 Certificates: []tls.Certificate{cert},
107                         },
108                 },
109                 Addr: ":" + port,
110         }
111         err = srv.Start()
112         if err != nil {
113                 return err
114         }
115         // Get the port number we are listening on (the port might be
116         // "0" or a port name, in which case this will be different).
117         _, port, err = net.SplitHostPort(srv.Addr)
118         if err != nil {
119                 return err
120         }
121         // When changing state to Running, we will set
122         // gateway_address to "HOST:PORT" where HOST is our
123         // external hostname/IP as provided by arvados-dispatch-cloud,
124         // and PORT is the port number we ended up listening on.
125         gw.Address = net.JoinHostPort(host, port)
126         return nil
127 }
128
129 // handleSSH connects to an SSH server that allows the caller to run
130 // interactive commands as root (or any other desired user) inside the
131 // container. The tunnel itself can only be created by an
132 // authenticated caller, so the SSH server itself is wide open (any
133 // password or key will be accepted).
134 //
135 // Requests must have path "/ssh" and the following headers:
136 //
137 // Connection: upgrade
138 // Upgrade: ssh
139 // X-Arvados-Target-Uuid: uuid of container
140 // X-Arvados-Authorization: must match
141 // hmac(AuthSecret,certfingerprint) (this prevents other containers
142 // and shell nodes from connecting directly)
143 //
144 // Optional headers:
145 //
146 // X-Arvados-Detach-Keys: argument to "docker exec --detach-keys",
147 // e.g., "ctrl-p,ctrl-q"
148 // X-Arvados-Login-Username: argument to "docker exec --user": account
149 // used to run command(s) inside the container.
150 func (gw *Gateway) handleSSH(w http.ResponseWriter, req *http.Request) {
151         // In future we'll handle browser traffic too, but for now the
152         // only traffic we expect is an SSH tunnel from
153         // (*lib/controller/localdb.Conn)ContainerSSH()
154         if req.Method != "GET" || req.Header.Get("Upgrade") != "ssh" {
155                 http.Error(w, "path not found", http.StatusNotFound)
156                 return
157         }
158         if want := req.Header.Get("X-Arvados-Target-Uuid"); want != gw.ContainerUUID {
159                 http.Error(w, fmt.Sprintf("misdirected request: meant for %q but received by crunch-run %q", want, gw.ContainerUUID), http.StatusBadGateway)
160                 return
161         }
162         if req.Header.Get("X-Arvados-Authorization") != gw.requestAuth {
163                 http.Error(w, "bad X-Arvados-Authorization header", http.StatusUnauthorized)
164                 return
165         }
166         detachKeys := req.Header.Get("X-Arvados-Detach-Keys")
167         username := req.Header.Get("X-Arvados-Login-Username")
168         if username == "" {
169                 username = "root"
170         }
171         hj, ok := w.(http.Hijacker)
172         if !ok {
173                 http.Error(w, "ResponseWriter does not support connection upgrade", http.StatusInternalServerError)
174                 return
175         }
176         netconn, _, err := hj.Hijack()
177         if !ok {
178                 http.Error(w, err.Error(), http.StatusInternalServerError)
179                 return
180         }
181         defer netconn.Close()
182         w.Header().Set("Connection", "upgrade")
183         w.Header().Set("Upgrade", "ssh")
184         w.Header().Set("X-Arvados-Authorization-Response", gw.respondAuth)
185         netconn.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n"))
186         w.Header().Write(netconn)
187         netconn.Write([]byte("\r\n"))
188
189         ctx := req.Context()
190
191         conn, newchans, reqs, err := ssh.NewServerConn(netconn, &gw.sshConfig)
192         if err != nil {
193                 gw.Log.Printf("ssh.NewServerConn: %s", err)
194                 return
195         }
196         defer conn.Close()
197         go ssh.DiscardRequests(reqs)
198         for newch := range newchans {
199                 if newch.ChannelType() != "session" {
200                         newch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unsupported channel type %q", newch.ChannelType()))
201                         continue
202                 }
203                 ch, reqs, err := newch.Accept()
204                 if err != nil {
205                         gw.Log.Printf("accept channel: %s", err)
206                         return
207                 }
208                 var pty0, tty0 *os.File
209                 go func() {
210                         // Where to send errors/messages for the
211                         // client to see
212                         logw := io.Writer(ch.Stderr())
213                         // How to end lines when sending
214                         // errors/messages to the client (changes to
215                         // \r\n when using a pty)
216                         eol := "\n"
217                         // Env vars to add to child process
218                         termEnv := []string(nil)
219                         for req := range reqs {
220                                 ok := false
221                                 switch req.Type {
222                                 case "shell", "exec":
223                                         ok = true
224                                         var payload struct {
225                                                 Command string
226                                         }
227                                         ssh.Unmarshal(req.Payload, &payload)
228                                         execargs, err := shlex.Split(payload.Command)
229                                         if err != nil {
230                                                 fmt.Fprintf(logw, "error parsing supplied command: %s"+eol, err)
231                                                 return
232                                         }
233                                         if len(execargs) == 0 {
234                                                 execargs = []string{"/bin/bash", "-login"}
235                                         }
236                                         go func() {
237                                                 cmd := exec.CommandContext(ctx, "docker", "exec", "-i", "--detach-keys="+detachKeys, "--user="+username)
238                                                 cmd.Stdin = ch
239                                                 cmd.Stdout = ch
240                                                 cmd.Stderr = ch.Stderr()
241                                                 if tty0 != nil {
242                                                         cmd.Args = append(cmd.Args, "-t")
243                                                         cmd.Stdin = tty0
244                                                         cmd.Stdout = tty0
245                                                         cmd.Stderr = tty0
246                                                         var wg sync.WaitGroup
247                                                         defer wg.Wait()
248                                                         wg.Add(2)
249                                                         go func() { io.Copy(ch, pty0); wg.Done() }()
250                                                         go func() { io.Copy(pty0, ch); wg.Done() }()
251                                                         // Send our own debug messages to tty as well.
252                                                         logw = tty0
253                                                 }
254                                                 cmd.Args = append(cmd.Args, *gw.DockerContainerID)
255                                                 cmd.Args = append(cmd.Args, execargs...)
256                                                 cmd.SysProcAttr = &syscall.SysProcAttr{
257                                                         Setctty: tty0 != nil,
258                                                         Setsid:  true,
259                                                 }
260                                                 cmd.Env = append(os.Environ(), termEnv...)
261                                                 err := cmd.Run()
262                                                 var resp struct {
263                                                         Status uint32
264                                                 }
265                                                 if exiterr, ok := err.(*exec.ExitError); ok {
266                                                         if status, ok := exiterr.Sys().(syscall.WaitStatus); ok {
267                                                                 resp.Status = uint32(status.ExitStatus())
268                                                         }
269                                                 } else if err != nil {
270                                                         // Propagate errors like `exec: "docker": executable file not found in $PATH`
271                                                         fmt.Fprintln(ch.Stderr(), err)
272                                                 }
273                                                 errClose := ch.CloseWrite()
274                                                 if resp.Status == 0 && (err != nil || errClose != nil) {
275                                                         resp.Status = 1
276                                                 }
277                                                 ch.SendRequest("exit-status", false, ssh.Marshal(&resp))
278                                                 ch.Close()
279                                         }()
280                                 case "pty-req":
281                                         eol = "\r\n"
282                                         p, t, err := pty.Open()
283                                         if err != nil {
284                                                 fmt.Fprintf(ch.Stderr(), "pty failed: %s"+eol, err)
285                                                 break
286                                         }
287                                         defer p.Close()
288                                         defer t.Close()
289                                         pty0, tty0 = p, t
290                                         ok = true
291                                         var payload struct {
292                                                 Term string
293                                                 Cols uint32
294                                                 Rows uint32
295                                                 X    uint32
296                                                 Y    uint32
297                                         }
298                                         ssh.Unmarshal(req.Payload, &payload)
299                                         termEnv = []string{"TERM=" + payload.Term, "USE_TTY=1"}
300                                         err = pty.Setsize(pty0, &pty.Winsize{Rows: uint16(payload.Rows), Cols: uint16(payload.Cols), X: uint16(payload.X), Y: uint16(payload.Y)})
301                                         if err != nil {
302                                                 fmt.Fprintf(logw, "pty-req: setsize failed: %s"+eol, err)
303                                         }
304                                 case "window-change":
305                                         var payload struct {
306                                                 Cols uint32
307                                                 Rows uint32
308                                                 X    uint32
309                                                 Y    uint32
310                                         }
311                                         ssh.Unmarshal(req.Payload, &payload)
312                                         err := pty.Setsize(pty0, &pty.Winsize{Rows: uint16(payload.Rows), Cols: uint16(payload.Cols), X: uint16(payload.X), Y: uint16(payload.Y)})
313                                         if err != nil {
314                                                 fmt.Fprintf(logw, "window-change: setsize failed: %s"+eol, err)
315                                                 break
316                                         }
317                                         ok = true
318                                 case "env":
319                                         // TODO: implement "env"
320                                         // requests by setting env
321                                         // vars in the docker-exec
322                                         // command (not docker-exec's
323                                         // own environment, which
324                                         // would be a gaping security
325                                         // hole).
326                                 default:
327                                         // fmt.Fprintf(logw, "declining %q req"+eol, req.Type)
328                                 }
329                                 if req.WantReply {
330                                         req.Reply(ok, nil)
331                                 }
332                         }
333                 }()
334         }
335 }