7 "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8 "git.curoverse.com/arvados.git/sdk/go/keepclient"
9 "github.com/gorilla/mux"
// DEFAULT_ADDR is the default TCP address on which keepproxy listens for
// requests. The ":port" form listens on all network interfaces. The address
// actually used is the value of the -listen flag.
const DEFAULT_ADDR = ":25107"

// listener is the socket the HTTP server accepts connections on. It is a
// package-level variable so that the signal handler installed in main can
// close it, which makes the server shut down gracefully.
var listener net.Listener
// NOTE(review): this listing omits short lines (the func header, the flag
// variable declarations and registrations, braces, `if err != nil {`, ...);
// the comments below describe only what is visible.
// Command-line flags. With flag.ExitOnError, flagset.Parse exits the process
// on an invalid flag instead of returning an error.
40 flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
46 "Interface on which to listen for requests, in the format "+
47 "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
48 "to listen on all network interfaces.")
54 "If set, disable GET operations")
60 "If set, disable PUT operations")
66 "Default number of replicas to write if not specified by the client.")
72 "Timeout on requests to internal Keep services (default 15 seconds)")
78 "Path to write pid file")
80 flagset.Parse(os.Args[1:])
// Build the Arvados API client and, from it, the Keep client; any failure
// is fatal.
82 arv, err := arvadosclient.MakeArvadosClient()
84 log.Fatalf("Error setting up arvados client %s", err.Error())
// Route keepclient debug output to the standard logger when ARVADOS_DEBUG
// is set to any non-empty value.
87 if os.Getenv("ARVADOS_DEBUG") != "" {
88 keepclient.DebugPrintf = log.Printf
90 kc, err := keepclient.MakeKeepClient(&arv)
92 log.Fatalf("Error setting up keep client %s", err.Error())
// Write this process's pid to pidfile and remove the file when main returns.
// NOTE(review): the deferred os.Remove does not run when the process exits
// through log.Fatalf (which calls os.Exit), e.g. if net.Listen fails below,
// so a stale pid file can be left behind. Also confirm f is closed; no
// Close call is visible in this listing.
96 f, err := os.Create(pidfile)
98 log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
100 fmt.Fprint(f, os.Getpid())
102 defer os.Remove(pidfile)
// Apply flag values: replicas to write when the client does not ask for a
// number, and the timeout (in seconds) for requests to Keep services.
105 kc.Want_replicas = default_replicas
107 kc.Client.Timeout = time.Duration(timeout) * time.Second
// The package-level listener is assigned here (note `=`, not `:=`) so the
// signal handler below can close it.
109 listener, err = net.Listen("tcp", listen)
// NOTE(review): err is not included in this message, which hides why
// Listen failed (address in use, permission denied, ...).
111 log.Fatalf("Could not listen on %v", listen)
// Keep the Keep service list current in the background.
114 go RefreshServicesList(kc)
116 // Shut down the server gracefully (by closing the listener)
117 // if SIGTERM or SIGINT is received.
118 term := make(chan os.Signal, 1)
119 go func(sig <-chan os.Signal) {
121 log.Println("caught signal:", s)
124 signal.Notify(term, syscall.SIGTERM)
125 signal.Notify(term, syscall.SIGINT)
127 log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
129 // Start listening for requests.
// NOTE(review): http.Serve always returns a non-nil error, both after the
// listener is closed and on other accept failures. It is discarded here, so
// a genuine failure is logged as a normal shutdown.
130 http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
132 log.Println("shutting down")
// ApiTokenCache remembers API tokens that recently passed
// CheckAuthorizationHeader, so a cached token is accepted without another
// round trip to the API server.
// NOTE(review): the remaining fields are elided from this listing; the
// methods use this.lock and this.expireTime, and MakeRESTRouter sets
// expireTime: 300.
135 type ApiTokenCache struct {
// tokens maps an API token to the Unix time (seconds) at which its cache
// entry expires; 0 (the map's zero value) means "not cached".
136 tokens map[string]int64
141 // RefreshServicesList refreshes the keep service list every five minutes.
// Each round calls kc.DiscoverKeepServers, logs the local and gateway roots
// whenever they differ from the previous round, and warns when no local
// services were found. previousRoots starts empty, so the first successful
// discovery is always logged.
// NOTE(review): the loop header and the updates to delay are elided from
// this listing, so the five-minute / three-second timings are not visible.
142 func RefreshServicesList(kc *keepclient.KeepClient) {
143 var previousRoots = []map[string]string{}
// delay holds a number of seconds: it is multiplied by time.Second before
// sleeping, even though its type is time.Duration. It starts at 0, so the
// first sleep returns immediately.
144 var delay time.Duration = 0
146 time.Sleep(delay * time.Second)
148 if err := kc.DiscoverKeepServers(); err != nil {
149 log.Println("Error retrieving services list:", err)
// Index 0 holds the local roots, index 1 the gateway roots.
153 newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
154 if !reflect.DeepEqual(previousRoots, newRoots) {
155 log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
157 if len(newRoots[0]) == 0 {
158 log.Print("WARNING: No local services. Retrying in 3 seconds.")
161 previousRoots = newRoots
165 // RememberToken caches token with an expiry of now + expireTime, in Unix
166 // seconds. If the token already has an expiry, it is not updated.
167 func (this *ApiTokenCache) RememberToken(token string) {
169 defer this.lock.Unlock()
171 now := time.Now().Unix()
// A zero entry means "not cached" (RecallToken also resets expired entries
// to zero), so only then is a fresh expiry recorded.
172 if this.tokens[token] == 0 {
173 this.tokens[token] = now + this.expireTime
177 // RecallToken reports whether token is cached and its expiry has not yet passed.
// An expired entry is reset to 0, so the token must be validated and
// remembered again before it is trusted.
// NOTE(review): expired entries are zeroed rather than deleted (no delete is
// visible in this listing), so the map never shrinks. Consider
// delete(this.tokens, token) if many distinct tokens are seen.
178 func (this *ApiTokenCache) RecallToken(token string) bool {
180 defer this.lock.Unlock()
182 now := time.Now().Unix()
183 if this.tokens[token] == 0 {
186 } else if now < this.tokens[token] {
187 // Token is known and still valid
// Expired: forget the expiry so the next request re-validates the token.
191 this.tokens[token] = 0
// GetRemoteAddress describes the client of req for log messages.
// If an X-Real-IP header is present and X-Forwarded-For differs from it, both
// are reported. If X-Real-IP is empty, it returns req.RemoteAddr. The result
// when the two headers are equal is on lines elided from this listing.
// These headers are set by whatever sent the request, so they are only
// trustworthy behind a proxy that overwrites them. Use the result for
// logging, not for access control.
// NOTE(review): when X-Real-IP is set but X-Forwarded-For is absent, the
// empty header still differs from realip, producing "<ip> (X-Forwarded-For )".
196 func GetRemoteAddress(req *http.Request) string {
197 if realip := req.Header.Get("X-Real-IP"); realip != "" {
198 if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
199 return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
204 return req.RemoteAddr
// CheckAuthorizationHeader extracts the API token from req's
// "Authorization: OAuth2 <token>" header. It returns whether the token is
// valid (pass) together with the token (tok).
// A token found in cache is accepted without contacting the API server.
// Otherwise the token is checked with a HEAD users/current call and, on
// success, added to cache.
// NOTE(review): construction of arv (the client used for that call) is
// elided from this listing. It is presumably a copy of kc.Arvados carrying
// tok; confirm that it does not modify the shared kc.Arvados.
207 func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
209 if auth = req.Header.Get("Authorization"); auth == "" {
213 _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
219 if cache.RecallToken(tok) {
220 // Valid in the cache, short circuit
226 if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
227 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
231 // Success! Update cache
232 cache.RememberToken(tok)
// Request handlers registered by MakeRESTRouter. The three block handlers
// are built as X{kc, t}, with t an *ApiTokenCache, and read
// this.ApiTokenCache. Each therefore also embeds *ApiTokenCache, on a line
// elided from this listing.

// GetBlockHandler serves GET and HEAD requests for a block locator through
// the embedded KeepClient.
237 type GetBlockHandler struct {
238 *keepclient.KeepClient
// PutBlockHandler serves PUT /{locator} and POST / block writes.
242 type PutBlockHandler struct {
243 *keepclient.KeepClient
// IndexHandler serves GET /index and GET /index/{prefix}.
247 type IndexHandler struct {
248 *keepclient.KeepClient
// InvalidPathHandler answers requests that match no route.
252 type InvalidPathHandler struct{}
// OptionsHandler answers OPTIONS requests (presumably CORS preflight).
254 type OptionsHandler struct{}
257 // MakeRESTRouter returns a mux.Router that passes block GET/HEAD, PUT/POST,
258 // /index and OPTIONS requests to the appropriate handlers.
// All block handlers share one ApiTokenCache whose entries are trusted for
// 300 seconds. Requests that match no route go to InvalidPathHandler.
// NOTE(review): the function name and the parameters before kc are elided
// from this listing (main passes !no_get and !no_put). So are the conditions
// that presumably register the GET/index and PUT/POST routes only when the
// corresponding flag allows it.
// NOTE(review): `/{any}` matches a single path segment only, so OPTIONS on
// multi-segment paths such as /index/{prefix} matches no route and gets a
// 400 from InvalidPathHandler.
263 kc *keepclient.KeepClient) *mux.Router {
265 t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
267 rest := mux.NewRouter()
// A locator is a 32-hex-digit hash, optionally followed by "+"-separated hints.
270 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
271 GetBlockHandler{kc, t}).Methods("GET", "HEAD")
272 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
275 rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET")
277 // List blocks whose hash has the given prefix
278 rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET")
// Writes: PUT with a locator, or POST to / without one.
282 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
283 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
284 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
285 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
286 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
289 rest.NotFoundHandler = InvalidPathHandler{}
// SetCorsHeaders adds Cross-Origin Resource Sharing headers to resp. They
// let browser-based clients served from any origin call the proxy with the
// methods and request headers it supports.
//
// NOTE(review): "86486400" seconds is roughly 1000 days and looks like a typo
// for 86400 (one day). Browsers clamp Access-Control-Max-Age to their own
// upper limit anyway, so the value is reproduced unchanged.
func SetCorsHeaders(resp http.ResponseWriter) {
	corsHeaders := [...]struct{ name, value string }{
		{"Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS"},
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas"},
		{"Access-Control-Max-Age", "86486400"},
	}
	header := resp.Header()
	for _, h := range corsHeaders {
		header.Set(h.name, h.value)
	}
}
301 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
302 log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
303 http.Error(resp, "Bad request", http.StatusBadRequest)
// ServeHTTP handles OPTIONS requests, logging the caller, method and path.
// NOTE(review): the rest of the body is elided from this listing. It
// presumably calls SetCorsHeaders so browsers' CORS preflight requests
// succeed; confirm.
306 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
307 log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
// Errors that the block handlers log and, for non-200 responses, send to the
// client as the response body via http.Error.
var (
	// BadAuthorizationHeader is reported with 403 Forbidden when the
	// Authorization header is missing or its token is rejected.
	BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")
	// ContentLengthMismatch is recorded when the number of bytes copied to a
	// GET client differs from the length the Keep service announced.
	ContentLengthMismatch = errors.New("Actual length != expected content length")
	// MethodNotSupported is reported with 501 Not Implemented.
	MethodNotSupported = errors.New("Method not supported")
)

// removeHint matches a "+K@xxxxx" location hint (five lowercase letters or
// digits) in a block locator, together with the "+" that follows it, if any.
// Replacing matches with "$1" removes the hint but keeps the separator
// before the next hint.
//
// MustCompile replaces Compile-with-the-error-discarded. An invalid pattern
// now fails loudly at program start, instead of leaving removeHint nil and
// panicking on the first request.
var removeHint = regexp.MustCompile(`\+K@[a-z0-9]{5}(\+|$)`)
// ServeHTTP proxies a GET or HEAD request for one block to the Keep services.
// After checking the client's token, it works on copies of the shared
// KeepClient and ArvadosClient that carry that token, so the shared clients
// are not modified per request. "+K@xxxxx" location hints are stripped from
// the locator before the lookup.
// The logging/error block at the top (its `defer func() {` line is
// presumably elided) logs every request and, for any non-200 status, sends
// err as the error body.
// Status mapping:
//   - 404 for keepclient.BlockNotFound
//   - 502 for temporary keepclient errors
//   - 500 as the fallback at the end of the switch
//   - 501 for unsupported methods
//   - 403 for a bad or missing token
317 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
320 locator := mux.Vars(req)["locator"]
323 var expectLength, responseLength int64
327 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
328 if status != http.StatusOK {
329 http.Error(resp, err.Error(), status)
333 kc := *this.KeepClient
337 if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
338 status, err = http.StatusForbidden, BadAuthorizationHeader
342 // Copy ArvadosClient struct and use the client's API token
343 arvclient := *kc.Arvados
344 arvclient.ApiToken = tok
345 kc.Arvados = &arvclient
347 var reader io.ReadCloser
// Drop "+K@xxxxx" hints; "$1" keeps the "+" that separated the next hint.
349 locator = removeHint.ReplaceAllString(locator, "$1")
353 expectLength, proxiedURI, err = kc.Ask(locator)
// NOTE(review): no reader.Close() is visible in this listing; confirm the
// body returned by kc.Get is closed.
355 reader, expectLength, proxiedURI, err = kc.Get(locator)
360 status, err = http.StatusNotImplemented, MethodNotSupported
364 if expectLength == -1 {
365 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
368 switch respErr := err.(type) {
370 status = http.StatusOK
// Headers must be set before io.Copy below starts writing the body.
// NOTE(review): when expectLength is -1 (the warning above), this sets
// "Content-Length: -1". net/http rejects that value (it logs and drops the
// header). Set the header only when expectLength >= 0.
371 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
376 responseLength, err = io.Copy(resp, reader)
// A length mismatch can only be logged: the 200 status and body are already
// sent, and status stays OK, so http.Error is not called for it.
377 if err == nil && expectLength > -1 && responseLength != expectLength {
378 err = ContentLengthMismatch
381 case keepclient.Error:
382 if respErr == keepclient.BlockNotFound {
383 status = http.StatusNotFound
384 } else if respErr.Temporary() {
385 status = http.StatusBadGateway
390 status = http.StatusInternalServerError
// Request-validation errors that PutBlockHandler sends to the client.
var (
	// LengthRequiredError is reported with 411 when the request carries no
	// usable (non-negative) Content-Length. Its text is the standard status
	// text, "Length Required".
	LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))

	// LengthMismatchError is reported with 400 when the size hint in the
	// locator disagrees with the Content-Length header.
	LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
)
// ServeHTTP stores one block.
// PUT /{locator} streams the body to the Keep services with kc.PutHR. POST /
// (no locator) reads the whole body and uses kc.PutB; the branch choosing
// between them is elided from this listing.
// The request must carry a non-negative Content-Length (411 otherwise), and
// a size hint in the locator must agree with it (400 otherwise). Like
// GetBlockHandler, it works on copies of the shared KeepClient and
// ArvadosClient carrying the caller's token.
// The response always carries the keepclient.X_Keep_Replicas_Stored header.
// On success the body is the stored locator. An InsufficientReplicasError
// with at least one replica written still counts as success (200).
397 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
400 kc := *this.KeepClient
402 var expectLength int64 = -1
403 var status = http.StatusInternalServerError
404 var wroteReplicas int
405 var locatorOut string = "-"
408 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
409 if status != http.StatusOK {
410 http.Error(resp, err.Error(), status)
414 locatorIn := mux.Vars(req)["locator"]
416 if req.Header.Get("Content-Length") != "" {
// This `err :=` shadows the outer err. A parse failure is not reported
// directly; it is caught below only if expectLength stays negative.
417 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
// NOTE(review): this copies the *request* Content-Length onto the *response*,
// whose body is only the locator. Check the guarding condition (elided from
// this listing): if this runs on successful replies, the response length is
// wrong.
419 resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
424 if expectLength < 0 {
425 err = LengthRequiredError
426 status = http.StatusLengthRequired
431 var loc *keepclient.Locator
432 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
433 status = http.StatusBadRequest
435 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
436 err = LengthMismatchError
437 status = http.StatusBadRequest
444 if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
445 err = BadAuthorizationHeader
446 status = http.StatusForbidden
450 // Copy ArvadosClient struct and use the client's API token
451 arvclient := *kc.Arvados
452 arvclient.ApiToken = tok
453 kc.Arvados = &arvclient
455 // Check if the client specified the number of replicas
// NOTE(review): this check uses the literal header name while the parse below
// uses keepclient.X_Keep_Desired_Replicas; use the constant in both places.
456 if req.Header.Get("X-Keep-Desired-Replicas") != "" {
458 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
464 // Now try to put the block through
// NOTE(review): BUG. `bytes, err :=` declares a new err scoped to this
// if/else (the line using `bytes` below is inside that scope), so:
//  (a) the message built on the next line never reaches the outer err used by
//      the logging/http.Error block at the top. With status 500 and the outer
//      err presumably still nil, err.Error() there would panic.
//  (b) the error returned by kc.PutB is also assigned to the inner err. The
//      result handling below presumably inspects the outer err, so a failed
//      PutB can be reported as success.
// Fix: declare `var body []byte` and use `body, err = ioutil.ReadAll(req.Body)`.
466 if bytes, err := ioutil.ReadAll(req.Body); err != nil {
467 err = errors.New(fmt.Sprintf("Error reading request body: %s", err))
468 status = http.StatusInternalServerError
471 locatorOut, wroteReplicas, err = kc.PutB(bytes)
474 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
477 // Tell the client how many successful PUTs we accomplished
478 resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
482 status = http.StatusOK
483 _, err = io.WriteString(resp, locatorOut)
485 case keepclient.OversizeBlockError:
487 status = http.StatusRequestEntityTooLarge
489 case keepclient.InsufficientReplicasError:
490 if wroteReplicas > 0 {
491 // At least one write is considered success. The
492 // client can decide if getting less than the number of
493 // replications it asked for is a fatal error.
494 status = http.StatusOK
495 _, err = io.WriteString(resp, locatorOut)
497 status = http.StatusServiceUnavailable
501 status = http.StatusBadGateway
505 // ServeHTTP implementation for IndexHandler
506 // Supports only GET requests for /index and /index/{prefix:[0-9a-f]{0,32}}
507 // For each keep server found in LocalRoots:
508 // Invokes GetIndex using keepclient
509 // Expects "complete" response (terminating with blank new line)
510 // Aborts on any errors
511 // Concatenates responses from all those keep servers and returns them,
// followed by a single "\n" that terminates the combined index.
// Servers are visited in map iteration order, so the order of the
// concatenated listings can differ between requests.
// NOTE(review): once one server's index has been copied to resp, the 200
// status and part of the body are already sent. A later GetIndex or io.Copy
// failure sets status 502, but http.Error can then only append its message
// to the truncated body; clients must rely on the missing terminating blank
// line to detect the failure.
512 func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
515 prefix := mux.Vars(req)["prefix"]
520 if status != http.StatusOK {
521 http.Error(resp, err.Error(), status)
525 kc := *handler.KeepClient
527 ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req)
529 status, err = http.StatusForbidden, BadAuthorizationHeader
533 // Copy ArvadosClient struct and use the client's API token
534 arvclient := *kc.Arvados
535 arvclient.ApiToken = token
536 kc.Arvados = &arvclient
538 // Only GET method is supported
539 if req.Method != "GET" {
540 status, err = http.StatusNotImplemented, MethodNotSupported
544 // Get index from all LocalRoots and write to resp
546 for uuid := range kc.LocalRoots() {
547 reader, err = kc.GetIndex(uuid, prefix)
549 status = http.StatusBadGateway
553 _, err = io.Copy(resp, reader)
555 status = http.StatusBadGateway
560 // Got index from all the keep servers and wrote to resp
561 status = http.StatusOK
// The error from this final write is ignored.
562 resp.Write([]byte("\n"))