12538: Refactor crunch-run shutdown
[arvados.git] / services / keepproxy / keepproxy.go
index 7a673aeba97b9780d3dcdbac0250d47c4023f905..3d1b4476255d0cadf2274ad12b9aae78f164f72f 100644 (file)
@@ -1,3 +1,7 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
 package main
 
 import (
@@ -20,6 +24,7 @@ import (
        "git.curoverse.com/arvados.git/sdk/go/arvados"
        "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
        "git.curoverse.com/arvados.git/sdk/go/config"
+       "git.curoverse.com/arvados.git/sdk/go/health"
        "git.curoverse.com/arvados.git/sdk/go/keepclient"
        "github.com/coreos/go-systemd/daemon"
        "github.com/ghodss/yaml"
@@ -35,6 +40,7 @@ type Config struct {
        Timeout         arvados.Duration
        PIDFile         string
        Debug           bool
+       ManagementToken string
 }
 
 func DefaultConfig() *Config {
@@ -62,6 +68,7 @@ func main() {
        flagset.IntVar(&cfg.DefaultReplicas, "default-replicas", cfg.DefaultReplicas, "Default number of replicas to write if not specified by the client. If 0, use site default."+deprecated)
        flagset.StringVar(&cfg.PIDFile, "pid", cfg.PIDFile, "Path to write pid file."+deprecated)
        timeoutSeconds := flagset.Int("timeout", int(time.Duration(cfg.Timeout)/time.Second), "Timeout (in seconds) on requests to internal Keep services."+deprecated)
+       flagset.StringVar(&cfg.ManagementToken, "management-token", cfg.ManagementToken, "Authorization token to be included in all health check requests.")
 
        var cfgPath string
        const defaultCfgPath = "/etc/arvados/keepproxy/keepproxy.yml"
@@ -104,6 +111,7 @@ func main() {
        if err != nil {
                log.Fatalf("Error setting up keep client %s", err.Error())
        }
+       keepclient.RefreshServiceDiscoveryOnSIGHUP()
 
        if cfg.PIDFile != "" {
                f, err := os.Create(cfg.PIDFile)
@@ -133,8 +141,6 @@ func main() {
        if cfg.DefaultReplicas > 0 {
                kc.Want_replicas = cfg.DefaultReplicas
        }
-       kc.Client.(*http.Client).Timeout = time.Duration(cfg.Timeout)
-       go kc.RefreshServices(5*time.Minute, 3*time.Second)
 
        listener, err = net.Listen("tcp", cfg.Listen)
        if err != nil {
@@ -157,7 +163,7 @@ func main() {
        signal.Notify(term, syscall.SIGINT)
 
        // Start serving requests.
-       router = MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc)
+       router = MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc, time.Duration(cfg.Timeout), cfg.ManagementToken)
        http.Serve(listener, router)
 
        log.Println("shutting down")
@@ -241,15 +247,29 @@ type proxyHandler struct {
        http.Handler
        *keepclient.KeepClient
        *ApiTokenCache
+       timeout   time.Duration
+       transport *http.Transport
 }
 
 // MakeRESTRouter returns an http.Handler that passes GET and PUT
 // requests to the appropriate handlers.
-func MakeRESTRouter(enable_get bool, enable_put bool, kc *keepclient.KeepClient) http.Handler {
+func MakeRESTRouter(enable_get bool, enable_put bool, kc *keepclient.KeepClient, timeout time.Duration, mgmtToken string) http.Handler {
        rest := mux.NewRouter()
+
+       transport := *(http.DefaultTransport.(*http.Transport))
+       transport.DialContext = (&net.Dialer{
+               Timeout:   keepclient.DefaultConnectTimeout,
+               KeepAlive: keepclient.DefaultKeepAlive,
+               DualStack: true,
+       }).DialContext
+       transport.TLSClientConfig = arvadosclient.MakeTLSConfig(kc.Arvados.ApiInsecure)
+       transport.TLSHandshakeTimeout = keepclient.DefaultTLSHandshakeTimeout
+
        h := &proxyHandler{
                Handler:    rest,
                KeepClient: kc,
+               timeout:    timeout,
+               transport:  &transport,
                ApiTokenCache: &ApiTokenCache{
                        tokens:     make(map[string]int64),
                        expireTime: 300,
@@ -275,6 +295,11 @@ func MakeRESTRouter(enable_get bool, enable_put bool, kc *keepclient.KeepClient)
                rest.HandleFunc(`/`, h.Options).Methods("OPTIONS")
        }
 
+       rest.Handle("/_health/{check}", &health.Handler{
+               Token:  mgmtToken,
+               Prefix: "/_health/",
+       }).Methods("GET")
+
        rest.NotFoundHandler = InvalidPathHandler{}
        return h
 }
@@ -320,6 +345,7 @@ func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) {
                return
        }
        SetCorsHeaders(resp)
+       resp.Header().Set("Via", req.Proto+" "+viaAlias)
 
        locator := mux.Vars(req)["locator"]
        var err error
@@ -334,12 +360,11 @@ func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) {
                }
        }()
 
-       kc := *h.KeepClient
-       kc.Client = &proxyClient{client: kc.Client, proto: req.Proto}
+       kc := h.makeKeepClient(req)
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(&kc, h.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(kc, h.ApiTokenCache, req); !pass {
                status, err = http.StatusForbidden, BadAuthorizationHeader
                return
        }
@@ -404,9 +429,9 @@ func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
                return
        }
        SetCorsHeaders(resp)
+       resp.Header().Set("Via", "HTTP/1.1 "+viaAlias)
 
-       kc := *h.KeepClient
-       kc.Client = &proxyClient{client: kc.Client, proto: req.Proto}
+       kc := h.makeKeepClient(req)
 
        var err error
        var expectLength int64
@@ -444,7 +469,7 @@ func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(&kc, h.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(kc, h.ApiTokenCache, req); !pass {
                err = BadAuthorizationHeader
                status = http.StatusForbidden
                return
@@ -525,9 +550,8 @@ func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) {
                }
        }()
 
-       kc := *h.KeepClient
-
-       ok, token := CheckAuthorizationHeader(&kc, h.ApiTokenCache, req)
+       kc := h.makeKeepClient(req)
+       ok, token := CheckAuthorizationHeader(kc, h.ApiTokenCache, req)
        if !ok {
                status, err = http.StatusForbidden, BadAuthorizationHeader
                return
@@ -564,3 +588,15 @@ func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) {
        status = http.StatusOK
        resp.Write([]byte("\n"))
 }
+
+func (h *proxyHandler) makeKeepClient(req *http.Request) *keepclient.KeepClient { // returns a per-request copy of h.KeepClient with its own HTTPClient
+       kc := *h.KeepClient // struct copy: assigning HTTPClient below does not touch the shared h.KeepClient
+       kc.HTTPClient = &proxyClient{
+               client: &http.Client{
+                       Timeout:   h.timeout, // the timeout value MakeRESTRouter stored on the handler
+                       Transport: h.transport, // the single transport built in MakeRESTRouter, reused by every request
+               },
+               proto: req.Proto, // protocol string of the incoming request, handed to proxyClient
+       }
+       return &kc
+}