Merge branch 'master' into 7661-fuse-by-pdh
[arvados.git] / services / keepproxy / keepproxy.go
index 7ba24809e066a62412996fe2fd2929e81e245b9a..79ed51eb0e00f57eb38d4a31251a87c2bc5e866c 100644 (file)
@@ -14,6 +14,8 @@ import (
        "net/http"
        "os"
        "os/signal"
+       "reflect"
+       "regexp"
        "sync"
        "syscall"
        "time"
@@ -35,7 +37,7 @@ func main() {
                pidfile          string
        )
 
-       flagset := flag.NewFlagSet("default", flag.ExitOnError)
+       flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
 
        flagset.StringVar(
                &listen,
@@ -82,6 +84,9 @@ func main() {
                log.Fatalf("Error setting up arvados client %s", err.Error())
        }
 
+       if os.Getenv("ARVADOS_DEBUG") != "" {
+               keepclient.DebugPrintf = log.Printf
+       }
        kc, err := keepclient.MakeKeepClient(&arv)
        if err != nil {
                log.Fatalf("Error setting up keep client %s", err.Error())
@@ -135,24 +140,28 @@ type ApiTokenCache struct {
 
-// Refresh the keep service list every five minutes.
+// RefreshServicesList keeps kc's list of keep services up to date. It
+// rediscovers the services every five minutes, or after only 3 seconds
+// if the last attempt failed or found no local services, and logs the
+// lists whenever they change. It never returns.
 func RefreshServicesList(kc *keepclient.KeepClient) {
-       previousRoots := ""
+       var previousRoots = []map[string]string{}
+       var delay time.Duration // zero, so the first discovery is immediate
        for {
+               time.Sleep(delay)
+               delay = 300 * time.Second
                if err := kc.DiscoverKeepServers(); err != nil {
                        log.Println("Error retrieving services list:", err)
-                       time.Sleep(3*time.Second)
-                       previousRoots = ""
-               } else if len(kc.LocalRoots()) == 0 {
-                       log.Println("Received empty services list")
-                       time.Sleep(3*time.Second)
-                       previousRoots = ""
-               } else {
-                       newRoots := fmt.Sprint("Locals ", kc.LocalRoots(), ", gateways ", kc.GatewayRoots())
-                       if newRoots != previousRoots {
-                               log.Println("Updated services list:", newRoots)
-                               previousRoots = newRoots
-                       }
-                       time.Sleep(300*time.Second)
+                       delay = 3 * time.Second
+                       continue
+               }
+               newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
+               if !reflect.DeepEqual(previousRoots, newRoots) {
+                       log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
                }
+               if len(newRoots[0]) == 0 {
+                       log.Print("WARNING: No local services. Retrying in 3 seconds.")
+                       delay = 3 * time.Second
+               }
+               previousRoots = newRoots
        }
 }
 
@@ -198,7 +204,7 @@ func GetRemoteAddress(req *http.Request) string {
        return req.RemoteAddr
 }
 
-func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
+func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
        var auth string
        if auth = req.Header.Get("Authorization"); auth == "" {
                return false, ""
@@ -238,6 +244,11 @@ type PutBlockHandler struct {
        *ApiTokenCache
 }
 
+type IndexHandler struct {
+       *keepclient.KeepClient // shared client; ServeHTTP copies it per request
+       *ApiTokenCache         // passed to CheckAuthorizationHeader
+}
+
 type InvalidPathHandler struct{}
 
 type OptionsHandler struct{}
@@ -259,6 +270,12 @@ func MakeRESTRouter(
                rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
                        GetBlockHandler{kc, t}).Methods("GET", "HEAD")
                rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
+
+               // List all blocks
+               rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET")
+
+               // List blocks whose hash has the given prefix
+               rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET")
        }
 
        if enable_put {
@@ -295,6 +312,14 @@ var BadAuthorizationHeader = errors.New("Missing or invalid Authorization header
 var ContentLengthMismatch = errors.New("Actual length != expected content length")
 var MethodNotSupported = errors.New("Method not supported")
 
+// removeHint matches a run of one or more locator hints of the form
+// "+K@" followed by five lowercase letters/digits, plus the "+" or end of
+// string that follows the run; ReplaceAllString(locator, "$1") deletes
+// the hints but keeps that separator. Matching the whole run (rather than
+// one hint at a time) matters because each match consumes the trailing
+// "+", which would otherwise hide the next hint's leading "+".
+var removeHint = regexp.MustCompile(`(?:\+K@[a-z0-9]{5})+(\+|$)`)
+
 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
        SetCorsHeaders(resp)
 
@@ -315,7 +334,7 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
                status, err = http.StatusForbidden, BadAuthorizationHeader
                return
        }
@@ -327,6 +346,8 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        var reader io.ReadCloser
 
+       locator = removeHint.ReplaceAllString(locator, "$1")
+
        switch req.Method {
        case "HEAD":
                expectLength, proxiedURI, err = kc.Ask(locator)
@@ -344,7 +365,7 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
        }
 
-       switch err {
+       switch respErr := err.(type) {
        case nil:
                status = http.StatusOK
                resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
@@ -357,10 +378,16 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                                err = ContentLengthMismatch
                        }
                }
-       case keepclient.BlockNotFound:
-               status = http.StatusNotFound
+       case keepclient.Error:
+               if respErr == keepclient.BlockNotFound {
+                       status = http.StatusNotFound
+               } else if respErr.Temporary() {
+                       status = http.StatusBadGateway
+               } else {
+                       status = 422
+               }
        default:
-               status = http.StatusBadGateway
+               status = http.StatusInternalServerError
        }
 }
 
@@ -414,7 +441,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
                err = BadAuthorizationHeader
                status = http.StatusForbidden
                return
@@ -474,3 +501,87 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                status = http.StatusBadGateway
        }
 }
+
+// ServeHTTP implementation for IndexHandler.
+// Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
+// (the router also sends plain /index here, with an empty prefix).
+//
+// After checking the Authorization header, it asks every keep server in
+// kc.LocalRoots() for its index of blocks whose hash starts with prefix,
+// using the requesting client's API token, and streams the responses back
+// concatenated, followed by one blank line marking the end of the listing.
+// NOTE(review): kc.GetIndex is presumably what checks that each server's
+// response was complete (terminated by a blank line) -- confirm.
+//
+// Any error aborts the request. Until index data has been written, the
+// error is reported with an HTTP error status (403, 501 or 502). After
+// that a 200 status has already been committed, so the error is only
+// logged and the response ends without the terminating blank line, which
+// a client checking for that terminator will treat as incomplete.
+func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+       SetCorsHeaders(resp)
+
+       prefix := mux.Vars(req)["prefix"]
+       var err error
+       var status int
+       // wroteData becomes true once index bytes have been copied to resp,
+       // i.e. once a 200 status is committed and can no longer be changed.
+       wroteData := false
+
+       defer func() {
+               switch {
+               case status == http.StatusOK: // success; nothing to report
+               case wroteData:
+                       // Too late for an error status: http.Error would just
+                       // append its message to the partial index.
+                       log.Println("Error:", GetRemoteAddress(req), req.Method, req.URL.Path, "incomplete index:", err)
+               default:
+                       http.Error(resp, err.Error(), status)
+               }
+       }()
+
+       kc := *handler.KeepClient
+
+       ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req)
+       if !ok {
+               status, err = http.StatusForbidden, BadAuthorizationHeader
+               return
+       }
+
+       // Copy the ArvadosClient struct, so the shared client is not modified,
+       // and use the requesting client's API token for the index requests.
+       arvclient := *kc.Arvados
+       arvclient.ApiToken = token
+       kc.Arvados = &arvclient
+
+       // Only GET method is supported
+       if req.Method != "GET" {
+               status, err = http.StatusNotImplemented, MethodNotSupported
+               return
+       }
+
+       // Get index from all LocalRoots and write to resp
+       var reader io.Reader
+       for uuid := range kc.LocalRoots() {
+               reader, err = kc.GetIndex(uuid, prefix)
+               if err != nil {
+                       status = http.StatusBadGateway
+                       return
+               }
+
+               var copied int64
+               copied, err = io.Copy(resp, reader)
+               if copied > 0 {
+                       wroteData = true
+               }
+               if err != nil {
+                       status = http.StatusBadGateway
+                       return
+               }
+       }
+
+       // Got index from all the keep servers and wrote to resp; the final
+       // blank line marks the listing as complete.
+       status = http.StatusOK
+       resp.Write([]byte("\n"))
+}