5824: Merge branch 'master' into 5824-keep-web-workbench
[arvados.git] / services / keepproxy / keepproxy.go
index d3dbeaf89e420d4546d0fedab643d9b15e9849c9..2efc9cef5b46fb4c87385b8e799eb355897694f1 100644 (file)
@@ -22,8 +22,8 @@ import (
 )
 
 // Default TCP address on which to listen for requests.
-// Initialized by the -listen flag.
-const DEFAULT_ADDR = ":25107"
+// Override with -listen.
+const DefaultAddr = ":25107"
 
 var listener net.Listener
 
@@ -37,12 +37,12 @@ func main() {
                pidfile          string
        )
 
-       flagset := flag.NewFlagSet("default", flag.ExitOnError)
+       flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
 
        flagset.StringVar(
                &listen,
                "listen",
-               DEFAULT_ADDR,
+               DefaultAddr,
                "Interface on which to listen for requests, in the format "+
                        "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
                        "to listen on all network interfaces.")
@@ -84,6 +84,9 @@ func main() {
                log.Fatalf("Error setting up arvados client %s", err.Error())
        }
 
+       if os.Getenv("ARVADOS_DEBUG") != "" {
+               keepclient.DebugPrintf = log.Printf
+       }
        kc, err := keepclient.MakeKeepClient(&arv)
        if err != nil {
                log.Fatalf("Error setting up keep client %s", err.Error())
@@ -100,15 +103,14 @@ func main() {
        }
 
        kc.Want_replicas = default_replicas
-
        kc.Client.Timeout = time.Duration(timeout) * time.Second
+       go RefreshServicesList(kc, 5*time.Minute, 3*time.Second)
 
        listener, err = net.Listen("tcp", listen)
        if err != nil {
                log.Fatalf("Could not listen on %v", listen)
        }
-
-       go RefreshServicesList(kc)
+       log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
 
        // Shut down the server gracefully (by closing the listener)
        // if SIGTERM is received.
@@ -121,9 +123,7 @@ func main() {
        signal.Notify(term, syscall.SIGTERM)
        signal.Notify(term, syscall.SIGINT)
 
-       log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
-
-       // Start listening for requests.
+       // Start serving requests.
        http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
 
        log.Println("shutting down")
@@ -135,27 +135,39 @@ type ApiTokenCache struct {
        expireTime int64
 }
 
-// Refresh the keep service list every five minutes.
-func RefreshServicesList(kc *keepclient.KeepClient) {
+// RefreshServicesList refreshes the keep service list on SIGHUP, when
+// interval has elapsed since the last refresh, or (if the last refresh
+// failed or found no local services) when errInterval has elapsed.
+func RefreshServicesList(kc *keepclient.KeepClient, interval, errInterval time.Duration) {
        var previousRoots = []map[string]string{}
-       var delay time.Duration = 0
+
+       timer := time.NewTimer(interval)
+       gotHUP := make(chan os.Signal, 1)
+       signal.Notify(gotHUP, syscall.SIGHUP)
+
        for {
-               time.Sleep(delay * time.Second)
-               delay = 300
+               select {
+               case <-gotHUP: // SIGHUP received: refresh now
+               case <-timer.C: // scheduled refresh is due
+               }
+               timer.Reset(interval)
+
                if err := kc.DiscoverKeepServers(); err != nil {
-                       log.Println("Error retrieving services list:", err)
-                       delay = 3
+                       log.Printf("Error retrieving services list: %v (retrying in %v)", err, errInterval)
+                       timer.Reset(errInterval)
                        continue
                }
                newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
+
                if !reflect.DeepEqual(previousRoots, newRoots) {
                        log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
+                       previousRoots = newRoots
                }
+
                if len(newRoots[0]) == 0 {
-                       log.Print("WARNING: No local services. Retrying in 3 seconds.")
-                       delay = 3
+                       log.Printf("WARNING: No local services (retrying in %v)", errInterval)
+                       timer.Reset(errInterval) // override the normal interval set above
                }
-               previousRoots = newRoots
        }
 }
 
@@ -191,17 +203,13 @@ func (this *ApiTokenCache) RecallToken(token string) bool {
 }
 
 func GetRemoteAddress(req *http.Request) string {
-       if realip := req.Header.Get("X-Real-IP"); realip != "" {
-               if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
-                       return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
-               } else {
-                       return realip
-               }
+       if xff := req.Header.Get("X-Forwarded-For"); xff != "" { // header present and non-empty
+               return xff + "," + req.RemoteAddr // forwarded value first, then this connection's RemoteAddr
        }
        return req.RemoteAddr
 }
 
-func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
+func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
        var auth string
        if auth = req.Header.Get("Authorization"); auth == "" {
                return false, ""
@@ -331,7 +339,7 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
                status, err = http.StatusForbidden, BadAuthorizationHeader
                return
        }
@@ -438,7 +446,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
                err = BadAuthorizationHeader
                status = http.StatusForbidden
                return
@@ -521,7 +529,7 @@ func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        kc := *handler.KeepClient
 
-       ok, token := CheckAuthorizationHeader(kc, handler.ApiTokenCache, req)
+       ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req)
        if !ok {
                status, err = http.StatusForbidden, BadAuthorizationHeader
                return