X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/6de67b677d281b99b53ee9a25c9228523fdf7ee2..8a2035547ad8bf6abad6a4a03bb0b59211a00932:/services/keepproxy/keepproxy.go diff --git a/services/keepproxy/keepproxy.go b/services/keepproxy/keepproxy.go index 46664931fd..0c0c08fe4d 100644 --- a/services/keepproxy/keepproxy.go +++ b/services/keepproxy/keepproxy.go @@ -1,18 +1,21 @@ +// Copyright (C) The Arvados Authors. All rights reserved. +// +// SPDX-License-Identifier: AGPL-3.0 + package main import ( - "encoding/json" "errors" "flag" "fmt" "io" "io/ioutil" - "log" "net" "net/http" "os" "os/signal" "regexp" + "strings" "sync" "syscall" "time" @@ -20,11 +23,17 @@ import ( "git.curoverse.com/arvados.git/sdk/go/arvados" "git.curoverse.com/arvados.git/sdk/go/arvadosclient" "git.curoverse.com/arvados.git/sdk/go/config" + "git.curoverse.com/arvados.git/sdk/go/health" + "git.curoverse.com/arvados.git/sdk/go/httpserver" "git.curoverse.com/arvados.git/sdk/go/keepclient" + log "github.com/Sirupsen/logrus" "github.com/coreos/go-systemd/daemon" + "github.com/ghodss/yaml" "github.com/gorilla/mux" ) +var version = "dev" + type Config struct { Client arvados.Client Listen string @@ -34,6 +43,7 @@ type Config struct { Timeout arvados.Duration PIDFile string Debug bool + ManagementToken string } func DefaultConfig() *Config { @@ -43,9 +53,18 @@ func DefaultConfig() *Config { } } -var listener net.Listener +var ( + listener net.Listener + router http.Handler +) + +const rfc3339NanoFixed = "2006-01-02T15:04:05.000000000Z07:00" func main() { + log.SetFormatter(&log.JSONFormatter{ + TimestampFormat: rfc3339NanoFixed, + }) + cfg := DefaultConfig() flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError) @@ -58,12 +77,21 @@ func main() { flagset.IntVar(&cfg.DefaultReplicas, "default-replicas", cfg.DefaultReplicas, "Default number of replicas to write if not specified by the client. If 0, use site default."+deprecated) flagset.StringVar(&cfg.PIDFile, "pid", cfg.PIDFile, "Path to write pid file."+deprecated) timeoutSeconds := flagset.Int("timeout", int(time.Duration(cfg.Timeout)/time.Second), "Timeout (in seconds) on requests to internal Keep services."+deprecated) + flagset.StringVar(&cfg.ManagementToken, "management-token", cfg.ManagementToken, "Authorization token to be included in all health check requests.") var cfgPath string - const defaultCfgPath = "/etc/arvados/keepproxy/config.json" + const defaultCfgPath = "/etc/arvados/keepproxy/keepproxy.yml" flagset.StringVar(&cfgPath, "config", defaultCfgPath, "Configuration file `path`") + dumpConfig := flagset.Bool("dump-config", false, "write current configuration to stdout and exit") + getVersion := flagset.Bool("version", false, "Print version information and exit.") flagset.Parse(os.Args[1:]) + // Print version information if requested + if *getVersion { + fmt.Printf("keepproxy %s\n", version) + return + } + err := config.LoadFile(cfg, cfgPath) if err != nil { h := os.Getenv("ARVADOS_API_HOST") @@ -77,12 +105,18 @@ func main() { if regexp.MustCompile("^(?i:1|yes|true)$").MatchString(os.Getenv("ARVADOS_API_HOST_INSECURE")) { cfg.Client.Insecure = true } - if j, err := json.MarshalIndent(cfg, "", " "); err == nil { - log.Print("Current configuration:\n", string(j)) + if y, err := yaml.Marshal(cfg); err == nil && !*dumpConfig { + log.Print("Current configuration:\n", string(y)) } cfg.Timeout = arvados.Duration(time.Duration(*timeoutSeconds) * time.Second) } + if *dumpConfig { + log.Fatal(config.DumpAndExit(cfg)) + } + + log.Printf("keepproxy %s started", version) + arv, err := arvadosclient.New(&cfg.Client) if err != nil { log.Fatalf("Error setting up arvados client %s", err.Error()) @@ -95,6 +129,7 @@ func main() { if err != nil { log.Fatalf("Error setting up keep client %s", err.Error()) } + keepclient.RefreshServiceDiscoveryOnSIGHUP() if cfg.PIDFile != "" { f, err := os.Create(cfg.PIDFile) @@ -124,14 +159,12 @@ func main() { if cfg.DefaultReplicas > 0 { kc.Want_replicas = cfg.DefaultReplicas } - kc.Client.Timeout = time.Duration(cfg.Timeout) - go kc.RefreshServices(5*time.Minute, 3*time.Second) listener, err = net.Listen("tcp", cfg.Listen) if err != nil { log.Fatalf("listen(%s): %s", cfg.Listen, err) } - if _, err := daemon.SdNotify("READY=1"); err != nil { + if _, err := daemon.SdNotify(false, "READY=1"); err != nil { log.Printf("Error notifying init daemon: %v", err) } log.Println("Listening at", listener.Addr()) @@ -148,7 +181,8 @@ func main() { signal.Notify(term, syscall.SIGINT) // Start serving requests. - http.Serve(listener, MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc)) + router = MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc, time.Duration(cfg.Timeout), cfg.ManagementToken) + http.Serve(listener, httpserver.AddRequestIDs(httpserver.LogRequests(router))) log.Println("shutting down") } @@ -198,90 +232,117 @@ func GetRemoteAddress(req *http.Request) string { } func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) { - var auth string - if auth = req.Header.Get("Authorization"); auth == "" { + parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2) + if len(parts) < 2 || !(parts[0] == "OAuth2" || parts[0] == "Bearer") || len(parts[1]) == 0 { return false, "" } - - _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok) - if err != nil { - // Scanning error - return false, "" + tok = parts[1] + + // Tokens are validated differently depending on what kind of + // operation is being performed. For example, tokens in + // collection-sharing links permit GET requests, but not + // PUT requests. + var op string + if req.Method == "GET" || req.Method == "HEAD" { + op = "read" + } else { + op = "write" } - if cache.RecallToken(tok) { + if cache.RecallToken(op + ":" + tok) { // Valid in the cache, short circuit return true, tok } + var err error arv := *kc.Arvados arv.ApiToken = tok - if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil { + if op == "read" { + err = arv.Call("HEAD", "keep_services", "", "accessible", nil, nil) + } else { + err = arv.Call("HEAD", "users", "", "current", nil, nil) + } + if err != nil { log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err) return false, "" } // Success! Update cache - cache.RememberToken(tok) + cache.RememberToken(op + ":" + tok) return true, tok } -type GetBlockHandler struct { +type proxyHandler struct { + http.Handler *keepclient.KeepClient *ApiTokenCache + timeout time.Duration + transport *http.Transport } -type PutBlockHandler struct { - *keepclient.KeepClient - *ApiTokenCache -} - -type IndexHandler struct { - *keepclient.KeepClient - *ApiTokenCache -} - -type InvalidPathHandler struct{} - -type OptionsHandler struct{} - -// MakeRESTRouter -// Returns a mux.Router that passes GET and PUT requests to the -// appropriate handlers. -// -func MakeRESTRouter( - enable_get bool, - enable_put bool, - kc *keepclient.KeepClient) *mux.Router { - - t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300} - +// MakeRESTRouter returns an http.Handler that passes GET and PUT +// requests to the appropriate handlers. +func MakeRESTRouter(enable_get bool, enable_put bool, kc *keepclient.KeepClient, timeout time.Duration, mgmtToken string) http.Handler { rest := mux.NewRouter() + transport := *(http.DefaultTransport.(*http.Transport)) + transport.DialContext = (&net.Dialer{ + Timeout: keepclient.DefaultConnectTimeout, + KeepAlive: keepclient.DefaultKeepAlive, + DualStack: true, + }).DialContext + transport.TLSClientConfig = arvadosclient.MakeTLSConfig(kc.Arvados.ApiInsecure) + transport.TLSHandshakeTimeout = keepclient.DefaultTLSHandshakeTimeout + + h := &proxyHandler{ + Handler: rest, + KeepClient: kc, + timeout: timeout, + transport: &transport, + ApiTokenCache: &ApiTokenCache{ + tokens: make(map[string]int64), + expireTime: 300, + }, + } + if enable_get { - rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, - GetBlockHandler{kc, t}).Methods("GET", "HEAD") - rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD") + rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Get).Methods("GET", "HEAD") + rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Get).Methods("GET", "HEAD") // List all blocks - rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET") + rest.HandleFunc(`/index`, h.Index).Methods("GET") // List blocks whose hash has the given prefix - rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET") + rest.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, h.Index).Methods("GET") } if enable_put { - rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT") - rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT") - rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST") - rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS") - rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS") + rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Put).Methods("PUT") + rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Put).Methods("PUT") + rest.HandleFunc(`/`, h.Put).Methods("POST") + rest.HandleFunc(`/{any}`, h.Options).Methods("OPTIONS") + rest.HandleFunc(`/`, h.Options).Methods("OPTIONS") } + rest.Handle("/_health/{check}", &health.Handler{ + Token: mgmtToken, + Prefix: "/_health/", + }).Methods("GET") + rest.NotFoundHandler = InvalidPathHandler{} + return h +} + +var errLoopDetected = errors.New("loop detected") - return rest +func (*proxyHandler) checkLoop(resp http.ResponseWriter, req *http.Request) error { + if via := req.Header.Get("Via"); strings.Index(via, " "+viaAlias) >= 0 { + log.Printf("proxy loop detected (request has Via: %q): perhaps keepproxy is misidentified by gateway config as an external client, or its keep_services record does not have service_type=proxy?", via) + http.Error(resp, errLoopDetected.Error(), http.StatusInternalServerError) + return errLoopDetected + } + return nil } func SetCorsHeaders(resp http.ResponseWriter) { @@ -291,12 +352,14 @@ func SetCorsHeaders(resp http.ResponseWriter) { resp.Header().Set("Access-Control-Max-Age", "86486400") } -func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +type InvalidPathHandler struct{} + +func (InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path) http.Error(resp, "Bad request", http.StatusBadRequest) } -func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (h *proxyHandler) Options(resp http.ResponseWriter, req *http.Request) { log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path) SetCorsHeaders(resp) } @@ -307,8 +370,12 @@ var MethodNotSupported = errors.New("Method not supported") var removeHint, _ = regexp.Compile("\\+K@[a-z0-9]{5}(\\+|$)") -func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) { + if err := h.checkLoop(resp, req); err != nil { + return + } SetCorsHeaders(resp) + resp.Header().Set("Via", req.Proto+" "+viaAlias) locator := mux.Vars(req)["locator"] var err error @@ -323,11 +390,11 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques } }() - kc := *this.KeepClient + kc := h.makeKeepClient(req) var pass bool var tok string - if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass { + if pass, tok = CheckAuthorizationHeader(kc, h.ApiTokenCache, req); !pass { status, err = http.StatusForbidden, BadAuthorizationHeader return } @@ -387,10 +454,15 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques var LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired)) var LengthMismatchError = errors.New("Locator size hint does not match Content-Length header") -func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) { + if err := h.checkLoop(resp, req); err != nil { + return + } SetCorsHeaders(resp) + resp.Header().Set("Via", "HTTP/1.1 "+viaAlias) + + kc := h.makeKeepClient(req) - kc := *this.KeepClient var err error var expectLength int64 var status = http.StatusInternalServerError @@ -427,7 +499,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques var pass bool var tok string - if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass { + if pass, tok = CheckAuthorizationHeader(kc, h.ApiTokenCache, req); !pass { err = BadAuthorizationHeader status = http.StatusForbidden return @@ -463,7 +535,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques // Tell the client how many successful PUTs we accomplished resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas)) - switch err { + switch err.(type) { case nil: status = http.StatusOK _, err = io.WriteString(resp, locatorOut) @@ -495,7 +567,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques // Expects "complete" response (terminating with blank new line) // Aborts on any errors // Concatenates responses from all those keep servers and returns -func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) { SetCorsHeaders(resp) prefix := mux.Vars(req)["prefix"] @@ -508,9 +580,8 @@ func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques } }() - kc := *handler.KeepClient - - ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req) + kc := h.makeKeepClient(req) + ok, token := CheckAuthorizationHeader(kc, h.ApiTokenCache, req) if !ok { status, err = http.StatusForbidden, BadAuthorizationHeader return @@ -547,3 +618,16 @@ func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques status = http.StatusOK resp.Write([]byte("\n")) } + +func (h *proxyHandler) makeKeepClient(req *http.Request) *keepclient.KeepClient { + kc := *h.KeepClient + kc.HTTPClient = &proxyClient{ + client: &http.Client{ + Timeout: h.timeout, + Transport: h.transport, + }, + proto: req.Proto, + requestID: req.Header.Get("X-Request-Id"), + } + return &kc +}