7 "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8 "git.curoverse.com/arvados.git/sdk/go/keepclient"
9 "github.com/gorilla/mux"
// DEFAULT_ADDR is the TCP address on which to listen for requests when
// no -listen flag is given (it is the default value of that flag).
const DEFAULT_ADDR = ":25107"

// listener is the socket the HTTP server accepts connections on. It is
// kept at package level so that main's signal handler can close it,
// which shuts the server down gracefully.
var listener net.Listener
41 flagset := flag.NewFlagSet("default", flag.ExitOnError)
47 "Interface on which to listen for requests, in the format "+
48 "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
49 "to listen on all network interfaces.")
55 "If set, disable GET operations")
61 "If set, disable PUT operations")
67 "Default number of replicas to write if not specified by the client.")
73 "Timeout on requests to internal Keep services (default 15 seconds)")
79 "Path to write pid file")
81 flagset.Parse(os.Args[1:])
83 arv, err := arvadosclient.MakeArvadosClient()
85 log.Fatalf("Error setting up arvados client %s", err.Error())
88 kc, err := keepclient.MakeKeepClient(&arv)
90 log.Fatalf("Error setting up keep client %s", err.Error())
94 f, err := os.Create(pidfile)
96 log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
98 fmt.Fprint(f, os.Getpid())
100 defer os.Remove(pidfile)
103 kc.Want_replicas = default_replicas
105 kc.Client.Timeout = time.Duration(timeout) * time.Second
107 listener, err = net.Listen("tcp", listen)
109 log.Fatalf("Could not listen on %v", listen)
112 go RefreshServicesList(kc)
114 // Shut down the server gracefully (by closing the listener)
115 // if SIGTERM is received.
116 term := make(chan os.Signal, 1)
117 go func(sig <-chan os.Signal) {
119 log.Println("caught signal:", s)
122 signal.Notify(term, syscall.SIGTERM)
123 signal.Notify(term, syscall.SIGINT)
125 log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
127 // Start listening for requests.
128 http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
130 log.Println("shutting down")
// ApiTokenCache remembers API tokens that were recently validated, so
// that requests carrying them can skip the round trip to the API server
// until their cache entry expires.
type ApiTokenCache struct {
	// tokens maps each remembered token to the Unix time (in seconds) at
	// which its entry expires. A zero or missing value means "unknown".
	tokens map[string]int64
	// lock serializes access to tokens.
	lock sync.Mutex
	// expireTime is how many seconds a newly remembered token stays valid.
	expireTime int64
}
139 // Refresh the keep service list every five minutes.
140 func RefreshServicesList(kc *keepclient.KeepClient) {
141 var previousRoots = []map[string]string{}
142 var delay time.Duration = 0
144 time.Sleep(delay * time.Second)
146 if err := kc.DiscoverKeepServers(); err != nil {
147 log.Println("Error retrieving services list:", err)
151 newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
152 if !reflect.DeepEqual(previousRoots, newRoots) {
153 log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
155 if len(newRoots[0]) == 0 {
156 log.Print("WARNING: No local services. Retrying in 3 seconds.")
159 previousRoots = newRoots
163 // Cache the token and set an expire time. If we already have an expire time
164 // on the token, it is not updated.
165 func (this *ApiTokenCache) RememberToken(token string) {
167 defer this.lock.Unlock()
169 now := time.Now().Unix()
170 if this.tokens[token] == 0 {
171 this.tokens[token] = now + this.expireTime
175 // Check if the cached token is known and still believed to be valid.
176 func (this *ApiTokenCache) RecallToken(token string) bool {
178 defer this.lock.Unlock()
180 now := time.Now().Unix()
181 if this.tokens[token] == 0 {
184 } else if now < this.tokens[token] {
185 // Token is known and still valid
189 this.tokens[token] = 0
// GetRemoteAddress returns a description of the client's address for
// log messages.
//
// If the request has an X-Real-IP header, that value is returned. It is
// annotated with the X-Forwarded-For value when that header is present
// and differs. Otherwise req.RemoteAddr is returned. An absent
// X-Forwarded-For header no longer produces a dangling
// "(X-Forwarded-For )" annotation.
func GetRemoteAddress(req *http.Request) string {
	realip := req.Header.Get("X-Real-IP")
	if realip == "" {
		return req.RemoteAddr
	}
	if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != "" && forwarded != realip {
		return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
	}
	return realip
}
205 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
207 if auth = req.Header.Get("Authorization"); auth == "" {
211 _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
217 if cache.RecallToken(tok) {
218 // Valid in the cache, short circut
224 if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
225 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
229 // Success! Update cache
230 cache.RememberToken(tok)
235 type GetBlockHandler struct {
236 *keepclient.KeepClient
240 type PutBlockHandler struct {
241 *keepclient.KeepClient
245 type IndexHandler struct {
246 *keepclient.KeepClient
250 type InvalidPathHandler struct{}
252 type OptionsHandler struct{}
255 // Returns a mux.Router that passes GET and PUT requests to the
256 // appropriate handlers.
261 kc *keepclient.KeepClient) *mux.Router {
263 t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
265 rest := mux.NewRouter()
268 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
269 GetBlockHandler{kc, t}).Methods("GET", "HEAD")
270 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
273 rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET")
275 // List blocks whose hash has the given prefix
276 rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET")
280 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
281 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
282 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
283 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
284 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
287 rest.NotFoundHandler = InvalidPathHandler{}
// SetCorsHeaders adds Cross-Origin Resource Sharing headers to resp. They
// allow browser clients from any origin to use the listed methods and to
// send the Authorization, Content-Length, Content-Type and
// X-Keep-Desired-Replicas request headers.
func SetCorsHeaders(resp http.ResponseWriter) {
	h := resp.Header()
	h.Set("Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS")
	h.Set("Access-Control-Allow-Origin", "*")
	h.Set("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas")
	// Let browsers cache the preflight result for one day (86400 s). The
	// previous "86486400" (~1000 days) looks like a typo; browsers clamp
	// this value to their own maximum anyway.
	h.Set("Access-Control-Max-Age", "86400")
}
299 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
300 log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
301 http.Error(resp, "Bad request", http.StatusBadRequest)
304 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
305 log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
// Errors reported to clients by the block and index handlers.
var (
	BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")
	ContentLengthMismatch  = errors.New("Actual length != expected content length")
	MethodNotSupported     = errors.New("Method not supported")
)

// removeHint matches a "+K@xxxxx" locator hint (five lowercase
// alphanumerics) together with the "+" or end of string that follows it.
// Replacing a match with "$1" drops the hint but keeps that separator.
// The pattern is a constant, so MustCompile is used: the previous
// regexp.Compile call silently discarded any compile error.
var removeHint = regexp.MustCompile(`\+K@[a-z0-9]{5}(\+|$)`)
315 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
318 locator := mux.Vars(req)["locator"]
321 var expectLength, responseLength int64
325 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
326 if status != http.StatusOK {
327 http.Error(resp, err.Error(), status)
331 kc := *this.KeepClient
335 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
336 status, err = http.StatusForbidden, BadAuthorizationHeader
340 // Copy ArvadosClient struct and use the client's API token
341 arvclient := *kc.Arvados
342 arvclient.ApiToken = tok
343 kc.Arvados = &arvclient
345 var reader io.ReadCloser
347 locator = removeHint.ReplaceAllString(locator, "$1")
351 expectLength, proxiedURI, err = kc.Ask(locator)
353 reader, expectLength, proxiedURI, err = kc.Get(locator)
358 status, err = http.StatusNotImplemented, MethodNotSupported
362 if expectLength == -1 {
363 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
368 status = http.StatusOK
369 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
374 responseLength, err = io.Copy(resp, reader)
375 if err == nil && expectLength > -1 && responseLength != expectLength {
376 err = ContentLengthMismatch
379 case keepclient.BlockNotFound:
380 status = http.StatusNotFound
382 status = http.StatusBadGateway
// Errors returned by PutBlockHandler when the request body's length is
// missing, or disagrees with the size hint in the locator.
var (
	LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))
	LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
)
389 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
392 kc := *this.KeepClient
394 var expectLength int64 = -1
395 var status = http.StatusInternalServerError
396 var wroteReplicas int
397 var locatorOut string = "-"
400 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
401 if status != http.StatusOK {
402 http.Error(resp, err.Error(), status)
406 locatorIn := mux.Vars(req)["locator"]
408 if req.Header.Get("Content-Length") != "" {
409 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
411 resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
416 if expectLength < 0 {
417 err = LengthRequiredError
418 status = http.StatusLengthRequired
423 var loc *keepclient.Locator
424 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
425 status = http.StatusBadRequest
427 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
428 err = LengthMismatchError
429 status = http.StatusBadRequest
436 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
437 err = BadAuthorizationHeader
438 status = http.StatusForbidden
442 // Copy ArvadosClient struct and use the client's API token
443 arvclient := *kc.Arvados
444 arvclient.ApiToken = tok
445 kc.Arvados = &arvclient
447 // Check if the client specified the number of replicas
448 if req.Header.Get("X-Keep-Desired-Replicas") != "" {
450 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
456 // Now try to put the block through
458 if bytes, err := ioutil.ReadAll(req.Body); err != nil {
459 err = errors.New(fmt.Sprintf("Error reading request body: %s", err))
460 status = http.StatusInternalServerError
463 locatorOut, wroteReplicas, err = kc.PutB(bytes)
466 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
469 // Tell the client how many successful PUTs we accomplished
470 resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
474 status = http.StatusOK
475 _, err = io.WriteString(resp, locatorOut)
477 case keepclient.OversizeBlockError:
479 status = http.StatusRequestEntityTooLarge
481 case keepclient.InsufficientReplicasError:
482 if wroteReplicas > 0 {
483 // At least one write is considered success. The
484 // client can decide if getting less than the number of
485 // replications it asked for is a fatal error.
486 status = http.StatusOK
487 _, err = io.WriteString(resp, locatorOut)
489 status = http.StatusServiceUnavailable
493 status = http.StatusBadGateway
497 // ServeHTTP implemenation for IndexHandler
498 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
499 // For each keep server found in LocalRoots:
500 // Invokes GetIndex using keepclient
501 // Expects "complete" response (terminating with blank new line)
502 // Aborts on any errors
503 // Concatenates responses from all those keep servers and returns
504 func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
507 prefix := mux.Vars(req)["prefix"]
512 if status != http.StatusOK {
513 http.Error(resp, err.Error(), status)
517 kc := *handler.KeepClient
521 if pass, tok = CheckAuthorizationHeader(kc, handler.ApiTokenCache, req); !pass {
522 status, err = http.StatusForbidden, BadAuthorizationHeader
526 // Copy ArvadosClient struct and use the client's API token
527 arvclient := *kc.Arvados
528 arvclient.ApiToken = tok
529 kc.Arvados = &arvclient
536 for uuid := range kc.LocalRoots() {
537 reader, err = kc.GetIndex(uuid, prefix)
543 readBytes, err = ioutil.ReadAll(reader)
548 // Got index; verify that it is complete
549 // The response should be "\n" if no locators matched the prefix
550 // Else, it should be a list of locators followed by a blank line
551 if (!strings.HasSuffix(string(readBytes), "\n\n")) && (string(readBytes) != "\n") {
552 err = errors.New("Got incomplete index")
555 // Trim the extra empty new line found in response from each server
556 indexResp = append(indexResp, (readBytes[0 : len(readBytes)-1])...)
559 // Append empty line at the end of concatenation of all server responses
560 indexResp = append(indexResp, ([]byte("\n"))...)
562 status, err = http.StatusNotImplemented, MethodNotSupported
568 status = http.StatusOK
569 resp.Header().Set("Content-Length", fmt.Sprint(len(indexResp)))
570 _, err = resp.Write(indexResp)
572 status = http.StatusBadGateway