4 "git.curoverse.com/arvados.git/sdk/go/keepclient"
5 "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8 "github.com/gorilla/mux"
20 // Default TCP address on which to listen for requests.
21 // Initialized by the -listen flag.
// NOTE(review): a Go const cannot be assigned at run time, so the flag cannot
// "initialize" this value; presumably it serves as the default for -listen —
// confirm against the flag definition.
22 const DEFAULT_ADDR = ":25107"
// listener is the package-level TCP listener. Within this file it is assigned
// from net.Listen and then handed to http.Serve.
24 var listener net.Listener
// Flags are parsed from os.Args with a dedicated FlagSet; flag.ExitOnError
// makes a parse failure terminate the process.
35 flagset := flag.NewFlagSet("default", flag.ExitOnError)
41 "Interface on which to listen for requests, in the format "+
42 "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
43 "to listen on all network interfaces.")
49 "If set, disable GET operations")
55 "If set, disable PUT operations")
61 "Default number of replicas to write if not specified by the client.")
67 "Path to write pid file")
69 flagset.Parse(os.Args[1:])
// Build the Arvados API client, then a Keep client from it. Failure of either
// step is fatal.
71 arv, err := arvadosclient.MakeArvadosClient()
73 log.Fatalf("Error setting up arvados client %s", err.Error())
76 kc, err := keepclient.MakeKeepClient(&arv)
78 log.Fatalf("Error setting up keep client %s", err.Error())
// Write this process's pid to the pid file and schedule the file's removal
// for when the enclosing function returns.
// NOTE(review): log.Fatalf exits through os.Exit, which does not run deferred
// calls, so the fatal listen path below would leave the pid file behind —
// confirm whether that matters.
82 f, err := os.Create(pidfile)
84 log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
86 fmt.Fprint(f, os.Getpid())
88 defer os.Remove(pidfile)
91 kc.Want_replicas = default_replicas
93 listener, err = net.Listen("tcp", listen)
// NOTE(review): this message reports the address but not err, so the reason
// for the failure is not logged.
95 log.Fatalf("Could not listen on %v", listen)
// Refresh the Keep service list in the background for the life of the process.
98 go RefreshServicesList(&kc)
100 // Shut down the server gracefully (by closing the listener)
101 // if SIGTERM is received.
// SIGINT is registered on the same channel below, so presumably it triggers
// the same shutdown path.
102 term := make(chan os.Signal, 1)
103 go func(sig <-chan os.Signal) {
105 log.Println("caught signal:", s)
109 signal.Notify(term, syscall.SIGTERM)
110 signal.Notify(term, syscall.SIGINT)
112 log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
114 // Start listening for requests.
// http.Serve blocks until it fails (for example once the listener is closed);
// its returned error is discarded here.
115 http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
117 log.Println("shutting down")
// ApiTokenCache remembers API tokens that have already been validated.
// tokens maps each token to an expiry time in Unix seconds; the methods in
// this file treat a value of 0 as "not cached".
120 type ApiTokenCache struct {
121 tokens map[string]int64
126 // Refresh the keep service list every five minutes.
// Each pass sleeps 300 seconds, re-runs service discovery on kc, and renders
// the service roots from before and after the discovery as strings.
// Any value returned by DiscoverKeepServers is discarded.
// NOTE(review): the loop and the comparison guarding the log line are not
// visible in this excerpt; presumably the new list is logged only when it
// differs from the old one.
127 func RefreshServicesList(kc *keepclient.KeepClient) {
129 time.Sleep(300 * time.Second)
130 oldservices := kc.ServiceRoots()
131 kc.DiscoverKeepServers()
132 newservices := kc.ServiceRoots()
133 s1 := fmt.Sprint(oldservices)
134 s2 := fmt.Sprint(newservices)
136 log.Printf("Updated server list to %v", s2)
141 // Cache the token and set an expire time. If we already have an expire time
142 // on the token, it is not updated.
// The expiry is the current Unix time plus this.expireTime. The cache lock is
// released by the deferred Unlock when the method returns.
143 func (this *ApiTokenCache) RememberToken(token string) {
145 defer this.lock.Unlock()
147 now := time.Now().Unix()
// A stored value of 0 means there is no entry for this token yet.
148 if this.tokens[token] == 0 {
149 this.tokens[token] = now + this.expireTime
153 // Check if the cached token is known and still believed to be valid.
// Three cases are distinguished: no entry (stored value 0), an entry whose
// expiry lies in the future, and an entry that has expired, which is reset
// to 0.
// NOTE(review): the return statements are not visible in this excerpt;
// presumably only the "still valid" case returns true.
154 func (this *ApiTokenCache) RecallToken(token string) bool {
156 defer this.lock.Unlock()
158 now := time.Now().Unix()
159 if this.tokens[token] == 0 {
162 } else if now < this.tokens[token] {
163 // Token is known and still valid
// Otherwise the entry has expired: clear it so it reads as "not cached".
167 this.tokens[token] = 0
// GetRemoteAddress returns a printable description of the request's origin.
// When an X-Real-IP header is present and X-Forwarded-For differs from it,
// both are included in the result; the final fallback is req.RemoteAddr.
// NOTE(review): the branch taken when the two headers agree is not visible in
// this excerpt. Both headers are supplied by the sender, so the result should
// be treated as informational only.
172 func GetRemoteAddress(req *http.Request) string {
173 if realip := req.Header.Get("X-Real-IP"); realip != "" {
174 if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
175 return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
180 return req.RemoteAddr
// CheckAuthorizationHeader examines the request's Authorization header and
// returns the API token it carries (tok) together with a pass flag.
// The header is expected in the form "OAuth2 <token>". A token the cache
// still recalls is handled without a remote call; otherwise it is checked
// with arv.Call("HEAD", "users", "", "current", ...) and, when that call
// succeeds, stored in the cache.
// NOTE(review): arv and the failure returns are set up on lines not visible
// in this excerpt; presumably arv is derived from kc with tok applied.
// kc is received by value, so changes made to it here do not reach the caller.
183 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
185 if auth = req.Header.Get("Authorization"); auth == "" {
189 _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
195 if cache.RecallToken(tok) {
196 // Valid in the cache, short circuit
202 if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
203 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
207 // Success! Update cache
208 cache.RememberToken(tok)
// GetBlockHandler serves GET and HEAD requests for blocks. It embeds the Keep
// client; the router in this file also supplies an ApiTokenCache as the
// second value when constructing it.
213 type GetBlockHandler struct {
214 *keepclient.KeepClient
// PutBlockHandler serves PUT requests for blocks. It embeds the Keep client;
// the router in this file also supplies an ApiTokenCache as the second value
// when constructing it.
218 type PutBlockHandler struct {
219 *keepclient.KeepClient
// InvalidPathHandler is a stateless handler for requests that match no route.
// The router in this file installs it as its NotFoundHandler.
223 type InvalidPathHandler struct{}
226 // Returns a mux.Router that passes GET and PUT requests to the
227 // appropriate handlers.
// Routes match a 32-character lowercase hex hash, optionally followed by "+"
// and hints. GET and HEAD go to GetBlockHandler, PUT goes to PutBlockHandler,
// and unmatched requests go to InvalidPathHandler.
// NOTE(review): the function header and any conditions around the route
// registrations are not visible in this excerpt; presumably the GET and PUT
// routes are only registered when the corresponding enable flag is true.
232 kc *keepclient.KeepClient) *mux.Router {
// One token cache is shared by every handler built here; its expireTime of
// 300 is added to Unix-second timestamps when a token is remembered.
234 t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
236 rest := mux.NewRouter()
239 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`,
240 GetBlockHandler{kc, t}).Methods("GET", "HEAD")
241 rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
245 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`, PutBlockHandler{kc, t}).Methods("PUT")
246 rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
249 rest.NotFoundHandler = InvalidPathHandler{}
// ServeHTTP logs the unroutable request (origin, method and path) and answers
// with 400 Bad Request.
254 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
255 log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
256 http.Error(resp, "Bad request", http.StatusBadRequest)
// ServeHTTP handles GET and HEAD requests for a block. It checks the
// Authorization header, switches a per-request copy of the Keep client to the
// caller's API token, and then calls AuthorizedGet (GET) or AuthorizedAsk
// (HEAD) with the hash and the locator's signature and timestamp.
// A missing or rejected token yields 403, keepclient.BlockNotFound yields 404,
// and 502 is sent with the error text (presumably for every other error; the
// surrounding switch is not fully visible in this excerpt).
259 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
// Work on a shallow copy of the shared client so the token swap below stays
// local to this request.
261 kc := *this.KeepClient
263 hash := mux.Vars(req)["hash"]
264 hints := mux.Vars(req)["hints"]
266 locator := keepclient.MakeLocator2(hash, hints)
268 log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
272 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
273 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
277 // Copy ArvadosClient struct and use the client's API token
278 arvclient := *kc.Arvados
279 arvclient.ApiToken = tok
280 kc.Arvados = &arvclient
282 var reader io.ReadCloser
// GET obtains a reader for the block body; HEAD obtains only the length.
286 if req.Method == "GET" {
287 reader, blocklen, _, err = kc.AuthorizedGet(hash, locator.Signature, locator.Timestamp)
289 } else if req.Method == "HEAD" {
290 blocklen, _, err = kc.AuthorizedAsk(hash, locator.Signature, locator.Timestamp)
294 resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
// Stream the block to the client; n is the number of bytes actually copied.
300 n, err2 := io.Copy(resp, reader)
302 log.Printf("%s: %s %s mismatched return %v with Content-Length %v error %v", GetRemoteAddress(req), req.Method, hash, n, blocklen, err2)
303 } else if err2 == nil {
304 log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
// NOTE(review): this looks like the err2 != nil branch, yet it formats err
// rather than err2; if err is nil here the call would panic — confirm and
// consider logging err2 instead.
306 log.Printf("%s: %s %s returned %v bytes error %v", GetRemoteAddress(req), req.Method, hash, n, err.Error())
309 log.Printf("%s: %s %s success", GetRemoteAddress(req), req.Method, hash)
311 case keepclient.BlockNotFound:
312 http.Error(resp, "Not found", http.StatusNotFound)
314 http.Error(resp, err.Error(), http.StatusBadGateway)
318 log.Printf("%s: %s %s error %s", GetRemoteAddress(req), req.Method, hash, err.Error())
// ServeHTTP handles PUT requests for a block. It validates Content-Length
// (and the locator's size hint when one is given), checks the Authorization
// header, switches a per-request copy of the Keep client to the caller's API
// token, reads an optional desired-replica count from the request headers,
// and then stores the body with PutHR. The number of replicas stored is
// reported in the X_Keep_Replicas_Stored response header.
// Status codes visible here: 411 for a Content-Length below 1, 400 for a size
// hint mismatch, 403 for a missing or invalid token, 413 for
// OversizeBlockError, and 503 and 502 on error paths whose conditions are not
// fully visible in this excerpt.
322 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
// Work on a shallow copy of the shared client so the per-request changes
// below do not affect other requests.
324 kc := *this.KeepClient
326 hash := mux.Vars(req)["hash"]
327 hints := mux.Vars(req)["hints"]
329 locator := keepclient.MakeLocator2(hash, hints)
// -1 marks "no usable Content-Length" until the header is parsed.
331 var contentLength int64 = -1
332 if req.Header.Get("Content-Length") != "" {
333 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &contentLength)
335 resp.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength))
340 log.Printf("%s: %s %s Content-Length %v", GetRemoteAddress(req), req.Method, hash, contentLength)
// NOTE(review): "< 1" also rejects an explicit Content-Length of 0, with a
// message that says the header is missing — confirm this is intended.
342 if contentLength < 1 {
343 http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
// A size hint of 0 or less is treated as "no hint" and is not checked.
347 if locator.Size > 0 && int64(locator.Size) != contentLength {
348 http.Error(resp, "Locator size hint does not match Content-Length header", http.StatusBadRequest)
354 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
355 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
359 // Copy ArvadosClient struct and use the client's API token
360 arvclient := *kc.Arvados
361 arvclient.ApiToken = tok
362 kc.Arvados = &arvclient
364 // Check if the client specified the number of replicas
// NOTE(review): the presence test uses a string literal while the read below
// uses keepclient.X_Keep_Desired_Replicas; presumably both name the same
// header — confirm, and prefer the constant in both places.
365 if req.Header.Get("X-Keep-Desired-Replicas") != "" {
367 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
373 // Now try to put the block through
// hash is overwritten with the value returned by PutHR.
374 hash, replicas, err := kc.PutHR(hash, req.Body, contentLength)
376 // Tell the client how many successful PUTs we accomplished
377 resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
381 // Default will return http.StatusOK
382 log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
383 n, err2 := io.WriteString(resp, hash)
// NOTE(review): the format string has three verbs but only two arguments
// follow; the leading %s (apparently the remote address) has no matching
// argument, so n and the error are printed under the wrong verbs.
385 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
388 case keepclient.OversizeBlockError:
390 http.Error(resp, fmt.Sprintf("Exceeded maximum blocksize %d", keepclient.BLOCKSIZE), http.StatusRequestEntityTooLarge)
392 case keepclient.InsufficientReplicasError:
394 // At least one write is considered success. The
395 // client can decide if getting less than the number of
396 // replications it asked for is a fatal error.
397 // Default will return http.StatusOK
398 n, err2 := io.WriteString(resp, hash)
// NOTE(review): same verb/argument mismatch as the earlier log call for the
// response write — three verbs, two arguments.
400 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
403 http.Error(resp, "", http.StatusServiceUnavailable)
407 http.Error(resp, err.Error(), http.StatusBadGateway)
411 log.Printf("%s: %s %s stored %v replicas (desired %v) got error %v", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas, err.Error())