import (
"arvados.org/keepclient"
+ "arvados.org/sdk"
"flag"
"fmt"
"github.com/gorilla/mux"
"net"
"net/http"
"os"
+ "os/signal"
"sync"
+ "syscall"
"time"
)
pidfile string
)
- flag.StringVar(
+ flagset := flag.NewFlagSet("default", flag.ExitOnError)
+
+ flagset.StringVar(
&listen,
"listen",
DEFAULT_ADDR,
"ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
"to listen on all network interfaces.")
- flag.BoolVar(
+ flagset.BoolVar(
&no_get,
"no-get",
false,
"If set, disable GET operations")
- flag.BoolVar(
- &no_get,
+ flagset.BoolVar(
+ &no_put,
"no-put",
false,
"If set, disable PUT operations")
- flag.IntVar(
+ flagset.IntVar(
&default_replicas,
"default-replicas",
2,
"Default number of replicas to write if not specified by the client.")
- flag.StringVar(
+ flagset.StringVar(
&pidfile,
"pid",
"",
"Path to write pid file")
- flag.Parse()
+ flagset.Parse(os.Args[1:])
- /*if no_get == false {
- log.Print("Must specify -no-get")
- return
- }*/
+ arv, err := sdk.MakeArvadosClient()
+ if err != nil {
+ log.Fatalf("Error setting up arvados client %s", err.Error())
+ }
- kc, err := keepclient.MakeKeepClient()
+ kc, err := keepclient.MakeKeepClient(&arv)
if err != nil {
- log.Print(err)
- return
+ log.Fatalf("Error setting up keep client %s", err.Error())
}
if pidfile != "" {
listener, err = net.Listen("tcp", listen)
if err != nil {
- log.Printf("Could not listen on %v", listen)
- return
+ log.Fatalf("Could not listen on %v", listen)
}
go RefreshServicesList(&kc)
+ // Shut down the server gracefully (by closing the listener)
+ // if SIGTERM is received.
+ term := make(chan os.Signal, 1)
+ go func(sig <-chan os.Signal) {
+ s := <-sig
+ log.Println("caught signal:", s)
+ listener.Close()
+ }(term)
+ signal.Notify(term, syscall.SIGTERM)
+ signal.Notify(term, syscall.SIGINT)
+
+ if pidfile != "" {
+ f, err := os.Create(pidfile)
+ if err == nil {
+ fmt.Fprint(f, os.Getpid())
+ f.Close()
+ } else {
+ log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
+ }
+ }
+
+ log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
+
// Start listening for requests.
http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
+
+ log.Println("shutting down")
+
+ if pidfile != "" {
+ os.Remove(pidfile)
+ }
}
type ApiTokenCache struct {
+// RefreshServicesList never returns: it loops forever, sleeping 300
+// seconds, then calling kc.DiscoverKeepServers and logging the new server
+// list whenever its fmt.Sprint rendering differs from the rendering taken
+// just before discovery. Elsewhere in this file it is started with `go`.
+// NOTE(review): the HTTP handlers copy *KeepClient per request while this
+// loop rediscovers servers on the same KeepClient; whether that is
+// race-free depends on keepclient internals not visible here — confirm.
func RefreshServicesList(kc *keepclient.KeepClient) {
for {
time.Sleep(300 * time.Second)
+ oldservices := kc.ServiceRoots()
kc.DiscoverKeepServers()
+ newservices := kc.ServiceRoots()
+ // Compare printed forms rather than the values themselves.
+ // NOTE(review): relies on ServiceRoots' result printing
+ // deterministically (fmt sorts map keys since Go 1.12) — confirm its type.
+ s1 := fmt.Sprint(oldservices)
+ s2 := fmt.Sprint(newservices)
+ if s1 != s2 {
+ log.Printf("Updated server list to %v", s2)
+ }
}
}
}
}
+// GetRemoteAddress returns a description of the client address for use in
+// log messages. If the request carries an X-Real-IP header, that value is
+// returned, followed by the X-Forwarded-For value in parentheses when that
+// header is present and differs from X-Real-IP. Without X-Real-IP, the
+// connection's req.RemoteAddr is returned.
+// Both headers are supplied by the client and are not validated here.
+func GetRemoteAddress(req *http.Request) string {
+ realip := req.Header.Get("X-Real-IP")
+ if realip == "" {
+ return req.RemoteAddr
+ }
+ forwarded := req.Header.Get("X-Forwarded-For")
+ // An absent X-Forwarded-For must not produce an empty
+ // "(X-Forwarded-For )" suffix.
+ if forwarded == "" || forwarded == realip {
+ return realip
+ }
+ return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
+}
+
func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) bool {
- if req.Header.Get("Authorization") == "" {
+ var auth string
+ if auth = req.Header.Get("Authorization"); auth == "" {
return false
}
var tok string
- _, err := fmt.Sscanf(req.Header.Get("Authorization"), "OAuth2 %s", &tok)
+ _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
if err != nil {
// Scanning error
return false
return true
}
- var usersreq *http.Request
-
- if usersreq, err = http.NewRequest("GET", fmt.Sprintf("https://%s/arvados/v1/users/current", kc.ApiServer), nil); err != nil {
- // Can't construct the request
- log.Print("CheckAuthorizationHeader error: %v", err)
- return false
- }
-
- // Add api token header
- usersreq.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", tok))
-
- // Actually make the request
- var resp *http.Response
- if resp, err = kc.Client.Do(usersreq); err != nil {
- // Something else failed
- log.Print("CheckAuthorizationHeader error: %v", err)
- return false
- }
-
- if resp.StatusCode != http.StatusOK {
- // Bad status
+ arv := *kc.Arvados
+ arv.ApiToken = tok
+ if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
+ log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
return false
}
*ApiTokenCache
}
+// InvalidPathHandler is installed as the router's NotFoundHandler; it
+// handles requests that match no registered route.
+type InvalidPathHandler struct{}
+
// MakeRESTRouter
// Returns a mux.Router that passes GET and PUT requests to the
// appropriate handlers.
t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
rest := mux.NewRouter()
- gh := rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t})
- ghsig := rest.Handle(
- `/{hash:[0-9a-f]{32}}+A{signature:[0-9a-f]+}@{timestamp:[0-9a-f]+}`,
- GetBlockHandler{kc, t})
- ph := rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t})
if enable_get {
- gh.Methods("GET", "HEAD")
- ghsig.Methods("GET", "HEAD")
+ rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`,
+ GetBlockHandler{kc, t}).Methods("GET", "HEAD")
+ rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
}
if enable_put {
- ph.Methods("PUT")
+ rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`, PutBlockHandler{kc, t}).Methods("PUT")
+ rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
}
+ rest.NotFoundHandler = InvalidPathHandler{}
+
return rest
}
+// ServeHTTP logs a request that matched no route and replies with
+// 400 Bad Request. The path is logged with %q: req.URL.Path is the
+// client-supplied path after percent-decoding, so it can contain newlines
+// or other control characters that would otherwise forge log lines.
+func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+ log.Printf("%s: %s %q unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
+ http.Error(resp, "Bad request", http.StatusBadRequest)
+}
+
+// ServeHTTP serves GET and HEAD requests for a single block. The block hash
+// and the optional "+hints" suffix come from the mux route variables; the
+// signature and timestamp parsed from the hints are passed to
+// AuthorizedGet (GET) or AuthorizedAsk (HEAD). It replies 403 when
+// CheckAuthorizationHeader rejects the request, 404 on
+// keepclient.BlockNotFound, and 502 with the error text on any other error.
func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
kc := *this.KeepClient
+ hash := mux.Vars(req)["hash"]
+ hints := mux.Vars(req)["hints"]
+
+ locator := keepclient.MakeLocator2(hash, hints)
+
+ log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
+
if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
+ return
}
- hash := mux.Vars(req)["hash"]
- signature := mux.Vars(req)["signature"]
- timestamp := mux.Vars(req)["timestamp"]
-
var reader io.ReadCloser
var err error
var blocklen int64
if req.Method == "GET" {
- reader, blocklen, _, err = kc.AuthorizedGet(hash, signature, timestamp)
- defer reader.Close()
+ reader, blocklen, _, err = kc.AuthorizedGet(hash, locator.Signature, locator.Timestamp)
+ // Nothing here guarantees AuthorizedGet returns a non-nil reader when
+ // it fails, and calling Close on a nil interface panics, so only
+ // defer Close when there is a reader.
+ if reader != nil {
+ defer reader.Close()
+ }
} else if req.Method == "HEAD" {
- blocklen, _, err = kc.AuthorizedAsk(hash, signature, timestamp)
+ blocklen, _, err = kc.AuthorizedAsk(hash, locator.Signature, locator.Timestamp)
}
- resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
+ if blocklen > 0 {
+ resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
+ }
switch err {
case nil:
if reader != nil {
- io.Copy(resp, reader)
+ n, err2 := io.Copy(resp, reader)
+ if n != blocklen {
+ log.Printf("%s: %s %s mismatched return %v with Content-Length %v error %v", GetRemoteAddress(req), req.Method, hash, n, blocklen, err2)
+ } else if err2 == nil {
+ log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
+ } else {
+ // err is nil in this branch; report the copy error instead.
+ log.Printf("%s: %s %s returned %v bytes error %v", GetRemoteAddress(req), req.Method, hash, n, err2)
+ }
+ } else {
+ // HEAD request: nothing to copy.
+ log.Printf("%s: %s %s success", GetRemoteAddress(req), req.Method, hash)
}
case keepclient.BlockNotFound:
http.Error(resp, "Not found", http.StatusNotFound)
default:
http.Error(resp, err.Error(), http.StatusBadGateway)
}
+
+ if err != nil {
+ log.Printf("%s: %s %s error %s", GetRemoteAddress(req), req.Method, hash, err.Error())
+ }
}
func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
- log.Print("PutBlockHandler start")
-
kc := *this.KeepClient
- if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
- http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
- }
-
hash := mux.Vars(req)["hash"]
+ hints := mux.Vars(req)["hints"]
+
+ locator := keepclient.MakeLocator2(hash, hints)
var contentLength int64 = -1
if req.Header.Get("Content-Length") != "" {
}
+ log.Printf("%s: %s %s Content-Length %v", GetRemoteAddress(req), req.Method, hash, contentLength)
+
if contentLength < 1 {
http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
return
}
+ if locator.Size > 0 && int64(locator.Size) != contentLength {
+ http.Error(resp, "Locator size hint does not match Content-Length header", http.StatusBadRequest)
+ return
+ }
+
+ if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
+ http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
+ return
+ }
+
// Check if the client specified the number of replicas
if req.Header.Get("X-Keep-Desired-Replicas") != "" {
var r int
- _, err := fmt.Sscanf(req.Header.Get("X-Keep-Desired-Replicas"), "%d", &r)
+ _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
if err != nil {
kc.Want_replicas = r
}
}
// Now try to put the block through
- replicas, err := kc.PutHR(hash, req.Body, contentLength)
-
- log.Printf("Replicas stored: %v err: %v", replicas, err)
+ hash, replicas, err := kc.PutHR(hash, req.Body, contentLength)
// Tell the client how many successful PUTs we accomplished
- resp.Header().Set("X-Keep-Replicas-Stored", fmt.Sprintf("%d", replicas))
+ resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
switch err {
case nil:
// Default will return http.StatusOK
+ log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
+ n, err2 := io.WriteString(resp, hash)
+ if err2 != nil {
+ log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
+ }
case keepclient.OversizeBlockError:
// Too much data
// client can decide if getting less than the number of
// replications it asked for is a fatal error.
// Default will return http.StatusOK
+ n, err2 := io.WriteString(resp, hash)
+ if err2 != nil {
+ log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
+ }
} else {
http.Error(resp, "", http.StatusServiceUnavailable)
}
http.Error(resp, err.Error(), http.StatusBadGateway)
}
+ if err != nil {
+ log.Printf("%s: %s %s stored %v replicas (desired %v) got error %v", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas, err.Error())
+ }
+
}