X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/38fae0458644b89322ddeac125971800b9e452e5..ca06cfbda0e84d469f7810a280cfa4dfa8997260:/services/keepproxy/keepproxy.go diff --git a/services/keepproxy/keepproxy.go b/services/keepproxy/keepproxy.go index 24df531fa4..ec074cff3f 100644 --- a/services/keepproxy/keepproxy.go +++ b/services/keepproxy/keepproxy.go @@ -1,7 +1,10 @@ +// Copyright (C) The Arvados Authors. All rights reserved. +// +// SPDX-License-Identifier: AGPL-3.0 + package main import ( - "encoding/json" "errors" "flag" "fmt" @@ -13,6 +16,7 @@ import ( "os" "os/signal" "regexp" + "strings" "sync" "syscall" "time" @@ -20,8 +24,11 @@ import ( "git.curoverse.com/arvados.git/sdk/go/arvados" "git.curoverse.com/arvados.git/sdk/go/arvadosclient" "git.curoverse.com/arvados.git/sdk/go/config" + "git.curoverse.com/arvados.git/sdk/go/health" "git.curoverse.com/arvados.git/sdk/go/keepclient" + arvadosVersion "git.curoverse.com/arvados.git/sdk/go/version" "github.com/coreos/go-systemd/daemon" + "github.com/ghodss/yaml" "github.com/gorilla/mux" ) @@ -34,6 +41,7 @@ type Config struct { Timeout arvados.Duration PIDFile string Debug bool + ManagementToken string } func DefaultConfig() *Config { @@ -43,7 +51,10 @@ func DefaultConfig() *Config { } } -var listener net.Listener +var ( + listener net.Listener + router http.Handler +) func main() { cfg := DefaultConfig() @@ -58,12 +69,21 @@ func main() { flagset.IntVar(&cfg.DefaultReplicas, "default-replicas", cfg.DefaultReplicas, "Default number of replicas to write if not specified by the client. If 0, use site default."+deprecated) flagset.StringVar(&cfg.PIDFile, "pid", cfg.PIDFile, "Path to write pid file."+deprecated) timeoutSeconds := flagset.Int("timeout", int(time.Duration(cfg.Timeout)/time.Second), "Timeout (in seconds) on requests to internal Keep services."+deprecated) + flagset.StringVar(&cfg.ManagementToken, "management-token", cfg.ManagementToken, "Authorization token to be included in all health check requests.") var cfgPath string const defaultCfgPath = "/etc/arvados/keepproxy/keepproxy.yml" flagset.StringVar(&cfgPath, "config", defaultCfgPath, "Configuration file `path`") + dumpConfig := flagset.Bool("dump-config", false, "write current configuration to stdout and exit") + getVersion := flagset.Bool("version", false, "Print version information and exit.") flagset.Parse(os.Args[1:]) + // Print version information if requested + if *getVersion { + fmt.Printf("Version: %s\n", arvadosVersion.GetVersion()) + os.Exit(0) + } + err := config.LoadFile(cfg, cfgPath) if err != nil { h := os.Getenv("ARVADOS_API_HOST") @@ -77,12 +97,18 @@ func main() { if regexp.MustCompile("^(?i:1|yes|true)$").MatchString(os.Getenv("ARVADOS_API_HOST_INSECURE")) { cfg.Client.Insecure = true } - if j, err := json.MarshalIndent(cfg, "", " "); err == nil { - log.Print("Current configuration:\n", string(j)) + if y, err := yaml.Marshal(cfg); err == nil && !*dumpConfig { + log.Print("Current configuration:\n", string(y)) } cfg.Timeout = arvados.Duration(time.Duration(*timeoutSeconds) * time.Second) } + if *dumpConfig { + log.Fatal(config.DumpAndExit(cfg)) + } + + log.Printf("keepproxy %q started", arvadosVersion.GetVersion()) + arv, err := arvadosclient.New(&cfg.Client) if err != nil { log.Fatalf("Error setting up arvados client %s", err.Error()) @@ -95,6 +121,7 @@ func main() { if err != nil { log.Fatalf("Error setting up keep client %s", err.Error()) } + keepclient.RefreshServiceDiscoveryOnSIGHUP() if cfg.PIDFile != "" { f, err := os.Create(cfg.PIDFile) @@ -124,8 +151,6 @@ func main() { if cfg.DefaultReplicas > 0 { kc.Want_replicas = cfg.DefaultReplicas } - kc.Client.Timeout = time.Duration(cfg.Timeout) - go kc.RefreshServices(5*time.Minute, 3*time.Second) listener, err = net.Listen("tcp", cfg.Listen) if err != nil { @@ -148,7 +173,8 @@ func main() { signal.Notify(term, syscall.SIGINT) // Start serving requests. - http.Serve(listener, MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc)) + router = MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc, time.Duration(cfg.Timeout), cfg.ManagementToken) + http.Serve(listener, router) log.Println("shutting down") } @@ -227,61 +253,76 @@ func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, r return true, tok } -type GetBlockHandler struct { - *keepclient.KeepClient - *ApiTokenCache -} - -type PutBlockHandler struct { - *keepclient.KeepClient - *ApiTokenCache -} - -type IndexHandler struct { +type proxyHandler struct { + http.Handler *keepclient.KeepClient *ApiTokenCache + timeout time.Duration + transport *http.Transport } -type InvalidPathHandler struct{} - -type OptionsHandler struct{} - -// MakeRESTRouter -// Returns a mux.Router that passes GET and PUT requests to the -// appropriate handlers. -// -func MakeRESTRouter( - enable_get bool, - enable_put bool, - kc *keepclient.KeepClient) *mux.Router { - - t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300} - +// MakeRESTRouter returns an http.Handler that passes GET and PUT +// requests to the appropriate handlers. +func MakeRESTRouter(enable_get bool, enable_put bool, kc *keepclient.KeepClient, timeout time.Duration, mgmtToken string) http.Handler { rest := mux.NewRouter() + transport := *(http.DefaultTransport.(*http.Transport)) + transport.DialContext = (&net.Dialer{ + Timeout: keepclient.DefaultConnectTimeout, + KeepAlive: keepclient.DefaultKeepAlive, + DualStack: true, + }).DialContext + transport.TLSClientConfig = arvadosclient.MakeTLSConfig(kc.Arvados.ApiInsecure) + transport.TLSHandshakeTimeout = keepclient.DefaultTLSHandshakeTimeout + + h := &proxyHandler{ + Handler: rest, + KeepClient: kc, + timeout: timeout, + transport: &transport, + ApiTokenCache: &ApiTokenCache{ + tokens: make(map[string]int64), + expireTime: 300, + }, + } + if enable_get { - rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, - GetBlockHandler{kc, t}).Methods("GET", "HEAD") - rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD") + rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Get).Methods("GET", "HEAD") + rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Get).Methods("GET", "HEAD") // List all blocks - rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET") + rest.HandleFunc(`/index`, h.Index).Methods("GET") // List blocks whose hash has the given prefix - rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET") + rest.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, h.Index).Methods("GET") } if enable_put { - rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT") - rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT") - rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST") - rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS") - rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS") + rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Put).Methods("PUT") + rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Put).Methods("PUT") + rest.HandleFunc(`/`, h.Put).Methods("POST") + rest.HandleFunc(`/{any}`, h.Options).Methods("OPTIONS") + rest.HandleFunc(`/`, h.Options).Methods("OPTIONS") } + rest.Handle("/_health/{check}", &health.Handler{ + Token: mgmtToken, + Prefix: "/_health/", + }).Methods("GET") + rest.NotFoundHandler = InvalidPathHandler{} + return h +} + +var errLoopDetected = errors.New("loop detected") - return rest +func (*proxyHandler) checkLoop(resp http.ResponseWriter, req *http.Request) error { + if via := req.Header.Get("Via"); strings.Index(via, " "+viaAlias) >= 0 { + log.Printf("proxy loop detected (request has Via: %q): perhaps keepproxy is misidentified by gateway config as an external client, or its keep_services record does not have service_type=proxy?", via) + http.Error(resp, errLoopDetected.Error(), http.StatusInternalServerError) + return errLoopDetected + } + return nil } func SetCorsHeaders(resp http.ResponseWriter) { @@ -291,12 +332,14 @@ func SetCorsHeaders(resp http.ResponseWriter) { resp.Header().Set("Access-Control-Max-Age", "86486400") } -func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +type InvalidPathHandler struct{} + +func (InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path) http.Error(resp, "Bad request", http.StatusBadRequest) } -func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (h *proxyHandler) Options(resp http.ResponseWriter, req *http.Request) { log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path) SetCorsHeaders(resp) } @@ -307,8 +350,12 @@ var MethodNotSupported = errors.New("Method not supported") var removeHint, _ = regexp.Compile("\\+K@[a-z0-9]{5}(\\+|$)") -func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) { + if err := h.checkLoop(resp, req); err != nil { + return + } SetCorsHeaders(resp) + resp.Header().Set("Via", req.Proto+" "+viaAlias) locator := mux.Vars(req)["locator"] var err error @@ -323,11 +370,11 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques } }() - kc := *this.KeepClient + kc := h.makeKeepClient(req) var pass bool var tok string - if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass { + if pass, tok = CheckAuthorizationHeader(kc, h.ApiTokenCache, req); !pass { status, err = http.StatusForbidden, BadAuthorizationHeader return } @@ -387,10 +434,15 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques var LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired)) var LengthMismatchError = errors.New("Locator size hint does not match Content-Length header") -func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) { + if err := h.checkLoop(resp, req); err != nil { + return + } SetCorsHeaders(resp) + resp.Header().Set("Via", "HTTP/1.1 "+viaAlias) + + kc := h.makeKeepClient(req) - kc := *this.KeepClient var err error var expectLength int64 var status = http.StatusInternalServerError @@ -427,7 +479,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques var pass bool var tok string - if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass { + if pass, tok = CheckAuthorizationHeader(kc, h.ApiTokenCache, req); !pass { err = BadAuthorizationHeader status = http.StatusForbidden return @@ -463,7 +515,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques // Tell the client how many successful PUTs we accomplished resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas)) - switch err { + switch err.(type) { case nil: status = http.StatusOK _, err = io.WriteString(resp, locatorOut) @@ -495,7 +547,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques // Expects "complete" response (terminating with blank new line) // Aborts on any errors // Concatenates responses from all those keep servers and returns -func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) { SetCorsHeaders(resp) prefix := mux.Vars(req)["prefix"] @@ -508,9 +560,8 @@ func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques } }() - kc := *handler.KeepClient - - ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req) + kc := h.makeKeepClient(req) + ok, token := CheckAuthorizationHeader(kc, h.ApiTokenCache, req) if !ok { status, err = http.StatusForbidden, BadAuthorizationHeader return @@ -547,3 +598,15 @@ func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques status = http.StatusOK resp.Write([]byte("\n")) } + +func (h *proxyHandler) makeKeepClient(req *http.Request) *keepclient.KeepClient { + kc := *h.KeepClient + kc.HTTPClient = &proxyClient{ + client: &http.Client{ + Timeout: h.timeout, + Transport: h.transport, + }, + proto: req.Proto, + } + return &kc +}