4 "git.curoverse.com/arvados.git/sdk/go/keepclient"
5 "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8 "github.com/gorilla/mux"
20 // Default TCP address on which to listen for requests.
21 // Initialized by the -listen flag.
22 const DEFAULT_ADDR = ":25107"
24 var listener net.Listener
35 flagset := flag.NewFlagSet("default", flag.ExitOnError)
41 "Interface on which to listen for requests, in the format "+
42 "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
43 "to listen on all network interfaces.")
49 "If set, disable GET operations")
55 "If set, disable PUT operations")
61 "Default number of replicas to write if not specified by the client.")
67 "Path to write pid file")
69 flagset.Parse(os.Args[1:])
71 arv, err := arvadosclient.MakeArvadosClient()
73 log.Fatalf("Error setting up arvados client %s", err.Error())
76 kc, err := keepclient.MakeKeepClient(&arv)
78 log.Fatalf("Error setting up keep client %s", err.Error())
82 f, err := os.Create(pidfile)
84 log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
86 fmt.Fprint(f, os.Getpid())
88 defer os.Remove(pidfile)
91 kc.Want_replicas = default_replicas
93 listener, err = net.Listen("tcp", listen)
95 log.Fatalf("Could not listen on %v", listen)
98 go RefreshServicesList(&kc)
100 // Shut down the server gracefully (by closing the listener)
101 // if SIGTERM is received.
102 term := make(chan os.Signal, 1)
103 go func(sig <-chan os.Signal) {
105 log.Println("caught signal:", s)
108 signal.Notify(term, syscall.SIGTERM)
109 signal.Notify(term, syscall.SIGINT)
111 log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
113 // Start listening for requests.
114 http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
116 log.Println("shutting down")
119 type ApiTokenCache struct {
120 tokens map[string]int64
125 // Refresh the keep service list every five minutes.
126 func RefreshServicesList(kc *keepclient.KeepClient) {
128 time.Sleep(300 * time.Second)
129 oldservices := kc.ServiceRoots()
130 kc.DiscoverKeepServers()
131 newservices := kc.ServiceRoots()
132 s1 := fmt.Sprint(oldservices)
133 s2 := fmt.Sprint(newservices)
135 log.Printf("Updated server list to %v", s2)
140 // Cache the token and set an expire time. If we already have an expire time
141 // on the token, it is not updated.
142 func (this *ApiTokenCache) RememberToken(token string) {
144 defer this.lock.Unlock()
146 now := time.Now().Unix()
147 if this.tokens[token] == 0 {
148 this.tokens[token] = now + this.expireTime
152 // Check if the cached token is known and still believed to be valid.
153 func (this *ApiTokenCache) RecallToken(token string) bool {
155 defer this.lock.Unlock()
157 now := time.Now().Unix()
158 if this.tokens[token] == 0 {
161 } else if now < this.tokens[token] {
162 // Token is known and still valid
166 this.tokens[token] = 0
171 func GetRemoteAddress(req *http.Request) string {
172 if realip := req.Header.Get("X-Real-IP"); realip != "" {
173 if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
174 return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
179 return req.RemoteAddr
182 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
184 if auth = req.Header.Get("Authorization"); auth == "" {
188 _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
194 if cache.RecallToken(tok) {
195 // Valid in the cache, short circut
201 if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
202 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
206 // Success! Update cache
207 cache.RememberToken(tok)
212 type GetBlockHandler struct {
213 *keepclient.KeepClient
217 type PutBlockHandler struct {
218 *keepclient.KeepClient
222 type InvalidPathHandler struct{}
225 // Returns a mux.Router that passes GET and PUT requests to the
226 // appropriate handlers.
231 kc *keepclient.KeepClient) *mux.Router {
233 t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
235 rest := mux.NewRouter()
238 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`,
239 GetBlockHandler{kc, t}).Methods("GET", "HEAD")
240 rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
244 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`, PutBlockHandler{kc, t}).Methods("PUT")
245 rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
248 rest.NotFoundHandler = InvalidPathHandler{}
253 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
254 log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
255 http.Error(resp, "Bad request", http.StatusBadRequest)
258 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
260 kc := *this.KeepClient
262 hash := mux.Vars(req)["hash"]
263 hints := mux.Vars(req)["hints"]
265 locator := keepclient.MakeLocator2(hash, hints)
267 log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
271 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
272 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
276 // Copy ArvadosClient struct and use the client's API token
277 arvclient := *kc.Arvados
278 arvclient.ApiToken = tok
279 kc.Arvados = &arvclient
281 var reader io.ReadCloser
285 if req.Method == "GET" {
286 reader, blocklen, _, err = kc.AuthorizedGet(hash, locator.Signature, locator.Timestamp)
288 } else if req.Method == "HEAD" {
289 blocklen, _, err = kc.AuthorizedAsk(hash, locator.Signature, locator.Timestamp)
293 resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
299 n, err2 := io.Copy(resp, reader)
301 log.Printf("%s: %s %s mismatched return %v with Content-Length %v error %v", GetRemoteAddress(req), req.Method, hash, n, blocklen, err2)
302 } else if err2 == nil {
303 log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
305 log.Printf("%s: %s %s returned %v bytes error %v", GetRemoteAddress(req), req.Method, hash, n, err.Error())
308 log.Printf("%s: %s %s success", GetRemoteAddress(req), req.Method, hash)
310 case keepclient.BlockNotFound:
311 http.Error(resp, "Not found", http.StatusNotFound)
313 http.Error(resp, err.Error(), http.StatusBadGateway)
317 log.Printf("%s: %s %s error %s", GetRemoteAddress(req), req.Method, hash, err.Error())
321 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
323 kc := *this.KeepClient
325 hash := mux.Vars(req)["hash"]
326 hints := mux.Vars(req)["hints"]
328 locator := keepclient.MakeLocator2(hash, hints)
330 var contentLength int64 = -1
331 if req.Header.Get("Content-Length") != "" {
332 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &contentLength)
334 resp.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength))
339 log.Printf("%s: %s %s Content-Length %v", GetRemoteAddress(req), req.Method, hash, contentLength)
341 if contentLength < 1 {
342 http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
346 if locator.Size > 0 && int64(locator.Size) != contentLength {
347 http.Error(resp, "Locator size hint does not match Content-Length header", http.StatusBadRequest)
353 if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
354 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
358 // Copy ArvadosClient struct and use the client's API token
359 arvclient := *kc.Arvados
360 arvclient.ApiToken = tok
361 kc.Arvados = &arvclient
363 // Check if the client specified the number of replicas
364 if req.Header.Get("X-Keep-Desired-Replicas") != "" {
366 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
372 // Now try to put the block through
373 hash, replicas, err := kc.PutHR(hash, req.Body, contentLength)
375 // Tell the client how many successful PUTs we accomplished
376 resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
380 // Default will return http.StatusOK
381 log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
382 n, err2 := io.WriteString(resp, hash)
384 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
387 case keepclient.OversizeBlockError:
389 http.Error(resp, fmt.Sprintf("Exceeded maximum blocksize %d", keepclient.BLOCKSIZE), http.StatusRequestEntityTooLarge)
391 case keepclient.InsufficientReplicasError:
393 // At least one write is considered success. The
394 // client can decide if getting less than the number of
395 // replications it asked for is a fatal error.
396 // Default will return http.StatusOK
397 n, err2 := io.WriteString(resp, hash)
399 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
402 http.Error(resp, "", http.StatusServiceUnavailable)
406 http.Error(resp, err.Error(), http.StatusBadGateway)
410 log.Printf("%s: %s %s stored %v replicas (desired %v) got error %v", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas, err.Error())