10666: Added version number to go sdk and go tools & services
[arvados.git] / services / keepproxy / keepproxy.go
index 76a8a1551fb867f5258221f1f22692656039d97a..ec074cff3f4a772a39e7a345d3e4307be5a04c7f 100644 (file)
@@ -1,3 +1,7 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
 package main
 
 import (
@@ -12,6 +16,7 @@ import (
        "os"
        "os/signal"
        "regexp"
+       "strings"
        "sync"
        "syscall"
        "time"
@@ -19,7 +24,9 @@ import (
        "git.curoverse.com/arvados.git/sdk/go/arvados"
        "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
        "git.curoverse.com/arvados.git/sdk/go/config"
+       "git.curoverse.com/arvados.git/sdk/go/health"
        "git.curoverse.com/arvados.git/sdk/go/keepclient"
+       arvadosVersion "git.curoverse.com/arvados.git/sdk/go/version"
        "github.com/coreos/go-systemd/daemon"
        "github.com/ghodss/yaml"
        "github.com/gorilla/mux"
@@ -34,6 +41,7 @@ type Config struct {
        Timeout         arvados.Duration
        PIDFile         string
        Debug           bool
+       ManagementToken string
 }
 
 func DefaultConfig() *Config {
@@ -43,7 +51,10 @@ func DefaultConfig() *Config {
        }
 }
 
-var listener net.Listener
+var (
+       listener net.Listener // assigned in main via net.Listen
+       router   http.Handler // assigned in main from MakeRESTRouter
+)
 
 func main() {
        cfg := DefaultConfig()
@@ -58,13 +69,21 @@ func main() {
        flagset.IntVar(&cfg.DefaultReplicas, "default-replicas", cfg.DefaultReplicas, "Default number of replicas to write if not specified by the client. If 0, use site default."+deprecated)
        flagset.StringVar(&cfg.PIDFile, "pid", cfg.PIDFile, "Path to write pid file."+deprecated)
        timeoutSeconds := flagset.Int("timeout", int(time.Duration(cfg.Timeout)/time.Second), "Timeout (in seconds) on requests to internal Keep services."+deprecated)
+       flagset.StringVar(&cfg.ManagementToken, "management-token", cfg.ManagementToken, "Authorization token to be included in all health check requests.")
 
        var cfgPath string
        const defaultCfgPath = "/etc/arvados/keepproxy/keepproxy.yml"
        flagset.StringVar(&cfgPath, "config", defaultCfgPath, "Configuration file `path`")
        dumpConfig := flagset.Bool("dump-config", false, "write current configuration to stdout and exit")
+       getVersion := flagset.Bool("version", false, "Print version information and exit.")
        flagset.Parse(os.Args[1:])
 
+       // Print version information if requested
+       if *getVersion {
+               fmt.Printf("Version: %s\n", arvadosVersion.GetVersion())
+               os.Exit(0)
+       }
+
        err := config.LoadFile(cfg, cfgPath)
        if err != nil {
                h := os.Getenv("ARVADOS_API_HOST")
@@ -88,6 +107,8 @@ func main() {
                log.Fatal(config.DumpAndExit(cfg))
        }
 
+       log.Printf("keepproxy %q started", arvadosVersion.GetVersion())
+
        arv, err := arvadosclient.New(&cfg.Client)
        if err != nil {
                log.Fatalf("Error setting up arvados client %s", err.Error())
@@ -100,6 +121,7 @@ func main() {
        if err != nil {
                log.Fatalf("Error setting up keep client %s", err.Error())
        }
+       keepclient.RefreshServiceDiscoveryOnSIGHUP()
 
        if cfg.PIDFile != "" {
                f, err := os.Create(cfg.PIDFile)
@@ -129,8 +151,6 @@ func main() {
        if cfg.DefaultReplicas > 0 {
                kc.Want_replicas = cfg.DefaultReplicas
        }
-       kc.Client.Timeout = time.Duration(cfg.Timeout)
-       go kc.RefreshServices(5*time.Minute, 3*time.Second)
 
        listener, err = net.Listen("tcp", cfg.Listen)
        if err != nil {
@@ -153,7 +173,8 @@ func main() {
        signal.Notify(term, syscall.SIGINT)
 
        // Start serving requests.
-       http.Serve(listener, MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc))
+       router = MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc, time.Duration(cfg.Timeout), cfg.ManagementToken)
+       http.Serve(listener, router)
 
        log.Println("shutting down")
 }
@@ -232,61 +253,76 @@ func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, r
        return true, tok
 }
 
-type GetBlockHandler struct {
-       *keepclient.KeepClient
-       *ApiTokenCache
-}
-
-type PutBlockHandler struct {
+type proxyHandler struct {
+       http.Handler // the mux router built in MakeRESTRouter; ServeHTTP dispatches through it
         *keepclient.KeepClient
         *ApiTokenCache
+       timeout   time.Duration   // http.Client.Timeout for the per-request clients built by makeKeepClient
+       transport *http.Transport // shared by every per-request http.Client (one connection pool)
 }
 
-type IndexHandler struct {
-       *keepclient.KeepClient
-       *ApiTokenCache
-}
-
-type InvalidPathHandler struct{}
-
-type OptionsHandler struct{}
-
-// MakeRESTRouter
-//     Returns a mux.Router that passes GET and PUT requests to the
-//     appropriate handlers.
-//
-func MakeRESTRouter(
-       enable_get bool,
-       enable_put bool,
-       kc *keepclient.KeepClient) *mux.Router {
-
-       t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
-
+// MakeRESTRouter returns the proxy's http.Handler (a *proxyHandler): block GET/HEAD and index
+// routes if enable_get, PUT/POST/OPTIONS routes if enable_put, and GET /_health/{check} always.
+func MakeRESTRouter(enable_get bool, enable_put bool, kc *keepclient.KeepClient, timeout time.Duration, mgmtToken string) http.Handler {
        rest := mux.NewRouter()
 
+       transport := *(http.DefaultTransport.(*http.Transport)) // NOTE(review): value copy also copies Transport's internal mutexes (go vet copylocks); use Clone() once Go >= 1.13 is required
+       transport.DialContext = (&net.Dialer{
+               Timeout:   keepclient.DefaultConnectTimeout,
+               KeepAlive: keepclient.DefaultKeepAlive,
+               DualStack: true, // NOTE(review): deprecated since Go 1.12 (fast fallback is on by default)
+       }).DialContext
+       transport.TLSClientConfig = arvadosclient.MakeTLSConfig(kc.Arvados.ApiInsecure)
+       transport.TLSHandshakeTimeout = keepclient.DefaultTLSHandshakeTimeout
+
+       h := &proxyHandler{
+               Handler:    rest,
+               KeepClient: kc,
+               timeout:    timeout,
+               transport:  &transport,
+               ApiTokenCache: &ApiTokenCache{
+                       tokens:     make(map[string]int64),
+                       expireTime: 300,
+               },
+       }
+
        if enable_get {
-               rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
-                       GetBlockHandler{kc, t}).Methods("GET", "HEAD")
-               rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
+               rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Get).Methods("GET", "HEAD")
+               rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Get).Methods("GET", "HEAD")
 
                // List all blocks
-               rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET")
+               rest.HandleFunc(`/index`, h.Index).Methods("GET")
 
                // List blocks whose hash has the given prefix
-               rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET")
+               rest.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, h.Index).Methods("GET")
        }
 
        if enable_put {
-               rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
-               rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
-               rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
-               rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
-               rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
+               rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Put).Methods("PUT")
+               rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Put).Methods("PUT")
+               rest.HandleFunc(`/`, h.Put).Methods("POST")
+               rest.HandleFunc(`/{any}`, h.Options).Methods("OPTIONS")
+               rest.HandleFunc(`/`, h.Options).Methods("OPTIONS")
        }
 
+       rest.Handle("/_health/{check}", &health.Handler{
+               Token:  mgmtToken,
+               Prefix: "/_health/",
+       }).Methods("GET")
+
        rest.NotFoundHandler = InvalidPathHandler{}
+       return h // h embeds rest as its http.Handler, so it serves the routes registered above
+}
 
-       return rest
+var errLoopDetected = errors.New("loop detected") // sent to the client (as a 500) and returned by checkLoop
+// checkLoop logs, replies 500 and returns errLoopDetected if req's Via header contains " "+viaAlias; else nil.
+func (*proxyHandler) checkLoop(resp http.ResponseWriter, req *http.Request) error {
+       if via := req.Header.Get("Via"); strings.Contains(via, " "+viaAlias) {
+               log.Printf("proxy loop detected (request has Via: %q): perhaps keepproxy is misidentified by gateway config as an external client, or its keep_services record does not have service_type=proxy?", via)
+               http.Error(resp, errLoopDetected.Error(), http.StatusInternalServerError)
+               return errLoopDetected
+       }
+       return nil
 }
 
 func SetCorsHeaders(resp http.ResponseWriter) {
@@ -296,12 +332,14 @@ func SetCorsHeaders(resp http.ResponseWriter) {
        resp.Header().Set("Access-Control-Max-Age", "86486400")
 }
 
-func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+type InvalidPathHandler struct{} // router NotFoundHandler: logs unroutable requests and replies 400 "Bad request"
+
+func (InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
        log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
        http.Error(resp, "Bad request", http.StatusBadRequest)
 }
 
-func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+func (h *proxyHandler) Options(resp http.ResponseWriter, req *http.Request) { // OPTIONS handler: logs the request and replies with CORS headers only
        log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
        SetCorsHeaders(resp)
 }
@@ -312,8 +350,12 @@ var MethodNotSupported = errors.New("Method not supported")
 
 var removeHint, _ = regexp.Compile("\\+K@[a-z0-9]{5}(\\+|$)")
 
-func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) {
+       if err := h.checkLoop(resp, req); err != nil {
+               return
+       }
        SetCorsHeaders(resp)
+       resp.Header().Set("Via", req.Proto+" "+viaAlias)
 
        locator := mux.Vars(req)["locator"]
        var err error
@@ -328,11 +370,11 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                }
        }()
 
-       kc := *this.KeepClient
+       kc := h.makeKeepClient(req)
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(kc, h.ApiTokenCache, req); !pass {
                status, err = http.StatusForbidden, BadAuthorizationHeader
                return
        }
@@ -392,10 +434,15 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 var LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))
 var LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
 
-func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
+       if err := h.checkLoop(resp, req); err != nil {
+               return
+       }
        SetCorsHeaders(resp)
+       resp.Header().Set("Via", "HTTP/1.1 "+viaAlias)
+
+       kc := h.makeKeepClient(req)
 
-       kc := *this.KeepClient
        var err error
        var expectLength int64
        var status = http.StatusInternalServerError
@@ -432,7 +479,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(kc, h.ApiTokenCache, req); !pass {
                err = BadAuthorizationHeader
                status = http.StatusForbidden
                return
@@ -468,7 +515,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
        // Tell the client how many successful PUTs we accomplished
        resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
 
-       switch err {
+       switch err.(type) {
        case nil:
                status = http.StatusOK
                _, err = io.WriteString(resp, locatorOut)
@@ -500,7 +547,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 //   Expects "complete" response (terminating with blank new line)
 //   Aborts on any errors
 // Concatenates responses from all those keep servers and returns
-func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) {
        SetCorsHeaders(resp)
 
        prefix := mux.Vars(req)["prefix"]
@@ -513,9 +560,8 @@ func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                }
        }()
 
-       kc := *handler.KeepClient
-
-       ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req)
+       kc := h.makeKeepClient(req)
+       ok, token := CheckAuthorizationHeader(kc, h.ApiTokenCache, req)
        if !ok {
                status, err = http.StatusForbidden, BadAuthorizationHeader
                return
@@ -552,3 +598,15 @@ func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
        status = http.StatusOK
        resp.Write([]byte("\n"))
 }
+
+// makeKeepClient returns a shallow copy of h.KeepClient for serving req.
+// The copy's HTTPClient is a proxyClient wrapping a new http.Client that
+// applies h.timeout and reuses the shared h.transport (and therefore its
+// connection pool). req.Proto is handed to the proxyClient;
+// NOTE(review): presumably for the outgoing Via header -- see proxyClient.
+func (h *proxyHandler) makeKeepClient(req *http.Request) *keepclient.KeepClient {
+       clone := *h.KeepClient
+       httpClient := &http.Client{Transport: h.transport, Timeout: h.timeout}
+       clone.HTTPClient = &proxyClient{proto: req.Proto, client: httpClient}
+       return &clone
+}