4 "git.curoverse.com/arvados.git/sdk/go/keepclient"
5 "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8 "github.com/gorilla/mux"
20 // Default TCP address on which to listen for requests.
21 // Being a const, it cannot itself be set at run time; presumably it is the
// default value of the -listen flag - TODO confirm against the flag definition.
22 const DEFAULT_ADDR = ":25107"
// listener holds the net.Listener that is assigned from
// net.Listen("tcp", listen) and then handed to http.Serve.
24 var listener net.Listener
// Command-line options are parsed from os.Args[1:] with a dedicated FlagSet;
// flag.ExitOnError makes a parse failure terminate the process.
35 flagset := flag.NewFlagSet("default", flag.ExitOnError)
41 "Interface on which to listen for requests, in the format "+
42 "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
43 "to listen on all network interfaces.")
49 "If set, disable GET operations")
55 "If set, disable PUT operations")
61 "Default number of replicas to write if not specified by the client.")
67 "Path to write pid file")
69 flagset.Parse(os.Args[1:])
// Build the Arvados API client and, from it, the Keep client. A failure of
// either one is logged with log.Fatalf.
71 arv, err := arvadosclient.MakeArvadosClient()
73 log.Fatalf("Error setting up arvados client %s", err.Error())
76 kc, err := keepclient.MakeKeepClient(&arv)
78 log.Fatalf("Error setting up keep client %s", err.Error())
// Write this process's pid to pidfile; a failure is only logged.
// NOTE(review): the same pid-file write appears a second time further down
// (just before the "started listening" log line) - one of the two looks
// redundant; confirm and remove the duplicate.
82 f, err := os.Create(pidfile)
84 fmt.Fprint(f, os.Getpid())
87 log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
// Set the Keep client's desired replica count from default_replicas.
91 kc.Want_replicas = default_replicas
// Open the TCP listening socket on the requested address.
// NOTE(review): the Fatalf below reports only the address, not err, so the
// reason for the failure is lost.
93 listener, err = net.Listen("tcp", listen)
95 log.Fatalf("Could not listen on %v", listen)
// Keep the service list fresh in a background goroutine.
98 go RefreshServicesList(&kc)
100 // Shut down the server gracefully (by closing the listener)
101 // if SIGTERM or SIGINT is received.
// The channel is buffered (capacity 1) and is registered for both signals
// with signal.Notify below.
102 term := make(chan os.Signal, 1)
103 go func(sig <-chan os.Signal) {
105 log.Println("caught signal:", s)
108 signal.Notify(term, syscall.SIGTERM)
109 signal.Notify(term, syscall.SIGINT)
// NOTE(review): second copy of the pid-file write already performed above.
112 f, err := os.Create(pidfile)
114 fmt.Fprint(f, os.Getpid())
117 log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
121 log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
123 // Start listening for requests.
// The router is built with the negated no_get / no_put values and the shared
// Keep client. The value returned by http.Serve is not checked here.
124 http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
126 log.Println("shutting down")
// ApiTokenCache remembers API tokens together with an expiry time so that a
// token does not have to be re-validated on every request (see RememberToken
// and RecallToken).
133 type ApiTokenCache struct {
// tokens maps a token to its expiry as Unix seconds (now + expireTime when
// it is remembered); a value of 0 is treated as "not cached".
134 tokens map[string]int64
139 // RefreshServicesList refreshes the keep service list every five minutes.
// It sleeps 300 seconds, re-runs kc.DiscoverKeepServers, and logs the
// resulting service roots.
140 func RefreshServicesList(kc *keepclient.KeepClient) {
142 time.Sleep(300 * time.Second)
// Snapshot the service roots before and after rediscovery.
143 oldservices := kc.ServiceRoots()
// The return value of DiscoverKeepServers, if any, is not checked.
144 kc.DiscoverKeepServers()
145 newservices := kc.ServiceRoots()
// s1/s2 are the printed forms of the old and new lists; presumably they are
// compared so the log line below is emitted only on change - TODO confirm.
146 s1 := fmt.Sprint(oldservices)
147 s2 := fmt.Sprint(newservices)
149 log.Printf("Updated server list to %v", s2)
154 // RememberToken caches the token and sets an expire time. If we already have
155 // an expire time on the token, it is not updated.
156 func (this *ApiTokenCache) RememberToken(token string) {
// this.lock is released when the method returns.
158 defer this.lock.Unlock()
160 now := time.Now().Unix()
// A zero entry means the token is not cached yet; only then is an expiry
// recorded, expireTime seconds after the current Unix time.
161 if this.tokens[token] == 0 {
162 this.tokens[token] = now + this.expireTime
166 // RecallToken checks if the cached token is known and still believed to be valid.
// Presumably it reports true only in the "known and still valid" branch below
// - TODO confirm.
167 func (this *ApiTokenCache) RecallToken(token string) bool {
// this.lock is released when the method returns.
169 defer this.lock.Unlock()
171 now := time.Now().Unix()
// A zero entry means the token was never remembered or has been reset.
172 if this.tokens[token] == 0 {
175 } else if now < this.tokens[token] {
176 // Token is known and still valid
// Remaining case (presumably the expired branch): reset the entry to 0 so
// the token reads as unknown and can be remembered again.
180 this.tokens[token] = 0
// GetRemoteAddress returns a printable client address for log messages.
// When the request carries a non-empty X-Real-IP header, that value is used,
// annotated with the X-Forwarded-For value when the two differ; without
// X-Real-IP the connection's req.RemoteAddr is returned.
185 func GetRemoteAddress(req *http.Request) string {
186 if realip := req.Header.Get("X-Real-IP"); realip != "" {
// An absent X-Forwarded-For header yields "" and so also counts as "differs".
187 if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
188 return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
193 return req.RemoteAddr
// CheckAuthorizationHeader extracts the API token from an
// "Authorization: OAuth2 <token>" request header and decides whether the
// request may proceed. A token that cache.RecallToken accepts is trusted
// without contacting the API server; otherwise a HEAD users/current call is
// made and, on success, the token is stored with cache.RememberToken.
//
// The named results are pass (whether the request is authorized) and tok (the
// token scanned from the header). Presumably pass is false when the header is
// missing, malformed, or rejected by the API call - TODO confirm.
//
// NOTE(review): kc is taken by value, so every call copies the
// keepclient.KeepClient struct; a pointer would avoid the copy.
196 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
198 if auth = req.Header.Get("Authorization"); auth == "" {
// %s stops at whitespace, so tok receives the single word after "OAuth2 ".
202 _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
208 if cache.RecallToken(tok) {
209 // Valid in the cache, short circuit
// Not cached: ask the API server about the current user with this token.
215 if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
216 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
220 // Success! Update cache
221 cache.RememberToken(tok)
// GetBlockHandler is the http.Handler registered for GET and HEAD block
// requests. It embeds the *keepclient.KeepClient it is constructed with.
226 type GetBlockHandler struct {
227 *keepclient.KeepClient
// PutBlockHandler is the http.Handler registered for PUT block requests.
// It embeds the *keepclient.KeepClient it is constructed with.
231 type PutBlockHandler struct {
232 *keepclient.KeepClient
// InvalidPathHandler is installed as the router's NotFoundHandler; its
// ServeHTTP logs the unroutable request and replies 400 Bad Request.
236 type InvalidPathHandler struct{}
239 // MakeRESTRouter returns a mux.Router that passes GET, HEAD and PUT requests
240 // to the appropriate handlers.
245 kc *keepclient.KeepClient) *mux.Router {
// One token cache is shared by every handler; expireTime is 300, which
// RememberToken adds to a Unix-seconds timestamp (i.e. five minutes).
247 t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
249 rest := mux.NewRouter()
// Each route matches a path consisting of a 32-character lowercase hex hash,
// either alone or followed by "+" and a free-form hints part.
252 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`,
253 GetBlockHandler{kc, t}).Methods("GET", "HEAD")
254 rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
258 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`, PutBlockHandler{kc, t}).Methods("PUT")
259 rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
// Anything that matches no route is answered by InvalidPathHandler.
262 rest.NotFoundHandler = InvalidPathHandler{}
// ServeHTTP logs the unroutable request (remote address, method and URL path)
// and replies with 400 Bad Request.
267 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
268 log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
269 http.Error(resp, "Bad request", http.StatusBadRequest)
// ServeHTTP handles GET and HEAD requests for a block. It checks the
// Authorization header (403 on failure), then fetches the block with
// kc.AuthorizedGet (GET) or queries it with kc.AuthorizedAsk (HEAD), sets the
// Content-Length response header from the reported block length and, when a
// reader was obtained, streams the block to the response with io.Copy.
// keepclient.BlockNotFound is answered with 404 and another error path with
// 502 Bad Gateway.
272 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
// Work on a per-request copy of the Keep client; presumably so the token
// swap below does not affect the shared client - TODO confirm.
274 kc := *this.KeepClient
// Route variables captured by the mux patterns; hints is "" when the route
// without "+{hints}" matched.
276 hash := mux.Vars(req)["hash"]
277 hints := mux.Vars(req)["hints"]
// The locator supplies the Signature and Timestamp passed to the Keep calls below.
279 locator := keepclient.MakeLocator2(hash, hints)
281 log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
285 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
286 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
290 // Copy ArvadosClient struct and use the client's API token
291 arvclient := *kc.Arvados
292 arvclient.ApiToken = tok
293 kc.Arvados = &arvclient
295 var reader io.ReadCloser
// GET obtains a reader for the block body; HEAD only asks for its length.
299 if req.Method == "GET" {
300 reader, blocklen, _, err = kc.AuthorizedGet(hash, locator.Signature, locator.Timestamp)
302 } else if req.Method == "HEAD" {
303 blocklen, _, err = kc.AuthorizedAsk(hash, locator.Signature, locator.Timestamp)
307 resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
// Stream the block to the client; n is the number of bytes actually copied.
313 n, err2 := io.Copy(resp, reader)
315 log.Printf("%s: %s %s mismatched return %v with Content-Length %v error %v", GetRemoteAddress(req), req.Method, hash, n, blocklen, err2)
316 } else if err2 == nil {
317 log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
// NOTE(review): this logs err, whereas the neighbouring branches test and log
// err2 (the io.Copy error). It looks like a typo for err2; if err is nil on
// this path, err.Error() would panic - confirm and fix.
319 log.Printf("%s: %s %s returned %v bytes error %v", GetRemoteAddress(req), req.Method, hash, n, err.Error())
322 log.Printf("%s: %s %s success", GetRemoteAddress(req), req.Method, hash)
324 case keepclient.BlockNotFound:
325 http.Error(resp, "Not found", http.StatusNotFound)
// Error path that forwards the error text to the client as 502 Bad Gateway.
327 http.Error(resp, err.Error(), http.StatusBadGateway)
331 log.Printf("%s: %s %s error %s", GetRemoteAddress(req), req.Method, hash, err.Error())
// ServeHTTP handles PUT requests for a block. It requires a Content-Length of
// at least 1 (411 otherwise), rejects a locator size hint that disagrees with
// it (400), checks the Authorization header (403), optionally honours a
// client-supplied desired-replica count, and then stores the block with
// kc.PutHR. The number of replicas stored is reported in the
// keepclient.X_Keep_Replicas_Stored response header. Visible error mappings:
// keepclient.OversizeBlockError -> 413, a 503 within the
// keepclient.InsufficientReplicasError handling, and 502 on another error path.
335 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
// Work on a per-request copy of the Keep client; presumably so the token
// swap below does not affect the shared client - TODO confirm.
337 kc := *this.KeepClient
// Route variables captured by the mux patterns; hints is "" when the route
// without "+{hints}" matched.
339 hash := mux.Vars(req)["hash"]
340 hints := mux.Vars(req)["hints"]
342 locator := keepclient.MakeLocator2(hash, hints)
// -1 stands for "no usable Content-Length"; it is caught by the < 1 check below.
344 var contentLength int64 = -1
345 if req.Header.Get("Content-Length") != "" {
346 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &contentLength)
// NOTE(review): this sets a Content-Length header on the *response* from the
// value parsed out of the request; the intent is unclear - confirm.
348 resp.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength))
353 log.Printf("%s: %s %s Content-Length %v", GetRemoteAddress(req), req.Method, hash, contentLength)
// Missing, unparsed (-1) and zero lengths are all rejected with 411.
355 if contentLength < 1 {
356 http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
// A positive size hint in the locator must agree with the declared body length.
360 if locator.Size > 0 && int64(locator.Size) != contentLength {
361 http.Error(resp, "Locator size hint does not match Content-Length header", http.StatusBadRequest)
367 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
368 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
372 // Copy ArvadosClient struct and use the client's API token
373 arvclient := *kc.Arvados
374 arvclient.ApiToken = tok
375 kc.Arvados = &arvclient
377 // Check if the client specified the number of replicas
// NOTE(review): the header is tested with a string literal but read through
// keepclient.X_Keep_Desired_Replicas; presumably both name the same header -
// confirm, and prefer the constant in both places.
378 if req.Header.Get("X-Keep-Desired-Replicas") != "" {
380 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
386 // Now try to put the block through
// hash is overwritten with the value returned by PutHR; that value is what
// gets written to the response body below.
387 hash, replicas, err := kc.PutHR(hash, req.Body, contentLength)
389 // Tell the client how many successful PUTs we accomplished
390 resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
394 // Default will return http.StatusOK
395 log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
396 n, err2 := io.WriteString(resp, hash)
// NOTE(review): the format string has three verbs but only two arguments are
// passed; the leading %s (presumably GetRemoteAddress(req)) is missing, so n
// and the error text are printed under the wrong verbs.
398 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
401 case keepclient.OversizeBlockError:
403 http.Error(resp, fmt.Sprintf("Exceeded maximum blocksize %d", keepclient.BLOCKSIZE), http.StatusRequestEntityTooLarge)
405 case keepclient.InsufficientReplicasError:
407 // At least one write is considered success. The
408 // client can decide if getting less than the number of
409 // replications it asked for is a fatal error.
410 // Default will return http.StatusOK
411 n, err2 := io.WriteString(resp, hash)
// NOTE(review): same verb/argument mismatch as the log call in the success
// path above (three verbs, two arguments).
413 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
// Presumably reached when no replica at all was written - TODO confirm.
416 http.Error(resp, "", http.StatusServiceUnavailable)
// Error path that forwards the error text to the client as 502 Bad Gateway.
420 http.Error(resp, err.Error(), http.StatusBadGateway)
424 log.Printf("%s: %s %s stored %v replicas (desired %v) got error %v", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas, err.Error())