5824: Merge branch 'master' into 5824-keep-web
[arvados.git] / services / keepproxy / keepproxy.go
index f2a93f1e3a058fdd9270091aecbc083882654dd3..1a1189658d7ba1473aa33036cce60276932cb9b8 100644 (file)
@@ -37,7 +37,7 @@ func main() {
                pidfile          string
        )
 
-       flagset := flag.NewFlagSet("default", flag.ExitOnError)
+       flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
 
        flagset.StringVar(
                &listen,
@@ -201,7 +201,7 @@ func GetRemoteAddress(req *http.Request) string {
        return req.RemoteAddr
 }
 
-func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
+func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
        var auth string
        if auth = req.Header.Get("Authorization"); auth == "" {
                return false, ""
@@ -331,7 +331,7 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
                status, err = http.StatusForbidden, BadAuthorizationHeader
                return
        }
@@ -362,7 +362,7 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
        }
 
-       switch err {
+       switch respErr := err.(type) {
        case nil:
                status = http.StatusOK
                resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
@@ -375,10 +375,16 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                                err = ContentLengthMismatch
                        }
                }
-       case keepclient.BlockNotFound:
-               status = http.StatusNotFound
+       case keepclient.Error:
+               if respErr == keepclient.BlockNotFound {
+                       status = http.StatusNotFound
+               } else if respErr.Temporary() {
+                       status = http.StatusBadGateway
+               } else {
+                       status = 422
+               }
        default:
-               status = http.StatusBadGateway
+               status = http.StatusInternalServerError
        }
 }
 
@@ -432,7 +438,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
                err = BadAuthorizationHeader
                status = http.StatusForbidden
                return
@@ -515,16 +521,15 @@ func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        kc := *handler.KeepClient
 
-       var pass bool
-       var tok string
-       if pass, tok = CheckAuthorizationHeader(kc, handler.ApiTokenCache, req); !pass {
+       ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req)
+       if !ok {
                status, err = http.StatusForbidden, BadAuthorizationHeader
                return
        }
 
        // Copy ArvadosClient struct and use the client's API token
        arvclient := *kc.Arvados
-       arvclient.ApiToken = tok
+       arvclient.ApiToken = token
        kc.Arvados = &arvclient
 
        // Only GET method is supported
@@ -533,6 +538,7 @@ func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                return
        }
 
+       // Get index from all LocalRoots and write to resp
        var reader io.Reader
        for uuid := range kc.LocalRoots() {
                reader, err = kc.GetIndex(uuid, prefix)
@@ -541,15 +547,7 @@ func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                        return
                }
 
-               var readBytes []byte
-               readBytes, err = ioutil.ReadAll(reader)
-               if err != nil {
-                       status = http.StatusBadGateway
-                       return
-               }
-
-               // Got index for this server; write to resp
-               _, err := resp.Write(readBytes)
+               _, err = io.Copy(resp, reader)
                if err != nil {
                        status = http.StatusBadGateway
                        return