X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/436f5c768dbc97135490b6477efd1ff0482a9dda..9d33e0c4f98da668b23b85c54d20d22fe4b0f342:/services/keepproxy/keepproxy.go diff --git a/services/keepproxy/keepproxy.go b/services/keepproxy/keepproxy.go index 76a8a1551f..65f7a42cd9 100644 --- a/services/keepproxy/keepproxy.go +++ b/services/keepproxy/keepproxy.go @@ -12,6 +12,7 @@ import ( "os" "os/signal" "regexp" + "strings" "sync" "syscall" "time" @@ -43,7 +44,10 @@ func DefaultConfig() *Config { } } -var listener net.Listener +var ( + listener net.Listener + router http.Handler +) func main() { cfg := DefaultConfig() @@ -129,7 +133,7 @@ func main() { if cfg.DefaultReplicas > 0 { kc.Want_replicas = cfg.DefaultReplicas } - kc.Client.Timeout = time.Duration(cfg.Timeout) + kc.Client.(*http.Client).Timeout = time.Duration(cfg.Timeout) go kc.RefreshServices(5*time.Minute, 3*time.Second) listener, err = net.Listen("tcp", cfg.Listen) @@ -153,7 +157,8 @@ func main() { signal.Notify(term, syscall.SIGINT) // Start serving requests. - http.Serve(listener, MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc)) + router = MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc) + http.Serve(listener, router) log.Println("shutting down") } @@ -232,61 +237,57 @@ func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, r return true, tok } -type GetBlockHandler struct { +type proxyHandler struct { + http.Handler *keepclient.KeepClient *ApiTokenCache } -type PutBlockHandler struct { - *keepclient.KeepClient - *ApiTokenCache -} - -type IndexHandler struct { - *keepclient.KeepClient - *ApiTokenCache -} - -type InvalidPathHandler struct{} - -type OptionsHandler struct{} - -// MakeRESTRouter -// Returns a mux.Router that passes GET and PUT requests to the -// appropriate handlers. -// -func MakeRESTRouter( - enable_get bool, - enable_put bool, - kc *keepclient.KeepClient) *mux.Router { - - t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300} - +// MakeRESTRouter returns an http.Handler that passes GET and PUT +// requests to the appropriate handlers. +func MakeRESTRouter(enable_get bool, enable_put bool, kc *keepclient.KeepClient) http.Handler { rest := mux.NewRouter() + h := &proxyHandler{ + Handler: rest, + KeepClient: kc, + ApiTokenCache: &ApiTokenCache{ + tokens: make(map[string]int64), + expireTime: 300, + }, + } if enable_get { - rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, - GetBlockHandler{kc, t}).Methods("GET", "HEAD") - rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD") + rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Get).Methods("GET", "HEAD") + rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Get).Methods("GET", "HEAD") // List all blocks - rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET") + rest.HandleFunc(`/index`, h.Index).Methods("GET") // List blocks whose hash has the given prefix - rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET") + rest.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, h.Index).Methods("GET") } if enable_put { - rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT") - rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT") - rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST") - rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS") - rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS") + rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Put).Methods("PUT") + rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Put).Methods("PUT") + rest.HandleFunc(`/`, h.Put).Methods("POST") + rest.HandleFunc(`/{any}`, h.Options).Methods("OPTIONS") + rest.HandleFunc(`/`, h.Options).Methods("OPTIONS") } rest.NotFoundHandler = InvalidPathHandler{} + return h +} - return rest +var errLoopDetected = errors.New("loop detected") + +func (*proxyHandler) checkLoop(resp http.ResponseWriter, req *http.Request) error { + if via := req.Header.Get("Via"); strings.Index(via, " "+viaAlias) >= 0 { + log.Printf("proxy loop detected (request has Via: %q): perhaps keepproxy is misidentified by gateway config as an external client, or its keep_services record does not have service_type=proxy?", via) + http.Error(resp, errLoopDetected.Error(), http.StatusInternalServerError) + return errLoopDetected + } + return nil } func SetCorsHeaders(resp http.ResponseWriter) { @@ -296,12 +297,14 @@ func SetCorsHeaders(resp http.ResponseWriter) { resp.Header().Set("Access-Control-Max-Age", "86486400") } -func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +type InvalidPathHandler struct{} + +func (InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path) http.Error(resp, "Bad request", http.StatusBadRequest) } -func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (h *proxyHandler) Options(resp http.ResponseWriter, req *http.Request) { log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path) SetCorsHeaders(resp) } @@ -312,8 +315,12 @@ var MethodNotSupported = errors.New("Method not supported") var removeHint, _ = regexp.Compile("\\+K@[a-z0-9]{5}(\\+|$)") -func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) { + if err := h.checkLoop(resp, req); err != nil { + return + } SetCorsHeaders(resp) + resp.Header().Set("Via", req.Proto+" "+viaAlias) locator := mux.Vars(req)["locator"] var err error @@ -328,11 +335,12 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques } }() - kc := *this.KeepClient + kc := *h.KeepClient + kc.Client = &proxyClient{client: kc.Client, proto: req.Proto} var pass bool var tok string - if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass { + if pass, tok = CheckAuthorizationHeader(&kc, h.ApiTokenCache, req); !pass { status, err = http.StatusForbidden, BadAuthorizationHeader return } @@ -392,10 +400,16 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques var LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired)) var LengthMismatchError = errors.New("Locator size hint does not match Content-Length header") -func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) { + if err := h.checkLoop(resp, req); err != nil { + return + } SetCorsHeaders(resp) + resp.Header().Set("Via", "HTTP/1.1 "+viaAlias) + + kc := *h.KeepClient + kc.Client = &proxyClient{client: kc.Client, proto: req.Proto} - kc := *this.KeepClient var err error var expectLength int64 var status = http.StatusInternalServerError @@ -432,7 +446,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques var pass bool var tok string - if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass { + if pass, tok = CheckAuthorizationHeader(&kc, h.ApiTokenCache, req); !pass { err = BadAuthorizationHeader status = http.StatusForbidden return @@ -468,7 +482,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques // Tell the client how many successful PUTs we accomplished resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas)) - switch err { + switch err.(type) { case nil: status = http.StatusOK _, err = io.WriteString(resp, locatorOut) @@ -500,7 +514,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques // Expects "complete" response (terminating with blank new line) // Aborts on any errors // Concatenates responses from all those keep servers and returns -func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) { SetCorsHeaders(resp) prefix := mux.Vars(req)["prefix"] @@ -513,9 +527,9 @@ func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques } }() - kc := *handler.KeepClient + kc := *h.KeepClient - ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req) + ok, token := CheckAuthorizationHeader(&kc, h.ApiTokenCache, req) if !ok { status, err = http.StatusForbidden, BadAuthorizationHeader return