14324: Use logrus in Azure driver. Fix Sirupsen->sirupsen in imports
[arvados.git] / services / keepproxy / keepproxy.go
index 7dfd01ad41e7fd96f74b763d6129cc1792448538..def97e6e003a9b1d86ed94ea50f48e343cc8b6a7 100644 (file)
@@ -1,3 +1,7 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
 package main
 
 import (
@@ -6,7 +10,6 @@ import (
        "fmt"
        "io"
        "io/ioutil"
-       "log"
        "net"
        "net/http"
        "os"
@@ -20,12 +23,17 @@ import (
        "git.curoverse.com/arvados.git/sdk/go/arvados"
        "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
        "git.curoverse.com/arvados.git/sdk/go/config"
+       "git.curoverse.com/arvados.git/sdk/go/health"
+       "git.curoverse.com/arvados.git/sdk/go/httpserver"
        "git.curoverse.com/arvados.git/sdk/go/keepclient"
+       log "github.com/sirupsen/logrus"
        "github.com/coreos/go-systemd/daemon"
        "github.com/ghodss/yaml"
        "github.com/gorilla/mux"
 )
 
+var version = "dev"
+
 type Config struct {
        Client          arvados.Client
        Listen          string
@@ -35,6 +43,7 @@ type Config struct {
        Timeout         arvados.Duration
        PIDFile         string
        Debug           bool
+       ManagementToken string
 }
 
 func DefaultConfig() *Config {
@@ -49,7 +58,13 @@ var (
        router   http.Handler
 )
 
+const rfc3339NanoFixed = "2006-01-02T15:04:05.000000000Z07:00"
+
 func main() {
+       log.SetFormatter(&log.JSONFormatter{
+               TimestampFormat: rfc3339NanoFixed,
+       })
+
        cfg := DefaultConfig()
 
        flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
@@ -62,13 +77,21 @@ func main() {
        flagset.IntVar(&cfg.DefaultReplicas, "default-replicas", cfg.DefaultReplicas, "Default number of replicas to write if not specified by the client. If 0, use site default."+deprecated)
        flagset.StringVar(&cfg.PIDFile, "pid", cfg.PIDFile, "Path to write pid file."+deprecated)
        timeoutSeconds := flagset.Int("timeout", int(time.Duration(cfg.Timeout)/time.Second), "Timeout (in seconds) on requests to internal Keep services."+deprecated)
+       flagset.StringVar(&cfg.ManagementToken, "management-token", cfg.ManagementToken, "Authorization token to be included in all health check requests.")
 
        var cfgPath string
        const defaultCfgPath = "/etc/arvados/keepproxy/keepproxy.yml"
        flagset.StringVar(&cfgPath, "config", defaultCfgPath, "Configuration file `path`")
        dumpConfig := flagset.Bool("dump-config", false, "write current configuration to stdout and exit")
+       getVersion := flagset.Bool("version", false, "Print version information and exit.")
        flagset.Parse(os.Args[1:])
 
+       // Print version information if requested
+       if *getVersion {
+               fmt.Printf("keepproxy %s\n", version)
+               return
+       }
+
        err := config.LoadFile(cfg, cfgPath)
        if err != nil {
                h := os.Getenv("ARVADOS_API_HOST")
@@ -92,6 +115,8 @@ func main() {
                log.Fatal(config.DumpAndExit(cfg))
        }
 
+       log.Printf("keepproxy %s started", version)
+
        arv, err := arvadosclient.New(&cfg.Client)
        if err != nil {
                log.Fatalf("Error setting up arvados client %s", err.Error())
@@ -156,8 +181,8 @@ func main() {
        signal.Notify(term, syscall.SIGINT)
 
        // Start serving requests.
-       router = MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc, time.Duration(cfg.Timeout))
-       http.Serve(listener, router)
+       router = MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc, time.Duration(cfg.Timeout), cfg.ManagementToken)
+       http.Serve(listener, httpserver.AddRequestIDs(httpserver.LogRequests(nil, router)))
 
        log.Println("shutting down")
 }
@@ -207,35 +232,56 @@ func GetRemoteAddress(req *http.Request) string {
 }
 
 func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
-       var auth string
-       if auth = req.Header.Get("Authorization"); auth == "" {
+       parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
+       if len(parts) < 2 || !(parts[0] == "OAuth2" || parts[0] == "Bearer") || len(parts[1]) == 0 {
                return false, ""
        }
+       tok = parts[1]
 
-       _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
-       if err != nil {
-               // Scanning error
-               return false, ""
+       // Tokens are validated differently depending on what kind of
+       // operation is being performed. For example, tokens in
+       // collection-sharing links permit GET requests, but not
+       // PUT requests.
+       var op string
+       if req.Method == "GET" || req.Method == "HEAD" {
+               op = "read"
+       } else {
+               op = "write"
        }
 
-       if cache.RecallToken(tok) {
+       if cache.RecallToken(op + ":" + tok) {
                // Valid in the cache, short circuit
                return true, tok
        }
 
+       var err error
        arv := *kc.Arvados
        arv.ApiToken = tok
-       if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
+       arv.RequestID = req.Header.Get("X-Request-Id")
+       if op == "read" {
+               err = arv.Call("HEAD", "keep_services", "", "accessible", nil, nil)
+       } else {
+               err = arv.Call("HEAD", "users", "", "current", nil, nil)
+       }
+       if err != nil {
                log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
                return false, ""
        }
 
        // Success!  Update cache
-       cache.RememberToken(tok)
+       cache.RememberToken(op + ":" + tok)
 
        return true, tok
 }
 
+// We need to make a private copy of the default http transport early
+// in initialization, then make copies of our private copy later. It
+// won't be safe to copy http.DefaultTransport itself later, because
+// its private mutexes might have already been used. (Without this,
+// the test suite sometimes panics "concurrent map writes" in
+// net/http.(*Transport).removeIdleConnLocked().)
+var defaultTransport = *(http.DefaultTransport.(*http.Transport))
+
 type proxyHandler struct {
        http.Handler
        *keepclient.KeepClient
@@ -246,10 +292,10 @@ type proxyHandler struct {
 
 // MakeRESTRouter returns an http.Handler that passes GET and PUT
 // requests to the appropriate handlers.
-func MakeRESTRouter(enable_get bool, enable_put bool, kc *keepclient.KeepClient, timeout time.Duration) http.Handler {
+func MakeRESTRouter(enable_get bool, enable_put bool, kc *keepclient.KeepClient, timeout time.Duration, mgmtToken string) http.Handler {
        rest := mux.NewRouter()
 
-       transport := *(http.DefaultTransport.(*http.Transport))
+       transport := defaultTransport
        transport.DialContext = (&net.Dialer{
                Timeout:   keepclient.DefaultConnectTimeout,
                KeepAlive: keepclient.DefaultKeepAlive,
@@ -288,6 +334,11 @@ func MakeRESTRouter(enable_get bool, enable_put bool, kc *keepclient.KeepClient,
                rest.HandleFunc(`/`, h.Options).Methods("OPTIONS")
        }
 
+       rest.Handle("/_health/{check}", &health.Handler{
+               Token:  mgmtToken,
+               Prefix: "/_health/",
+       }).Methods("GET")
+
        rest.NotFoundHandler = InvalidPathHandler{}
        return h
 }
@@ -436,6 +487,15 @@ func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
 
        locatorIn := mux.Vars(req)["locator"]
 
+       // Check if the client specified storage classes
+       if req.Header.Get("X-Keep-Storage-Classes") != "" {
+               var scl []string
+               for _, sc := range strings.Split(req.Header.Get("X-Keep-Storage-Classes"), ",") {
+                       scl = append(scl, strings.Trim(sc, " "))
+               }
+               kc.StorageClasses = scl
+       }
+
        _, err = fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
        if err != nil || expectLength < 0 {
                err = LengthRequiredError
@@ -479,13 +539,13 @@ func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
 
        // Now try to put the block through
        if locatorIn == "" {
-               if bytes, err := ioutil.ReadAll(req.Body); err != nil {
-                       err = errors.New(fmt.Sprintf("Error reading request body: %s", err))
+               bytes, err2 := ioutil.ReadAll(req.Body)
+               if err2 != nil {
+                       _ = errors.New(fmt.Sprintf("Error reading request body: %s", err2))
                        status = http.StatusInternalServerError
                        return
-               } else {
-                       locatorOut, wroteReplicas, err = kc.PutB(bytes)
                }
+               locatorOut, wroteReplicas, err = kc.PutB(bytes)
        } else {
                locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
        }
@@ -579,6 +639,7 @@ func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) {
 
 func (h *proxyHandler) makeKeepClient(req *http.Request) *keepclient.KeepClient {
        kc := *h.KeepClient
+       kc.RequestID = req.Header.Get("X-Request-Id")
        kc.HTTPClient = &proxyClient{
                client: &http.Client{
                        Timeout:   h.timeout,