14324: Use logrus in Azure driver. Fix Sirupsen->sirupsen in imports
[arvados.git] / services / keepproxy / keepproxy.go
index 3d1b4476255d0cadf2274ad12b9aae78f164f72f..def97e6e003a9b1d86ed94ea50f48e343cc8b6a7 100644 (file)
@@ -10,7 +10,6 @@ import (
        "fmt"
        "io"
        "io/ioutil"
-       "log"
        "net"
        "net/http"
        "os"
@@ -25,12 +24,16 @@ import (
        "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
        "git.curoverse.com/arvados.git/sdk/go/config"
        "git.curoverse.com/arvados.git/sdk/go/health"
+       "git.curoverse.com/arvados.git/sdk/go/httpserver"
        "git.curoverse.com/arvados.git/sdk/go/keepclient"
+       log "github.com/sirupsen/logrus"
        "github.com/coreos/go-systemd/daemon"
        "github.com/ghodss/yaml"
        "github.com/gorilla/mux"
 )
 
+var version = "dev"
+
 type Config struct {
        Client          arvados.Client
        Listen          string
@@ -55,7 +58,13 @@ var (
        router   http.Handler
 )
 
+const rfc3339NanoFixed = "2006-01-02T15:04:05.000000000Z07:00"
+
 func main() {
+       log.SetFormatter(&log.JSONFormatter{
+               TimestampFormat: rfc3339NanoFixed,
+       })
+
        cfg := DefaultConfig()
 
        flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
@@ -74,8 +83,15 @@ func main() {
        const defaultCfgPath = "/etc/arvados/keepproxy/keepproxy.yml"
        flagset.StringVar(&cfgPath, "config", defaultCfgPath, "Configuration file `path`")
        dumpConfig := flagset.Bool("dump-config", false, "write current configuration to stdout and exit")
+       getVersion := flagset.Bool("version", false, "Print version information and exit.")
        flagset.Parse(os.Args[1:])
 
+       // Print version information if requested
+       if *getVersion {
+               fmt.Printf("keepproxy %s\n", version)
+               return
+       }
+
        err := config.LoadFile(cfg, cfgPath)
        if err != nil {
                h := os.Getenv("ARVADOS_API_HOST")
@@ -99,6 +115,8 @@ func main() {
                log.Fatal(config.DumpAndExit(cfg))
        }
 
+       log.Printf("keepproxy %s started", version)
+
        arv, err := arvadosclient.New(&cfg.Client)
        if err != nil {
                log.Fatalf("Error setting up arvados client %s", err.Error())
@@ -164,7 +182,7 @@ func main() {
 
        // Start serving requests.
        router = MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc, time.Duration(cfg.Timeout), cfg.ManagementToken)
-       http.Serve(listener, router)
+       http.Serve(listener, httpserver.AddRequestIDs(httpserver.LogRequests(nil, router)))
 
        log.Println("shutting down")
 }
@@ -214,35 +232,56 @@ func GetRemoteAddress(req *http.Request) string {
 }
 
 func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
-       var auth string
-       if auth = req.Header.Get("Authorization"); auth == "" {
+       parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
+       if len(parts) < 2 || !(parts[0] == "OAuth2" || parts[0] == "Bearer") || len(parts[1]) == 0 {
                return false, ""
        }
+       tok = parts[1]
 
-       _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
-       if err != nil {
-               // Scanning error
-               return false, ""
+       // Tokens are validated differently depending on what kind of
+       // operation is being performed. For example, tokens in
+       // collection-sharing links permit GET requests, but not
+       // PUT requests.
+       var op string
+       if req.Method == "GET" || req.Method == "HEAD" {
+               op = "read"
+       } else {
+               op = "write"
        }
 
-       if cache.RecallToken(tok) {
+       if cache.RecallToken(op + ":" + tok) {
                // Valid in the cache, short circuit
                return true, tok
        }
 
+       var err error
        arv := *kc.Arvados
        arv.ApiToken = tok
-       if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
+       arv.RequestID = req.Header.Get("X-Request-Id")
+       if op == "read" {
+               err = arv.Call("HEAD", "keep_services", "", "accessible", nil, nil)
+       } else {
+               err = arv.Call("HEAD", "users", "", "current", nil, nil)
+       }
+       if err != nil {
                log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
                return false, ""
        }
 
        // Success!  Update cache
-       cache.RememberToken(tok)
+       cache.RememberToken(op + ":" + tok)
 
        return true, tok
 }
 
+// We need to make a private copy of the default http transport early
+// in initialization, then make copies of our private copy later. It
+// won't be safe to copy http.DefaultTransport itself later, because
+// its private mutexes might have already been used. (Without this,
+// the test suite sometimes panics "concurrent map writes" in
+// net/http.(*Transport).removeIdleConnLocked().)
+var defaultTransport = *(http.DefaultTransport.(*http.Transport))
+
 type proxyHandler struct {
        http.Handler
        *keepclient.KeepClient
@@ -256,7 +295,7 @@ type proxyHandler struct {
 func MakeRESTRouter(enable_get bool, enable_put bool, kc *keepclient.KeepClient, timeout time.Duration, mgmtToken string) http.Handler {
        rest := mux.NewRouter()
 
-       transport := *(http.DefaultTransport.(*http.Transport))
+       transport := defaultTransport
        transport.DialContext = (&net.Dialer{
                Timeout:   keepclient.DefaultConnectTimeout,
                KeepAlive: keepclient.DefaultKeepAlive,
@@ -448,6 +487,15 @@ func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
 
        locatorIn := mux.Vars(req)["locator"]
 
+       // Check if the client specified storage classes
+       if req.Header.Get("X-Keep-Storage-Classes") != "" {
+               var scl []string
+               for _, sc := range strings.Split(req.Header.Get("X-Keep-Storage-Classes"), ",") {
+                       scl = append(scl, strings.Trim(sc, " "))
+               }
+               kc.StorageClasses = scl
+       }
+
        _, err = fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
        if err != nil || expectLength < 0 {
                err = LengthRequiredError
@@ -491,13 +539,13 @@ func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
 
        // Now try to put the block through
        if locatorIn == "" {
-               if bytes, err := ioutil.ReadAll(req.Body); err != nil {
-                       err = errors.New(fmt.Sprintf("Error reading request body: %s", err))
+               bytes, err2 := ioutil.ReadAll(req.Body)
+               if err2 != nil {
+                       _ = errors.New(fmt.Sprintf("Error reading request body: %s", err2))
                        status = http.StatusInternalServerError
                        return
-               } else {
-                       locatorOut, wroteReplicas, err = kc.PutB(bytes)
                }
+               locatorOut, wroteReplicas, err = kc.PutB(bytes)
        } else {
                locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
        }
@@ -591,6 +639,7 @@ func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) {
 
 func (h *proxyHandler) makeKeepClient(req *http.Request) *keepclient.KeepClient {
        kc := *h.KeepClient
+       kc.RequestID = req.Header.Get("X-Request-Id")
        kc.HTTPClient = &proxyClient{
                client: &http.Client{
                        Timeout:   h.timeout,