"os/signal"
"reflect"
"regexp"
- "strings"
"sync"
"syscall"
"time"
pidfile string
)
- flagset := flag.NewFlagSet("default", flag.ExitOnError)
+ flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
flagset.StringVar(
&listen,
log.Fatalf("Error setting up arvados client %s", err.Error())
}
+ if os.Getenv("ARVADOS_DEBUG") != "" {
+ keepclient.DebugPrintf = log.Printf
+ }
kc, err := keepclient.MakeKeepClient(&arv)
if err != nil {
log.Fatalf("Error setting up keep client %s", err.Error())
return req.RemoteAddr
}
-func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
+func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
var auth string
if auth = req.Header.Get("Authorization"); auth == "" {
return false, ""
var pass bool
var tok string
- if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
+ if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
status, err = http.StatusForbidden, BadAuthorizationHeader
return
}
log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
}
- switch err {
+ switch respErr := err.(type) {
case nil:
status = http.StatusOK
resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
err = ContentLengthMismatch
}
}
- case keepclient.BlockNotFound:
- status = http.StatusNotFound
+ case keepclient.Error:
+ if respErr == keepclient.BlockNotFound {
+ status = http.StatusNotFound
+ } else if respErr.Temporary() {
+ status = http.StatusBadGateway
+ } else {
+ status = 422
+ }
default:
- status = http.StatusBadGateway
+ status = http.StatusInternalServerError
}
}
var pass bool
var tok string
- if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
+ if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
err = BadAuthorizationHeader
status = http.StatusForbidden
return
}
}
-// ServeHTTP implemenation for IndexHandler
+// ServeHTTP implementation for IndexHandler
// Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
// For each keep server found in LocalRoots:
// Invokes GetIndex using keepclient
kc := *handler.KeepClient
- var pass bool
- var tok string
- if pass, tok = CheckAuthorizationHeader(kc, handler.ApiTokenCache, req); !pass {
+ ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req)
+ if !ok {
status, err = http.StatusForbidden, BadAuthorizationHeader
return
}
// Copy ArvadosClient struct and use the client's API token
arvclient := *kc.Arvados
- arvclient.ApiToken = tok
+ arvclient.ApiToken = token
kc.Arvados = &arvclient
- var indexResp []byte
- var reader io.Reader
-
- switch req.Method {
- case "GET":
- for uuid := range kc.LocalRoots() {
- reader, err = kc.GetIndex(uuid, prefix)
- if err != nil {
- break
- }
-
- var readBytes []byte
- readBytes, err = ioutil.ReadAll(reader)
- if err != nil {
- break
- }
-
- // Got index; verify that it is complete
- // The response should be "\n" if no locators matched the prefix
- // Else, it should be a list of locators followed by a blank line
- if (!strings.HasSuffix(string(readBytes), "\n\n")) && (string(readBytes) != "\n") {
- err = errors.New("Got incomplete index")
- }
-
- // Trim the extra empty new line found in response from each server
- indexResp = append(indexResp, (readBytes[0 : len(readBytes)-1])...)
- }
-
- // Append empty line at the end of concatenation of all server responses
- indexResp = append(indexResp, ([]byte("\n"))...)
- default:
+ // Only GET method is supported
+ if req.Method != "GET" {
status, err = http.StatusNotImplemented, MethodNotSupported
return
}
- switch err {
- case nil:
- status = http.StatusOK
- resp.Header().Set("Content-Length", fmt.Sprint(len(indexResp)))
- _, err = resp.Write(indexResp)
- default:
- status = http.StatusBadGateway
+ // Get index from all LocalRoots and write to resp
+ var reader io.Reader
+ for uuid := range kc.LocalRoots() {
+ reader, err = kc.GetIndex(uuid, prefix)
+ if err != nil {
+ status = http.StatusBadGateway
+ return
+ }
+
+ _, err = io.Copy(resp, reader)
+ if err != nil {
+ status = http.StatusBadGateway
+ return
+ }
}
+
+ // Got index from all the keep servers and wrote to resp
+ status = http.StatusOK
+ resp.Write([]byte("\n"))
}