Keepproxy use client-supplied token when forwarding GET and PUT requests.
[arvados.git] / services / keep / src / arvados.org / keepproxy / keepproxy.go
index 9e0b2ff90c29fe070ec30e3bd11f96f1038f9a87..367854bed382ec1b17589c825126d5e31a9bf6a6 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "arvados.org/keepclient"
+       "arvados.org/sdk"
        "flag"
        "fmt"
        "github.com/gorilla/mux"
@@ -10,7 +11,9 @@ import (
        "net"
        "net/http"
        "os"
+       "os/signal"
        "sync"
+       "syscall"
        "time"
 )
 
@@ -29,7 +32,9 @@ func main() {
                pidfile          string
        )
 
-       flag.StringVar(
+       flagset := flag.NewFlagSet("default", flag.ExitOnError)
+
+       flagset.StringVar(
                &listen,
                "listen",
                DEFAULT_ADDR,
@@ -37,35 +42,40 @@ func main() {
                        "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
                        "to listen on all network interfaces.")
 
-       flag.BoolVar(
+       flagset.BoolVar(
                &no_get,
                "no-get",
                false,
                "If set, disable GET operations")
 
-       flag.BoolVar(
-               &no_get,
+       flagset.BoolVar(
+               &no_put,
                "no-put",
                false,
                "If set, disable PUT operations")
 
-       flag.IntVar(
+       flagset.IntVar(
                &default_replicas,
                "default-replicas",
                2,
                "Default number of replicas to write if not specified by the client.")
 
-       flag.StringVar(
+       flagset.StringVar(
                &pidfile,
                "pid",
                "",
                "Path to write pid file")
 
-       flag.Parse()
+       flagset.Parse(os.Args[1:])
 
-       kc, err := keepclient.MakeKeepClient()
+       arv, err := sdk.MakeArvadosClient()
        if err != nil {
-               log.Fatal(err)
+               log.Fatalf("Error setting up arvados client %s", err.Error())
+       }
+
+       kc, err := keepclient.MakeKeepClient(&arv)
+       if err != nil {
+               log.Fatalf("Error setting up keep client %s", err.Error())
        }
 
        if pidfile != "" {
@@ -87,8 +97,37 @@ func main() {
 
        go RefreshServicesList(&kc)
 
+       // Shut down the server gracefully (by closing the listener)
+       // if SIGTERM is received.
+       term := make(chan os.Signal, 1)
+       go func(sig <-chan os.Signal) {
+               s := <-sig
+               log.Println("caught signal:", s)
+               listener.Close()
+       }(term)
+       signal.Notify(term, syscall.SIGTERM)
+       signal.Notify(term, syscall.SIGINT)
+
+       if pidfile != "" {
+               f, err := os.Create(pidfile)
+               if err == nil {
+                       fmt.Fprint(f, os.Getpid())
+                       f.Close()
+               } else {
+                       log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
+               }
+       }
+
+       log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
+
        // Start listening for requests.
        http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
+
+       log.Println("shutting down")
+
+       if pidfile != "" {
+               os.Remove(pidfile)
+       }
 }
 
 type ApiTokenCache struct {
@@ -101,7 +140,14 @@ type ApiTokenCache struct {
 func RefreshServicesList(kc *keepclient.KeepClient) {
        for {
                time.Sleep(300 * time.Second)
+               oldservices := kc.ServiceRoots()
                kc.DiscoverKeepServers()
+               newservices := kc.ServiceRoots()
+               s1 := fmt.Sprint(oldservices)
+               s2 := fmt.Sprint(newservices)
+               if s1 != s2 {
+                       log.Printf("Updated server list to %v", s2)
+               }
        }
 }
 
@@ -147,53 +193,34 @@ func GetRemoteAddress(req *http.Request) string {
        return req.RemoteAddr
 }
 
-func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) bool {
+func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
        var auth string
        if auth = req.Header.Get("Authorization"); auth == "" {
-               return false
+               return false, ""
        }
 
-       var tok string
        _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
        if err != nil {
                // Scanning error
-               return false
+               return false, ""
        }
 
        if cache.RecallToken(tok) {
                // Valid in the cache, short circut
-               return true
+               return true, tok
        }
 
-       var usersreq *http.Request
-
-       if usersreq, err = http.NewRequest("GET", fmt.Sprintf("https://%s/arvados/v1/users/current", kc.ApiServer), nil); err != nil {
-               // Can't construct the request
+       arv := *kc.Arvados
+       arv.ApiToken = tok
+       if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
                log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
-               return false
-       }
-
-       // Add api token header
-       usersreq.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", tok))
-
-       // Actually make the request
-       var resp *http.Response
-       if resp, err = kc.Client.Do(usersreq); err != nil {
-               // Something else failed
-               log.Printf("%s: CheckAuthorizationHeader error connecting to API server: %v", GetRemoteAddress(req), err.Error())
-               return false
-       }
-
-       if resp.StatusCode != http.StatusOK {
-               // Bad status
-               log.Printf("%s: CheckAuthorizationHeader API server responded: %v", GetRemoteAddress(req), resp.Status)
-               return false
+               return false, ""
        }
 
        // Success!  Update cache
        cache.RememberToken(tok)
 
-       return true
+       return true, tok
 }
 
 type GetBlockHandler struct {
@@ -206,6 +233,8 @@ type PutBlockHandler struct {
        *ApiTokenCache
 }
 
+type InvalidPathHandler struct{}
+
 // MakeRESTRouter
 //     Returns a mux.Router that passes GET and PUT requests to the
 //     appropriate handlers.
@@ -218,58 +247,72 @@ func MakeRESTRouter(
        t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
 
        rest := mux.NewRouter()
-       gh := rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t})
-       ghsig := rest.Handle(
-               `/{hash:[0-9a-f]{32}}+A{signature:[0-9a-f]+}@{timestamp:[0-9a-f]+}`,
-               GetBlockHandler{kc, t})
-       ph := rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t})
 
        if enable_get {
-               gh.Methods("GET", "HEAD")
-               ghsig.Methods("GET", "HEAD")
+               rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`,
+                       GetBlockHandler{kc, t}).Methods("GET", "HEAD")
+               rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
        }
 
        if enable_put {
-               ph.Methods("PUT")
+               rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`, PutBlockHandler{kc, t}).Methods("PUT")
+               rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
        }
 
+       rest.NotFoundHandler = InvalidPathHandler{}
+
        return rest
 }
 
+func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+       log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
+       http.Error(resp, "Bad request", http.StatusBadRequest)
+}
+
 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
 
        kc := *this.KeepClient
 
        hash := mux.Vars(req)["hash"]
-       signature := mux.Vars(req)["signature"]
-       timestamp := mux.Vars(req)["timestamp"]
+       hints := mux.Vars(req)["hints"]
+
+       locator := keepclient.MakeLocator2(hash, hints)
 
        log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
 
-       if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
+       var pass bool
+       var tok string
+       if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
                http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
                return
        }
 
+       // Copy ArvadosClient struct and use the client's API token
+       arvclient := *kc.Arvados
+       arvclient.ApiToken = tok
+       kc.Arvados = &arvclient
+
        var reader io.ReadCloser
        var err error
        var blocklen int64
 
        if req.Method == "GET" {
-               reader, blocklen, _, err = kc.AuthorizedGet(hash, signature, timestamp)
+               reader, blocklen, _, err = kc.AuthorizedGet(hash, locator.Signature, locator.Timestamp)
                defer reader.Close()
        } else if req.Method == "HEAD" {
-               blocklen, _, err = kc.AuthorizedAsk(hash, signature, timestamp)
+               blocklen, _, err = kc.AuthorizedAsk(hash, locator.Signature, locator.Timestamp)
        }
 
-       resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
+       if blocklen > 0 {
+               resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
+       }
 
        switch err {
        case nil:
                if reader != nil {
                        n, err2 := io.Copy(resp, reader)
                        if n != blocklen {
-                               log.Printf("%s: %s %s mismatched return %v with Content-Length %v error", GetRemoteAddress(req), req.Method, hash, n, blocklen, err.Error())
+                               log.Printf("%s: %s %s mismatched return %v with Content-Length %v error %v", GetRemoteAddress(req), req.Method, hash, n, blocklen, err2)
                        } else if err2 == nil {
                                log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
                        } else {
@@ -294,6 +337,9 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
        kc := *this.KeepClient
 
        hash := mux.Vars(req)["hash"]
+       hints := mux.Vars(req)["hints"]
+
+       locator := keepclient.MakeLocator2(hash, hints)
 
        var contentLength int64 = -1
        if req.Header.Get("Content-Length") != "" {
@@ -311,11 +357,23 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                return
        }
 
-       if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
+       if locator.Size > 0 && int64(locator.Size) != contentLength {
+               http.Error(resp, "Locator size hint does not match Content-Length header", http.StatusBadRequest)
+               return
+       }
+
+       var pass bool
+       var tok string
+       if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
                http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
                return
        }
 
+       // Copy ArvadosClient struct and use the client's API token
+       arvclient := *kc.Arvados
+       arvclient.ApiToken = tok
+       kc.Arvados = &arvclient
+
        // Check if the client specified the number of replicas
        if req.Header.Get("X-Keep-Desired-Replicas") != "" {
                var r int
@@ -326,7 +384,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
        }
 
        // Now try to put the block through
-       replicas, err := kc.PutHR(hash, req.Body, contentLength)
+       hash, replicas, err := kc.PutHR(hash, req.Body, contentLength)
 
        // Tell the client how many successful PUTs we accomplished
        resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
@@ -335,6 +393,10 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
        case nil:
                // Default will return http.StatusOK
                log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
+               n, err2 := io.WriteString(resp, hash)
+               if err2 != nil {
+                       log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
+               }
 
        case keepclient.OversizeBlockError:
                // Too much data
@@ -346,6 +408,10 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                        // client can decide if getting less than the number of
                        // replications it asked for is a fatal error.
                        // Default will return http.StatusOK
+                       n, err2 := io.WriteString(resp, hash)
+                       if err2 != nil {
+                               log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
+                       }
                } else {
                        http.Error(resp, "", http.StatusServiceUnavailable)
                }