19166: Proxy container shell request to another controller instance.
author Tom Clegg <tom@curii.com>
Wed, 22 Jun 2022 01:03:52 +0000 (21:03 -0400)
committer Tom Clegg <tom@curii.com>
Fri, 24 Jun 2022 18:23:25 +0000 (14:23 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

lib/controller/federation/conn.go
lib/controller/localdb/container_gateway.go
lib/controller/localdb/container_gateway_test.go
lib/controller/router/request.go
lib/controller/rpc/conn.go
lib/crunchrun/container_gateway.go
lib/crunchrun/crunchrun.go
sdk/go/arvados/api.go
sdk/go/arvados/container_gateway.go
sdk/go/arvadostest/api.go

index 08d3ab1a6ec9f6f46ab4e9b494c81674d4644a17..ffb150bf26aa148b511f4bbde98305469ffef5df 100644 (file)
@@ -375,7 +375,7 @@ func (conn *Conn) ContainerUnlock(ctx context.Context, options arvados.GetOption
        return conn.chooseBackend(options.UUID).ContainerUnlock(ctx, options)
 }
 
-func (conn *Conn) ContainerSSH(ctx context.Context, options arvados.ContainerSSHOptions) (arvados.ContainerSSHConnection, error) {
+func (conn *Conn) ContainerSSH(ctx context.Context, options arvados.ContainerSSHOptions) (arvados.ConnectionResponse, error) {
        return conn.chooseBackend(options.UUID).ContainerSSH(ctx, options)
 }
 
index 79812465477d152c02618b57f57a5d3210d42618..90c95deb35444e6d27efa02958dc71ebe287c779 100644 (file)
@@ -15,11 +15,15 @@ import (
        "crypto/x509"
        "errors"
        "fmt"
+       "io"
+       "io/ioutil"
        "net"
        "net/http"
        "net/url"
        "strings"
 
+       "git.arvados.org/arvados.git/lib/controller/rpc"
+       "git.arvados.org/arvados.git/lib/service"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/auth"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
@@ -27,41 +31,42 @@ import (
        "github.com/hashicorp/yamux"
 )
 
+var (
+       forceProxyForTest       = false
+       forceInternalURLForTest *arvados.URL
+)
+
 // ContainerSSH returns a connection to the SSH server in the
 // appropriate crunch-run process on the worker node where the
 // specified container is running.
 //
 // If the returned error is nil, the caller is responsible for closing
 // sshconn.Conn.
-func (conn *Conn) ContainerSSH(ctx context.Context, opts arvados.ContainerSSHOptions) (sshconn arvados.ContainerSSHConnection, err error) {
+func (conn *Conn) ContainerSSH(ctx context.Context, opts arvados.ContainerSSHOptions) (sshconn arvados.ConnectionResponse, err error) {
        user, err := conn.railsProxy.UserGetCurrent(ctx, arvados.GetOptions{})
        if err != nil {
-               return
+               return sshconn, err
        }
        ctr, err := conn.railsProxy.ContainerGet(ctx, arvados.GetOptions{UUID: opts.UUID})
        if err != nil {
-               return
+               return sshconn, err
        }
        ctxRoot := auth.NewContext(ctx, &auth.Credentials{Tokens: []string{conn.cluster.SystemRootToken}})
        if !user.IsAdmin || !conn.cluster.Containers.ShellAccess.Admin {
                if !conn.cluster.Containers.ShellAccess.User {
-                       err = httpserver.ErrorWithStatus(errors.New("shell access is disabled in config"), http.StatusServiceUnavailable)
-                       return
+                       return sshconn, httpserver.ErrorWithStatus(errors.New("shell access is disabled in config"), http.StatusServiceUnavailable)
                }
-               var crs arvados.ContainerRequestList
-               crs, err = conn.railsProxy.ContainerRequestList(ctxRoot, arvados.ListOptions{Limit: -1, Filters: []arvados.Filter{{"container_uuid", "=", opts.UUID}}})
+               crs, err := conn.railsProxy.ContainerRequestList(ctxRoot, arvados.ListOptions{Limit: -1, Filters: []arvados.Filter{{"container_uuid", "=", opts.UUID}}})
                if err != nil {
-                       return
+                       return sshconn, err
                }
                for _, cr := range crs.Items {
                        if cr.ModifiedByUserUUID != user.UUID {
-                               err = httpserver.ErrorWithStatus(errors.New("permission denied: container is associated with requests submitted by other users"), http.StatusForbidden)
-                               return
+                               return sshconn, httpserver.ErrorWithStatus(errors.New("permission denied: container is associated with requests submitted by other users"), http.StatusForbidden)
                        }
                }
                if crs.ItemsAvailable != len(crs.Items) {
-                       err = httpserver.ErrorWithStatus(errors.New("incomplete response while checking permission"), http.StatusInternalServerError)
-                       return
+                       return sshconn, httpserver.ErrorWithStatus(errors.New("incomplete response while checking permission"), http.StatusInternalServerError)
                }
        }
 
@@ -70,26 +75,77 @@ func (conn *Conn) ContainerSSH(ctx context.Context, opts arvados.ContainerSSHOpt
        conn.gwTunnelsLock.Unlock()
 
        if ctr.State == arvados.ContainerStateQueued || ctr.State == arvados.ContainerStateLocked {
-               err = httpserver.ErrorWithStatus(fmt.Errorf("container is not running yet (state is %q)", ctr.State), http.StatusServiceUnavailable)
-               return
+               return sshconn, httpserver.ErrorWithStatus(fmt.Errorf("container is not running yet (state is %q)", ctr.State), http.StatusServiceUnavailable)
        } else if ctr.State != arvados.ContainerStateRunning {
-               err = httpserver.ErrorWithStatus(fmt.Errorf("container has ended (state is %q)", ctr.State), http.StatusGone)
-               return
+               return sshconn, httpserver.ErrorWithStatus(fmt.Errorf("container has ended (state is %q)", ctr.State), http.StatusGone)
        }
 
+       // targetHost is the value we'll use in the Host header in our
+       // "Upgrade: ssh" http request. It's just a placeholder
+       // "localhost", unless we decide to connect directly, in which
+       // case we'll set it to the gateway's external ip:host. (The
+       // gateway doesn't even look at it, but we might as well.)
+       targetHost := "localhost"
+       myURL, _ := service.URLFromContext(ctx)
+
        var rawconn net.Conn
-       if ctr.GatewayAddress != "" && !strings.HasPrefix(ctr.GatewayAddress, "127.0.0.1:") {
+       if host, _, splitErr := net.SplitHostPort(ctr.GatewayAddress); splitErr == nil && host != "" && host != "127.0.0.1" {
+               // If crunch-run provided a GatewayAddress like
+               // "ipaddr:port", that means "ipaddr" is one of the
+               // external interfaces where the gateway is
+               // listening. In that case, it's the most
+               // reliable/direct option, so we use it even if a
+               // tunnel might also be available.
+               targetHost = ctr.GatewayAddress
                rawconn, err = net.Dial("tcp", ctr.GatewayAddress)
-       } else if tunnel != nil {
+               if err != nil {
+                       return sshconn, httpserver.ErrorWithStatus(err, http.StatusServiceUnavailable)
+               }
+       } else if tunnel != nil && !(forceProxyForTest && !opts.NoForward) {
+               // If we can't connect directly, and the gateway has
+               // established a yamux tunnel with us, connect through
+               // the tunnel.
+               //
+               // ...except: forceProxyForTest means we are emulating
+               // a situation where the gateway has established a
+               // yamux tunnel with controller B, and the
+               // ContainerSSH request arrives at controller A. If
+               // opts.NoForward==false then we are acting as A, so
+               // we pretend not to have a tunnel, and fall through
+               // to the "tunurl" case below. If opts.NoForward==true
+               // then the client is A and we are acting as B, so we
+               // connect to our tunnel.
                rawconn, err = tunnel.Open()
+               if err != nil {
+                       return sshconn, httpserver.ErrorWithStatus(err, http.StatusServiceUnavailable)
+               }
        } else if ctr.GatewayAddress == "" {
-               err = errors.New("container is running but gateway is not available")
+               return sshconn, httpserver.ErrorWithStatus(errors.New("container is running but gateway is not available"), http.StatusServiceUnavailable)
+       } else if tunurl := strings.TrimPrefix(ctr.GatewayAddress, "tunnel "); tunurl != ctr.GatewayAddress &&
+               tunurl != "" &&
+               tunurl != myURL.String() &&
+               !opts.NoForward {
+               // If crunch-run provided a GatewayAddress like
+               // "tunnel https://10.0.0.10:1010/", that means the
+               // gateway has established a yamux tunnel with the
+               // controller process at the indicated InternalURL
+               // (which isn't us, otherwise we would have had
+               // "tunnel != nil" above). We need to proxy through to
+               // the other controller process in order to use the
+               // tunnel.
+               for u := range conn.cluster.Services.Controller.InternalURLs {
+                       if u.String() == tunurl {
+                               ctxlog.FromContext(ctx).Debugf("proxying ContainerSSH request to other controller at %s", u)
+                               u := url.URL(u)
+                               arpc := rpc.NewConn(conn.cluster.ClusterID, &u, conn.cluster.TLS.Insecure, rpc.PassthroughTokenProvider)
+                               opts.NoForward = true
+                               return arpc.ContainerSSH(ctx, opts)
+                       }
+               }
+               ctxlog.FromContext(ctx).Warnf("container gateway provided a tunnel endpoint %s that is not one of Services.Controller.InternalURLs", tunurl)
+               return sshconn, httpserver.ErrorWithStatus(errors.New("container gateway is running but tunnel endpoint is invalid"), http.StatusServiceUnavailable)
        } else {
-               err = errors.New("container gateway is running but tunnel is down")
-       }
-       if err != nil {
-               err = httpserver.ErrorWithStatus(err, http.StatusServiceUnavailable)
-               return
+               return sshconn, httpserver.ErrorWithStatus(errors.New("container gateway is running but tunnel is down"), http.StatusServiceUnavailable)
        }
 
        // crunch-run uses a self-signed / unverifiable TLS
@@ -131,27 +187,25 @@ func (conn *Conn) ContainerSSH(ctx context.Context, opts arvados.ContainerSSHOpt
        })
        err = tlsconn.HandshakeContext(ctx)
        if err != nil {
-               err = httpserver.ErrorWithStatus(err, http.StatusBadGateway)
-               return
+               return sshconn, httpserver.ErrorWithStatus(err, http.StatusBadGateway)
        }
        if respondAuth == "" {
                tlsconn.Close()
-               err = httpserver.ErrorWithStatus(errors.New("BUG: no respondAuth"), http.StatusInternalServerError)
-               return
+               return sshconn, httpserver.ErrorWithStatus(errors.New("BUG: no respondAuth"), http.StatusInternalServerError)
        }
        bufr := bufio.NewReader(tlsconn)
        bufw := bufio.NewWriter(tlsconn)
 
        u := url.URL{
                Scheme: "http",
-               Host:   ctr.GatewayAddress,
+               Host:   targetHost,
                Path:   "/ssh",
        }
        postform := url.Values{
                "uuid":           {opts.UUID},
                "detach_keys":    {opts.DetachKeys},
                "login_username": {opts.LoginUsername},
-               "no_forward":     {"true"},
+               "no_forward":     {fmt.Sprintf("%v", opts.NoForward)},
        }
        postdata := postform.Encode()
        bufw.WriteString("POST " + u.String() + " HTTP/1.1\r\n")
@@ -163,22 +217,25 @@ func (conn *Conn) ContainerSSH(ctx context.Context, opts arvados.ContainerSSHOpt
        bufw.WriteString("\r\n")
        bufw.WriteString(postdata)
        bufw.Flush()
-       resp, err := http.ReadResponse(bufr, &http.Request{Method: "GET"})
+       resp, err := http.ReadResponse(bufr, &http.Request{Method: "POST"})
        if err != nil {
-               err = httpserver.ErrorWithStatus(fmt.Errorf("error reading http response from gateway: %w", err), http.StatusBadGateway)
                tlsconn.Close()
-               return
+               return sshconn, httpserver.ErrorWithStatus(fmt.Errorf("error reading http response from gateway: %w", err), http.StatusBadGateway)
        }
-       if resp.Header.Get("X-Arvados-Authorization-Response") != respondAuth {
-               err = httpserver.ErrorWithStatus(errors.New("bad X-Arvados-Authorization-Response header"), http.StatusBadGateway)
+       defer resp.Body.Close()
+       if resp.StatusCode != http.StatusSwitchingProtocols {
+               body, _ := ioutil.ReadAll(io.LimitReader(resp.Body, 1000))
                tlsconn.Close()
-               return
+               return sshconn, httpserver.ErrorWithStatus(fmt.Errorf("unexpected status %s %q", resp.Status, body), http.StatusBadGateway)
        }
        if strings.ToLower(resp.Header.Get("Upgrade")) != "ssh" ||
                strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
-               err = httpserver.ErrorWithStatus(errors.New("bad upgrade"), http.StatusBadGateway)
                tlsconn.Close()
-               return
+               return sshconn, httpserver.ErrorWithStatus(errors.New("bad upgrade"), http.StatusBadGateway)
+       }
+       if resp.Header.Get("X-Arvados-Authorization-Response") != respondAuth {
+               tlsconn.Close()
+               return sshconn, httpserver.ErrorWithStatus(errors.New("bad X-Arvados-Authorization-Response header"), http.StatusBadGateway)
        }
 
        if !ctr.InteractiveSessionStarted {
@@ -190,15 +247,15 @@ func (conn *Conn) ContainerSSH(ctx context.Context, opts arvados.ContainerSSHOpt
                })
                if err != nil {
                        tlsconn.Close()
-                       return
+                       return sshconn, httpserver.ErrorWithStatus(err, http.StatusInternalServerError)
                }
        }
 
        sshconn.Conn = tlsconn
        sshconn.Bufrw = &bufio.ReadWriter{Reader: bufr, Writer: bufw}
        sshconn.Logger = ctxlog.FromContext(ctx)
-       sshconn.UpgradeHeader = "ssh"
-       return
+       sshconn.Header = http.Header{"Upgrade": {"ssh"}}
+       return sshconn, nil
 }
 
 // ContainerGatewayTunnel sets up a tunnel enabling us (controller) to
@@ -243,6 +300,11 @@ func (conn *Conn) ContainerGatewayTunnel(ctx context.Context, opts arvados.Conta
        resp.Conn = clientconn
        resp.Bufrw = &bufio.ReadWriter{Reader: bufio.NewReader(&bytes.Buffer{}), Writer: bufio.NewWriter(&bytes.Buffer{})}
        resp.Logger = ctxlog.FromContext(ctx)
-       resp.UpgradeHeader = "tunnel"
+       resp.Header = http.Header{"Upgrade": {"tunnel"}}
+       if u, ok := service.URLFromContext(ctx); ok {
+               resp.Header.Set("X-Arvados-Internal-Url", u.String())
+       } else if forceInternalURLForTest != nil {
+               resp.Header.Set("X-Arvados-Internal-Url", forceInternalURLForTest.String())
+       }
        return
 }
index b3b604e53419f82b6bb475d50b5000c3954f9422..2c882c7852a87b0ed23e2821d3e5fde6ec400439 100644 (file)
@@ -13,10 +13,13 @@ import (
        "io/ioutil"
        "net"
        "net/http/httptest"
+       "net/url"
+       "strings"
        "time"
 
        "git.arvados.org/arvados.git/lib/config"
        "git.arvados.org/arvados.git/lib/controller/router"
+       "git.arvados.org/arvados.git/lib/controller/rpc"
        "git.arvados.org/arvados.git/lib/crunchrun"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/arvadostest"
@@ -60,6 +63,11 @@ func (s *ContainerGatewaySuite) SetUpSuite(c *check.C) {
        rtr := router.New(s.localdb, router.Config{})
        srv := httptest.NewUnstartedServer(rtr)
        srv.StartTLS()
+       // the test setup doesn't use lib/service so
+       // service.URLFromContext() returns nothing -- instead, this
+       // is how we advertise our internal URL and enable
+       // proxy-to-other-controller mode,
+       forceInternalURLForTest = &arvados.URL{Scheme: "https", Host: srv.Listener.Addr().String()}
        ac := &arvados.Client{
                APIHost:   srv.Listener.Addr().String(),
                AuthToken: arvadostest.Dispatch1Token,
@@ -278,13 +286,46 @@ func (s *ContainerGatewaySuite) TestCreateTunnel(c *check.C) {
        c.Check(conn.Conn, check.NotNil)
 }
 
-func (s *ContainerGatewaySuite) TestConnectThroughTunnel(c *check.C) {
+func (s *ContainerGatewaySuite) TestConnectThroughTunnelWithProxyOK(c *check.C) {
+       forceProxyForTest = true
+       defer func() { forceProxyForTest = false }()
+       s.cluster.Services.Controller.InternalURLs[*forceInternalURLForTest] = arvados.ServiceInstance{}
+       defer delete(s.cluster.Services.Controller.InternalURLs, *forceInternalURLForTest)
+       s.testConnectThroughTunnel(c, "")
+}
+
+func (s *ContainerGatewaySuite) TestConnectThroughTunnelWithProxyError(c *check.C) {
+       forceProxyForTest = true
+       defer func() { forceProxyForTest = false }()
+       // forceInternalURLForTest shouldn't be used because it isn't
+       // listed in s.cluster.Services.Controller.InternalURLs
+       s.testConnectThroughTunnel(c, `.*tunnel endpoint is invalid.*`)
+}
+
+func (s *ContainerGatewaySuite) TestConnectThroughTunnelNoProxyOK(c *check.C) {
+       s.testConnectThroughTunnel(c, "")
+}
+
+func (s *ContainerGatewaySuite) testConnectThroughTunnel(c *check.C, expectErrorMatch string) {
+       rootctx := auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{s.cluster.SystemRootToken}})
+       // Until the tunnel starts up, set gateway_address to a value
+       // that can't work. We want to ensure the only way we can
+       // reach the gateway is through the tunnel.
+       gwaddr := "127.0.0.1:0"
        tungw := &crunchrun.Gateway{
                ContainerUUID: s.ctrUUID,
                AuthSecret:    s.gw.AuthSecret,
                Log:           ctxlog.TestLogger(c),
                Target:        crunchrun.GatewayTargetStub{},
                ArvadosClient: s.gw.ArvadosClient,
+               UpdateTunnelURL: func(url string) {
+                       c.Logf("UpdateTunnelURL(%q)", url)
+                       gwaddr = "tunnel " + url
+                       s.localdb.ContainerUpdate(rootctx, arvados.UpdateOptions{
+                               UUID: s.ctrUUID,
+                               Attrs: map[string]interface{}{
+                                       "gateway_address": gwaddr}})
+               },
        }
        c.Assert(tungw.Start(), check.IsNil)
 
@@ -294,26 +335,30 @@ func (s *ContainerGatewaySuite) TestConnectThroughTunnel(c *check.C) {
        c.Assert(err, check.IsNil)
        c.Check(host, check.Equals, "127.0.0.1")
 
-       // Set the gateway_address field to 127.0.0.1:badport to
-       // ensure the ContainerSSH() handler connects through the
-       // tunnel, rather than the gateway server on 127.0.0.1 (which
-       // wouldn't work IRL where controller and gateway are on
-       // different hosts, but would allow the test to cheat).
-       rootctx := auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{s.cluster.SystemRootToken}})
        _, err = s.localdb.ContainerUpdate(rootctx, arvados.UpdateOptions{
                UUID: s.ctrUUID,
                Attrs: map[string]interface{}{
                        "state":           arvados.ContainerStateRunning,
-                       "gateway_address": "127.0.0.1:0"}})
+                       "gateway_address": gwaddr}})
        c.Assert(err, check.IsNil)
 
-       ctr, err := s.localdb.ContainerGet(s.ctx, arvados.GetOptions{UUID: s.ctrUUID})
-       c.Check(err, check.IsNil)
-       c.Check(ctr.InteractiveSessionStarted, check.Equals, false)
-       c.Check(ctr.GatewayAddress, check.Equals, "127.0.0.1:0")
+       for deadline := time.Now().Add(5 * time.Second); time.Now().Before(deadline); time.Sleep(time.Second / 2) {
+               ctr, err := s.localdb.ContainerGet(s.ctx, arvados.GetOptions{UUID: s.ctrUUID})
+               c.Assert(err, check.IsNil)
+               c.Check(ctr.InteractiveSessionStarted, check.Equals, false)
+               c.Logf("ctr.GatewayAddress == %s", ctr.GatewayAddress)
+               if strings.HasPrefix(ctr.GatewayAddress, "tunnel ") {
+                       break
+               }
+       }
 
        c.Log("connecting to gateway through tunnel")
-       sshconn, err := s.localdb.ContainerSSH(s.ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
+       arpc := rpc.NewConn("", &url.URL{Scheme: "https", Host: s.gw.ArvadosClient.APIHost}, true, rpc.PassthroughTokenProvider)
+       sshconn, err := arpc.ContainerSSH(s.ctx, arvados.ContainerSSHOptions{UUID: s.ctrUUID})
+       if expectErrorMatch != "" {
+               c.Check(err, check.ErrorMatches, expectErrorMatch)
+               return
+       }
        c.Assert(err, check.IsNil)
        c.Assert(sshconn.Conn, check.NotNil)
        defer sshconn.Conn.Close()
@@ -344,7 +389,7 @@ func (s *ContainerGatewaySuite) TestConnectThroughTunnel(c *check.C) {
        case <-time.After(time.Second):
                c.Fail()
        }
-       ctr, err = s.localdb.ContainerGet(s.ctx, arvados.GetOptions{UUID: s.ctrUUID})
+       ctr, err := s.localdb.ContainerGet(s.ctx, arvados.GetOptions{UUID: s.ctrUUID})
        c.Check(err, check.IsNil)
        c.Check(ctr.InteractiveSessionStarted, check.Equals, true)
 }
index 06141b1033e3f0034e003eab07da11c17153496e..31f2e1d7baf5098a377ffe9d1acd7b737958231d 100644 (file)
@@ -176,6 +176,7 @@ var boolParams = map[string]bool{
        "bypass_federation":       true,
        "recursive":               true,
        "exclude_home_project":    true,
+       "no_forward":              true,
 }
 
 func stringToBool(s string) bool {
index 8e25ca0d05b741e963addc987359fe8575404388..1475a5e01fb0546c54fddf4856759ed587719255 100644 (file)
@@ -23,6 +23,7 @@ import (
 
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/auth"
+       "git.arvados.org/arvados.git/sdk/go/ctxlog"
        "git.arvados.org/arvados.git/sdk/go/httpserver"
 )
 
@@ -331,21 +332,17 @@ func (conn *Conn) ContainerUnlock(ctx context.Context, options arvados.GetOption
 // ContainerSSH returns a connection to the out-of-band SSH server for
 // a running container. If the returned error is nil, the caller is
 // responsible for closing sshconn.Conn.
-func (conn *Conn) ContainerSSH(ctx context.Context, options arvados.ContainerSSHOptions) (sshconn arvados.ContainerSSHConnection, err error) {
+func (conn *Conn) ContainerSSH(ctx context.Context, options arvados.ContainerSSHOptions) (sshconn arvados.ConnectionResponse, err error) {
        u, err := conn.baseURL.Parse("/" + strings.Replace(arvados.EndpointContainerSSH.Path, "{uuid}", options.UUID, -1))
        if err != nil {
                err = fmt.Errorf("url.Parse: %w", err)
                return
        }
-       u.RawQuery = url.Values{
+       return conn.socket(ctx, u, "ssh", url.Values{
                "detach_keys":    {options.DetachKeys},
                "login_username": {options.LoginUsername},
-       }.Encode()
-       resp, err := conn.socket(ctx, u, "ssh", nil)
-       if err != nil {
-               return
-       }
-       return arvados.ContainerSSHConnection(resp), nil
+               "no_forward":     {fmt.Sprintf("%v", options.NoForward)},
+       })
 }
 
 // ContainerGatewayTunnel returns a connection to a yamux session on
@@ -376,8 +373,7 @@ func (conn *Conn) socket(ctx context.Context, u *url.URL, upgradeHeader string,
        }
        netconn, err := tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: insecure})
        if err != nil {
-               err = fmt.Errorf("tls.Dial: %w", err)
-               return
+               return connresp, fmt.Errorf("tls.Dial: %w", err)
        }
        defer func() {
                if err != nil {
@@ -389,10 +385,9 @@ func (conn *Conn) socket(ctx context.Context, u *url.URL, upgradeHeader string,
 
        tokens, err := conn.tokenProvider(ctx)
        if err != nil {
-               return
+               return connresp, err
        } else if len(tokens) < 1 {
-               err = httpserver.ErrorWithStatus(errors.New("unauthorized"), http.StatusUnauthorized)
-               return
+               return connresp, httpserver.ErrorWithStatus(errors.New("unauthorized"), http.StatusUnauthorized)
        }
        postdata := postform.Encode()
        bufw.WriteString("POST " + u.String() + " HTTP/1.1\r\n")
@@ -402,18 +397,16 @@ func (conn *Conn) socket(ctx context.Context, u *url.URL, upgradeHeader string,
        bufw.WriteString("Content-Type: application/x-www-form-urlencoded\r\n")
        fmt.Fprintf(bufw, "Content-Length: %d\r\n", len(postdata))
        bufw.WriteString("\r\n")
-       if len(postdata) > 0 {
-               bufw.WriteString(postdata)
-       }
+       bufw.WriteString(postdata)
        bufw.Flush()
-       resp, err := http.ReadResponse(bufr, &http.Request{Method: "GET"})
+       resp, err := http.ReadResponse(bufr, &http.Request{Method: "POST"})
        if err != nil {
-               err = fmt.Errorf("http.ReadResponse: %w", err)
-               return
+               return connresp, fmt.Errorf("http.ReadResponse: %w", err)
        }
+       defer resp.Body.Close()
        if resp.StatusCode != http.StatusSwitchingProtocols {
-               defer resp.Body.Close()
-               body, _ := ioutil.ReadAll(resp.Body)
+               ctxlog.FromContext(ctx).Infof("rpc.Conn.socket: server %s did not switch protocols, got status %s", u.String(), resp.Status)
+               body, _ := ioutil.ReadAll(io.LimitReader(resp.Body, 10000))
                var message string
                var errDoc httpserver.ErrorResponse
                if err := json.Unmarshal(body, &errDoc); err == nil {
@@ -421,17 +414,16 @@ func (conn *Conn) socket(ctx context.Context, u *url.URL, upgradeHeader string,
                } else {
                        message = fmt.Sprintf("%q", body)
                }
-               err = fmt.Errorf("server did not provide a tunnel: %s (HTTP %d)", message, resp.StatusCode)
-               return
+               return connresp, fmt.Errorf("server did not provide a tunnel: %s %s", resp.Status, message)
        }
        if strings.ToLower(resp.Header.Get("Upgrade")) != upgradeHeader ||
                strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
-               err = fmt.Errorf("bad response from server: Upgrade %q Connection %q", resp.Header.Get("Upgrade"), resp.Header.Get("Connection"))
-               return
+               return connresp, fmt.Errorf("bad response from server: Upgrade %q Connection %q", resp.Header.Get("Upgrade"), resp.Header.Get("Connection"))
        }
        connresp.Conn = netconn
        connresp.Bufrw = &bufio.ReadWriter{Reader: bufr, Writer: bufw}
-       return
+       connresp.Header = resp.Header
+       return connresp, nil
 }
 
 func (conn *Conn) ContainerRequestCreate(ctx context.Context, options arvados.CreateOptions) (arvados.ContainerRequest, error) {
index ba52f8ab43cd6f49a107777b0d28f9c14fec92ab..02df06cf2159fe6c16802485c8215c37468dd0e1 100644 (file)
@@ -73,6 +73,11 @@ type Gateway struct {
        // address is unknown or not routable from controller.
        ArvadosClient *arvados.Client
 
+       // When a tunnel is connected or reconnected, this func (if
+       // not nil) will be called with the InternalURL of the
+       // controller process at the other end of the tunnel.
+       UpdateTunnelURL func(url string)
+
        sshConfig   ssh.ServerConfig
        requestAuth string
        respondAuth string
@@ -209,6 +214,9 @@ func (gw *Gateway) runTunnel(addr string) error {
        if err != nil {
                return fmt.Errorf("error setting up mux client end: %s", err)
        }
+       if url := tun.Header.Get("X-Arvados-Internal-Url"); url != "" && gw.UpdateTunnelURL != nil {
+               gw.UpdateTunnelURL(url)
+       }
        for {
                muxconn, err := mux.Accept()
                if err != nil {
index c2ed37e75a43a72d6f71d0b0083e65720ad9ee0b..df3abe630900fdc32cdb383239569855879e1c83 100644 (file)
@@ -1917,6 +1917,20 @@ func (command) RunCommand(prog string, args []string, stdin io.Reader, stdout, s
                        Target:        cr.executor,
                        Log:           cr.CrunchLog,
                        ArvadosClient: cr.dispatcherClient,
+                       UpdateTunnelURL: func(url string) {
+                               if gwListen != "" {
+                                       // prefer connecting directly
+                                       return
+                               }
+                               // direct connection won't work, so we
+                               // use the gateway_address field to
+                               // indicate the internalURL of the
+                               // controller process that has the
+                               // current tunnel connection.
+                               cr.gateway.Address = "tunnel " + url
+                               cr.DispatcherArvClient.Update("containers", containerUUID,
+                                       arvadosclient.Dict{"container": arvadosclient.Dict{"gateway_address": cr.gateway.Address}}, nil)
+                       },
                }
                err = cr.gateway.Start()
                if err != nil {
index 8a41cb851c059b8ae498c647093a100910bff0a1..3797a17f50d504ae2894ac4c6a68f598b4e37564 100644 (file)
@@ -10,6 +10,7 @@ import (
        "encoding/json"
        "io"
        "net"
+       "net/http"
 
        "github.com/sirupsen/logrus"
 )
@@ -100,13 +101,11 @@ type ContainerSSHOptions struct {
        NoForward     bool   `json:"no_forward"`
 }
 
-type ContainerSSHConnection ConnectionResponse
-
 type ConnectionResponse struct {
-       Conn          net.Conn           `json:"-"`
-       Bufrw         *bufio.ReadWriter  `json:"-"`
-       Logger        logrus.FieldLogger `json:"-"`
-       UpgradeHeader string             `json:"-"`
+       Conn   net.Conn           `json:"-"`
+       Bufrw  *bufio.ReadWriter  `json:"-"`
+       Logger logrus.FieldLogger `json:"-"`
+       Header http.Header        `json:"-"`
 }
 
 type ContainerGatewayTunnelOptions struct {
@@ -264,7 +263,7 @@ type API interface {
        ContainerDelete(ctx context.Context, options DeleteOptions) (Container, error)
        ContainerLock(ctx context.Context, options GetOptions) (Container, error)
        ContainerUnlock(ctx context.Context, options GetOptions) (Container, error)
-       ContainerSSH(ctx context.Context, options ContainerSSHOptions) (ContainerSSHConnection, error)
+       ContainerSSH(ctx context.Context, options ContainerSSHOptions) (ConnectionResponse, error)
        ContainerGatewayTunnel(ctx context.Context, options ContainerGatewayTunnelOptions) (ConnectionResponse, error)
        ContainerRequestCreate(ctx context.Context, options CreateOptions) (ContainerRequest, error)
        ContainerRequestUpdate(ctx context.Context, options UpdateOptions) (ContainerRequest, error)
index d1d512856ac9744f091c2152985a7ebc5fd81651..ce33fb3105a218a537fe8b0c28cf122041b962a5 100644 (file)
@@ -21,7 +21,9 @@ func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Reque
                return
        }
        w.Header().Set("Connection", "upgrade")
-       w.Header().Set("Upgrade", cresp.UpgradeHeader)
+       for k, v := range cresp.Header {
+               w.Header()[k] = v
+       }
        w.WriteHeader(http.StatusSwitchingProtocols)
        conn, bufrw, err := hj.Hijack()
        if err != nil {
@@ -32,7 +34,7 @@ func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Reque
 
        var bytesIn, bytesOut int64
        var wg sync.WaitGroup
-       ctx, cancel := context.WithCancel(context.Background())
+       ctx, cancel := context.WithCancel(req.Context())
        wg.Add(1)
        go func() {
                defer wg.Done()
@@ -44,7 +46,7 @@ func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Reque
                        bytesOut += n
                }
                if err != nil {
-                       ctxlog.FromContext(req.Context()).WithError(err).Error("error copying downstream")
+                       ctxlog.FromContext(ctx).WithError(err).Error("error copying downstream")
                }
        }()
        wg.Add(1)
@@ -58,17 +60,14 @@ func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Reque
                        bytesIn += n
                }
                if err != nil {
-                       ctxlog.FromContext(req.Context()).WithError(err).Error("error copying upstream")
+                       ctxlog.FromContext(ctx).WithError(err).Error("error copying upstream")
                }
        }()
-       <-ctx.Done()
+       wg.Wait()
        if cresp.Logger != nil {
-               go func() {
-                       wg.Wait()
-                       cresp.Logger.WithFields(logrus.Fields{
-                               "bytesIn":  bytesIn,
-                               "bytesOut": bytesOut,
-                       }).Info("closed connection")
-               }()
+               cresp.Logger.WithFields(logrus.Fields{
+                       "bytesIn":  bytesIn,
+                       "bytesOut": bytesOut,
+               }).Info("closed connection")
        }
 }
index d784abf6719c26c855da0d24017444e57d962328..d6da579d6b9ce1323dfbeb9b50f993232822379a 100644 (file)
@@ -109,9 +109,9 @@ func (as *APIStub) ContainerUnlock(ctx context.Context, options arvados.GetOptio
        as.appendCall(ctx, as.ContainerUnlock, options)
        return arvados.Container{}, as.Error
 }
-func (as *APIStub) ContainerSSH(ctx context.Context, options arvados.ContainerSSHOptions) (arvados.ContainerSSHConnection, error) {
+func (as *APIStub) ContainerSSH(ctx context.Context, options arvados.ContainerSSHOptions) (arvados.ConnectionResponse, error) {
        as.appendCall(ctx, as.ContainerSSH, options)
-       return arvados.ContainerSSHConnection{}, as.Error
+       return arvados.ConnectionResponse{}, as.Error
 }
 func (as *APIStub) ContainerGatewayTunnel(ctx context.Context, options arvados.ContainerGatewayTunnelOptions) (arvados.ConnectionResponse, error) {
        as.appendCall(ctx, as.ContainerGatewayTunnel, options)