4 "arvados.org/keepclient"
8 "github.com/gorilla/mux"
const (
	// DEFAULT_ADDR is the TCP address the proxy listens on when the -listen
	// flag is not given. A ":port" value with no host part means "all
	// network interfaces".
	DEFAULT_ADDR = ":25107"
)

// listener accepts client connections. It lives at package level so that
// main's SIGTERM/SIGINT handler can close it, which makes http.Serve return
// and lets the process shut down gracefully.
var listener net.Listener
35 flagset := flag.NewFlagSet("default", flag.ExitOnError)
41 "Interface on which to listen for requests, in the format "+
42 "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
43 "to listen on all network interfaces.")
49 "If set, disable GET operations")
55 "If set, disable PUT operations")
61 "Default number of replicas to write if not specified by the client.")
67 "Path to write pid file")
69 flagset.Parse(os.Args[1:])
71 arv, err := sdk.MakeArvadosClient()
73 log.Fatalf("Error setting up arvados client %s", err.Error())
76 kc, err := keepclient.MakeKeepClient(&arv)
78 log.Fatalf("Error setting up keep client %s", err.Error())
82 f, err := os.Create(pidfile)
84 fmt.Fprint(f, os.Getpid())
87 log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
91 kc.Want_replicas = default_replicas
93 listener, err = net.Listen("tcp", listen)
95 log.Fatalf("Could not listen on %v", listen)
98 go RefreshServicesList(&kc)
100 // Shut down the server gracefully (by closing the listener)
101 // if SIGTERM is received.
102 term := make(chan os.Signal, 1)
103 go func(sig <-chan os.Signal) {
105 log.Println("caught signal:", s)
108 signal.Notify(term, syscall.SIGTERM)
109 signal.Notify(term, syscall.SIGINT)
112 f, err := os.Create(pidfile)
114 fmt.Fprint(f, os.Getpid())
117 log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
121 log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
123 // Start listening for requests.
124 http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
126 log.Println("shutting down")
// ApiTokenCache remembers API tokens the API server has recently accepted,
// so that not every proxied request has to be re-validated upstream.
type ApiTokenCache struct {
	// tokens maps a token to the Unix time (seconds) at which its cache
	// entry expires; a missing or zero entry means "not cached".
	tokens map[string]int64
	// lock guards tokens.
	lock sync.Mutex
	// expireTime is the lifetime, in seconds, given to a newly remembered
	// token.
	expireTime int64
}
139 // Refresh the keep service list every five minutes.
140 func RefreshServicesList(kc *keepclient.KeepClient) {
142 time.Sleep(300 * time.Second)
143 oldservices := kc.ServiceRoots()
144 kc.DiscoverKeepServers()
145 newservices := kc.ServiceRoots()
146 s1 := fmt.Sprint(oldservices)
147 s2 := fmt.Sprint(newservices)
149 log.Printf("Updated server list to %v", s2)
154 // Cache the token and set an expire time. If we already have an expire time
155 // on the token, it is not updated.
156 func (this *ApiTokenCache) RememberToken(token string) {
158 defer this.lock.Unlock()
160 now := time.Now().Unix()
161 if this.tokens[token] == 0 {
162 this.tokens[token] = now + this.expireTime
166 // Check if the cached token is known and still believed to be valid.
167 func (this *ApiTokenCache) RecallToken(token string) bool {
169 defer this.lock.Unlock()
171 now := time.Now().Unix()
172 if this.tokens[token] == 0 {
175 } else if now < this.tokens[token] {
176 // Token is known and still valid
180 this.tokens[token] = 0
// GetRemoteAddress describes the client's address for log messages.
//
// If the request carries an X-Real-IP header (presumably set by whatever
// sits in front of this proxy — confirm against the deployment), that
// address is returned. The X-Forwarded-For chain is appended only when that
// header is present and differs from X-Real-IP. Without X-Real-IP, the
// connection's RemoteAddr is returned.
func GetRemoteAddress(req *http.Request) string {
	realip := req.Header.Get("X-Real-IP")
	if realip == "" {
		return req.RemoteAddr
	}
	// Skip the suffix when X-Forwarded-For is absent, instead of logging
	// an empty "(X-Forwarded-For )".
	if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != "" && forwarded != realip {
		return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
	}
	return realip
}
196 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) bool {
198 if auth = req.Header.Get("Authorization"); auth == "" {
203 _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
209 if cache.RecallToken(tok) {
210 // Valid in the cache, short circut
216 if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
217 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
221 // Success! Update cache
222 cache.RememberToken(tok)
227 type GetBlockHandler struct {
228 *keepclient.KeepClient
232 type PutBlockHandler struct {
233 *keepclient.KeepClient
237 type InvalidPathHandler struct{}
240 // Returns a mux.Router that passes GET and PUT requests to the
241 // appropriate handlers.
246 kc *keepclient.KeepClient) *mux.Router {
248 t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
250 rest := mux.NewRouter()
253 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`,
254 GetBlockHandler{kc, t}).Methods("GET", "HEAD")
255 rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
259 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`, PutBlockHandler{kc, t}).Methods("PUT")
260 rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
263 rest.NotFoundHandler = InvalidPathHandler{}
268 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
269 log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
270 http.Error(resp, "Bad request", http.StatusBadRequest)
273 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
275 kc := *this.KeepClient
277 hash := mux.Vars(req)["hash"]
278 hints := mux.Vars(req)["hints"]
280 locator := keepclient.MakeLocator2(hash, hints)
282 log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
284 if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
285 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
289 var reader io.ReadCloser
293 if req.Method == "GET" {
294 reader, blocklen, _, err = kc.AuthorizedGet(hash, locator.Signature, locator.Timestamp)
296 } else if req.Method == "HEAD" {
297 blocklen, _, err = kc.AuthorizedAsk(hash, locator.Signature, locator.Timestamp)
300 resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
305 n, err2 := io.Copy(resp, reader)
307 log.Printf("%s: %s %s mismatched return %v with Content-Length %v error", GetRemoteAddress(req), req.Method, hash, n, blocklen, err.Error())
308 } else if err2 == nil {
309 log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
311 log.Printf("%s: %s %s returned %v bytes error %v", GetRemoteAddress(req), req.Method, hash, n, err.Error())
314 log.Printf("%s: %s %s success", GetRemoteAddress(req), req.Method, hash)
316 case keepclient.BlockNotFound:
317 http.Error(resp, "Not found", http.StatusNotFound)
319 http.Error(resp, err.Error(), http.StatusBadGateway)
323 log.Printf("%s: %s %s error %s", GetRemoteAddress(req), req.Method, hash, err.Error())
327 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
329 kc := *this.KeepClient
331 hash := mux.Vars(req)["hash"]
332 hints := mux.Vars(req)["hints"]
334 locator := keepclient.MakeLocator2(hash, hints)
336 var contentLength int64 = -1
337 if req.Header.Get("Content-Length") != "" {
338 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &contentLength)
340 resp.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength))
345 log.Printf("%s: %s %s Content-Length %v", GetRemoteAddress(req), req.Method, hash, contentLength)
347 if contentLength < 1 {
348 http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
352 if locator.Size > 0 && int64(locator.Size) != contentLength {
353 http.Error(resp, "Locator size hint does not match Content-Length header", http.StatusBadRequest)
357 if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
358 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
362 // Check if the client specified the number of replicas
363 if req.Header.Get("X-Keep-Desired-Replicas") != "" {
365 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
371 // Now try to put the block through
372 hash, replicas, err := kc.PutHR(hash, req.Body, contentLength)
374 // Tell the client how many successful PUTs we accomplished
375 resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
379 // Default will return http.StatusOK
380 log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
381 n, err2 := io.WriteString(resp, hash)
383 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
386 case keepclient.OversizeBlockError:
388 http.Error(resp, fmt.Sprintf("Exceeded maximum blocksize %d", keepclient.BLOCKSIZE), http.StatusRequestEntityTooLarge)
390 case keepclient.InsufficientReplicasError:
392 // At least one write is considered success. The
393 // client can decide if getting less than the number of
394 // replications it asked for is a fatal error.
395 // Default will return http.StatusOK
396 n, err2 := io.WriteString(resp, hash)
398 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
401 http.Error(resp, "", http.StatusServiceUnavailable)
405 http.Error(resp, err.Error(), http.StatusBadGateway)
409 log.Printf("%s: %s %s stored %v replicas (desired %v) got error %v", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas, err.Error())