7 "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8 "git.curoverse.com/arvados.git/sdk/go/keepclient"
9 "github.com/gorilla/mux"
24 // DefaultAddr is the default TCP address on which to listen for requests.
25 // Override with -listen.
26 const DefaultAddr = ":25107"
// listener is the socket main serves HTTP requests on (assigned from
// net.Listen in main). Closing it makes http.Serve return, which is how
// main shuts the server down gracefully on SIGTERM/SIGINT.
28 var listener net.Listener
40 flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
46 "Interface on which to listen for requests, in the format "+
47 "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
48 "to listen on all network interfaces.")
54 "If set, disable GET operations")
60 "If set, disable PUT operations")
66 "Default number of replicas to write if not specified by the client.")
72 "Timeout on requests to internal Keep services (default 15 seconds)")
78 "Path to write pid file")
80 flagset.Parse(os.Args[1:])
// Record our pid in pidfile. Removal is deferred, so it only happens when
// main returns normally (log.Fatalf exits without running defers).
83 f, err := os.Create(pidfile)
85 log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
87 fmt.Fprint(f, os.Getpid())
89 defer os.Remove(pidfile)
92 arv, err := arvadosclient.MakeArvadosClient()
94 log.Fatalf("setting up arvados client: %v", err)
96 kc, err := keepclient.MakeKeepClient(&arv)
98 log.Fatalf("setting up keep client: %v", err)
100 kc.Want_replicas = default_replicas
101 kc.Client.Timeout = time.Duration(timeout) * time.Second
// Keep the service list current in the background: every 5 minutes, on
// SIGHUP, and 3 seconds after a failed refresh.
102 go RefreshServicesList(kc, 5*time.Minute, 3*time.Second)
104 listener, err = net.Listen("tcp", listen)
// Include the cause (address in use, permission denied, ...) so that a
// startup failure can be diagnosed from the log.
106 log.Fatalf("Could not listen on %v: %v", listen, err)
108 log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
110 // Shut down the server gracefully (by closing the listener)
111 // if SIGTERM or SIGINT is received.
112 term := make(chan os.Signal, 1)
113 go func(sig <-chan os.Signal) {
115 log.Println("caught signal:", s)
118 signal.Notify(term, syscall.SIGTERM)
119 signal.Notify(term, syscall.SIGINT)
121 // Start serving requests. http.Serve returns once the listener is
// closed; its returned error is not inspected.
122 http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
124 log.Println("shutting down")
// ApiTokenCache remembers API tokens that have already been validated, so
// a request with a recently validated token does not need a round trip to
// the API server. tokens maps a token to the Unix time (seconds) until which
// it is trusted; 0 means unknown or expired. Access is serialized by the
// lock field, and expireTime is the trust period in seconds.
127 type ApiTokenCache struct {
128 tokens map[string]int64
133 // RefreshServicesList refreshes kc's keep service list on SIGHUP; when
134 // interval has elapsed since the last refresh; and (if the last refresh
135 // failed or found no local services) when errInterval has elapsed.
// Changes to the list are logged. main runs it in its own goroutine.
136 func RefreshServicesList(kc *keepclient.KeepClient, interval, errInterval time.Duration) {
137 var previousRoots = []map[string]string{}
139 timer := time.NewTimer(interval)
140 gotHUP := make(chan os.Signal, 1)
141 signal.Notify(gotHUP, syscall.SIGHUP)
148 timer.Reset(interval)
150 if err := kc.DiscoverKeepServers(); err != nil {
// Printf, not Println: the message carries %v verbs, which Println would
// print literally instead of formatting.
151 log.Printf("Error retrieving services list: %v (retrying in %v)", err, errInterval)
152 timer.Reset(errInterval)
155 newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
// Log only when the local or gateway service set has actually changed.
157 if !reflect.DeepEqual(previousRoots, newRoots) {
158 log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
159 previousRoots = newRoots
// No local services found: retry on the shorter errInterval.
162 if len(newRoots[0]) == 0 {
163 log.Printf("WARNING: No local services (retrying in %v)", errInterval)
164 timer.Reset(errInterval)
169 // RememberToken caches token as valid until expireTime seconds from now.
170 // If the token already has a nonzero expiry time, it is left unchanged
// (even if that time has already passed).
171 func (this *ApiTokenCache) RememberToken(token string) {
173 defer this.lock.Unlock()
175 now := time.Now().Unix()
176 if this.tokens[token] == 0 {
177 this.tokens[token] = now + this.expireTime
181 // RecallToken reports whether token is cached and still believed to be
// valid, i.e. its cached expiry time (Unix seconds) is still in the future.
// An expired entry is reset to 0, so the token must be re-validated.
// NOTE(review): entries are zeroed rather than deleted; unless something
// else prunes the map, it grows with every distinct token ever presented.
// Consider delete(this.tokens, token) instead.
182 func (this *ApiTokenCache) RecallToken(token string) bool {
184 defer this.lock.Unlock()
186 now := time.Now().Unix()
187 if this.tokens[token] == 0 {
// Token is not known (never cached, or reset after expiring).
190 } else if now < this.tokens[token] {
191 // Token is known and still valid
// Token is known but its expiry time has passed: forget it.
195 this.tokens[token] = 0
// GetRemoteAddress describes the client's address for log messages: it
// returns req.RemoteAddr, preceded by the X-Forwarded-For header value and a
// comma when that header is present.
// X-Forwarded-For is supplied by the client (or a proxy in front of us), so
// the result is only suitable for logging, not for access decisions.
200 func GetRemoteAddress(req *http.Request) string {
201 if xff := req.Header.Get("X-Forwarded-For"); xff != "" {
202 return xff + "," + req.RemoteAddr
204 return req.RemoteAddr
// CheckAuthorizationHeader extracts the API token from an
// "Authorization: OAuth2 <token>" request header and reports (as pass)
// whether it is valid, returning the token as tok. A token still valid in
// cache is accepted without contacting the API server. Otherwise the token
// is checked with a HEAD request for the current user, and on success it is
// added to the cache.
207 func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
209 if auth = req.Header.Get("Authorization"); auth == "" {
213 _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
219 if cache.RecallToken(tok) {
220 // Valid in the cache, short circuit
// NOTE(review): arv is presumably a copy of kc.Arvados carrying tok as its
// ApiToken (set up just above) -- confirm.
226 if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
227 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
231 // Success! Update cache
232 cache.RememberToken(tok)
// GetBlockHandler serves GET and HEAD requests for blocks. It embeds the
// shared KeepClient; its handlers also use an ApiTokenCache.
237 type GetBlockHandler struct {
238 *keepclient.KeepClient
// PutBlockHandler serves PUT /{locator} and POST / requests that store a
// block.
242 type PutBlockHandler struct {
243 *keepclient.KeepClient
// IndexHandler serves GET /index and GET /index/{prefix} requests.
247 type IndexHandler struct {
248 *keepclient.KeepClient
// InvalidPathHandler responds to requests that match no route.
252 type InvalidPathHandler struct{}
// OptionsHandler responds to OPTIONS requests.
254 type OptionsHandler struct{}
257 // Returns a mux.Router that passes block GET/HEAD, index GET and block
258 // PUT/POST requests to the appropriate handlers, answers OPTIONS on "/"
// and on any single-segment path, and sends everything else to
// InvalidPathHandler. The GET and PUT routes are presumably registered only
// when the corresponding enable flag is set -- confirm. All handlers share
// one token cache whose entries are trusted for 300 seconds.
263 kc *keepclient.KeepClient) *mux.Router {
265 t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
267 rest := mux.NewRouter()
// Locator followed by hints (e.g. "+size", "+K@...").
270 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
271 GetBlockHandler{kc, t}).Methods("GET", "HEAD")
// Bare 32-hex-digit locator.
272 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
// List all blocks.
275 rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET")
277 // List blocks whose hash has the given prefix
278 rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET")
282 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
283 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
// POST / carries no locator in the path.
284 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
285 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
286 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
289 rest.NotFoundHandler = InvalidPathHandler{}
// SetCorsHeaders adds CORS response headers that allow requests from any
// origin, using the methods this proxy routes and the request headers its
// clients send.
// NOTE(review): 86486400 seconds is roughly 1000 days and looks like a typo
// for 86400 (one day). Browsers clamp Access-Control-Max-Age to far smaller
// values, so the practical effect is small; confirm intent before changing.
294 func SetCorsHeaders(resp http.ResponseWriter) {
295 resp.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS")
296 resp.Header().Set("Access-Control-Allow-Origin", "*")
297 resp.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas")
298 resp.Header().Set("Access-Control-Max-Age", "86486400")
// ServeHTTP logs the unroutable request and responds 400 "Bad request".
301 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
302 log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
303 http.Error(resp, "Bad request", http.StatusBadRequest)
// ServeHTTP handles OPTIONS requests. It logs the request method and path.
// NOTE(review): presumably it then calls SetCorsHeaders(resp) -- confirm.
306 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
307 log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
// Errors used by the request handlers in this file; they are logged and,
// for non-200 responses, sent to the client as the error body.
311 var BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")
312 var ContentLengthMismatch = errors.New("Actual length != expected content length")
313 var MethodNotSupported = errors.New("Method not supported")
// removeHint matches a "+K@xxxxx" location hint in a locator together with
// the "+" that follows it (if any), so that ReplaceAllString(loc, "$1")
// strips the hint while keeping any remaining hints "+"-separated.
// MustCompile: an invalid pattern should fail at program start instead of
// leaving removeHint nil until the first request dereferences it.
315 var removeHint = regexp.MustCompile("\\+K@[a-z0-9]{5}(\\+|$)")
// ServeHTTP proxies a GET or HEAD block request to Keep. It checks the
// client's API token, strips any "+K@" location hint from the locator, asks
// Keep for the block (kc.Ask or kc.Get) and, for GET, streams it back. The
// request is logged with its final status and lengths; for any non-200
// status, err is sent as the error body.
317 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
320 locator := mux.Vars(req)["locator"]
323 var expectLength, responseLength int64
327 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
328 if status != http.StatusOK {
329 http.Error(resp, err.Error(), status)
// Work on a copy of the shared KeepClient so the per-request API token set
// below does not leak into other requests.
333 kc := *this.KeepClient
337 if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
338 status, err = http.StatusForbidden, BadAuthorizationHeader
342 // Copy ArvadosClient struct and use the client's API token
343 arvclient := *kc.Arvados
344 arvclient.ApiToken = tok
345 kc.Arvados = &arvclient
347 var reader io.ReadCloser
// Strip a +K@ location hint, keeping any hints that follow it.
349 locator = removeHint.ReplaceAllString(locator, "$1")
353 expectLength, proxiedURI, err = kc.Ask(locator)
355 reader, expectLength, proxiedURI, err = kc.Get(locator)
360 status, err = http.StatusNotImplemented, MethodNotSupported
364 if expectLength == -1 {
365 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
368 switch respErr := err.(type) {
370 status = http.StatusOK
// NOTE(review): when expectLength is -1 (warned about above), this sets
// "Content-Length: -1". Recent net/http versions drop an invalid value with
// a logged warning; consider setting the header only when expectLength >= 0.
371 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
// Stream the block. Because status is already 200 and the body has started,
// a ContentLengthMismatch detected here is only logged, not sent as an error.
376 responseLength, err = io.Copy(resp, reader)
377 if err == nil && expectLength > -1 && responseLength != expectLength {
378 err = ContentLengthMismatch
381 case keepclient.Error:
382 if respErr == keepclient.BlockNotFound {
383 status = http.StatusNotFound
384 } else if respErr.Temporary() {
385 status = http.StatusBadGateway
390 status = http.StatusInternalServerError
// Errors PutBlockHandler reports when a request's length information is
// missing (411 Length Required) or disagrees with the locator's size hint
// (400 Bad Request).
394 var LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))
395 var LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
// ServeHTTP stores a block sent by the client via PUT /{locator} or POST /.
// It requires a Content-Length, checks it against the locator's size hint,
// validates the API token, honors X-Keep-Desired-Replicas, writes the block
// through Keep, and responds with the stored locator plus an
// X-Keep-Replicas-Stored header. An InsufficientReplicasError with at least
// one replica written is still reported as 200. Every request is logged with
// its final status; for any non-200 status, err is sent as the error body.
397 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
// Work on a copy so per-request settings (token, replicas) stay local.
400 kc := *this.KeepClient
402 var expectLength int64 = -1
403 var status = http.StatusInternalServerError
404 var wroteReplicas int
405 var locatorOut string = "-"
408 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
409 if status != http.StatusOK {
410 http.Error(resp, err.Error(), status)
414 locatorIn := mux.Vars(req)["locator"]
416 if req.Header.Get("Content-Length") != "" {
417 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
// NOTE(review): this sets the *response* Content-Length from the request's,
// although the response body is the locator or an error message. It looks
// unintended -- confirm.
419 resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
424 if expectLength < 0 {
425 err = LengthRequiredError
426 status = http.StatusLengthRequired
431 var loc *keepclient.Locator
432 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
433 status = http.StatusBadRequest
435 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
436 err = LengthMismatchError
437 status = http.StatusBadRequest
444 if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
445 err = BadAuthorizationHeader
446 status = http.StatusForbidden
450 // Copy ArvadosClient struct and use the client's API token
451 arvclient := *kc.Arvados
452 arvclient.ApiToken = tok
453 kc.Arvados = &arvclient
455 // Check if the client specified the number of replicas
456 if req.Header.Get("X-Keep-Desired-Replicas") != "" {
458 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
464 // Now try to put the block through
// The read error gets its own name. Declaring "err" in this if statement
// would shadow the function's err, and two things would then be lost:
// - the message set below, so the deferred http.Error would dereference a
//   nil err;
// - the error returned by PutB, so a failed write would be reported as
//   200 OK.
466 if bytes, readErr := ioutil.ReadAll(req.Body); readErr != nil {
467 err = fmt.Errorf("Error reading request body: %s", readErr)
468 status = http.StatusInternalServerError
471 locatorOut, wroteReplicas, err = kc.PutB(bytes)
474 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
477 // Tell the client how many successful PUTs we accomplished
478 resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
482 status = http.StatusOK
483 _, err = io.WriteString(resp, locatorOut)
485 case keepclient.OversizeBlockError:
487 status = http.StatusRequestEntityTooLarge
489 case keepclient.InsufficientReplicasError:
490 if wroteReplicas > 0 {
491 // At least one write is considered success. The
492 // client can decide if getting less than the number of
493 // replications it asked for is a fatal error.
494 status = http.StatusOK
495 _, err = io.WriteString(resp, locatorOut)
497 status = http.StatusServiceUnavailable
501 status = http.StatusBadGateway
505 // ServeHTTP serves GET /index and GET /index/{prefix:[0-9a-f]{0,32}}.
506 // For each keep server in kc.LocalRoots() it calls kc.GetIndex(uuid, prefix)
507 // and copies the result into the response, stopping at the first error
508 // (reported as 502 Bad Gateway). After all servers, it writes a final "\n".
509 // Only GET is supported; other methods get 501 Not Implemented.
510 // Each server's index is expected to be "complete" (terminated by a blank
511 // line); presumably GetIndex enforces this -- confirm.
512 func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
515 prefix := mux.Vars(req)["prefix"]
520 if status != http.StatusOK {
521 http.Error(resp, err.Error(), status)
// Work on a copy so the per-request API token does not leak to other
// requests.
525 kc := *handler.KeepClient
527 ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req)
529 status, err = http.StatusForbidden, BadAuthorizationHeader
533 // Copy ArvadosClient struct and use the client's API token
534 arvclient := *kc.Arvados
535 arvclient.ApiToken = token
536 kc.Arvados = &arvclient
538 // Only GET method is supported
539 if req.Method != "GET" {
540 status, err = http.StatusNotImplemented, MethodNotSupported
544 // Get index from all LocalRoots and write to resp
545 // (map iteration order, so the server order varies between requests)
546 for uuid := range kc.LocalRoots() {
547 reader, err = kc.GetIndex(uuid, prefix)
549 status = http.StatusBadGateway
// NOTE(review): once a server's index has been copied, the response has
// already started with an implicit 200. A later failure cannot change the
// status; the error text set for http.Error is appended to the partial body.
553 _, err = io.Copy(resp, reader)
555 status = http.StatusBadGateway
560 // Got index from all the keep servers and wrote to resp
561 status = http.StatusOK
562 resp.Write([]byte("\n"))