4 "arvados.org/keepclient"
7 "github.com/gorilla/mux"
19 // DEFAULT_ADDR is the default TCP address on which to listen for requests.
20 // It is the fallback value of the -listen flag, which overrides it at runtime.
21 const DEFAULT_ADDR = ":25107"
23 var listener net.Listener // assigned in main by net.Listen, passed to http.Serve; closing it on SIGTERM shuts the server down
34 flagset := flag.NewFlagSet("default", flag.ExitOnError)
40 "Interface on which to listen for requests, in the format "+
41 "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
42 "to listen on all network interfaces.")
48 "If set, disable GET operations")
54 "If set, disable PUT operations")
60 "Default number of replicas to write if not specified by the client.")
66 "Path to write pid file")
68 flagset.Parse(os.Args[1:])
70 kc, err := keepclient.MakeKeepClient()
72 log.Fatalf("Error setting up keep client %s", err.Error())
76 f, err := os.Create(pidfile)
78 fmt.Fprint(f, os.Getpid())
81 log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
85 kc.Want_replicas = default_replicas
87 listener, err = net.Listen("tcp", listen)
89 log.Fatalf("Could not listen on %v", listen)
92 go RefreshServicesList(&kc)
94 // Shut down the server gracefully (by closing the listener)
95 // if SIGTERM is received.
96 term := make(chan os.Signal, 1)
97 go func(sig <-chan os.Signal) {
99 log.Println("caught signal:", s)
102 signal.Notify(term, syscall.SIGTERM)
103 signal.Notify(term, syscall.SIGINT)
106 f, err := os.Create(pidfile)
108 fmt.Fprint(f, os.Getpid())
111 log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
115 log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
117 // Start listening for requests.
118 http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
120 log.Println("shutting down")
127 type ApiTokenCache struct {
128 tokens map[string]int64
133 // RefreshServicesList re-runs Keep service discovery every five minutes and logs the updated service list.
134 func RefreshServicesList(kc *keepclient.KeepClient) {
136 time.Sleep(300 * time.Second)
137 oldservices := kc.ServiceRoots()
138 kc.DiscoverKeepServers()
139 newservices := kc.ServiceRoots()
140 s1 := fmt.Sprint(oldservices)
141 s2 := fmt.Sprint(newservices)
143 log.Printf("Updated server list to %v", s2)
148 // RememberToken caches token with an expiry time of expireTime seconds from
149 // now. If the token already has an expiry time, that time is left unchanged.
150 func (this *ApiTokenCache) RememberToken(token string) {
152 defer this.lock.Unlock()
154 now := time.Now().Unix()
155 if this.tokens[token] == 0 {
156 this.tokens[token] = now + this.expireTime
160 // RecallToken reports whether token is cached and its expiry time has not yet passed.
161 func (this *ApiTokenCache) RecallToken(token string) bool {
163 defer this.lock.Unlock()
165 now := time.Now().Unix()
166 if this.tokens[token] == 0 {
169 } else if now < this.tokens[token] {
170 // Token is known and still valid
174 this.tokens[token] = 0
179 func GetRemoteAddress(req *http.Request) string {
180 if realip := req.Header.Get("X-Real-IP"); realip != "" {
181 if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
182 return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
187 return req.RemoteAddr
190 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) bool {
192 if auth = req.Header.Get("Authorization"); auth == "" {
197 _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
203 if cache.RecallToken(tok) {
204 // Valid in the cache, short circut
208 var usersreq *http.Request
210 if usersreq, err = http.NewRequest("HEAD", fmt.Sprintf("https://%s/arvados/v1/users/current", kc.ApiServer), nil); err != nil {
211 // Can't construct the request
212 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
216 // Add api token header
217 usersreq.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", tok))
219 // Actually make the request
220 var resp *http.Response
221 if resp, err = kc.Client.Do(usersreq); err != nil {
222 // Something else failed
223 log.Printf("%s: CheckAuthorizationHeader error connecting to API server: %v", GetRemoteAddress(req), err.Error())
227 if resp.StatusCode != http.StatusOK {
229 log.Printf("%s: CheckAuthorizationHeader API server responded: %v", GetRemoteAddress(req), resp.Status)
233 // Success! Update cache
234 cache.RememberToken(tok)
239 type GetBlockHandler struct {
240 *keepclient.KeepClient
244 type PutBlockHandler struct {
245 *keepclient.KeepClient
249 type InvalidPathHandler struct{} // router NotFoundHandler: logs the unroutable request and replies 400 Bad Request
252 // Returns a mux.Router that routes GET and HEAD requests to GetBlockHandler and PUT
253 // requests to PutBlockHandler (each only when enabled); unroutable requests go to InvalidPathHandler.
258 kc *keepclient.KeepClient) *mux.Router {
260 t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
262 rest := mux.NewRouter()
265 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`,
266 GetBlockHandler{kc, t}).Methods("GET", "HEAD")
267 rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
271 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`, PutBlockHandler{kc, t}).Methods("PUT")
272 rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
275 rest.NotFoundHandler = InvalidPathHandler{}
280 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
281 log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
282 http.Error(resp, "Bad request", http.StatusBadRequest)
285 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
287 kc := *this.KeepClient
289 hash := mux.Vars(req)["hash"]
290 hints := mux.Vars(req)["hints"]
292 locator := keepclient.MakeLocator2(hash, hints)
294 log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
296 if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
297 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
301 var reader io.ReadCloser
305 if req.Method == "GET" {
306 reader, blocklen, _, err = kc.AuthorizedGet(hash, locator.Signature, locator.Timestamp)
308 } else if req.Method == "HEAD" {
309 blocklen, _, err = kc.AuthorizedAsk(hash, locator.Signature, locator.Timestamp)
312 resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
317 n, err2 := io.Copy(resp, reader)
319 log.Printf("%s: %s %s mismatched return %v with Content-Length %v error", GetRemoteAddress(req), req.Method, hash, n, blocklen, err.Error())
320 } else if err2 == nil {
321 log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
323 log.Printf("%s: %s %s returned %v bytes error %v", GetRemoteAddress(req), req.Method, hash, n, err.Error())
326 log.Printf("%s: %s %s success", GetRemoteAddress(req), req.Method, hash)
328 case keepclient.BlockNotFound:
329 http.Error(resp, "Not found", http.StatusNotFound)
331 http.Error(resp, err.Error(), http.StatusBadGateway)
335 log.Printf("%s: %s %s error %s", GetRemoteAddress(req), req.Method, hash, err.Error())
339 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
341 kc := *this.KeepClient
343 hash := mux.Vars(req)["hash"]
344 hints := mux.Vars(req)["hints"]
346 locator := keepclient.MakeLocator2(hash, hints)
348 var contentLength int64 = -1
349 if req.Header.Get("Content-Length") != "" {
350 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &contentLength)
352 resp.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength))
357 log.Printf("%s: %s %s Content-Length %v", GetRemoteAddress(req), req.Method, hash, contentLength)
359 if contentLength < 1 {
360 http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
364 if locator.Size > 0 && int64(locator.Size) != contentLength {
365 http.Error(resp, "Locator size hint does not match Content-Length header", http.StatusBadRequest)
369 if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
370 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
374 // Check if the client specified the number of replicas
375 if req.Header.Get("X-Keep-Desired-Replicas") != "" {
377 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
383 // Now try to put the block through
384 hash, replicas, err := kc.PutHR(hash, req.Body, contentLength)
386 // Tell the client how many successful PUTs we accomplished
387 resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
391 // Default will return http.StatusOK
392 log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
393 n, err2 := io.WriteString(resp, hash)
395 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
398 case keepclient.OversizeBlockError:
400 http.Error(resp, fmt.Sprintf("Exceeded maximum blocksize %d", keepclient.BLOCKSIZE), http.StatusRequestEntityTooLarge)
402 case keepclient.InsufficientReplicasError:
404 // At least one write is considered success. The
405 // client can decide if getting less than the number of
406 // replications it asked for is a fatal error.
407 // Default will return http.StatusOK
408 n, err2 := io.WriteString(resp, hash)
410 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
413 http.Error(resp, "", http.StatusServiceUnavailable)
417 http.Error(resp, err.Error(), http.StatusBadGateway)
421 log.Printf("%s: %s %s stored %v replicas (desired %v) got error %v", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas, err.Error())