7 "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8 "git.curoverse.com/arvados.git/sdk/go/keepclient"
9 "github.com/gorilla/mux"
23 // DEFAULT_ADDR is the default TCP address (":port" form) on which to listen for requests.
24 // The address actually used comes from the -listen flag (presumably defaulting to this value — confirm).
25 const DEFAULT_ADDR = ":25107"
27 var listener net.Listener // assigned in main by net.Listen and passed to http.Serve
39 flagset := flag.NewFlagSet("default", flag.ExitOnError)
45 "Interface on which to listen for requests, in the format "+
46 "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
47 "to listen on all network interfaces.")
53 "If set, disable GET operations")
59 "If set, disable PUT operations")
65 "Default number of replicas to write if not specified by the client.")
71 "Timeout on requests to internal Keep services (default 15 seconds)")
77 "Path to write pid file")
79 flagset.Parse(os.Args[1:])
81 arv, err := arvadosclient.MakeArvadosClient()
83 log.Fatalf("Error setting up arvados client %s", err.Error())
86 kc, err := keepclient.MakeKeepClient(&arv)
88 log.Fatalf("Error setting up keep client %s", err.Error())
92 f, err := os.Create(pidfile)
94 log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
96 fmt.Fprint(f, os.Getpid())
98 defer os.Remove(pidfile) // NOTE(review): skipped if a later log.Fatalf exits the process (os.Exit does not run deferred calls)
101 kc.Want_replicas = default_replicas
103 kc.Client.Timeout = time.Duration(timeout) * time.Second
105 listener, err = net.Listen("tcp", listen)
107 log.Fatalf("Could not listen on %v", listen) // NOTE(review): the net.Listen error (err) is not included in this message
110 go RefreshServicesList(kc)
112 // Shut down the server gracefully (by closing the listener)
113 // if SIGTERM or SIGINT is received.
114 term := make(chan os.Signal, 1)
115 go func(sig <-chan os.Signal) {
117 log.Println("caught signal:", s)
120 signal.Notify(term, syscall.SIGTERM)
121 signal.Notify(term, syscall.SIGINT)
123 log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
125 // Serve requests until the listener is closed (e.g. by the signal handler above).
126 http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
128 log.Println("shutting down")
131 type ApiTokenCache struct { // remembers API tokens that recently passed CheckAuthorizationHeader
132 tokens map[string]int64 // token -> expiry time in Unix seconds; 0 means not currently trusted
137 // RefreshServicesList re-discovers kc's keep services every five minutes (retrying after 3 seconds while no local services are found), logging any change.
138 func RefreshServicesList(kc *keepclient.KeepClient) {
139 var previousRoots = []map[string]string{}
140 var delay time.Duration = 0 // a count of seconds (scaled by time.Second below); 0 on the first pass
142 time.Sleep(delay * time.Second)
144 if err := kc.DiscoverKeepServers(); err != nil {
145 log.Println("Error retrieving services list:", err)
149 newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
150 if !reflect.DeepEqual(previousRoots, newRoots) {
151 log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
153 if len(newRoots[0]) == 0 {
154 log.Print("WARNING: No local services. Retrying in 3 seconds.")
157 previousRoots = newRoots
161 // RememberToken caches token with an expiry time of now + expireTime seconds (Unix time).
162 // If the token already has a nonzero expiry time, it is not updated.
163 func (this *ApiTokenCache) RememberToken(token string) {
165 defer this.lock.Unlock()
167 now := time.Now().Unix()
168 if this.tokens[token] == 0 {
169 this.tokens[token] = now + this.expireTime
173 // RecallToken reports whether token is cached and its expiry time has not yet passed.
174 func (this *ApiTokenCache) RecallToken(token string) bool {
176 defer this.lock.Unlock()
178 now := time.Now().Unix()
179 if this.tokens[token] == 0 {
182 } else if now < this.tokens[token] {
183 // Token is known and still valid
187 this.tokens[token] = 0
192 func GetRemoteAddress(req *http.Request) string { // describes the client for log messages: X-Real-IP (plus X-Forwarded-For if it differs), else req.RemoteAddr
193 if realip := req.Header.Get("X-Real-IP"); realip != "" { // client-supplied unless a trusted proxy sets it; only suitable for logging
194 if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
195 return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
200 return req.RemoteAddr
203 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
205 if auth = req.Header.Get("Authorization"); auth == "" {
209 _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
215 if cache.RecallToken(tok) {
216 // Valid in the cache, short circuit
222 if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
223 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
227 // Success! Update cache
228 cache.RememberToken(tok)
233 type GetBlockHandler struct { // handles GET and HEAD requests for blocks
234 *keepclient.KeepClient
238 type PutBlockHandler struct { // handles PUT block writes (and POST to /)
239 *keepclient.KeepClient
243 type InvalidPathHandler struct{} // rejects unroutable requests with 400 Bad Request
245 type OptionsHandler struct{} // answers OPTIONS requests
248 // MakeRESTRouter returns a mux.Router that passes GET/HEAD and PUT/POST block requests to the appropriate
249 // handlers, OPTIONS requests to OptionsHandler, and anything else to InvalidPathHandler.
254 kc *keepclient.KeepClient) *mux.Router {
256 t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300} // validated tokens are trusted for 300 seconds
258 rest := mux.NewRouter()
261 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
262 GetBlockHandler{kc, t}).Methods("GET", "HEAD")
263 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
267 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
268 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
269 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
270 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
271 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
274 rest.NotFoundHandler = InvalidPathHandler{}
279 func SetCorsHeaders(resp http.ResponseWriter) { // permits cross-origin use of the proxy from any origin (Access-Control-Allow-Origin: *)
280 resp.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS")
281 resp.Header().Set("Access-Control-Allow-Origin", "*")
282 resp.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas")
283 resp.Header().Set("Access-Control-Max-Age", "86400") // preflight cache lifetime in seconds: one day (browsers clamp larger values, e.g. Firefox to 24h, Chromium to 2h)
286 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { // logs the unroutable request and replies 400 Bad Request
287 log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
288 http.Error(resp, "Bad request", http.StatusBadRequest)
291 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { // answers OPTIONS requests routed here by MakeRESTRouter; presumably also sets CORS headers — confirm
292 log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
296 var BadAuthorizationHeader = errors.New("Missing or invalid Authorization header") // reported with 403 Forbidden
297 var ContentLengthMismatch = errors.New("Actual length != expected content length") // GET response body length differed from the expected length
298 var MethodNotSupported = errors.New("Method not supported") // reported with 501 Not Implemented
300 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { // proxies GET/HEAD block requests to Keep services using the caller's own API token
303 locator := mux.Vars(req)["locator"]
306 var expectLength, responseLength int64
310 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
311 if status != http.StatusOK {
312 http.Error(resp, err.Error(), status)
316 kc := *this.KeepClient // per-request shallow copy, so replacing kc.Arvados below does not touch the shared client
320 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
321 status, err = http.StatusForbidden, BadAuthorizationHeader
325 // Copy ArvadosClient struct and use the client's API token
326 arvclient := *kc.Arvados
327 arvclient.ApiToken = tok
328 kc.Arvados = &arvclient
330 var reader io.ReadCloser
334 expectLength, proxiedURI, err = kc.Ask(locator)
336 reader, expectLength, proxiedURI, err = kc.Get(locator)
341 status, err = http.StatusNotImplemented, MethodNotSupported
345 if expectLength == -1 {
346 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
351 status = http.StatusOK
352 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
357 responseLength, err = io.Copy(resp, reader)
358 if err == nil && expectLength > -1 && responseLength != expectLength {
359 err = ContentLengthMismatch
362 case keepclient.BlockNotFound:
363 status = http.StatusNotFound
365 status = http.StatusBadGateway
369 var LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired)) // PUT without a usable Content-Length; reported with 411
370 var LengthMismatchError = errors.New("Locator size hint does not match Content-Length header") // reported with 400 Bad Request
372 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { // stores the request body as a block via Keep services using the caller's API token
375 kc := *this.KeepClient
377 var expectLength int64 = -1
378 var status = http.StatusInternalServerError
379 var wroteReplicas int
380 var locatorOut string = "-"
383 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
384 if status != http.StatusOK {
385 http.Error(resp, err.Error(), status)
389 locatorIn := mux.Vars(req)["locator"]
391 if req.Header.Get("Content-Length") != "" {
392 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
394 resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
399 if expectLength < 0 {
400 err = LengthRequiredError
401 status = http.StatusLengthRequired
406 var loc *keepclient.Locator
407 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
408 status = http.StatusBadRequest
410 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
411 err = LengthMismatchError
412 status = http.StatusBadRequest
419 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
420 err = BadAuthorizationHeader
421 status = http.StatusForbidden
425 // Copy ArvadosClient struct and use the client's API token
426 arvclient := *kc.Arvados
427 arvclient.ApiToken = tok
428 kc.Arvados = &arvclient
430 // Check if the client specified the number of replicas
431 if req.Header.Get("X-Keep-Desired-Replicas") != "" {
433 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
439 // Now try to put the block through
441 if bytes, err := ioutil.ReadAll(req.Body); err != nil { // NOTE(review): this := shadows the outer err used by the deferred logger/http.Error
442 err = errors.New(fmt.Sprintf("Error reading request body: %s", err)) // NOTE(review): assigns the shadowed err; the deferred http.Error then calls Error() on the outer (likely nil) err — assign the outer err instead
443 status = http.StatusInternalServerError
446 locatorOut, wroteReplicas, err = kc.PutB(bytes) // NOTE(review): if this is the else branch of the ReadAll if-statement, err is still the shadowed one and PutB failures never reach the status handling below — confirm
449 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
452 // Tell the client how many successful PUTs we accomplished
453 resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
457 status = http.StatusOK
458 _, err = io.WriteString(resp, locatorOut)
460 case keepclient.OversizeBlockError:
462 status = http.StatusRequestEntityTooLarge
464 case keepclient.InsufficientReplicasError:
465 if wroteReplicas > 0 {
466 // At least one write is considered success. The
467 // client can decide if getting less than the number of
468 // replications it asked for is a fatal error.
469 status = http.StatusOK
470 _, err = io.WriteString(resp, locatorOut)
472 status = http.StatusServiceUnavailable
476 status = http.StatusBadGateway