19166: Set up tunnel for container gateway requests
[arvados.git] / lib / controller / localdb / container_gateway.go
index 3b40eccaff68c138e498b9ec497f5f704e249f64..fcfa599e4db81c5ae64a8f720a7afd30dcbc3ca3 100644 (file)
@@ -6,13 +6,16 @@ package localdb
 
 import (
        "bufio"
+       "bytes"
        "context"
        "crypto/hmac"
        "crypto/sha256"
+       "crypto/subtle"
        "crypto/tls"
        "crypto/x509"
        "errors"
        "fmt"
+       "net"
        "net/http"
        "net/url"
        "strings"
@@ -21,6 +24,7 @@ import (
        "git.arvados.org/arvados.git/sdk/go/auth"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
        "git.arvados.org/arvados.git/sdk/go/httpserver"
+       "github.com/hashicorp/yamux"
 )
 
 // ContainerSSH returns a connection to the SSH server in the
@@ -61,19 +65,33 @@ func (conn *Conn) ContainerSSH(ctx context.Context, opts arvados.ContainerSSHOpt
                }
        }
 
-       switch ctr.State {
-       case arvados.ContainerStateQueued, arvados.ContainerStateLocked:
+       conn.gwTunnelsLock.Lock()
+       tunnel := conn.gwTunnels[opts.UUID]
+       conn.gwTunnelsLock.Unlock()
+
+       if ctr.State == arvados.ContainerStateQueued || ctr.State == arvados.ContainerStateLocked {
                err = httpserver.ErrorWithStatus(fmt.Errorf("container is not running yet (state is %q)", ctr.State), http.StatusServiceUnavailable)
                return
-       case arvados.ContainerStateRunning:
-               if ctr.GatewayAddress == "" {
-                       err = httpserver.ErrorWithStatus(errors.New("container is running but gateway is not available -- installation problem or feature not supported"), http.StatusServiceUnavailable)
-                       return
-               }
-       default:
+       } else if ctr.State != arvados.ContainerStateRunning {
                err = httpserver.ErrorWithStatus(fmt.Errorf("container has ended (state is %q)", ctr.State), http.StatusGone)
                return
        }
+
+       var rawconn net.Conn
+       if ctr.GatewayAddress != "" && !strings.HasPrefix(ctr.GatewayAddress, "127.0.0.1:") {
+               rawconn, err = net.Dial("tcp", ctr.GatewayAddress)
+       } else if tunnel != nil {
+               rawconn, err = tunnel.Open()
+       } else if ctr.GatewayAddress == "" {
+               err = errors.New("container is running but gateway is not available")
+       } else {
+               err = errors.New("container gateway is running but tunnel is down")
+       }
+       if err != nil {
+               err = httpserver.ErrorWithStatus(err, http.StatusServiceUnavailable)
+               return
+       }
+
        // crunch-run uses a self-signed / unverifiable TLS
        // certificate, so we use the following scheme to ensure we're
        // not talking to a MITM.
@@ -93,7 +111,7 @@ func (conn *Conn) ContainerSSH(ctx context.Context, opts arvados.ContainerSSHOpt
        // X-Arvados-Authorization-Response header, proving that the
        // server knows ctrKey.
        var requestAuth, respondAuth string
-       netconn, err := tls.Dial("tcp", ctr.GatewayAddress, &tls.Config{
+       tlsconn := tls.Client(rawconn, &tls.Config{
                InsecureSkipVerify: true,
                VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
                        if len(rawCerts) == 0 {
@@ -111,23 +129,25 @@ func (conn *Conn) ContainerSSH(ctx context.Context, opts arvados.ContainerSSHOpt
                        return nil
                },
        })
+       err = tlsconn.HandshakeContext(ctx)
        if err != nil {
                err = httpserver.ErrorWithStatus(err, http.StatusBadGateway)
                return
        }
        if respondAuth == "" {
+               tlsconn.Close()
                err = httpserver.ErrorWithStatus(errors.New("BUG: no respondAuth"), http.StatusInternalServerError)
                return
        }
-       bufr := bufio.NewReader(netconn)
-       bufw := bufio.NewWriter(netconn)
+       bufr := bufio.NewReader(tlsconn)
+       bufw := bufio.NewWriter(tlsconn)
 
        u := url.URL{
                Scheme: "http",
                Host:   ctr.GatewayAddress,
                Path:   "/ssh",
        }
-       bufw.WriteString("GET " + u.String() + " HTTP/1.1\r\n")
+       bufw.WriteString("POST " + u.String() + " HTTP/1.1\r\n")
        bufw.WriteString("Host: " + u.Host + "\r\n")
        bufw.WriteString("Upgrade: ssh\r\n")
        bufw.WriteString("X-Arvados-Target-Uuid: " + opts.UUID + "\r\n")
@@ -139,18 +159,18 @@ func (conn *Conn) ContainerSSH(ctx context.Context, opts arvados.ContainerSSHOpt
        resp, err := http.ReadResponse(bufr, &http.Request{Method: "GET"})
        if err != nil {
                err = httpserver.ErrorWithStatus(fmt.Errorf("error reading http response from gateway: %w", err), http.StatusBadGateway)
-               netconn.Close()
+               tlsconn.Close()
                return
        }
        if resp.Header.Get("X-Arvados-Authorization-Response") != respondAuth {
                err = httpserver.ErrorWithStatus(errors.New("bad X-Arvados-Authorization-Response header"), http.StatusBadGateway)
-               netconn.Close()
+               tlsconn.Close()
                return
        }
        if strings.ToLower(resp.Header.Get("Upgrade")) != "ssh" ||
                strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
                err = httpserver.ErrorWithStatus(errors.New("bad upgrade"), http.StatusBadGateway)
-               netconn.Close()
+               tlsconn.Close()
                return
        }
 
@@ -162,13 +182,60 @@ func (conn *Conn) ContainerSSH(ctx context.Context, opts arvados.ContainerSSHOpt
                        },
                })
                if err != nil {
-                       netconn.Close()
+                       tlsconn.Close()
                        return
                }
        }
 
-       sshconn.Conn = netconn
+       sshconn.Conn = tlsconn
        sshconn.Bufrw = &bufio.ReadWriter{Reader: bufr, Writer: bufw}
        sshconn.Logger = ctxlog.FromContext(ctx)
+       sshconn.UpgradeHeader = "ssh"
+       return
+}
+
+// ContainerGatewayTunnel sets up a tunnel enabling us (controller) to connect to the caller's (crunch-run's) gateway server.
+// opts.AuthSecret must equal the hex HMAC-SHA256 of opts.UUID keyed with the cluster's SystemRootToken, otherwise a 401 error is returned; a yamux setup failure returns a 500 error.
+func (conn *Conn) ContainerGatewayTunnel(ctx context.Context, opts arvados.ContainerGatewayTunnelOptions) (resp arvados.ConnectionResponse, err error) {
+       h := hmac.New(sha256.New, []byte(conn.cluster.SystemRootToken)) // HMAC key: cluster root token
+       fmt.Fprint(h, opts.UUID) // HMAC message: the UUID named in the request
+       authSecret := fmt.Sprintf("%x", h.Sum(nil)) // expected secret, lowercase hex
+       if subtle.ConstantTimeCompare([]byte(authSecret), []byte(opts.AuthSecret)) != 1 {
+               ctxlog.FromContext(ctx).Info("received incorrect auth_secret")
+               return resp, httpserver.ErrorWithStatus(errors.New("authentication error"), http.StatusUnauthorized)
+       }
+
+       muxconn, clientconn := net.Pipe() // in-memory pipe: muxconn feeds yamux, clientconn is handed back via resp.Conn
+       tunnel, err := yamux.Server(muxconn, nil) // nil config: yamux uses its default configuration
+       if err != nil {
+               clientconn.Close() // resp is returned without a Conn on this path, so close the unused pipe end
+               return resp, httpserver.ErrorWithStatus(err, http.StatusInternalServerError)
+       }
+
+       conn.gwTunnelsLock.Lock() // gwTunnels is also read by ContainerSSH under this lock
+       if conn.gwTunnels == nil { // the map is created lazily on first use
+               conn.gwTunnels = map[string]*yamux.Session{opts.UUID: tunnel}
+       } else {
+               conn.gwTunnels[opts.UUID] = tunnel // overwrites any earlier entry for this UUID; the old session is not closed here
+       }
+       conn.gwTunnelsLock.Unlock()
+
+       go func() { // cleanup goroutine: lives until this yamux session closes
+               <-tunnel.CloseChan() // blocks until the session is closed
+               conn.gwTunnelsLock.Lock()
+               if conn.gwTunnels[opts.UUID] == tunnel { // leave the entry alone if a newer tunnel has replaced ours
+                       delete(conn.gwTunnels, opts.UUID)
+               }
+               conn.gwTunnelsLock.Unlock()
+       }()
+
+       // Assuming we're acting as the backend of an http server,
+       // lib/controller/router will call resp's ServeHTTP handler,
+       // which upgrades the incoming http connection to a raw socket
+       // and connects it to our yamux.Server through our net.Pipe().
+       resp.Conn = clientconn
+       resp.Bufrw = &bufio.ReadWriter{Reader: bufio.NewReader(&bytes.Buffer{}), Writer: bufio.NewWriter(&bytes.Buffer{})} // empty buffers -- presumably nothing is pre-buffered on this path; TODO confirm against ConnectionResponse
+       resp.Logger = ctxlog.FromContext(ctx)
+       resp.UpgradeHeader = "tunnel"
        return
 }