12630: Fix test
[arvados.git] / lib / crunchrun / container_gateway.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package crunchrun
6
7 import (
8         "crypto/hmac"
9         "crypto/rand"
10         "crypto/rsa"
11         "crypto/sha256"
12         "crypto/tls"
13         "fmt"
14         "io"
15         "net"
16         "net/http"
17         "os"
18         "os/exec"
19         "sync"
20         "sync/atomic"
21         "syscall"
22         "time"
23
24         "git.arvados.org/arvados.git/lib/selfsigned"
25         "git.arvados.org/arvados.git/sdk/go/ctxlog"
26         "git.arvados.org/arvados.git/sdk/go/httpserver"
27         "github.com/creack/pty"
28         dockerclient "github.com/docker/docker/client"
29         "github.com/google/shlex"
30         "golang.org/x/crypto/ssh"
31         "golang.org/x/net/context"
32 )
33
// Gateway is a container gateway server. Start runs an HTTPS server
// whose handler (handleSSH) upgrades authenticated requests to an SSH
// server connection. Over that connection clients can run commands in
// the docker container via "docker exec" ("session" channels) and
// forward TCP connections to the container's IP address
// ("direct-tcpip" channels).
type Gateway struct {
	// Docker container ID. It is dereferenced each time a command is
	// run (not when the Gateway is created), presumably so the caller
	// can fill it in after construction -- TODO confirm.
	DockerContainerID *string
	// UUID of the Arvados container. Requests whose
	// X-Arvados-Target-Uuid header names a different container are
	// rejected.
	ContainerUUID     string
	Address           string // listen host:port; if port=0, Start() will change it to the selected port
	// Secret used as the HMAC-SHA256 key when Start derives the
	// request/response authorization tokens.
	AuthSecret        string
	// Destination for errors that cannot be reported to the client
	// (e.g., SSH handshake and channel accept failures).
	Log               interface {
		Printf(fmt string, args ...interface{})
	}
	// return local ip address of running container, or "" if not available.
	// Used for "direct-tcpip" forwarding. If nil, forwarding requests
	// are refused with "container has no IP address".
	ContainerIPAddress func() (string, error)

	// SSH server configuration (auth callbacks and an ephemeral RSA
	// host key), built by Start.
	sshConfig   ssh.ServerConfig
	// Expected X-Arvados-Authorization value: hex-encoded
	// HMAC-SHA256(AuthSecret, DER bytes of the TLS certificate).
	requestAuth string
	// Value returned in X-Arvados-Authorization-Response: hex-encoded
	// HMAC-SHA256(AuthSecret, requestAuth).
	respondAuth string
}
49
50 // Start starts an http server that allows authenticated clients to open an
51 // interactive "docker exec" session and (in future) connect to tcp ports
52 // inside the docker container.
53 func (gw *Gateway) Start() error {
54         gw.sshConfig = ssh.ServerConfig{
55                 NoClientAuth: true,
56                 PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
57                         if c.User() == "_" {
58                                 return nil, nil
59                         }
60                         return nil, fmt.Errorf("cannot specify user %q via ssh client", c.User())
61                 },
62                 PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
63                         if c.User() == "_" {
64                                 return &ssh.Permissions{
65                                         Extensions: map[string]string{
66                                                 "pubkey-fp": ssh.FingerprintSHA256(pubKey),
67                                         },
68                                 }, nil
69                         }
70                         return nil, fmt.Errorf("cannot specify user %q via ssh client", c.User())
71                 },
72         }
73         pvt, err := rsa.GenerateKey(rand.Reader, 2048)
74         if err != nil {
75                 return err
76         }
77         err = pvt.Validate()
78         if err != nil {
79                 return err
80         }
81         signer, err := ssh.NewSignerFromKey(pvt)
82         if err != nil {
83                 return err
84         }
85         gw.sshConfig.AddHostKey(signer)
86
87         // Address (typically provided by arvados-dispatch-cloud) is
88         // HOST:PORT where HOST is our IP address or hostname as seen
89         // from arvados-controller, and PORT is either the desired
90         // port where we should run our gateway server, or "0" if we
91         // should choose an available port.
92         host, port, err := net.SplitHostPort(gw.Address)
93         if err != nil {
94                 return err
95         }
96         cert, err := selfsigned.CertGenerator{}.Generate()
97         if err != nil {
98                 return err
99         }
100         h := hmac.New(sha256.New, []byte(gw.AuthSecret))
101         h.Write(cert.Certificate[0])
102         gw.requestAuth = fmt.Sprintf("%x", h.Sum(nil))
103         h.Reset()
104         h.Write([]byte(gw.requestAuth))
105         gw.respondAuth = fmt.Sprintf("%x", h.Sum(nil))
106
107         srv := &httpserver.Server{
108                 Server: http.Server{
109                         Handler: http.HandlerFunc(gw.handleSSH),
110                         TLSConfig: &tls.Config{
111                                 Certificates: []tls.Certificate{cert},
112                         },
113                 },
114                 Addr: ":" + port,
115         }
116         err = srv.Start()
117         if err != nil {
118                 return err
119         }
120         // Get the port number we are listening on (the port might be
121         // "0" or a port name, in which case this will be different).
122         _, port, err = net.SplitHostPort(srv.Addr)
123         if err != nil {
124                 return err
125         }
126         // When changing state to Running, we will set
127         // gateway_address to "HOST:PORT" where HOST is our
128         // external hostname/IP as provided by arvados-dispatch-cloud,
129         // and PORT is the port number we ended up listening on.
130         gw.Address = net.JoinHostPort(host, port)
131         return nil
132 }
133
134 // handleSSH connects to an SSH server that allows the caller to run
135 // interactive commands as root (or any other desired user) inside the
136 // container. The tunnel itself can only be created by an
137 // authenticated caller, so the SSH server itself is wide open (any
138 // password or key will be accepted).
139 //
140 // Requests must have path "/ssh" and the following headers:
141 //
142 // Connection: upgrade
143 // Upgrade: ssh
144 // X-Arvados-Target-Uuid: uuid of container
145 // X-Arvados-Authorization: must match
146 // hmac(AuthSecret,certfingerprint) (this prevents other containers
147 // and shell nodes from connecting directly)
148 //
149 // Optional headers:
150 //
151 // X-Arvados-Detach-Keys: argument to "docker exec --detach-keys",
152 // e.g., "ctrl-p,ctrl-q"
153 // X-Arvados-Login-Username: argument to "docker exec --user": account
154 // used to run command(s) inside the container.
155 func (gw *Gateway) handleSSH(w http.ResponseWriter, req *http.Request) {
156         // In future we'll handle browser traffic too, but for now the
157         // only traffic we expect is an SSH tunnel from
158         // (*lib/controller/localdb.Conn)ContainerSSH()
159         if req.Method != "GET" || req.Header.Get("Upgrade") != "ssh" {
160                 http.Error(w, "path not found", http.StatusNotFound)
161                 return
162         }
163         if want := req.Header.Get("X-Arvados-Target-Uuid"); want != gw.ContainerUUID {
164                 http.Error(w, fmt.Sprintf("misdirected request: meant for %q but received by crunch-run %q", want, gw.ContainerUUID), http.StatusBadGateway)
165                 return
166         }
167         if req.Header.Get("X-Arvados-Authorization") != gw.requestAuth {
168                 http.Error(w, "bad X-Arvados-Authorization header", http.StatusUnauthorized)
169                 return
170         }
171         detachKeys := req.Header.Get("X-Arvados-Detach-Keys")
172         username := req.Header.Get("X-Arvados-Login-Username")
173         if username == "" {
174                 username = "root"
175         }
176         hj, ok := w.(http.Hijacker)
177         if !ok {
178                 http.Error(w, "ResponseWriter does not support connection upgrade", http.StatusInternalServerError)
179                 return
180         }
181         netconn, _, err := hj.Hijack()
182         if !ok {
183                 http.Error(w, err.Error(), http.StatusInternalServerError)
184                 return
185         }
186         defer netconn.Close()
187         w.Header().Set("Connection", "upgrade")
188         w.Header().Set("Upgrade", "ssh")
189         w.Header().Set("X-Arvados-Authorization-Response", gw.respondAuth)
190         netconn.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n"))
191         w.Header().Write(netconn)
192         netconn.Write([]byte("\r\n"))
193
194         ctx := req.Context()
195
196         conn, newchans, reqs, err := ssh.NewServerConn(netconn, &gw.sshConfig)
197         if err != nil {
198                 gw.Log.Printf("ssh.NewServerConn: %s", err)
199                 return
200         }
201         defer conn.Close()
202         go ssh.DiscardRequests(reqs)
203         for newch := range newchans {
204                 switch newch.ChannelType() {
205                 case "direct-tcpip":
206                         go gw.handleDirectTCPIP(ctx, newch)
207                 case "session":
208                         go gw.handleSession(ctx, newch, detachKeys, username)
209                 default:
210                         go newch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unsupported channel type %q", newch.ChannelType()))
211                 }
212         }
213 }
214
215 func (gw *Gateway) handleDirectTCPIP(ctx context.Context, newch ssh.NewChannel) {
216         ch, reqs, err := newch.Accept()
217         if err != nil {
218                 gw.Log.Printf("accept direct-tcpip channel: %s", err)
219                 return
220         }
221         defer ch.Close()
222         go ssh.DiscardRequests(reqs)
223
224         // RFC 4254 7.2 (copy of channelOpenDirectMsg in
225         // golang.org/x/crypto/ssh)
226         var msg struct {
227                 Raddr string
228                 Rport uint32
229                 Laddr string
230                 Lport uint32
231         }
232         err = ssh.Unmarshal(newch.ExtraData(), &msg)
233         if err != nil {
234                 fmt.Fprintf(ch.Stderr(), "unmarshal direct-tcpip extradata: %s\n", err)
235                 return
236         }
237         switch msg.Raddr {
238         case "localhost", "0.0.0.0", "127.0.0.1", "::1", "::":
239         default:
240                 fmt.Fprintf(ch.Stderr(), "cannot forward to ports on %q, only localhost\n", msg.Raddr)
241                 return
242         }
243
244         var dstaddr string
245         if gw.ContainerIPAddress != nil {
246                 dstaddr, err = gw.ContainerIPAddress()
247                 if err != nil {
248                         fmt.Fprintf(ch.Stderr(), "container has no IP address: %s\n", err)
249                         return
250                 }
251         }
252         if dstaddr == "" {
253                 fmt.Fprintf(ch.Stderr(), "container has no IP address\n")
254                 return
255         }
256
257         dst := net.JoinHostPort(dstaddr, fmt.Sprintf("%d", msg.Rport))
258         tcpconn, err := net.Dial("tcp", dst)
259         if err != nil {
260                 fmt.Fprintf(ch.Stderr(), "%s: %s\n", dst, err)
261                 return
262         }
263         go func() {
264                 n, _ := io.Copy(ch, tcpconn)
265                 ctxlog.FromContext(ctx).Debugf("tcpip: sent %d bytes\n", n)
266                 ch.CloseWrite()
267         }()
268         n, _ := io.Copy(tcpconn, ch)
269         ctxlog.FromContext(ctx).Debugf("tcpip: received %d bytes\n", n)
270 }
271
272 func (gw *Gateway) handleSession(ctx context.Context, newch ssh.NewChannel, detachKeys, username string) {
273         ch, reqs, err := newch.Accept()
274         if err != nil {
275                 gw.Log.Printf("accept session channel: %s", err)
276                 return
277         }
278         var pty0, tty0 *os.File
279         // Where to send errors/messages for the client to see
280         logw := io.Writer(ch.Stderr())
281         // How to end lines when sending errors/messages to the client
282         // (changes to \r\n when using a pty)
283         eol := "\n"
284         // Env vars to add to child process
285         termEnv := []string(nil)
286         for req := range reqs {
287                 ok := false
288                 switch req.Type {
289                 case "shell", "exec":
290                         ok = true
291                         var payload struct {
292                                 Command string
293                         }
294                         ssh.Unmarshal(req.Payload, &payload)
295                         execargs, err := shlex.Split(payload.Command)
296                         if err != nil {
297                                 fmt.Fprintf(logw, "error parsing supplied command: %s"+eol, err)
298                                 return
299                         }
300                         if len(execargs) == 0 {
301                                 execargs = []string{"/bin/bash", "-login"}
302                         }
303                         go func() {
304                                 cmd := exec.CommandContext(ctx, "docker", "exec", "-i", "--detach-keys="+detachKeys, "--user="+username)
305                                 cmd.Stdin = ch
306                                 cmd.Stdout = ch
307                                 cmd.Stderr = ch.Stderr()
308                                 if tty0 != nil {
309                                         cmd.Args = append(cmd.Args, "-t")
310                                         cmd.Stdin = tty0
311                                         cmd.Stdout = tty0
312                                         cmd.Stderr = tty0
313                                         var wg sync.WaitGroup
314                                         defer wg.Wait()
315                                         wg.Add(2)
316                                         go func() { io.Copy(ch, pty0); wg.Done() }()
317                                         go func() { io.Copy(pty0, ch); wg.Done() }()
318                                         // Send our own debug messages to tty as well.
319                                         logw = tty0
320                                 }
321                                 cmd.Args = append(cmd.Args, *gw.DockerContainerID)
322                                 cmd.Args = append(cmd.Args, execargs...)
323                                 cmd.SysProcAttr = &syscall.SysProcAttr{
324                                         Setctty: tty0 != nil,
325                                         Setsid:  true,
326                                 }
327                                 cmd.Env = append(os.Environ(), termEnv...)
328                                 err := cmd.Run()
329                                 var resp struct {
330                                         Status uint32
331                                 }
332                                 if exiterr, ok := err.(*exec.ExitError); ok {
333                                         if status, ok := exiterr.Sys().(syscall.WaitStatus); ok {
334                                                 resp.Status = uint32(status.ExitStatus())
335                                         }
336                                 } else if err != nil {
337                                         // Propagate errors like `exec: "docker": executable file not found in $PATH`
338                                         fmt.Fprintln(ch.Stderr(), err)
339                                 }
340                                 errClose := ch.CloseWrite()
341                                 if resp.Status == 0 && (err != nil || errClose != nil) {
342                                         resp.Status = 1
343                                 }
344                                 ch.SendRequest("exit-status", false, ssh.Marshal(&resp))
345                                 ch.Close()
346                         }()
347                 case "pty-req":
348                         eol = "\r\n"
349                         p, t, err := pty.Open()
350                         if err != nil {
351                                 fmt.Fprintf(ch.Stderr(), "pty failed: %s"+eol, err)
352                                 break
353                         }
354                         defer p.Close()
355                         defer t.Close()
356                         pty0, tty0 = p, t
357                         ok = true
358                         var payload struct {
359                                 Term string
360                                 Cols uint32
361                                 Rows uint32
362                                 X    uint32
363                                 Y    uint32
364                         }
365                         ssh.Unmarshal(req.Payload, &payload)
366                         termEnv = []string{"TERM=" + payload.Term, "USE_TTY=1"}
367                         err = pty.Setsize(pty0, &pty.Winsize{Rows: uint16(payload.Rows), Cols: uint16(payload.Cols), X: uint16(payload.X), Y: uint16(payload.Y)})
368                         if err != nil {
369                                 fmt.Fprintf(logw, "pty-req: setsize failed: %s"+eol, err)
370                         }
371                 case "window-change":
372                         var payload struct {
373                                 Cols uint32
374                                 Rows uint32
375                                 X    uint32
376                                 Y    uint32
377                         }
378                         ssh.Unmarshal(req.Payload, &payload)
379                         err := pty.Setsize(pty0, &pty.Winsize{Rows: uint16(payload.Rows), Cols: uint16(payload.Cols), X: uint16(payload.X), Y: uint16(payload.Y)})
380                         if err != nil {
381                                 fmt.Fprintf(logw, "window-change: setsize failed: %s"+eol, err)
382                                 break
383                         }
384                         ok = true
385                 case "env":
386                         // TODO: implement "env"
387                         // requests by setting env
388                         // vars in the docker-exec
389                         // command (not docker-exec's
390                         // own environment, which
391                         // would be a gaping security
392                         // hole).
393                 default:
394                         // fmt.Fprintf(logw, "declining %q req"+eol, req.Type)
395                 }
396                 if req.WantReply {
397                         req.Reply(ok, nil)
398                 }
399         }
400 }
401
402 func dockerContainerIPAddress(containerID *string) func() (string, error) {
403         var saved atomic.Value
404         return func() (string, error) {
405                 if ip, ok := saved.Load().(*string); ok {
406                         return *ip, nil
407                 }
408                 docker, err := dockerclient.NewClient(dockerclient.DefaultDockerHost, "1.21", nil, nil)
409                 if err != nil {
410                         return "", fmt.Errorf("cannot create docker client: %s", err)
411                 }
412                 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
413                 defer cancel()
414                 ctr, err := docker.ContainerInspect(ctx, *containerID)
415                 if err != nil {
416                         return "", fmt.Errorf("cannot get docker container info: %s", err)
417                 }
418                 ip := ctr.NetworkSettings.IPAddress
419                 if ip == "" {
420                         // TODO: try to enable networking if it wasn't
421                         // already enabled when the container was
422                         // created.
423                         return "", fmt.Errorf("container has no IP address")
424                 }
425                 saved.Store(&ip)
426                 return ip, nil
427         }
428 }