19099: Refactor container shell backend so it's not docker-specific.
[arvados.git] / lib / crunchrun / container_gateway.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package crunchrun
6
7 import (
8         "crypto/hmac"
9         "crypto/rand"
10         "crypto/rsa"
11         "crypto/sha256"
12         "crypto/tls"
13         "fmt"
14         "io"
15         "net"
16         "net/http"
17         "os"
18         "os/exec"
19         "sync"
20         "syscall"
21
22         "git.arvados.org/arvados.git/lib/selfsigned"
23         "git.arvados.org/arvados.git/sdk/go/ctxlog"
24         "git.arvados.org/arvados.git/sdk/go/httpserver"
25         "github.com/creack/pty"
26         "github.com/google/shlex"
27         "golang.org/x/crypto/ssh"
28         "golang.org/x/net/context"
29 )
30
// GatewayTarget is the container backend a Gateway uses to run
// commands in, and find the network address of, its container.
type GatewayTarget interface {
	// InjectCommand returns a command that, when run, will execute
	// cmd inside the container.
	InjectCommand(ctx context.Context, detachKeys, username string, usingTTY bool, cmd []string) (*exec.Cmd, error)

	// IPAddress returns the container's IP address. The gateway
	// dials this address when forwarding TCP connections, and
	// treats an empty string as "no address".
	IPAddress() (string, error)
}
38
39 type Gateway struct {
40         ContainerUUID string
41         Address       string // listen host:port; if port=0, Start() will change it to the selected port
42         AuthSecret    string
43         Target        GatewayTarget
44         Log           interface {
45                 Printf(fmt string, args ...interface{})
46         }
47
48         sshConfig   ssh.ServerConfig
49         requestAuth string
50         respondAuth string
51 }
52
53 // Start starts an http server that allows authenticated clients to open an
54 // interactive "docker exec" session and (in future) connect to tcp ports
55 // inside the docker container.
56 func (gw *Gateway) Start() error {
57         gw.sshConfig = ssh.ServerConfig{
58                 NoClientAuth: true,
59                 PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
60                         if c.User() == "_" {
61                                 return nil, nil
62                         }
63                         return nil, fmt.Errorf("cannot specify user %q via ssh client", c.User())
64                 },
65                 PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
66                         if c.User() == "_" {
67                                 return &ssh.Permissions{
68                                         Extensions: map[string]string{
69                                                 "pubkey-fp": ssh.FingerprintSHA256(pubKey),
70                                         },
71                                 }, nil
72                         }
73                         return nil, fmt.Errorf("cannot specify user %q via ssh client", c.User())
74                 },
75         }
76         pvt, err := rsa.GenerateKey(rand.Reader, 2048)
77         if err != nil {
78                 return err
79         }
80         err = pvt.Validate()
81         if err != nil {
82                 return err
83         }
84         signer, err := ssh.NewSignerFromKey(pvt)
85         if err != nil {
86                 return err
87         }
88         gw.sshConfig.AddHostKey(signer)
89
90         // Address (typically provided by arvados-dispatch-cloud) is
91         // HOST:PORT where HOST is our IP address or hostname as seen
92         // from arvados-controller, and PORT is either the desired
93         // port where we should run our gateway server, or "0" if we
94         // should choose an available port.
95         host, port, err := net.SplitHostPort(gw.Address)
96         if err != nil {
97                 return err
98         }
99         cert, err := selfsigned.CertGenerator{}.Generate()
100         if err != nil {
101                 return err
102         }
103         h := hmac.New(sha256.New, []byte(gw.AuthSecret))
104         h.Write(cert.Certificate[0])
105         gw.requestAuth = fmt.Sprintf("%x", h.Sum(nil))
106         h.Reset()
107         h.Write([]byte(gw.requestAuth))
108         gw.respondAuth = fmt.Sprintf("%x", h.Sum(nil))
109
110         srv := &httpserver.Server{
111                 Server: http.Server{
112                         Handler: http.HandlerFunc(gw.handleSSH),
113                         TLSConfig: &tls.Config{
114                                 Certificates: []tls.Certificate{cert},
115                         },
116                 },
117                 Addr: ":" + port,
118         }
119         err = srv.Start()
120         if err != nil {
121                 return err
122         }
123         // Get the port number we are listening on (the port might be
124         // "0" or a port name, in which case this will be different).
125         _, port, err = net.SplitHostPort(srv.Addr)
126         if err != nil {
127                 return err
128         }
129         // When changing state to Running, we will set
130         // gateway_address to "HOST:PORT" where HOST is our
131         // external hostname/IP as provided by arvados-dispatch-cloud,
132         // and PORT is the port number we ended up listening on.
133         gw.Address = net.JoinHostPort(host, port)
134         return nil
135 }
136
137 // handleSSH connects to an SSH server that allows the caller to run
138 // interactive commands as root (or any other desired user) inside the
139 // container. The tunnel itself can only be created by an
140 // authenticated caller, so the SSH server itself is wide open (any
141 // password or key will be accepted).
142 //
143 // Requests must have path "/ssh" and the following headers:
144 //
145 // Connection: upgrade
146 // Upgrade: ssh
147 // X-Arvados-Target-Uuid: uuid of container
148 // X-Arvados-Authorization: must match
149 // hmac(AuthSecret,certfingerprint) (this prevents other containers
150 // and shell nodes from connecting directly)
151 //
152 // Optional headers:
153 //
154 // X-Arvados-Detach-Keys: argument to "docker exec --detach-keys",
155 // e.g., "ctrl-p,ctrl-q"
156 // X-Arvados-Login-Username: argument to "docker exec --user": account
157 // used to run command(s) inside the container.
158 func (gw *Gateway) handleSSH(w http.ResponseWriter, req *http.Request) {
159         // In future we'll handle browser traffic too, but for now the
160         // only traffic we expect is an SSH tunnel from
161         // (*lib/controller/localdb.Conn)ContainerSSH()
162         if req.Method != "GET" || req.Header.Get("Upgrade") != "ssh" {
163                 http.Error(w, "path not found", http.StatusNotFound)
164                 return
165         }
166         if want := req.Header.Get("X-Arvados-Target-Uuid"); want != gw.ContainerUUID {
167                 http.Error(w, fmt.Sprintf("misdirected request: meant for %q but received by crunch-run %q", want, gw.ContainerUUID), http.StatusBadGateway)
168                 return
169         }
170         if req.Header.Get("X-Arvados-Authorization") != gw.requestAuth {
171                 http.Error(w, "bad X-Arvados-Authorization header", http.StatusUnauthorized)
172                 return
173         }
174         detachKeys := req.Header.Get("X-Arvados-Detach-Keys")
175         username := req.Header.Get("X-Arvados-Login-Username")
176         if username == "" {
177                 username = "root"
178         }
179         hj, ok := w.(http.Hijacker)
180         if !ok {
181                 http.Error(w, "ResponseWriter does not support connection upgrade", http.StatusInternalServerError)
182                 return
183         }
184         netconn, _, err := hj.Hijack()
185         if !ok {
186                 http.Error(w, err.Error(), http.StatusInternalServerError)
187                 return
188         }
189         defer netconn.Close()
190         w.Header().Set("Connection", "upgrade")
191         w.Header().Set("Upgrade", "ssh")
192         w.Header().Set("X-Arvados-Authorization-Response", gw.respondAuth)
193         netconn.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n"))
194         w.Header().Write(netconn)
195         netconn.Write([]byte("\r\n"))
196
197         ctx := req.Context()
198
199         conn, newchans, reqs, err := ssh.NewServerConn(netconn, &gw.sshConfig)
200         if err != nil {
201                 gw.Log.Printf("ssh.NewServerConn: %s", err)
202                 return
203         }
204         defer conn.Close()
205         go ssh.DiscardRequests(reqs)
206         for newch := range newchans {
207                 switch newch.ChannelType() {
208                 case "direct-tcpip":
209                         go gw.handleDirectTCPIP(ctx, newch)
210                 case "session":
211                         go gw.handleSession(ctx, newch, detachKeys, username)
212                 default:
213                         go newch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unsupported channel type %q", newch.ChannelType()))
214                 }
215         }
216 }
217
218 func (gw *Gateway) handleDirectTCPIP(ctx context.Context, newch ssh.NewChannel) {
219         ch, reqs, err := newch.Accept()
220         if err != nil {
221                 gw.Log.Printf("accept direct-tcpip channel: %s", err)
222                 return
223         }
224         defer ch.Close()
225         go ssh.DiscardRequests(reqs)
226
227         // RFC 4254 7.2 (copy of channelOpenDirectMsg in
228         // golang.org/x/crypto/ssh)
229         var msg struct {
230                 Raddr string
231                 Rport uint32
232                 Laddr string
233                 Lport uint32
234         }
235         err = ssh.Unmarshal(newch.ExtraData(), &msg)
236         if err != nil {
237                 fmt.Fprintf(ch.Stderr(), "unmarshal direct-tcpip extradata: %s\n", err)
238                 return
239         }
240         switch msg.Raddr {
241         case "localhost", "0.0.0.0", "127.0.0.1", "::1", "::":
242         default:
243                 fmt.Fprintf(ch.Stderr(), "cannot forward to ports on %q, only localhost\n", msg.Raddr)
244                 return
245         }
246
247         dstaddr, err := gw.Target.IPAddress()
248         if err != nil {
249                 fmt.Fprintf(ch.Stderr(), "container has no IP address: %s\n", err)
250                 return
251         } else if dstaddr == "" {
252                 fmt.Fprintf(ch.Stderr(), "container has no IP address\n")
253                 return
254         }
255
256         dst := net.JoinHostPort(dstaddr, fmt.Sprintf("%d", msg.Rport))
257         tcpconn, err := net.Dial("tcp", dst)
258         if err != nil {
259                 fmt.Fprintf(ch.Stderr(), "%s: %s\n", dst, err)
260                 return
261         }
262         go func() {
263                 n, _ := io.Copy(ch, tcpconn)
264                 ctxlog.FromContext(ctx).Debugf("tcpip: sent %d bytes\n", n)
265                 ch.CloseWrite()
266         }()
267         n, _ := io.Copy(tcpconn, ch)
268         ctxlog.FromContext(ctx).Debugf("tcpip: received %d bytes\n", n)
269 }
270
271 func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, detachKeys, username string) {
272         ch, reqs, err := newch.Accept()
273         if err != nil {
274                 gw.Log.Printf("accept session channel: %s", err)
275                 return
276         }
277         var pty0, tty0 *os.File
278         // Where to send errors/messages for the client to see
279         logw := io.Writer(ch.Stderr())
280         // How to end lines when sending errors/messages to the client
281         // (changes to \r\n when using a pty)
282         eol := "\n"
283         // Env vars to add to child process
284         termEnv := []string(nil)
285         for req := range reqs {
286                 ok := false
287                 switch req.Type {
288                 case "shell", "exec":
289                         ok = true
290                         var payload struct {
291                                 Command string
292                         }
293                         ssh.Unmarshal(req.Payload, &payload)
294                         execargs, err := shlex.Split(payload.Command)
295                         if err != nil {
296                                 fmt.Fprintf(logw, "error parsing supplied command: %s"+eol, err)
297                                 return
298                         }
299                         if len(execargs) == 0 {
300                                 execargs = []string{"/bin/bash", "-login"}
301                         }
302                         go func() {
303                                 var resp struct {
304                                         Status uint32
305                                 }
306                                 defer func() {
307                                         ch.SendRequest("exit-status", false, ssh.Marshal(&resp))
308                                         ch.Close()
309                                 }()
310
311                                 cmd, err := gw.Target.InjectCommand(ctx, detachKeys, username, tty0 != nil, execargs)
312                                 if err != nil {
313                                         fmt.Fprintln(ch.Stderr(), err)
314                                         ch.CloseWrite()
315                                         resp.Status = 1
316                                         return
317                                 }
318                                 cmd.Stdin = ch
319                                 cmd.Stdout = ch
320                                 cmd.Stderr = ch.Stderr()
321                                 if tty0 != nil {
322                                         cmd.Stdin = tty0
323                                         cmd.Stdout = tty0
324                                         cmd.Stderr = tty0
325                                         var wg sync.WaitGroup
326                                         defer wg.Wait()
327                                         wg.Add(2)
328                                         go func() { io.Copy(ch, pty0); wg.Done() }()
329                                         go func() { io.Copy(pty0, ch); wg.Done() }()
330                                         // Send our own debug messages to tty as well.
331                                         logw = tty0
332                                 }
333                                 cmd.SysProcAttr = &syscall.SysProcAttr{
334                                         Setctty: tty0 != nil,
335                                         Setsid:  true,
336                                 }
337                                 cmd.Env = append(os.Environ(), termEnv...)
338                                 err = cmd.Run()
339                                 if exiterr, ok := err.(*exec.ExitError); ok {
340                                         if status, ok := exiterr.Sys().(syscall.WaitStatus); ok {
341                                                 resp.Status = uint32(status.ExitStatus())
342                                         }
343                                 } else if err != nil {
344                                         // Propagate errors like `exec: "docker": executable file not found in $PATH`
345                                         fmt.Fprintln(ch.Stderr(), err)
346                                 }
347                                 errClose := ch.CloseWrite()
348                                 if resp.Status == 0 && (err != nil || errClose != nil) {
349                                         resp.Status = 1
350                                 }
351                         }()
352                 case "pty-req":
353                         eol = "\r\n"
354                         p, t, err := pty.Open()
355                         if err != nil {
356                                 fmt.Fprintf(ch.Stderr(), "pty failed: %s"+eol, err)
357                                 break
358                         }
359                         defer p.Close()
360                         defer t.Close()
361                         pty0, tty0 = p, t
362                         ok = true
363                         var payload struct {
364                                 Term string
365                                 Cols uint32
366                                 Rows uint32
367                                 X    uint32
368                                 Y    uint32
369                         }
370                         ssh.Unmarshal(req.Payload, &payload)
371                         termEnv = []string{"TERM=" + payload.Term, "USE_TTY=1"}
372                         err = pty.Setsize(pty0, &pty.Winsize{Rows: uint16(payload.Rows), Cols: uint16(payload.Cols), X: uint16(payload.X), Y: uint16(payload.Y)})
373                         if err != nil {
374                                 fmt.Fprintf(logw, "pty-req: setsize failed: %s"+eol, err)
375                         }
376                 case "window-change":
377                         var payload struct {
378                                 Cols uint32
379                                 Rows uint32
380                                 X    uint32
381                                 Y    uint32
382                         }
383                         ssh.Unmarshal(req.Payload, &payload)
384                         err := pty.Setsize(pty0, &pty.Winsize{Rows: uint16(payload.Rows), Cols: uint16(payload.Cols), X: uint16(payload.X), Y: uint16(payload.Y)})
385                         if err != nil {
386                                 fmt.Fprintf(logw, "window-change: setsize failed: %s"+eol, err)
387                                 break
388                         }
389                         ok = true
390                 case "env":
391                         // TODO: implement "env"
392                         // requests by setting env
393                         // vars in the docker-exec
394                         // command (not docker-exec's
395                         // own environment, which
396                         // would be a gaping security
397                         // hole).
398                 default:
399                         // fmt.Fprintf(logw, "declining %q req"+eol, req.Type)
400                 }
401                 if req.WantReply {
402                         req.Reply(ok, nil)
403                 }
404         }
405 }