17170: Dry up close-connection-on-error.
authorTom Clegg <tom@curii.com>
Wed, 13 Jan 2021 22:06:11 +0000 (17:06 -0500)
committerTom Clegg <tom@curii.com>
Wed, 13 Jan 2021 22:06:11 +0000 (17:06 -0500)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

lib/controller/rpc/conn.go

index 2accfd8f2f88c4245093e6c1f0b99bdb11a3de00..7dd89452bec7fbeff1ca0c70cdbd3b8e24eaa914 100644 (file)
@@ -302,12 +302,16 @@ func (conn *Conn) ContainerSSH(ctx context.Context, options arvados.ContainerSSH
        if err != nil {
                return
        }
+       defer func() {
+               if err != nil {
+                       netconn.Close()
+               }
+       }()
        bufr := bufio.NewReader(netconn)
        bufw := bufio.NewWriter(netconn)
 
        u, err := conn.baseURL.Parse("/" + strings.Replace(arvados.EndpointContainerSSH.Path, "{uuid}", options.UUID, -1))
        if err != nil {
-               netconn.Close()
                return
        }
        u.RawQuery = url.Values{
@@ -316,11 +320,9 @@ func (conn *Conn) ContainerSSH(ctx context.Context, options arvados.ContainerSSH
        }.Encode()
        tokens, err := conn.tokenProvider(ctx)
        if err != nil {
-               netconn.Close()
                return
        } else if len(tokens) < 1 {
                err = httpserver.ErrorWithStatus(errors.New("unauthorized"), http.StatusUnauthorized)
-               netconn.Close()
                return
        }
        bufw.WriteString("GET " + u.String() + " HTTP/1.1\r\n")
@@ -331,20 +333,17 @@ func (conn *Conn) ContainerSSH(ctx context.Context, options arvados.ContainerSSH
        bufw.Flush()
        resp, err := http.ReadResponse(bufr, &http.Request{Method: "GET"})
        if err != nil {
-               netconn.Close()
                return
        }
        if resp.StatusCode != http.StatusSwitchingProtocols {
                defer resp.Body.Close()
                body, _ := ioutil.ReadAll(resp.Body)
                err = fmt.Errorf("server did not provide a tunnel: %d %q", resp.StatusCode, body)
-               netconn.Close()
                return
        }
        if strings.ToLower(resp.Header.Get("Upgrade")) != "ssh" ||
                strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
                err = fmt.Errorf("bad response from server: Upgrade %q Connection %q", resp.Header.Get("Upgrade"), resp.Header.Get("Connection"))
-               netconn.Close()
                return
        }
        sshconn.Conn = netconn