7 "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8 "git.curoverse.com/arvados.git/sdk/go/keepclient"
9 "github.com/gorilla/mux"
24 // DEFAULT_ADDR is the default TCP address on which to listen for requests.
25 // NOTE(review): a constant cannot be set by a flag; presumably this is the
// default value of the -listen flag — confirm against the flag definition.
26 const DEFAULT_ADDR = ":25107"
// listener is the package-level network listener. It is assigned from
// net.Listen("tcp", ...) and handed to http.Serve elsewhere in this file.
28 var listener net.Listener
// Startup sequence: parse the command-line flags, build the Arvados and Keep
// clients, write the pid file, open the TCP listener, start the background
// services-list refresher, install signal handling, and serve HTTP.
40 flagset := flag.NewFlagSet("default", flag.ExitOnError)
46 "Interface on which to listen for requests, in the format "+
47 "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
48 "to listen on all network interfaces.")
54 "If set, disable GET operations")
60 "If set, disable PUT operations")
66 "Default number of replicas to write if not specified by the client.")
72 "Timeout on requests to internal Keep services (default 15 seconds)")
78 "Path to write pid file")
// Parse everything after the program name.
80 flagset.Parse(os.Args[1:])
82 arv, err := arvadosclient.MakeArvadosClient()
84 log.Fatalf("Error setting up arvados client %s", err.Error())
87 kc, err := keepclient.MakeKeepClient(&arv)
89 log.Fatalf("Error setting up keep client %s", err.Error())
// Record this process's pid in the pid file.
93 f, err := os.Create(pidfile)
95 log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
97 fmt.Fprint(f, os.Getpid())
// NOTE(review): log.Fatalf exits the process without running deferred calls,
// so the pid file is not removed on the fatal paths that follow.
99 defer os.Remove(pidfile)
102 kc.Want_replicas = default_replicas
// timeout is a number of seconds.
104 kc.Client.Timeout = time.Duration(timeout) * time.Second
106 listener, err = net.Listen("tcp", listen)
// NOTE(review): the error returned by net.Listen is not part of this message.
108 log.Fatalf("Could not listen on %v", listen)
// Keep the services list fresh in the background.
111 go RefreshServicesList(kc)
113 // Shut down the server gracefully (by closing the listener)
114 // if SIGTERM or SIGINT is received.
115 term := make(chan os.Signal, 1)
116 go func(sig <-chan os.Signal) {
118 log.Println("caught signal:", s)
121 signal.Notify(term, syscall.SIGTERM)
122 signal.Notify(term, syscall.SIGINT)
124 log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
126 // Serve requests on the listener; http.Serve blocks until it returns an error.
127 http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
129 log.Println("shutting down")
// ApiTokenCache remembers API tokens together with the time at which each
// cached entry expires.
132 type ApiTokenCache struct {
// tokens maps a token to its expiry as a Unix timestamp in seconds (entries
// are written as time.Now().Unix() + expireTime). The value 0 is treated as
// "not cached".
133 tokens map[string]int64
138 // RefreshServicesList refreshes the keep service list every five minutes.
// It re-runs Keep service discovery on kc and logs the local and gateway
// roots whenever they differ from the previously seen ones. It is started as
// a background goroutine.
// NOTE(review): confirm the five-minute interval against the delay
// assignments inside the loop.
139 func RefreshServicesList(kc *keepclient.KeepClient) {
// Roots seen on the previous pass, used to detect changes.
140 var previousRoots = []map[string]string{}
// delay is a count of seconds: it is multiplied by time.Second below.
141 var delay time.Duration = 0
143 time.Sleep(delay * time.Second)
145 if err := kc.DiscoverKeepServers(); err != nil {
146 log.Println("Error retrieving services list:", err)
// Index 0 holds the local roots, index 1 the gateway roots.
150 newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
// Only log when something actually changed.
151 if !reflect.DeepEqual(previousRoots, newRoots) {
152 log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
154 if len(newRoots[0]) == 0 {
155 log.Print("WARNING: No local services. Retrying in 3 seconds.")
158 previousRoots = newRoots
162 // RememberToken caches the token with an expiry time of now + expireTime.
163 // If the token already has a non-zero expiry time, it is not updated.
164 func (this *ApiTokenCache) RememberToken(token string) {
// NOTE(review): the matching Lock call is presumably on the preceding line — confirm.
166 defer this.lock.Unlock()
168 now := time.Now().Unix()
// A zero entry means "not cached" (a missing map key also reads as 0).
169 if this.tokens[token] == 0 {
170 this.tokens[token] = now + this.expireTime
174 // RecallToken checks if the cached token is known and still believed to be valid.
// NOTE(review): presumably it returns true only in the "known and still
// valid" branch and false otherwise — confirm against the return statements.
175 func (this *ApiTokenCache) RecallToken(token string) bool {
177 defer this.lock.Unlock()
179 now := time.Now().Unix()
// Zero means the token was never cached or has already been expired.
180 if this.tokens[token] == 0 {
183 } else if now < this.tokens[token] {
184 // Token is known and still valid
// Otherwise the entry has expired: reset it to the "not cached" value.
188 this.tokens[token] = 0
// GetRemoteAddress returns a printable client address, used in this file's
// log messages. When the request has an X-Real-IP header and it differs from
// X-Forwarded-For, the result is "<real ip> (X-Forwarded-For <value>)".
// The final fallback is req.RemoteAddr.
193 func GetRemoteAddress(req *http.Request) string {
194 if realip := req.Header.Get("X-Real-IP"); realip != "" {
// NOTE(review): this comparison is also true when X-Forwarded-For is absent
// (empty string), which yields "<real ip> (X-Forwarded-For )".
195 if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
196 return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
201 return req.RemoteAddr
// CheckAuthorizationHeader reads the token from an "Authorization: OAuth2
// <token>" request header and validates it. A token that is still valid in
// cache is accepted without a remote call; otherwise a HEAD call on
// users/current is made and, on success, the token is remembered in cache.
// It returns whether the check passed, and the parsed token.
// NOTE(review): kc is received by value, so this function works on a copy.
204 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
// A missing or empty Authorization header is handled first.
206 if auth = req.Header.Get("Authorization"); auth == "" {
// Parse the token that follows the "OAuth2 " scheme.
210 _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
216 if cache.RecallToken(tok) {
217 // Valid in the cache, short circuit
// NOTE(review): arv is set up on a line not adjacent to this call —
// presumably an Arvados client carrying tok; confirm.
223 if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
224 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
228 // Success! Update cache
229 cache.RememberToken(tok)
// GetBlockHandler serves block reads (routed for GET and HEAD). It embeds the
// Keep client used to look up and fetch blocks.
234 type GetBlockHandler struct {
235 *keepclient.KeepClient
// PutBlockHandler serves block writes (routed for PUT and POST). It embeds
// the Keep client used to store blocks.
239 type PutBlockHandler struct {
240 *keepclient.KeepClient
// InvalidPathHandler is installed as the router's NotFoundHandler; it answers
// requests that match no route.
244 type InvalidPathHandler struct{}
// OptionsHandler is registered for OPTIONS requests.
246 type OptionsHandler struct{}
249 // MakeRESTRouter returns a mux.Router that passes GET and PUT requests to the
250 // appropriate handlers. HEAD, POST and OPTIONS routes are registered as well,
// and unmatched requests go to InvalidPathHandler.
// NOTE(review): the first two arguments presumably enable/disable the GET and
// PUT routes (the caller passes !no_get, !no_put) — confirm.
255 kc *keepclient.KeepClient) *mux.Router {
// One token cache shared by every handler. expireTime is in seconds: it is
// added to a Unix timestamp in RememberToken.
257 t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
259 rest := mux.NewRouter()
// Reads: a 32-hex-digit locator, with or without "+..." suffixes.
262 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
263 GetBlockHandler{kc, t}).Methods("GET", "HEAD")
264 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
// Writes: PUT to a locator, or POST to "/".
268 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
269 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
270 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
// OPTIONS requests on "/" or any single path segment.
271 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
272 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
// Anything that matches no route above.
275 rest.NotFoundHandler = InvalidPathHandler{}
// SetCorsHeaders sets the Access-Control-* headers on resp: the allowed
// methods, any origin ("*"), the allowed request headers, and the max age.
280 func SetCorsHeaders(resp http.ResponseWriter) {
281 resp.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS")
282 resp.Header().Set("Access-Control-Allow-Origin", "*")
283 resp.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas")
// NOTE(review): 86486400 seconds is 1001 days (86400 is one day) — confirm
// the value is intended.
284 resp.Header().Set("Access-Control-Max-Age", "86486400")
// ServeHTTP logs the unroutable request (client address, method, path) and
// replies with 400 Bad Request.
287 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
288 log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
289 http.Error(resp, "Bad request", http.StatusBadRequest)
// ServeHTTP logs the request (client address, method, path).
// NOTE(review): presumably it then sends the CORS headers via SetCorsHeaders
// — confirm.
292 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
293 log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
// Errors reported by the block handlers.
//
// BadAuthorizationHeader is used when CheckAuthorizationHeader does not pass;
// the handlers pair it with 403 Forbidden.
297 var BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")
// ContentLengthMismatch is used when the number of bytes copied to the client
// differs from the expected length.
298 var ContentLengthMismatch = errors.New("Actual length != expected content length")
// MethodNotSupported is paired with 501 Not Implemented in GetBlockHandler.
299 var MethodNotSupported = errors.New("Method not supported")
// removeHint matches a "+K@xxxxx" hint (five lowercase alphanumerics) that is
// followed by another "+" or by the end of the string. Group 1 captures that
// terminator, so replacing a match with "$1" drops the hint while keeping the
// separator that follows it.
//
// MustCompile is used instead of discarding the error from Compile, so a
// malformed pattern fails at program start rather than leaving a nil *Regexp
// that panics on first use.
var removeHint = regexp.MustCompile("\\+K@[a-z0-9]{5}(\\+|$)")
// ServeHTTP handles read requests for a block. It checks the client's
// Authorization header, queries Keep for the block using the client's own API
// token (kc.Ask for length only, kc.Get for the data), streams any data to
// the client, and maps Keep errors to HTTP status codes.
303 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
// The locator comes from the {locator} variable of the matched route.
306 locator := mux.Vars(req)["locator"]
309 var expectLength, responseLength int64
// Access log line (presumably inside a deferred function — confirm). When the
// final status is not 200, the error text is sent as the response.
313 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
314 if status != http.StatusOK {
315 http.Error(resp, err.Error(), status)
// Work on a per-request copy of the KeepClient so the changes below do not
// touch the shared one.
319 kc := *this.KeepClient
323 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
324 status, err = http.StatusForbidden, BadAuthorizationHeader
328 // Copy ArvadosClient struct and use the client's API token
329 arvclient := *kc.Arvados
330 arvclient.ApiToken = tok
331 kc.Arvados = &arvclient
333 var reader io.ReadCloser
// Strip "+K@xxxxx" hints from the locator before forwarding it.
335 locator = removeHint.ReplaceAllString(locator, "$1")
// NOTE(review): presumably Ask serves HEAD and Get serves GET; other methods
// get 501 Not Implemented — confirm against the method switch.
339 expectLength, proxiedURI, err = kc.Ask(locator)
341 reader, expectLength, proxiedURI, err = kc.Get(locator)
346 status, err = http.StatusNotImplemented, MethodNotSupported
// -1 signals that no Content-Length was provided.
350 if expectLength == -1 {
351 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
356 status = http.StatusOK
357 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
// Stream the block to the client and verify the number of bytes sent.
362 responseLength, err = io.Copy(resp, reader)
363 if err == nil && expectLength > -1 && responseLength != expectLength {
364 err = ContentLengthMismatch
// Map Keep errors to HTTP statuses.
367 case keepclient.BlockNotFound:
368 status = http.StatusNotFound
370 status = http.StatusBadGateway
// LengthRequiredError carries the standard status text for HTTP 411 Length
// Required; PutBlockHandler uses it when no valid Content-Length is known.
374 var LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))
// LengthMismatchError is used when the size in the locator disagrees with the
// request's Content-Length.
375 var LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
// ServeHTTP handles write requests for a block (PUT to a locator, POST to
// "/"). It requires a Content-Length, validates the locator against it,
// checks the client's Authorization header, writes the block to Keep using
// the client's own API token, reports the number of replicas written in a
// response header, and maps the outcome to an HTTP status code.
377 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
// Per-request copy of the KeepClient, so changes here stay local.
380 kc := *this.KeepClient
// Stays -1 until a Content-Length header has been parsed.
382 var expectLength int64 = -1
383 var status = http.StatusInternalServerError
384 var wroteReplicas int
385 var locatorOut string = "-"
// Access log line (presumably inside a deferred function — confirm). When the
// final status is not 200, the error text is sent as the response.
388 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
389 if status != http.StatusOK {
390 http.Error(resp, err.Error(), status)
// Empty for POST "/", whose route has no {locator} variable.
394 locatorIn := mux.Vars(req)["locator"]
396 if req.Header.Get("Content-Length") != "" {
// NOTE(review): := declares a new err that is local to this block.
397 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
// NOTE(review): this sets the *response* Content-Length from the request's
// length — confirm that is intended.
399 resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
// A known, non-negative body length is mandatory.
404 if expectLength < 0 {
405 err = LengthRequiredError
406 status = http.StatusLengthRequired
411 var loc *keepclient.Locator
412 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
413 status = http.StatusBadRequest
// A size given in the locator must agree with Content-Length.
415 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
416 err = LengthMismatchError
417 status = http.StatusBadRequest
424 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
425 err = BadAuthorizationHeader
426 status = http.StatusForbidden
430 // Copy ArvadosClient struct and use the client's API token
431 arvclient := *kc.Arvados
432 arvclient.ApiToken = tok
433 kc.Arvados = &arvclient
435 // Check if the client specified the number of replicas
436 if req.Header.Get("X-Keep-Desired-Replicas") != "" {
438 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
444 // Now try to put the block through
// NOTE(review): the err declared by := in this if-initializer is scoped to
// the whole if/else statement. The assignment on the next line — and the
// PutB result below, if it sits in the matching else branch — therefore
// appears to go to that inner err, not the one used by the logging/response
// code and the status switch. Confirm and fix if so.
446 if bytes, err := ioutil.ReadAll(req.Body); err != nil {
447 err = errors.New(fmt.Sprintf("Error reading request body: %s", err))
448 status = http.StatusInternalServerError
451 locatorOut, wroteReplicas, err = kc.PutB(bytes)
454 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
457 // Tell the client how many successful PUTs we accomplished
458 resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
// Translate the outcome of the write into an HTTP status; on success the
// resulting locator is written as the response body.
462 status = http.StatusOK
463 _, err = io.WriteString(resp, locatorOut)
465 case keepclient.OversizeBlockError:
467 status = http.StatusRequestEntityTooLarge
469 case keepclient.InsufficientReplicasError:
470 if wroteReplicas > 0 {
471 // At least one write is considered success. The
472 // client can decide if getting less than the number of
473 // replications it asked for is a fatal error.
474 status = http.StatusOK
475 _, err = io.WriteString(resp, locatorOut)
477 status = http.StatusServiceUnavailable
481 status = http.StatusBadGateway