4 "arvados.org/keepclient"
7 "github.com/gorilla/mux"
19 // Default TCP address on which to listen for requests.
20 // Initialized by the -listen flag.
21 const DEFAULT_ADDR = ":25107"
// listener is the package-level TCP listener: it is assigned from net.Listen
// and then handed to http.Serve.
23 var listener net.Listener
// NOTE(review): the enclosing function header and a number of short lines
// (flag names, defaults, braces, error checks) are not visible in this view;
// the comments below describe only the statements that can be seen.
//
// Flags are registered on a dedicated FlagSet; flag.ExitOnError makes a parse
// failure terminate the process.
34 flagset := flag.NewFlagSet("default", flag.ExitOnError)
40 "Interface on which to listen for requests, in the format "+
41 "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
42 "to listen on all network interfaces.")
48 "If set, disable GET operations")
54 "If set, disable PUT operations")
60 "Default number of replicas to write if not specified by the client.")
66 "Path to write pid file")
68 flagset.Parse(os.Args[1:])
// Build the keep client; the visible handling of a failure is log.Fatalf.
70 kc, err := keepclient.MakeKeepClient()
72 log.Fatalf("Error setting up keep client %s", err.Error())
// Write this process's pid to pidfile; a failure is only logged.
76 f, err := os.Create(pidfile)
78 fmt.Fprint(f, os.Getpid())
81 log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
// The replica count from the command line becomes the client's default.
85 kc.Want_replicas = default_replicas
// Assigns the package-level listener (plain "=", not ":=").
87 listener, err = net.Listen("tcp", listen)
// NOTE(review): the fatal message names the address but not the error
// returned by net.Listen, which makes bind failures hard to diagnose.
89 log.Fatalf("Could not listen on %v", listen)
// The service list is refreshed from a separate goroutine.
92 go RefreshServicesList(&kc)
94 // Shut down the server gracefully (by closing the listener)
95 // if SIGTERM is received.
// The channel is buffered (capacity 1) as signal.Notify requires for a
// non-blocking send.
96 term := make(chan os.Signal, 1)
97 go func(sig <-chan os.Signal) {
99 log.Println("caught signal:", s)
102 signal.Notify(term, syscall.SIGTERM)
// NOTE(review): these three lines repeat the pid-file write above verbatim, so
// the file appears to be created and written twice; confirm whether the
// repetition is intended.
105 f, err := os.Create(pidfile)
107 fmt.Fprint(f, os.Getpid())
110 log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
114 log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
116 // Start listening for requests.
// The no_get / no_put flags are inverted into "enabled" switches for the router.
117 http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
// Reached once http.Serve returns.
119 log.Println("shutting down")
// ApiTokenCache remembers API tokens together with the time at which each
// cached entry expires.
// NOTE(review): the methods also use this.lock and this.expireTime; those
// field declarations are not visible in this view.
126 type ApiTokenCache struct {
// tokens maps a token to its expire time in Unix seconds; 0 is treated as
// "not cached" by RememberToken and RecallToken.
127 tokens map[string]int64
132 // Refresh the keep service list every five minutes.
// Each pass sleeps 300 seconds, re-runs kc.DiscoverKeepServers, and renders
// the service roots from before and after discovery as strings (s1, s2).
// NOTE(review): the loop and the condition guarding the log line are not
// visible in this view; presumably the update is logged only when s1 and s2
// differ — confirm.
133 func RefreshServicesList(kc *keepclient.KeepClient) {
135 time.Sleep(300 * time.Second)
136 oldservices := kc.ServiceRoots()
137 kc.DiscoverKeepServers()
138 newservices := kc.ServiceRoots()
// The two lists are compared through their printed form.
139 s1 := fmt.Sprint(oldservices)
140 s2 := fmt.Sprint(newservices)
142 log.Printf("Updated server list to %v", s2)
147 // Cache the token and set an expire time. If we already have an expire time
148 // on the token, it is not updated.
149 func (this *ApiTokenCache) RememberToken(token string) {
// The lock is released when the method returns.
151 defer this.lock.Unlock()
153 now := time.Now().Unix()
// A zero entry means the token has no expire time yet; the map's zero value
// makes an absent token look the same.
154 if this.tokens[token] == 0 {
// The expire time is stored in Unix seconds: now plus the cache's expireTime.
155 this.tokens[token] = now + this.expireTime
159 // Check if the cached token is known and still believed to be valid.
// NOTE(review): the return statements are not visible in this view; the
// true/false result of each branch should be confirmed against the full file.
160 func (this *ApiTokenCache) RecallToken(token string) bool {
// The lock is released when the method returns.
162 defer this.lock.Unlock()
164 now := time.Now().Unix()
// Zero (also the value read for an absent key) means the token is not cached.
165 if this.tokens[token] == 0 {
168 } else if now < this.tokens[token] {
169 // Token is known and still valid
// Presumably the remaining case (expire time has passed): the entry is reset
// to 0 so the token reads as "not cached" afterwards.
173 this.tokens[token] = 0
// GetRemoteAddress returns a client address string used in log messages.
// When the X-Real-IP header is set and differs from X-Forwarded-For, both are
// reported as "<real ip> (X-Forwarded-For <forwarded>)". When X-Real-IP is
// not set, req.RemoteAddr is returned.
// NOTE(review): the branch taken when the two headers are equal is not
// visible in this view.
178 func GetRemoteAddress(req *http.Request) string {
179 if realip := req.Header.Get("X-Real-IP"); realip != "" {
180 if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
181 return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
186 return req.RemoteAddr
// CheckAuthorizationHeader reports whether req carries an acceptable API token.
//
// The token is read from an "Authorization: OAuth2 <token>" request header.
// A token that cache.RecallToken accepts is not checked again. Otherwise a
// HEAD request for /arvados/v1/users/current is sent to kc.ApiServer with the
// same token, and the token is stored in the cache after the status check.
//
// kc is received by value, so the caller's client is not modified here.
// NOTE(review): the return statements are not visible in this view; which
// branches yield true or false is inferred from the comments and log messages.
189 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) bool {
// An empty or missing Authorization header is handled first.
191 if auth = req.Header.Get("Authorization"); auth == "" {
// Extract the token that follows the literal "OAuth2 " prefix.
196 _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
202 if cache.RecallToken(tok) {
203 // Valid in the cache, short circut
207 var usersreq *http.Request
209 if usersreq, err = http.NewRequest("HEAD", fmt.Sprintf("https://%s/arvados/v1/users/current", kc.ApiServer), nil); err != nil {
210 // Can't construct the request
211 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
215 // Add api token header
216 usersreq.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", tok))
218 // Actually make the request
219 var resp *http.Response
220 if resp, err = kc.Client.Do(usersreq); err != nil {
221 // Something else failed
222 log.Printf("%s: CheckAuthorizationHeader error connecting to API server: %v", GetRemoteAddress(req), err.Error())
// NOTE(review): no close of resp.Body is visible in this view — confirm the
// response body is closed so the connection can be reused.
// Any status other than 200 OK is logged as a rejection by the API server.
226 if resp.StatusCode != http.StatusOK {
228 log.Printf("%s: CheckAuthorizationHeader API server responded: %v", GetRemoteAddress(req), resp.Status)
232 // Success! Update cache
233 cache.RememberToken(tok)
// GetBlockHandler is the http.Handler registered for GET and HEAD block
// requests. It embeds the keep client used to reach the keep servers.
// NOTE(review): its ServeHTTP also reads this.ApiTokenCache; that field is not
// visible in this view.
238 type GetBlockHandler struct {
239 *keepclient.KeepClient
// PutBlockHandler is the http.Handler registered for PUT block requests. It
// embeds the keep client used to reach the keep servers.
// NOTE(review): its ServeHTTP also reads this.ApiTokenCache; that field is not
// visible in this view.
243 type PutBlockHandler struct {
244 *keepclient.KeepClient
// InvalidPathHandler is installed as the router's NotFoundHandler and answers
// requests that match no route. It carries no state.
248 type InvalidPathHandler struct{}
251 // Returns a mux.Router that passes GET and PUT requests to the
252 // appropriate handlers.
// NOTE(review): the start of the signature is not visible in this view; only
// the final parameter (kc) and the *mux.Router result can be seen, and the
// conditions that enable the GET and PUT routes are likewise hidden.
257 kc *keepclient.KeepClient) *mux.Router {
// One token cache, expiring entries after 300 seconds, is shared by every
// handler registered below.
259 t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
261 rest := mux.NewRouter()
// GET and HEAD on a 32-character lowercase hex hash, with or without a
// "+hints" suffix, go to GetBlockHandler.
264 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`,
265 GetBlockHandler{kc, t}).Methods("GET", "HEAD")
266 rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
// PUT on the same two path shapes goes to PutBlockHandler.
270 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`, PutBlockHandler{kc, t}).Methods("PUT")
271 rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
// Requests that match no route are answered by InvalidPathHandler.
274 rest.NotFoundHandler = InvalidPathHandler{}
// ServeHTTP logs the unroutable request (client address, method and path) and
// replies with 400 Bad Request.
279 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
280 log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
281 http.Error(resp, "Bad request", http.StatusBadRequest)
284 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
286 kc := *this.KeepClient
288 hash := mux.Vars(req)["hash"]
289 hints := mux.Vars(req)["hints"]
291 locator := keepclient.MakeLocator2(hash, hints)
293 log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
295 if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
296 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
300 var reader io.ReadCloser
304 if req.Method == "GET" {
305 reader, blocklen, _, err = kc.AuthorizedGet(hash, locator.Signature, locator.Timestamp)
307 } else if req.Method == "HEAD" {
308 blocklen, _, err = kc.AuthorizedAsk(hash, locator.Signature, locator.Timestamp)
311 resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
316 n, err2 := io.Copy(resp, reader)
318 log.Printf("%s: %s %s mismatched return %v with Content-Length %v error", GetRemoteAddress(req), req.Method, hash, n, blocklen, err.Error())
319 } else if err2 == nil {
320 log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
322 log.Printf("%s: %s %s returned %v bytes error %v", GetRemoteAddress(req), req.Method, hash, n, err.Error())
325 log.Printf("%s: %s %s success", GetRemoteAddress(req), req.Method, hash)
327 case keepclient.BlockNotFound:
328 http.Error(resp, "Not found", http.StatusNotFound)
330 http.Error(resp, err.Error(), http.StatusBadGateway)
334 log.Printf("%s: %s %s error %s", GetRemoteAddress(req), req.Method, hash, err.Error())
338 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
340 kc := *this.KeepClient
342 hash := mux.Vars(req)["hash"]
343 hints := mux.Vars(req)["hints"]
345 locator := keepclient.MakeLocator2(hash, hints)
347 var contentLength int64 = -1
348 if req.Header.Get("Content-Length") != "" {
349 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &contentLength)
351 resp.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength))
356 log.Printf("%s: %s %s Content-Length %v", GetRemoteAddress(req), req.Method, hash, contentLength)
358 if contentLength < 1 {
359 http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
363 if locator.Size > 0 && int64(locator.Size) != contentLength {
364 http.Error(resp, "Locator size hint does not match Content-Length header", http.StatusBadRequest)
368 if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
369 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
373 // Check if the client specified the number of replicas
374 if req.Header.Get("X-Keep-Desired-Replicas") != "" {
376 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
382 // Now try to put the block through
383 hash, replicas, err := kc.PutHR(hash, req.Body, contentLength)
385 // Tell the client how many successful PUTs we accomplished
386 resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
390 // Default will return http.StatusOK
391 log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
392 n, err2 := io.WriteString(resp, hash)
394 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
397 case keepclient.OversizeBlockError:
399 http.Error(resp, fmt.Sprintf("Exceeded maximum blocksize %d", keepclient.BLOCKSIZE), http.StatusRequestEntityTooLarge)
401 case keepclient.InsufficientReplicasError:
403 // At least one write is considered success. The
404 // client can decide if getting less than the number of
405 // replications it asked for is a fatal error.
406 // Default will return http.StatusOK
407 n, err2 := io.WriteString(resp, hash)
409 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
412 http.Error(resp, "", http.StatusServiceUnavailable)
416 http.Error(resp, err.Error(), http.StatusBadGateway)
420 log.Printf("%s: %s %s stored %v replicas (desired %v) got error %v", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas, err.Error())