17344: Check that needed ports are usable before doing any init.
authorTom Clegg <tom@curii.com>
Wed, 13 Jul 2022 02:32:40 +0000 (22:32 -0400)
committerTom Clegg <tom@curii.com>
Thu, 14 Jul 2022 13:21:41 +0000 (09:21 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

lib/install/init.go

index 3eeac5c54984c101587e4df93d72b2a3adf643ab..80fa35cc48c82e0e685fd3c5b166366b3afa8747 100644 (file)
@@ -13,6 +13,8 @@ import (
        "flag"
        "fmt"
        "io"
+       "net"
+       "net/http"
        "net/url"
        "os"
        "os/exec"
@@ -20,7 +22,9 @@ import (
        "regexp"
        "strconv"
        "strings"
+       "sync/atomic"
        "text/template"
+       "time"
 
        "git.arvados.org/arvados.git/lib/cmd"
        "git.arvados.org/arvados.git/lib/config"
@@ -111,6 +115,15 @@ func (initcmd *initCommand) RunCommand(prog string, args []string, stdin io.Read
                return 1
        }
 
+       err = initcmd.checkPort(ctx, "4440")
+       err = initcmd.checkPort(ctx, "443")
+       if initcmd.TLS == "auto" {
+               err = initcmd.checkPort(ctx, "80")
+               if err != nil {
+                       return 1
+               }
+       }
+
        // Do the "create extension" thing early. This way, if there's
        // no local postgresql server (a likely failure mode), we can
        // bail out without any side effects, and the user can start
@@ -417,3 +430,82 @@ func (initcmd *initCommand) createDB(ctx context.Context, dbconn arvados.Postgre
        }
        return nil
 }
+
+// Confirm that http://{initcmd.Domain}:{port} reaches a server that
+// we run on {port}.
+//
+// If port is "80", listening fails, and Nginx appears to be using the
+// debian-packaged default configuration that listens on port 80,
+// disable that Nginx config and try again.
+//
+// (Typically, the reason Nginx is installed is so that Arvados can
+// run an Nginx child process; the default Nginx service using config
+// from /etc/nginx is just an unfortunate side effect of installing
+// Nginx by way of the Debian package.)
+func (initcmd *initCommand) checkPort(ctx context.Context, port string) error {
+       err := initcmd.checkPortOnce(ctx, port)
+       if err == nil || port != "80" {
+               // success, or poking Nginx in the eye won't help
+               return err
+       }
+       d, err2 := os.Open("/etc/nginx/sites-enabled/.")
+       if err2 != nil {
+               return err
+       }
+       fis, err2 := d.Readdir(-1)
+       if err2 != nil || len(fis) != 1 {
+               return err
+       }
+       if target, err2 := os.Readlink("/etc/nginx/sites-enabled/default"); err2 != nil || target != "/etc/nginx/sites-available/default" {
+               return err
+       }
+       err2 = os.Remove("/etc/nginx/sites-enabled/default")
+       if err2 != nil {
+               return err
+       }
+       exec.CommandContext(ctx, "nginx", "-s", "reload").Run()
+       time.Sleep(time.Second)
+       return initcmd.checkPortOnce(ctx, port)
+}
+
+// Start an http server on 0.0.0.0:{port} and confirm that
+// http://{initcmd.Domain}:{port} reaches that server.
+func (initcmd *initCommand) checkPortOnce(ctx context.Context, port string) error {
+       b := make([]byte, 128)
+       _, err := rand.Read(b)
+       if err != nil {
+               return err
+       }
+       token := fmt.Sprintf("%x", b)
+
+       srv := http.Server{
+               Addr: net.JoinHostPort("", port),
+               Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+                       fmt.Fprint(w, token)
+               })}
+       var errServe atomic.Value
+       go func() {
+               errServe.Store(srv.ListenAndServe())
+       }()
+       defer srv.Close()
+       url := "http://" + net.JoinHostPort(initcmd.Domain, port) + "/probe"
+       req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+       if err != nil {
+               return err
+       }
+       resp, err := http.DefaultClient.Do(req)
+       if errServe, _ := errServe.Load().(error); errServe != nil {
+               // If server already exited, return that error
+               // (probably "can't listen"), not the request error.
+               return errServe
+       }
+       if err != nil {
+               return err
+       }
+       buf := make([]byte, len(token))
+       n, err := io.ReadFull(resp.Body, buf)
+       if string(buf[:n]) != token {
+               return fmt.Errorf("listened on port %s but %s connected to something else, returned %q, err %v", port, url, buf[:n], err)
+       }
+       return nil
+}