7 "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8 "git.curoverse.com/arvados.git/sdk/go/keepclient"
9 "github.com/gorilla/mux"
// DEFAULT_ADDR is the fallback TCP address (":port" form, i.e. all network
// interfaces) on which to listen for requests. Being a constant, it is not
// set by the -listen flag; presumably it is passed as that flag's default
// value when the flag is registered in main.
const DEFAULT_ADDR = ":25107"
var (
	// listener is the socket that http.Serve accepts proxy requests on.
	// It is package-level, apparently so that the signal-handling goroutine
	// started in main can close it and thereby shut the server down.
	listener net.Listener
)
38 flagset := flag.NewFlagSet("default", flag.ExitOnError)
44 "Interface on which to listen for requests, in the format "+
45 "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
46 "to listen on all network interfaces.")
52 "If set, disable GET operations")
58 "If set, disable PUT operations")
64 "Default number of replicas to write if not specified by the client.")
70 "Timeout on requests to internal Keep services (default 15 seconds)")
76 "Path to write pid file")
78 flagset.Parse(os.Args[1:])
80 arv, err := arvadosclient.MakeArvadosClient()
82 log.Fatalf("Error setting up arvados client %s", err.Error())
85 kc, err := keepclient.MakeKeepClient(&arv)
87 log.Fatalf("Error setting up keep client %s", err.Error())
91 f, err := os.Create(pidfile)
93 log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
95 fmt.Fprint(f, os.Getpid())
97 defer os.Remove(pidfile)
100 kc.Want_replicas = default_replicas
102 kc.Client.Timeout = time.Duration(timeout) * time.Second
104 listener, err = net.Listen("tcp", listen)
106 log.Fatalf("Could not listen on %v", listen)
109 go RefreshServicesList(kc)
111 // Shut down the server gracefully (by closing the listener)
112 // if SIGTERM is received.
113 term := make(chan os.Signal, 1)
114 go func(sig <-chan os.Signal) {
116 log.Println("caught signal:", s)
119 signal.Notify(term, syscall.SIGTERM)
120 signal.Notify(term, syscall.SIGINT)
122 log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
124 // Start listening for requests.
125 http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
127 log.Println("shutting down")
130 type ApiTokenCache struct {
131 tokens map[string]int64
136 // Refresh the keep service list every five minutes.
137 func RefreshServicesList(kc *keepclient.KeepClient) {
140 if err := kc.DiscoverKeepServers(); err != nil {
141 log.Println("Error retrieving services list:", err)
142 time.Sleep(3*time.Second)
144 } else if len(kc.LocalRoots()) == 0 {
145 log.Println("Received empty services list")
146 time.Sleep(3*time.Second)
149 newRoots := fmt.Sprint("Locals ", kc.LocalRoots(), ", gateways ", kc.GatewayRoots())
150 if newRoots != previousRoots {
151 log.Println("Updated services list:", newRoots)
152 previousRoots = newRoots
154 time.Sleep(300*time.Second)
159 // Cache the token and set an expire time. If we already have an expire time
160 // on the token, it is not updated.
161 func (this *ApiTokenCache) RememberToken(token string) {
163 defer this.lock.Unlock()
165 now := time.Now().Unix()
166 if this.tokens[token] == 0 {
167 this.tokens[token] = now + this.expireTime
171 // Check if the cached token is known and still believed to be valid.
172 func (this *ApiTokenCache) RecallToken(token string) bool {
174 defer this.lock.Unlock()
176 now := time.Now().Unix()
177 if this.tokens[token] == 0 {
180 } else if now < this.tokens[token] {
181 // Token is known and still valid
185 this.tokens[token] = 0
190 func GetRemoteAddress(req *http.Request) string {
191 if realip := req.Header.Get("X-Real-IP"); realip != "" {
192 if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
193 return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
198 return req.RemoteAddr
201 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
203 if auth = req.Header.Get("Authorization"); auth == "" {
207 _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
213 if cache.RecallToken(tok) {
214 // Valid in the cache, short circut
220 if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
221 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
225 // Success! Update cache
226 cache.RememberToken(tok)
231 type GetBlockHandler struct {
232 *keepclient.KeepClient
236 type PutBlockHandler struct {
237 *keepclient.KeepClient
// Stateless handlers used by MakeRESTRouter.
type (
	// InvalidPathHandler is installed as the router's NotFoundHandler.
	// Its ServeHTTP logs the unroutable request and replies
	// 400 Bad Request.
	InvalidPathHandler struct{}

	// OptionsHandler serves OPTIONS requests for "/" and "/{any}"
	// (presumably CORS preflight; see SetCorsHeaders). Its ServeHTTP
	// logs each request.
	OptionsHandler struct{}
)
246 // Returns a mux.Router that passes GET and PUT requests to the
247 // appropriate handlers.
252 kc *keepclient.KeepClient) *mux.Router {
254 t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
256 rest := mux.NewRouter()
259 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
260 GetBlockHandler{kc, t}).Methods("GET", "HEAD")
261 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
265 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
266 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
267 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
268 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
269 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
272 rest.NotFoundHandler = InvalidPathHandler{}
277 func SetCorsHeaders(resp http.ResponseWriter) {
278 resp.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS")
279 resp.Header().Set("Access-Control-Allow-Origin", "*")
280 resp.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas")
281 resp.Header().Set("Access-Control-Max-Age", "86486400")
284 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
285 log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
286 http.Error(resp, "Bad request", http.StatusBadRequest)
289 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
290 log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
// Errors reported to clients by the block handlers. Their messages are
// sent verbatim in HTTP error responses, so they must stay stable.
var (
	// BadAuthorizationHeader is reported with 403 Forbidden when
	// CheckAuthorizationHeader rejects a GET/HEAD or PUT/POST request.
	BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")

	// ContentLengthMismatch is recorded by GetBlockHandler when the number
	// of bytes copied to the client differs from the expected length.
	ContentLengthMismatch = errors.New("Actual length != expected content length")

	// MethodNotSupported is reported with 501 Not Implemented by
	// GetBlockHandler for a request method it does not proxy.
	MethodNotSupported = errors.New("Method not supported")
)
298 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
301 locator := mux.Vars(req)["locator"]
304 var expectLength, responseLength int64
308 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
309 if status != http.StatusOK {
310 http.Error(resp, err.Error(), status)
314 kc := *this.KeepClient
318 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
319 status, err = http.StatusForbidden, BadAuthorizationHeader
323 // Copy ArvadosClient struct and use the client's API token
324 arvclient := *kc.Arvados
325 arvclient.ApiToken = tok
326 kc.Arvados = &arvclient
328 var reader io.ReadCloser
332 expectLength, proxiedURI, err = kc.Ask(locator)
334 reader, expectLength, proxiedURI, err = kc.Get(locator)
339 status, err = http.StatusNotImplemented, MethodNotSupported
343 if expectLength == -1 {
344 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
349 status = http.StatusOK
350 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
355 responseLength, err = io.Copy(resp, reader)
356 if err == nil && expectLength > -1 && responseLength != expectLength {
357 err = ContentLengthMismatch
360 case keepclient.BlockNotFound:
361 status = http.StatusNotFound
363 status = http.StatusBadGateway
// Errors reported by PutBlockHandler while validating an upload.
var (
	// LengthRequiredError carries the standard status text for 411
	// ("Length Required"). It is reported when the request's expected
	// length is unknown (negative).
	LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))

	// LengthMismatchError is reported with 400 Bad Request when the size
	// hint in the block locator disagrees with the Content-Length header.
	LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
)
370 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
373 kc := *this.KeepClient
375 var expectLength int64 = -1
376 var status = http.StatusInternalServerError
377 var wroteReplicas int
378 var locatorOut string = "-"
381 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
382 if status != http.StatusOK {
383 http.Error(resp, err.Error(), status)
387 locatorIn := mux.Vars(req)["locator"]
389 if req.Header.Get("Content-Length") != "" {
390 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
392 resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
397 if expectLength < 0 {
398 err = LengthRequiredError
399 status = http.StatusLengthRequired
404 var loc *keepclient.Locator
405 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
406 status = http.StatusBadRequest
408 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
409 err = LengthMismatchError
410 status = http.StatusBadRequest
417 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
418 err = BadAuthorizationHeader
419 status = http.StatusForbidden
423 // Copy ArvadosClient struct and use the client's API token
424 arvclient := *kc.Arvados
425 arvclient.ApiToken = tok
426 kc.Arvados = &arvclient
428 // Check if the client specified the number of replicas
429 if req.Header.Get("X-Keep-Desired-Replicas") != "" {
431 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
437 // Now try to put the block through
439 if bytes, err := ioutil.ReadAll(req.Body); err != nil {
440 err = errors.New(fmt.Sprintf("Error reading request body: %s", err))
441 status = http.StatusInternalServerError
444 locatorOut, wroteReplicas, err = kc.PutB(bytes)
447 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
450 // Tell the client how many successful PUTs we accomplished
451 resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
455 status = http.StatusOK
456 _, err = io.WriteString(resp, locatorOut)
458 case keepclient.OversizeBlockError:
460 status = http.StatusRequestEntityTooLarge
462 case keepclient.InsufficientReplicasError:
463 if wroteReplicas > 0 {
464 // At least one write is considered success. The
465 // client can decide if getting less than the number of
466 // replications it asked for is a fatal error.
467 status = http.StatusOK
468 _, err = io.WriteString(resp, locatorOut)
470 status = http.StatusServiceUnavailable
474 status = http.StatusBadGateway