Merge branch '5824-keep-web' into 5824-keep-web-workbench
[arvados.git] / services / keepproxy / keepproxy.go
index 1a1189658d7ba1473aa33036cce60276932cb9b8..bad0d22bf1a81868799d8a437860f941d6fbe770 100644 (file)
@@ -22,8 +22,8 @@ import (
 )
 
 // Default TCP address on which to listen for requests.
-// Initialized by the -listen flag.
-const DEFAULT_ADDR = ":25107"
+// Override with -listen.
+const DefaultAddr = ":25107"
 
 var listener net.Listener
 
@@ -42,7 +42,7 @@ func main() {
        flagset.StringVar(
                &listen,
                "listen",
-               DEFAULT_ADDR,
+               DefaultAddr,
                "Interface on which to listen for requests, in the format "+
                        "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
                        "to listen on all network interfaces.")
@@ -79,16 +79,6 @@ func main() {
 
        flagset.Parse(os.Args[1:])
 
-       arv, err := arvadosclient.MakeArvadosClient()
-       if err != nil {
-               log.Fatalf("Error setting up arvados client %s", err.Error())
-       }
-
-       kc, err := keepclient.MakeKeepClient(&arv)
-       if err != nil {
-               log.Fatalf("Error setting up keep client %s", err.Error())
-       }
-
        if pidfile != "" {
                f, err := os.Create(pidfile)
                if err != nil {
@@ -99,16 +89,23 @@ func main() {
                defer os.Remove(pidfile)
        }
 
+       arv, err := arvadosclient.MakeArvadosClient()
+       if err != nil {
+               log.Fatalf("setting up arvados client: %v", err)
+       }
+       kc, err := keepclient.MakeKeepClient(&arv)
+       if err != nil {
+               log.Fatalf("setting up keep client: %v", err)
+       }
        kc.Want_replicas = default_replicas
-
        kc.Client.Timeout = time.Duration(timeout) * time.Second
+       go RefreshServicesList(kc, 5*time.Minute, 3*time.Second)
 
        listener, err = net.Listen("tcp", listen)
        if err != nil {
                log.Fatalf("Could not listen on %v", listen)
        }
-
-       go RefreshServicesList(kc)
+       log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
 
        // Shut down the server gracefully (by closing the listener)
        // if SIGTERM is received.
@@ -121,9 +118,7 @@ func main() {
        signal.Notify(term, syscall.SIGTERM)
        signal.Notify(term, syscall.SIGINT)
 
-       log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
-
-       // Start listening for requests.
+       // Start serving requests.
        http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
 
        log.Println("shutting down")
@@ -135,27 +130,39 @@ type ApiTokenCache struct {
        expireTime int64
 }
 
-// Refresh the keep service list every five minutes.
-func RefreshServicesList(kc *keepclient.KeepClient) {
+// Refresh the keep service list on SIGHUP; when the given interval
+// has elapsed since the last refresh; and (if the last refresh
+// failed) the given errInterval has elapsed.
+func RefreshServicesList(kc *keepclient.KeepClient, interval, errInterval time.Duration) {
        var previousRoots = []map[string]string{}
-       var delay time.Duration = 0
+
+       timer := time.NewTimer(interval)
+       gotHUP := make(chan os.Signal, 1)
+       signal.Notify(gotHUP, syscall.SIGHUP)
+
        for {
-               time.Sleep(delay * time.Second)
-               delay = 300
+               select {
+               case <-gotHUP:
+               case <-timer.C:
+               }
+               timer.Reset(interval)
+
                if err := kc.DiscoverKeepServers(); err != nil {
-                       log.Println("Error retrieving services list:", err)
-                       delay = 3
+                       log.Println("Error retrieving services list: %v (retrying in %v)", err, errInterval)
+                       timer.Reset(errInterval)
                        continue
                }
                newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
+
                if !reflect.DeepEqual(previousRoots, newRoots) {
                        log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
+                       previousRoots = newRoots
                }
+
                if len(newRoots[0]) == 0 {
-                       log.Print("WARNING: No local services. Retrying in 3 seconds.")
-                       delay = 3
+                       log.Printf("WARNING: No local services (retrying in %v)", errInterval)
+                       timer.Reset(errInterval)
                }
-               previousRoots = newRoots
        }
 }
 
@@ -191,12 +198,8 @@ func (this *ApiTokenCache) RecallToken(token string) bool {
 }
 
 func GetRemoteAddress(req *http.Request) string {
-       if realip := req.Header.Get("X-Real-IP"); realip != "" {
-               if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
-                       return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
-               } else {
-                       return realip
-               }
+       if xff := req.Header.Get("X-Forwarded-For"); xff != "" {
+               return xff + "," + req.RemoteAddr
        }
        return req.RemoteAddr
 }