2760: Merge branch '2760-folder-hierarchy' refs #2760
[arvados.git] / services / keep / src / arvados.org / keepproxy / keepproxy.go
index ed33ac9bbd62a79a59e293f9170af9a2b1b3cf2a..38e14fd2832f10d01d2997b6b66dd70f02585459 100644 (file)
@@ -29,7 +29,9 @@ func main() {
                pidfile          string
        )
 
-       flag.StringVar(
+       flagset := flag.NewFlagSet("default", flag.ExitOnError)
+
+       flagset.StringVar(
                &listen,
                "listen",
                DEFAULT_ADDR,
@@ -37,41 +39,35 @@ func main() {
                        "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
                        "to listen on all network interfaces.")
 
-       flag.BoolVar(
+       flagset.BoolVar(
                &no_get,
                "no-get",
                false,
                "If set, disable GET operations")
 
-       flag.BoolVar(
-               &no_get,
+       flagset.BoolVar(
+               &no_put,
                "no-put",
                false,
                "If set, disable PUT operations")
 
-       flag.IntVar(
+       flagset.IntVar(
                &default_replicas,
                "default-replicas",
                2,
                "Default number of replicas to write if not specified by the client.")
 
-       flag.StringVar(
+       flagset.StringVar(
                &pidfile,
                "pid",
                "",
                "Path to write pid file")
 
-       flag.Parse()
-
-       /*if no_get == false {
-               log.Print("Must specify -no-get")
-               return
-       }*/
+       flagset.Parse(os.Args[1:])
 
        kc, err := keepclient.MakeKeepClient()
        if err != nil {
-               log.Print(err)
-               return
+               log.Fatalf("Error setting up keep client %s", err.Error())
        }
 
        if pidfile != "" {
@@ -88,12 +84,15 @@ func main() {
 
        listener, err = net.Listen("tcp", listen)
        if err != nil {
-               log.Printf("Could not listen on %v", listen)
-               return
+               log.Fatalf("Could not listen on %v", listen)
        }
 
+       go RefreshServicesList(&kc)
+
+       log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
+
        // Start listening for requests.
-       http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
+       http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
 }
 
 type ApiTokenCache struct {
@@ -102,6 +101,21 @@ type ApiTokenCache struct {
        expireTime int64
 }
 
+// RefreshServicesList re-runs keep server discovery on kc every five minutes,
+// forever, and logs the new service roots whenever their printed form differs
+// from the list seen before that discovery pass. It never returns.
+//
+// NOTE(review): DiscoverKeepServers updates *kc while the request handlers
+// copy *kc concurrently (kc := *this.KeepClient); unless keepclient
+// synchronizes internally this is a data race — confirm and add locking.
+func RefreshServicesList(kc *keepclient.KeepClient) {
+       for {
+               time.Sleep(300 * time.Second)
+               // Format the old list BEFORE rediscovering: if ServiceRoots
+               // returns storage that DiscoverKeepServers updates in place,
+               // formatting it afterwards would make old and new look equal
+               // and the change would never be logged.
+               before := fmt.Sprint(kc.ServiceRoots())
+               // NOTE(review): any result DiscoverKeepServers returns is
+               // discarded here; if it reports an error, it should be logged.
+               kc.DiscoverKeepServers()
+               if after := fmt.Sprint(kc.ServiceRoots()); after != before {
+                       log.Printf("Updated server list to %v", after)
+               }
+       }
+}
+
 // Cache the token and set an expire time.  If we already have an expire time
 // on the token, it is not updated.
 func (this *ApiTokenCache) RememberToken(token string) {
@@ -133,13 +147,25 @@ func (this *ApiTokenCache) RecallToken(token string) bool {
        }
 }
 
+// GetRemoteAddress describes the client of req for log messages.
+//
+// If a fronting proxy supplied X-Real-IP, that address is returned. The
+// X-Forwarded-For chain is appended in parentheses only when that header is
+// present and differs from X-Real-IP. Without X-Real-IP, req.RemoteAddr is
+// returned. Both headers can be set by the client unless a trusted proxy
+// overwrites them, so the result is only suitable for logging.
+func GetRemoteAddress(req *http.Request) string {
+       realip := req.Header.Get("X-Real-IP")
+       if realip == "" {
+               return req.RemoteAddr
+       }
+       forwarded := req.Header.Get("X-Forwarded-For")
+       if forwarded == "" || forwarded == realip {
+               // No extra information: avoid logging an empty "(X-Forwarded-For )".
+               return realip
+       }
+       return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
+}
+
 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) bool {
-       if req.Header.Get("Authorization") == "" {
+       var auth string
+       if auth = req.Header.Get("Authorization"); auth == "" {
                return false
        }
 
        var tok string
-       _, err := fmt.Sscanf(req.Header.Get("Authorization"), "OAuth2 %s", &tok)
+       _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
        if err != nil {
                // Scanning error
                return false
@@ -152,9 +178,9 @@ func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, re
 
        var usersreq *http.Request
 
-       if usersreq, err = http.NewRequest("GET", fmt.Sprintf("https://%s/arvados/v1/users/current", kc.ApiServer), nil); err != nil {
+       if usersreq, err = http.NewRequest("HEAD", fmt.Sprintf("https://%s/arvados/v1/users/current", kc.ApiServer), nil); err != nil {
                // Can't construct the request
-               log.Print("CheckAuthorizationHeader error: %v", err)
+               log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
                return false
        }
 
@@ -165,12 +191,13 @@ func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, re
        var resp *http.Response
        if resp, err = kc.Client.Do(usersreq); err != nil {
                // Something else failed
-               log.Print("CheckAuthorizationHeader error: %v", err)
+               log.Printf("%s: CheckAuthorizationHeader error connecting to API server: %v", GetRemoteAddress(req), err.Error())
                return false
        }
 
        if resp.StatusCode != http.StatusOK {
                // Bad status
+               log.Printf("%s: CheckAuthorizationHeader API server responded: %v", GetRemoteAddress(req), resp.Status)
                return false
        }
 
@@ -181,15 +208,17 @@ func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, re
 }
 
+// GetBlockHandler serves GET and HEAD requests for a block hash (bare, or
+// with a +A<signature>@<timestamp> suffix) through the embedded KeepClient,
+// after checking the request's Authorization header with the ApiTokenCache.
+// The client is held by pointer; ServeHTTP copies *KeepClient at the start
+// of each request.
 type GetBlockHandler struct {
-       keepclient.KeepClient
+       *keepclient.KeepClient
        *ApiTokenCache
 }

+// PutBlockHandler serves PUT requests for a bare block hash through the
+// embedded KeepClient, after checking the Authorization header with the
+// ApiTokenCache. ServeHTTP works on a per-request copy of *KeepClient, so
+// a client-requested replica count does not leak into other requests.
 type PutBlockHandler struct {
-       keepclient.KeepClient
+       *keepclient.KeepClient
        *ApiTokenCache
 }
+
+// InvalidPathHandler is installed as the router's NotFoundHandler; it logs
+// requests that matched no registered route and answers 400 Bad Request.
+type InvalidPathHandler struct{}
+
 // MakeRESTRouter
 //     Returns a mux.Router that passes GET and PUT requests to the
 //     appropriate handlers.
@@ -197,48 +226,60 @@ type PutBlockHandler struct {
 func MakeRESTRouter(
        enable_get bool,
        enable_put bool,
-       kc keepclient.KeepClient) *mux.Router {
+       kc *keepclient.KeepClient) *mux.Router {
 
+       // One token cache shared by every handler for Authorization checks.
+       // NOTE(review): expireTime 300 is presumably seconds — confirm against
+       // RememberToken.
        t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
 
        rest := mux.NewRouter()
-       gh := rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t})
-       ghsig := rest.Handle(
-               `/{hash:[0-9a-f]{32}}+A{signature:[0-9a-f]+}@{timestamp:[0-9a-f]+}`,
-               GetBlockHandler{kc, t})
-       ph := rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t})
 
+       // GET/HEAD routes are registered only when enabled: one for a bare
+       // 32-hex-digit hash, one for the hash plus a +A<signature>@<timestamp>
+       // suffix. Both share the same pointer to the client.
        if enable_get {
+               gh := rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t})
+               ghsig := rest.Handle(
+                       `/{hash:[0-9a-f]{32}}+A{signature:[0-9a-f]+}@{timestamp:[0-9a-f]+}`,
+                       GetBlockHandler{kc, t})
+
                gh.Methods("GET", "HEAD")
                ghsig.Methods("GET", "HEAD")
        }
 
+       // PUT is registered only when enabled, and only on the bare-hash route.
        if enable_put {
-               ph.Methods("PUT")
+               rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
        }
 
+       // Requests that match no registered route are logged and rejected by
+       // InvalidPathHandler instead of the router's default handler.
+       rest.NotFoundHandler = InvalidPathHandler{}
+
        return rest
 }
 
+// ServeHTTP handles any request the router could not match to a route: it
+// logs the client, method and path, then replies 400 "Bad request".
+func (InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+       // req.URL.Path is percent-decoded and fully client-controlled; %q
+       // escapes newlines and control characters so a crafted path cannot
+       // inject forged lines into the log.
+       log.Printf("%s: %s %q unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
+       http.Error(resp, "Bad request", http.StatusBadRequest)
+}
+
 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
 
-       if !CheckAuthorizationHeader(this.KeepClient, this.ApiTokenCache, req) {
-               http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
-       }
+       kc := *this.KeepClient
 
        hash := mux.Vars(req)["hash"]
        signature := mux.Vars(req)["signature"]
        timestamp := mux.Vars(req)["timestamp"]
 
+       log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
+
+       if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
+               http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
+               return
+       }
+
        var reader io.ReadCloser
        var err error
        var blocklen int64
 
        if req.Method == "GET" {
-               reader, blocklen, _, err = this.KeepClient.AuthorizedGet(hash, signature, timestamp)
+               reader, blocklen, _, err = kc.AuthorizedGet(hash, signature, timestamp)
                defer reader.Close()
        } else if req.Method == "HEAD" {
-               blocklen, _, err = this.KeepClient.AuthorizedAsk(hash, signature, timestamp)
+               blocklen, _, err = kc.AuthorizedAsk(hash, signature, timestamp)
        }
 
        resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
@@ -246,22 +287,31 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
        switch err {
        case nil:
                if reader != nil {
-                       io.Copy(resp, reader)
+                       n, err2 := io.Copy(resp, reader)
+                       if n != blocklen {
+                               log.Printf("%s: %s %s mismatched return %v with Content-Length %v error", GetRemoteAddress(req), req.Method, hash, n, blocklen, err.Error())
+                       } else if err2 == nil {
+                               log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
+                       } else {
+                               log.Printf("%s: %s %s returned %v bytes error %v", GetRemoteAddress(req), req.Method, hash, n, err.Error())
+                       }
+               } else {
+                       log.Printf("%s: %s %s success", GetRemoteAddress(req), req.Method, hash)
                }
        case keepclient.BlockNotFound:
                http.Error(resp, "Not found", http.StatusNotFound)
        default:
                http.Error(resp, err.Error(), http.StatusBadGateway)
        }
+
+       if err != nil {
+               log.Printf("%s: %s %s error %s", GetRemoteAddress(req), req.Method, hash, err.Error())
+       }
 }
 
 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
 
-       log.Print("PutBlockHandler start")
-
-       if !CheckAuthorizationHeader(this.KeepClient, this.ApiTokenCache, req) {
-               http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
-       }
+       kc := *this.KeepClient
 
        hash := mux.Vars(req)["hash"]
 
@@ -274,31 +324,37 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        }
 
+       log.Printf("%s: %s %s Content-Length %v", GetRemoteAddress(req), req.Method, hash, contentLength)
+
        if contentLength < 1 {
                http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
                return
        }
 
+       if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
+               http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
+               return
+       }
+
        // Check if the client specified the number of replicas
        if req.Header.Get("X-Keep-Desired-Replicas") != "" {
                var r int
-               _, err := fmt.Sscanf(req.Header.Get("X-Keep-Desired-Replicas"), "%d", &r)
+               _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
                if err != nil {
-                       this.KeepClient.Want_replicas = r
+                       kc.Want_replicas = r
                }
        }
 
        // Now try to put the block through
-       replicas, err := this.KeepClient.PutHR(hash, req.Body, contentLength)
-
-       log.Printf("Replicas stored: %v err: %v", replicas, err)
+       replicas, err := kc.PutHR(hash, req.Body, contentLength)
 
        // Tell the client how many successful PUTs we accomplished
-       resp.Header().Set("X-Keep-Replicas-Stored", fmt.Sprintf("%d", replicas))
+       resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
 
        switch err {
        case nil:
                // Default will return http.StatusOK
+               log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
 
        case keepclient.OversizeBlockError:
                // Too much data
@@ -318,4 +374,8 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                http.Error(resp, err.Error(), http.StatusBadGateway)
        }
 
+       if err != nil {
+               log.Printf("%s: %s %s stored %v replicas (desired %v) got error %v", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas, err.Error())
+       }
+
 }