7 "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8 "git.curoverse.com/arvados.git/sdk/go/keepclient"
9 "github.com/gorilla/mux"
24 // DEFAULT_ADDR is the default TCP address on which to listen for requests
25 // (port 25107 on all interfaces). Presumably the default of the -listen flag; confirm.
26 const DEFAULT_ADDR = ":25107"
// listener is the package-level listener: it receives the result of
// net.Listen and is the listener that http.Serve runs on.
28 var listener net.Listener
// NOTE(review): these are the visible lines of the startup sequence; the
// enclosing function header and several statements are not in view.
40 flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
46 "Interface on which to listen for requests, in the format "+
47 "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
48 "to listen on all network interfaces.")
54 "If set, disable GET operations")
60 "If set, disable PUT operations")
66 "Default number of replicas to write if not specified by the client.")
72 "Timeout on requests to internal Keep services (default 15 seconds)")
78 "Path to write pid file")
// With flag.ExitOnError, Parse exits the process on a bad flag, so its
// return value is not inspected.
80 flagset.Parse(os.Args[1:])
82 arv, err := arvadosclient.MakeArvadosClient()
84 log.Fatalf("Error setting up arvados client %s", err.Error())
87 kc, err := keepclient.MakeKeepClient(&arv)
89 log.Fatalf("Error setting up keep client %s", err.Error())
// Write this process id to pidfile.
// NOTE(review): the result of fmt.Fprint is discarded, and a later
// log.Fatalf exits through os.Exit, which skips the deferred os.Remove.
93 f, err := os.Create(pidfile)
95 log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
97 fmt.Fprint(f, os.Getpid())
99 defer os.Remove(pidfile)
// Apply the replica count and the request timeout (in seconds) to the Keep client.
102 kc.Want_replicas = default_replicas
104 kc.Client.Timeout = time.Duration(timeout) * time.Second
// NOTE(review): the message below omits err, so the reason for a failed
// listen is not logged.
106 listener, err = net.Listen("tcp", listen)
108 log.Fatalf("Could not listen on %v", listen)
// Refresh the services list in the background: every 5 minutes, or 3 seconds
// after a failed refresh.
111 go RefreshServicesList(kc, 5*time.Minute, 3*time.Second)
113 // Shut down the server gracefully (by closing the listener)
114 // if SIGTERM or SIGINT is received.
115 term := make(chan os.Signal, 1)
116 go func(sig <-chan os.Signal) {
118 log.Println("caught signal:", s)
121 signal.Notify(term, syscall.SIGTERM)
122 signal.Notify(term, syscall.SIGINT)
124 log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
126 // Start listening for requests.
// NOTE(review): the error returned by http.Serve is not checked.
127 http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
129 log.Println("shutting down")
// ApiTokenCache remembers API tokens together with the time at which each
// cached entry expires.
132 type ApiTokenCache struct {
// tokens maps a token to its expiry time in Unix seconds. RememberToken and
// RecallToken treat a value of 0 as not cached.
133 tokens map[string]int64
138 // RefreshServicesList refreshes the keep service list on SIGHUP; when the
139 // given interval has elapsed since the last refresh; and (if the last
140 // refresh failed) when the given errInterval has elapsed.
141 func RefreshServicesList(kc *keepclient.KeepClient, interval, errInterval time.Duration) {
// previousRoots holds the local and gateway roots from the last logged
// update, so the list is logged only when it changes.
142 var previousRoots = []map[string]string{}
144 timer := time.NewTimer(interval)
145 gotHUP := make(chan os.Signal, 1)
146 signal.Notify(gotHUP, syscall.SIGHUP)
// Schedule the next regular refresh; the error paths below shorten it.
153 timer.Reset(interval)
155 if err := kc.DiscoverKeepServers(); err != nil {
// Printf rather than Println: the message contains format verbs.
156 log.Printf("Error retrieving services list: %v (retrying in %v)", err, errInterval)
157 timer.Reset(errInterval)
160 newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
162 if !reflect.DeepEqual(previousRoots, newRoots) {
163 log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
164 previousRoots = newRoots
// An empty local services list is retried on the short interval as well.
167 if len(newRoots[0]) == 0 {
168 log.Printf("WARNING: No local services (retrying in %v)", errInterval)
169 timer.Reset(errInterval)
174 // RememberToken caches the token and sets its expiry time. If the token
175 // already has an expiry time, it is not updated.
176 func (this *ApiTokenCache) RememberToken(token string) {
178 defer this.lock.Unlock()
180 now := time.Now().Unix()
// A zero entry means the token is not cached (never stored, or reset by RecallToken).
181 if this.tokens[token] == 0 {
// expireTime is added to a Unix-seconds timestamp, so it is a number of seconds.
182 this.tokens[token] = now + this.expireTime
186 // RecallToken checks if the cached token is known and still believed to be valid.
187 func (this *ApiTokenCache) RecallToken(token string) bool {
189 defer this.lock.Unlock()
191 now := time.Now().Unix()
// Zero means there is no cached entry for this token.
192 if this.tokens[token] == 0 {
195 } else if now < this.tokens[token] {
196 // Token is known and still valid
// Resetting the entry to zero marks the token as not cached, so that
// RememberToken can store it again.
200 this.tokens[token] = 0
// GetRemoteAddress returns a description of the client address. It consults
// the X-Real-IP and X-Forwarded-For request headers and otherwise returns
// req.RemoteAddr.
205 func GetRemoteAddress(req *http.Request) string {
206 if realip := req.Header.Get("X-Real-IP"); realip != "" {
// When X-Forwarded-For differs from X-Real-IP, report both values.
207 if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
208 return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
// No X-Real-IP header: fall back to the address of the connection.
213 return req.RemoteAddr
// CheckAuthorizationHeader parses the token from an Authorization header of
// the form OAuth2 TOKEN into tok. A token that the cache recalls is accepted
// without a remote call; otherwise a HEAD request for users/current is made
// and, on success, the token is remembered in the cache.
216 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
218 if auth = req.Header.Get("Authorization"); auth == "" {
// The %s verb stops at whitespace, so tok is the first word after OAuth2.
222 _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
228 if cache.RecallToken(tok) {
229 // Valid in the cache, short circuit
// NOTE(review): arv is set up on lines not in view; presumably an Arvados
// client carrying tok - confirm.
235 if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
236 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
240 // Success! Update cache
241 cache.RememberToken(tok)
// GetBlockHandler is registered for GET and HEAD requests on block locators.
// NOTE(review): each handler struct is built as {kc, t} and read through
// an ApiTokenCache selector, so it also holds a token cache field that is
// not in view.
246 type GetBlockHandler struct {
247 *keepclient.KeepClient
// PutBlockHandler is registered for PUT requests on block locators and for
// POST requests on the root path.
251 type PutBlockHandler struct {
252 *keepclient.KeepClient
// IndexHandler is registered for GET requests on /index and /index/{prefix}.
256 type IndexHandler struct {
257 *keepclient.KeepClient
// InvalidPathHandler is installed as the router NotFoundHandler; its
// ServeHTTP answers unroutable requests with 400 Bad Request.
261 type InvalidPathHandler struct{}
// OptionsHandler is registered for OPTIONS requests.
263 type OptionsHandler struct{}
266 // Returns a mux.Router that passes GET, HEAD, PUT, POST and OPTIONS
267 // requests to the appropriate handlers.
272 kc *keepclient.KeepClient) *mux.Router {
// One token cache is shared by every handler. expireTime is in seconds
// (RememberToken adds it to a Unix-seconds timestamp).
274 t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
276 rest := mux.NewRouter()
// Block reads: a 32-hex-digit hash, with or without +hint suffixes.
279 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
280 GetBlockHandler{kc, t}).Methods("GET", "HEAD")
281 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
// List all blocks
284 rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET")
286 // List blocks whose hash has the given prefix
287 rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET")
// Block writes: PUT to a locator, or POST to the root path.
291 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
292 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
293 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
294 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
295 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
// Anything that matches no route above goes to InvalidPathHandler.
298 rest.NotFoundHandler = InvalidPathHandler{}
// SetCorsHeaders sets the Access-Control-* response headers: any origin is
// allowed, together with the listed methods and request headers.
303 func SetCorsHeaders(resp http.ResponseWriter) {
304 resp.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS")
305 resp.Header().Set("Access-Control-Allow-Origin", "*")
306 resp.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas")
// NOTE(review): 86486400 seconds is 1001 days; possibly a typo for 86400
// (one day) - confirm before changing.
307 resp.Header().Set("Access-Control-Max-Age", "86486400")
// ServeHTTP logs the unroutable request and answers 400 Bad Request.
310 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
311 log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
312 http.Error(resp, "Bad request", http.StatusBadRequest)
// ServeHTTP logs the request. NOTE(review): the rest of the body is not in view.
315 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
316 log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
// BadAuthorizationHeader accompanies 403 Forbidden when CheckAuthorizationHeader
// does not pass.
320 var BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")
// ContentLengthMismatch is set when the number of bytes copied to the client
// differs from the expected length.
321 var ContentLengthMismatch = errors.New("Actual length != expected content length")
// MethodNotSupported accompanies 501 Not Implemented.
322 var MethodNotSupported = errors.New("Method not supported")
// removeHint matches a "+K@xxxxx" hint (five lowercase alphanumerics) in a
// locator, followed by either another "+" or the end of the string. The
// trailing "+" is captured so that replacing a match with "$1" keeps the
// separator for any hint that follows.
//
// MustCompile is used so that an invalid pattern fails at program start
// rather than leaving a nil regexp that panics on first use.
var removeHint = regexp.MustCompile("\\+K@[a-z0-9]{5}(\\+|$)")
// ServeHTTP proxies a block read. Visible steps: authenticate the request,
// strip +K@ hints from the locator, fetch the block size (kc.Ask) or the
// block itself (kc.Get), and copy the data to the client.
326 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
329 locator := mux.Vars(req)["locator"]
332 var expectLength, responseLength int64
// Log the outcome and, for any status other than 200, send the error text.
336 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
337 if status != http.StatusOK {
338 http.Error(resp, err.Error(), status)
// Work on a copy of the shared KeepClient so that the token swap below
// stays local to this request.
342 kc := *this.KeepClient
346 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
347 status, err = http.StatusForbidden, BadAuthorizationHeader
351 // Copy ArvadosClient struct and use the client's API token
352 arvclient := *kc.Arvados
353 arvclient.ApiToken = tok
354 kc.Arvados = &arvclient
356 var reader io.ReadCloser
// Drop +K@xxxxx hints from the locator before passing it on (see removeHint).
358 locator = removeHint.ReplaceAllString(locator, "$1")
// Ask yields only the length; Get also yields a reader for the data.
// Presumably selected by request method (HEAD vs GET) - confirm.
362 expectLength, proxiedURI, err = kc.Ask(locator)
364 reader, expectLength, proxiedURI, err = kc.Get(locator)
369 status, err = http.StatusNotImplemented, MethodNotSupported
// An expected length of -1 is logged as a missing Content-Length.
373 if expectLength == -1 {
374 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
379 status = http.StatusOK
380 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
// Stream the block and verify the byte count when a length was given.
385 responseLength, err = io.Copy(resp, reader)
386 if err == nil && expectLength > -1 && responseLength != expectLength {
387 err = ContentLengthMismatch
390 case keepclient.BlockNotFound:
391 status = http.StatusNotFound
393 status = http.StatusBadGateway
// LengthRequiredError accompanies 411 Length Required when the expected
// length is negative; its text is the standard status text for 411.
397 var LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))
// LengthMismatchError accompanies 400 Bad Request when the locator size hint
// differs from the Content-Length value.
398 var LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
// ServeHTTP proxies a block write. Visible steps: read Content-Length,
// validate the locator against it, authenticate the request, honour the
// desired-replicas header, store the block with kc.PutB or kc.PutHR, and map
// the outcome to an HTTP status.
400 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
// Work on a copy of the shared KeepClient so that per-request changes
// stay local to this request.
403 kc := *this.KeepClient
405 var expectLength int64 = -1
406 var status = http.StatusInternalServerError
407 var wroteReplicas int
408 var locatorOut string = "-"
// Log the outcome and, for any status other than 200, send the error text.
411 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
412 if status != http.StatusOK {
413 http.Error(resp, err.Error(), status)
417 locatorIn := mux.Vars(req)["locator"]
419 if req.Header.Get("Content-Length") != "" {
// This err is local to the enclosing if block.
420 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
422 resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
// expectLength stays at -1 when no usable Content-Length was supplied.
427 if expectLength < 0 {
428 err = LengthRequiredError
429 status = http.StatusLengthRequired
434 var loc *keepclient.Locator
435 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
436 status = http.StatusBadRequest
438 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
439 err = LengthMismatchError
440 status = http.StatusBadRequest
447 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
448 err = BadAuthorizationHeader
449 status = http.StatusForbidden
453 // Copy ArvadosClient struct and use the client's API token
454 arvclient := *kc.Arvados
455 arvclient.ApiToken = tok
456 kc.Arvados = &arvclient
458 // Check if the client specified the number of replicas
459 if req.Header.Get("X-Keep-Desired-Replicas") != "" {
// This err is local to the enclosing if block.
461 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
467 // Now try to put the block through
// The read error gets its own name so that the assignments below reach the
// function-level err, which the status switch and the error response use.
469 if bytes, readErr := ioutil.ReadAll(req.Body); readErr != nil {
470 err = fmt.Errorf("Error reading request body: %s", readErr)
471 status = http.StatusInternalServerError
474 locatorOut, wroteReplicas, err = kc.PutB(bytes)
477 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
480 // Tell the client how many successful PUTs we accomplished
481 resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
485 status = http.StatusOK
486 _, err = io.WriteString(resp, locatorOut)
488 case keepclient.OversizeBlockError:
490 status = http.StatusRequestEntityTooLarge
492 case keepclient.InsufficientReplicasError:
493 if wroteReplicas > 0 {
494 // At least one write is considered success. The
495 // client can decide if getting less than the number of
496 // replications it asked for is a fatal error.
497 status = http.StatusOK
498 _, err = io.WriteString(resp, locatorOut)
500 status = http.StatusServiceUnavailable
504 status = http.StatusBadGateway
508 // ServeHTTP implementation for IndexHandler
509 // Supports only GET requests for /index and /index/{prefix:[0-9a-f]{0,32}}
510 // For each keep server found in LocalRoots:
511 // Invokes GetIndex using keepclient
512 // Expects "complete" response (terminated by a blank line)
513 // Aborts on any errors
514 // Concatenates responses from all those keep servers and returns
515 func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
// prefix is empty when the request matched the bare /index route.
518 prefix := mux.Vars(req)["prefix"]
// For any status other than 200, send the error text to the client.
523 if status != http.StatusOK {
524 http.Error(resp, err.Error(), status)
// Work on a copy of the shared KeepClient so that the token swap below
// stays local to this request.
528 kc := *handler.KeepClient
530 ok, token := CheckAuthorizationHeader(kc, handler.ApiTokenCache, req)
532 status, err = http.StatusForbidden, BadAuthorizationHeader
536 // Copy ArvadosClient struct and use the client's API token
537 arvclient := *kc.Arvados
538 arvclient.ApiToken = token
539 kc.Arvados = &arvclient
541 // Only GET method is supported
542 if req.Method != "GET" {
543 status, err = http.StatusNotImplemented, MethodNotSupported
547 // Get index from all LocalRoots and write to resp
// NOTE(review): each index is copied to resp as it arrives, so a failure on
// a later server happens after part of the body may have been written.
549 for uuid := range kc.LocalRoots() {
550 reader, err = kc.GetIndex(uuid, prefix)
552 status = http.StatusBadGateway
556 _, err = io.Copy(resp, reader)
558 status = http.StatusBadGateway
563 // Got index from all the keep servers and wrote to resp
564 status = http.StatusOK
// The trailing blank line marks the combined response as complete.
565 resp.Write([]byte("\n"))