11537: Add Via header to proxied keepstore requests.
[arvados.git] / services / keepproxy / keepproxy.go
index 76a8a1551fb867f5258221f1f22692656039d97a..7a673aeba97b9780d3dcdbac0250d47c4023f905 100644 (file)
@@ -12,6 +12,7 @@ import (
        "os"
        "os/signal"
        "regexp"
+       "strings"
        "sync"
        "syscall"
        "time"
@@ -43,7 +44,10 @@ func DefaultConfig() *Config {
        }
 }
 
-var listener net.Listener
+var (
+       listener net.Listener
+       router   http.Handler
+)
 
 func main() {
        cfg := DefaultConfig()
@@ -129,7 +133,7 @@ func main() {
        if cfg.DefaultReplicas > 0 {
                kc.Want_replicas = cfg.DefaultReplicas
        }
-       kc.Client.Timeout = time.Duration(cfg.Timeout)
+       kc.Client.(*http.Client).Timeout = time.Duration(cfg.Timeout)
        go kc.RefreshServices(5*time.Minute, 3*time.Second)
 
        listener, err = net.Listen("tcp", cfg.Listen)
@@ -153,7 +157,8 @@ func main() {
        signal.Notify(term, syscall.SIGINT)
 
        // Start serving requests.
-       http.Serve(listener, MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc))
+       router = MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc)
+       http.Serve(listener, router)
 
        log.Println("shutting down")
 }
@@ -232,61 +237,57 @@ func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, r
        return true, tok
 }
 
-type GetBlockHandler struct {
+type proxyHandler struct {
+       http.Handler
        *keepclient.KeepClient
        *ApiTokenCache
 }
 
-type PutBlockHandler struct {
-       *keepclient.KeepClient
-       *ApiTokenCache
-}
-
-type IndexHandler struct {
-       *keepclient.KeepClient
-       *ApiTokenCache
-}
-
-type InvalidPathHandler struct{}
-
-type OptionsHandler struct{}
-
-// MakeRESTRouter
-//     Returns a mux.Router that passes GET and PUT requests to the
-//     appropriate handlers.
-//
-func MakeRESTRouter(
-       enable_get bool,
-       enable_put bool,
-       kc *keepclient.KeepClient) *mux.Router {
-
-       t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
-
+// MakeRESTRouter returns an http.Handler that passes GET and PUT
+// requests to the appropriate handlers.
+func MakeRESTRouter(enable_get bool, enable_put bool, kc *keepclient.KeepClient) http.Handler {
        rest := mux.NewRouter()
+       h := &proxyHandler{
+               Handler:    rest,
+               KeepClient: kc,
+               ApiTokenCache: &ApiTokenCache{
+                       tokens:     make(map[string]int64),
+                       expireTime: 300,
+               },
+       }
 
        if enable_get {
-               rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
-                       GetBlockHandler{kc, t}).Methods("GET", "HEAD")
-               rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
+               rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Get).Methods("GET", "HEAD")
+               rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Get).Methods("GET", "HEAD")
 
                // List all blocks
-               rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET")
+               rest.HandleFunc(`/index`, h.Index).Methods("GET")
 
                // List blocks whose hash has the given prefix
-               rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET")
+               rest.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, h.Index).Methods("GET")
        }
 
        if enable_put {
-               rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
-               rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
-               rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
-               rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
-               rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
+               rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Put).Methods("PUT")
+               rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Put).Methods("PUT")
+               rest.HandleFunc(`/`, h.Put).Methods("POST")
+               rest.HandleFunc(`/{any}`, h.Options).Methods("OPTIONS")
+               rest.HandleFunc(`/`, h.Options).Methods("OPTIONS")
        }
 
        rest.NotFoundHandler = InvalidPathHandler{}
+       return h
+}
 
-       return rest
+var errLoopDetected = errors.New("loop detected")
+
+func (*proxyHandler) checkLoop(resp http.ResponseWriter, req *http.Request) error {
+       if via := req.Header.Get("Via"); strings.Index(via, " "+viaAlias) >= 0 {
+               log.Printf("proxy loop detected (request has Via: %q): perhaps keepproxy is misidentified by gateway config as an external client, or its keep_services record does not have service_type=proxy?", via)
+               http.Error(resp, errLoopDetected.Error(), http.StatusInternalServerError)
+               return errLoopDetected
+       }
+       return nil
 }
 
 func SetCorsHeaders(resp http.ResponseWriter) {
@@ -296,12 +297,14 @@ func SetCorsHeaders(resp http.ResponseWriter) {
        resp.Header().Set("Access-Control-Max-Age", "86486400")
 }
 
-func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+type InvalidPathHandler struct{}
+
+func (InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
        log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
        http.Error(resp, "Bad request", http.StatusBadRequest)
 }
 
-func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+func (h *proxyHandler) Options(resp http.ResponseWriter, req *http.Request) {
        log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
        SetCorsHeaders(resp)
 }
@@ -312,7 +315,10 @@ var MethodNotSupported = errors.New("Method not supported")
 
 var removeHint, _ = regexp.Compile("\\+K@[a-z0-9]{5}(\\+|$)")
 
-func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) {
+       if err := h.checkLoop(resp, req); err != nil {
+               return
+       }
        SetCorsHeaders(resp)
 
        locator := mux.Vars(req)["locator"]
@@ -328,11 +334,12 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                }
        }()
 
-       kc := *this.KeepClient
+       kc := *h.KeepClient
+       kc.Client = &proxyClient{client: kc.Client, proto: req.Proto}
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(&kc, h.ApiTokenCache, req); !pass {
                status, err = http.StatusForbidden, BadAuthorizationHeader
                return
        }
@@ -392,10 +399,15 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 var LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))
 var LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
 
-func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
+       if err := h.checkLoop(resp, req); err != nil {
+               return
+       }
        SetCorsHeaders(resp)
 
-       kc := *this.KeepClient
+       kc := *h.KeepClient
+       kc.Client = &proxyClient{client: kc.Client, proto: req.Proto}
+
        var err error
        var expectLength int64
        var status = http.StatusInternalServerError
@@ -432,7 +444,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(&kc, h.ApiTokenCache, req); !pass {
                err = BadAuthorizationHeader
                status = http.StatusForbidden
                return
@@ -468,7 +480,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
        // Tell the client how many successful PUTs we accomplished
        resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
 
-       switch err {
+       switch err.(type) {
        case nil:
                status = http.StatusOK
                _, err = io.WriteString(resp, locatorOut)
@@ -500,7 +512,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 //   Expects "complete" response (terminating with blank new line)
 //   Aborts on any errors
 // Concatenates responses from all those keep servers and returns
-func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) {
        SetCorsHeaders(resp)
 
        prefix := mux.Vars(req)["prefix"]
@@ -513,9 +525,9 @@ func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                }
        }()
 
-       kc := *handler.KeepClient
+       kc := *h.KeepClient
 
-       ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req)
+       ok, token := CheckAuthorizationHeader(&kc, h.ApiTokenCache, req)
        if !ok {
                status, err = http.StatusForbidden, BadAuthorizationHeader
                return