Keepproxy use client-supplied token when forwarding GET and PUT requests.
[arvados.git] / services / keep / src / arvados.org / keepproxy / keepproxy.go
index 414835c0e249cd9771f475a087793d6f923bd5bb..367854bed382ec1b17589c825126d5e31a9bf6a6 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "arvados.org/keepclient"
+       "arvados.org/sdk"
        "flag"
        "fmt"
        "github.com/gorilla/mux"
@@ -67,7 +68,12 @@ func main() {
 
        flagset.Parse(os.Args[1:])
 
-       kc, err := keepclient.MakeKeepClient()
+       arv, err := sdk.MakeArvadosClient()
+       if err != nil {
+               log.Fatalf("Error setting up arvados client %s", err.Error())
+       }
+
+       kc, err := keepclient.MakeKeepClient(&arv)
        if err != nil {
                log.Fatalf("Error setting up keep client %s", err.Error())
        }
@@ -100,6 +106,7 @@ func main() {
                listener.Close()
        }(term)
        signal.Notify(term, syscall.SIGTERM)
+       signal.Notify(term, syscall.SIGINT)
 
        if pidfile != "" {
                f, err := os.Create(pidfile)
@@ -186,53 +193,34 @@ func GetRemoteAddress(req *http.Request) string {
        return req.RemoteAddr
 }
 
-func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) bool {
+func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
        var auth string
        if auth = req.Header.Get("Authorization"); auth == "" {
-               return false
+               return false, ""
        }
 
-       var tok string
        _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
        if err != nil {
                // Scanning error
-               return false
+               return false, ""
        }
 
        if cache.RecallToken(tok) {
                // Valid in the cache, short circut
-               return true
+               return true, tok
        }
 
-       var usersreq *http.Request
-
-       if usersreq, err = http.NewRequest("HEAD", fmt.Sprintf("https://%s/arvados/v1/users/current", kc.ApiServer), nil); err != nil {
-               // Can't construct the request
+       arv := *kc.Arvados
+       arv.ApiToken = tok
+       if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
                log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
-               return false
-       }
-
-       // Add api token header
-       usersreq.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", tok))
-
-       // Actually make the request
-       var resp *http.Response
-       if resp, err = kc.Client.Do(usersreq); err != nil {
-               // Something else failed
-               log.Printf("%s: CheckAuthorizationHeader error connecting to API server: %v", GetRemoteAddress(req), err.Error())
-               return false
-       }
-
-       if resp.StatusCode != http.StatusOK {
-               // Bad status
-               log.Printf("%s: CheckAuthorizationHeader API server responded: %v", GetRemoteAddress(req), resp.Status)
-               return false
+               return false, ""
        }
 
        // Success!  Update cache
        cache.RememberToken(tok)
 
-       return true
+       return true, tok
 }
 
 type GetBlockHandler struct {
@@ -292,11 +280,18 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
 
-       if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
+       var pass bool
+       var tok string
+       if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
                http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
                return
        }
 
+       // Copy ArvadosClient struct and use the client's API token
+       arvclient := *kc.Arvados
+       arvclient.ApiToken = tok
+       kc.Arvados = &arvclient
+
        var reader io.ReadCloser
        var err error
        var blocklen int64
@@ -308,14 +303,16 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                blocklen, _, err = kc.AuthorizedAsk(hash, locator.Signature, locator.Timestamp)
        }
 
-       resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
+       if blocklen > 0 {
+               resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
+       }
 
        switch err {
        case nil:
                if reader != nil {
                        n, err2 := io.Copy(resp, reader)
                        if n != blocklen {
-                               log.Printf("%s: %s %s mismatched return %v with Content-Length %v error", GetRemoteAddress(req), req.Method, hash, n, blocklen, err.Error())
+                               log.Printf("%s: %s %s mismatched return %v with Content-Length %v error %v", GetRemoteAddress(req), req.Method, hash, n, blocklen, err2)
                        } else if err2 == nil {
                                log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
                        } else {
@@ -365,11 +362,18 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                return
        }
 
-       if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
+       var pass bool
+       var tok string
+       if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
                http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
                return
        }
 
+       // Copy ArvadosClient struct and use the client's API token
+       arvclient := *kc.Arvados
+       arvclient.ApiToken = tok
+       kc.Arvados = &arvclient
+
        // Check if the client specified the number of replicas
        if req.Header.Get("X-Keep-Desired-Replicas") != "" {
                var r int