refs #10028
[arvados.git] / services / keepproxy / keepproxy.go
1 package main
2
3 import (
4         "encoding/json"
5         "errors"
6         "flag"
7         "fmt"
8         "io"
9         "io/ioutil"
10         "log"
11         "net"
12         "net/http"
13         "os"
14         "os/signal"
15         "regexp"
16         "sync"
17         "syscall"
18         "time"
19
20         "git.curoverse.com/arvados.git/sdk/go/arvados"
21         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
22         "git.curoverse.com/arvados.git/sdk/go/config"
23         "git.curoverse.com/arvados.git/sdk/go/keepclient"
24         "github.com/coreos/go-systemd/daemon"
25         "github.com/gorilla/mux"
26 )
27
28 type Config struct {
29         Client          arvados.Client
30         Listen          string
31         DisableGet      bool
32         DisablePut      bool
33         DefaultReplicas int
34         Timeout         arvados.Duration
35         PIDFile         string
36         Debug           bool
37 }
38
39 func DefaultConfig() *Config {
40         return &Config{
41                 Listen:  ":25107",
42                 Timeout: arvados.Duration(15 * time.Second),
43         }
44 }
45
46 var listener net.Listener
47
48 func main() {
49         cfg := DefaultConfig()
50
51         flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
52         flagset.Usage = usage
53
54         const deprecated = " (DEPRECATED -- use config file instead)"
55         flagset.StringVar(&cfg.Listen, "listen", cfg.Listen, "Local port to listen on."+deprecated)
56         flagset.BoolVar(&cfg.DisableGet, "no-get", cfg.DisableGet, "Disable GET operations."+deprecated)
57         flagset.BoolVar(&cfg.DisablePut, "no-put", cfg.DisablePut, "Disable PUT operations."+deprecated)
58         flagset.IntVar(&cfg.DefaultReplicas, "default-replicas", cfg.DefaultReplicas, "Default number of replicas to write if not specified by the client. If 0, use site default."+deprecated)
59         flagset.StringVar(&cfg.PIDFile, "pid", cfg.PIDFile, "Path to write pid file."+deprecated)
60         timeoutSeconds := flagset.Int("timeout", int(time.Duration(cfg.Timeout)/time.Second), "Timeout (in seconds) on requests to internal Keep services."+deprecated)
61
62         var cfgPath string
63         const defaultCfgPath = "/etc/arvados/keepproxy/config.json"
64         flagset.StringVar(&cfgPath, "config", defaultCfgPath, "Configuration file `path`")
65         flagset.Parse(os.Args[1:])
66
67         err := config.LoadFile(cfg, cfgPath)
68         if err != nil {
69                 h := os.Getenv("ARVADOS_API_HOST")
70                 t := os.Getenv("ARVADOS_API_TOKEN")
71                 if h == "" || t == "" || !os.IsNotExist(err) || cfgPath != defaultCfgPath {
72                         log.Fatal(err)
73                 }
74                 log.Print("DEPRECATED: No config file found, but ARVADOS_API_HOST and ARVADOS_API_TOKEN environment variables are set. Please use a config file instead.")
75                 cfg.Client.APIHost = h
76                 cfg.Client.AuthToken = t
77                 if regexp.MustCompile("^(?i:1|yes|true)$").MatchString(os.Getenv("ARVADOS_API_HOST_INSECURE")) {
78                         cfg.Client.Insecure = true
79                 }
80                 if j, err := json.MarshalIndent(cfg, "", "    "); err == nil {
81                         log.Print("Current configuration:\n", string(j))
82                 }
83                 cfg.Timeout = arvados.Duration(time.Duration(*timeoutSeconds) * time.Second)
84         }
85
86         arv, err := arvadosclient.New(&cfg.Client)
87         if err != nil {
88                 log.Fatalf("Error setting up arvados client %s", err.Error())
89         }
90
91         if cfg.Debug {
92                 keepclient.DebugPrintf = log.Printf
93         }
94         kc, err := keepclient.MakeKeepClient(arv)
95         if err != nil {
96                 log.Fatalf("Error setting up keep client %s", err.Error())
97         }
98
99         if cfg.PIDFile != "" {
100                 f, err := os.Create(cfg.PIDFile)
101                 if err != nil {
102                         log.Fatal(err)
103                 }
104                 defer f.Close()
105                 err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
106                 if err != nil {
107                         log.Fatalf("flock(%s): %s", cfg.PIDFile, err)
108                 }
109                 defer os.Remove(cfg.PIDFile)
110                 err = f.Truncate(0)
111                 if err != nil {
112                         log.Fatalf("truncate(%s): %s", cfg.PIDFile, err)
113                 }
114                 _, err = fmt.Fprint(f, os.Getpid())
115                 if err != nil {
116                         log.Fatalf("write(%s): %s", cfg.PIDFile, err)
117                 }
118                 err = f.Sync()
119                 if err != nil {
120                         log.Fatal("sync(%s): %s", cfg.PIDFile, err)
121                 }
122         }
123
124         if cfg.DefaultReplicas > 0 {
125                 kc.Want_replicas = cfg.DefaultReplicas
126         }
127         kc.Client.Timeout = time.Duration(cfg.Timeout)
128         go kc.RefreshServices(5*time.Minute, 3*time.Second)
129
130         listener, err = net.Listen("tcp", cfg.Listen)
131         if err != nil {
132                 log.Fatalf("listen(%s): %s", cfg.Listen, err)
133         }
134         if _, err := daemon.SdNotify("READY=1"); err != nil {
135                 log.Printf("Error notifying init daemon: %v", err)
136         }
137         log.Println("Listening at", listener.Addr())
138
139         // Shut down the server gracefully (by closing the listener)
140         // if SIGTERM is received.
141         term := make(chan os.Signal, 1)
142         go func(sig <-chan os.Signal) {
143                 s := <-sig
144                 log.Println("caught signal:", s)
145                 listener.Close()
146         }(term)
147         signal.Notify(term, syscall.SIGTERM)
148         signal.Notify(term, syscall.SIGINT)
149
150         // Start serving requests.
151         http.Serve(listener, MakeRESTRouter(!cfg.DisableGet, !cfg.DisablePut, kc))
152
153         log.Println("shutting down")
154 }
155
// ApiTokenCache remembers API tokens that were recently verified, along
// with the time at which each cached verification stops being trusted.
type ApiTokenCache struct {
	tokens     map[string]int64 // token -> expiry time (Unix seconds)
	lock       sync.Mutex       // guards tokens
	expireTime int64            // lifetime of a cache entry, in seconds
}

// RememberToken caches the token and sets an expire time. If the token
// already has an expire time, it is not updated.
func (c *ApiTokenCache) RememberToken(token string) {
	c.lock.Lock()
	defer c.lock.Unlock()

	// Allow a zero-value cache to be used without an explicit make().
	if c.tokens == nil {
		c.tokens = make(map[string]int64)
	}
	if c.tokens[token] == 0 {
		c.tokens[token] = time.Now().Unix() + c.expireTime
	}
}

// RecallToken reports whether the token is cached and still believed to
// be valid. An expired entry is removed from the cache so that the map
// does not grow without bound as tokens come and go.
func (c *ApiTokenCache) RecallToken(token string) bool {
	c.lock.Lock()
	defer c.lock.Unlock()

	expiry := c.tokens[token]
	switch {
	case expiry == 0:
		// Unknown token
		return false
	case time.Now().Unix() < expiry:
		// Token is known and still valid
		return true
	default:
		// Token is expired
		delete(c.tokens, token)
		return false
	}
}
192
// GetRemoteAddress returns the address of the requesting client for
// logging purposes. When the request carries an X-Forwarded-For header,
// its value is returned followed by a comma and the connection's remote
// address; otherwise only the remote address is returned.
func GetRemoteAddress(req *http.Request) string {
	forwarded := req.Header.Get("X-Forwarded-For")
	if forwarded == "" {
		return req.RemoteAddr
	}
	return forwarded + "," + req.RemoteAddr
}
199
200 func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
201         var auth string
202         if auth = req.Header.Get("Authorization"); auth == "" {
203                 return false, ""
204         }
205
206         _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
207         if err != nil {
208                 // Scanning error
209                 return false, ""
210         }
211
212         if cache.RecallToken(tok) {
213                 // Valid in the cache, short circuit
214                 return true, tok
215         }
216
217         arv := *kc.Arvados
218         arv.ApiToken = tok
219         if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
220                 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
221                 return false, ""
222         }
223
224         // Success!  Update cache
225         cache.RememberToken(tok)
226
227         return true, tok
228 }
229
230 type GetBlockHandler struct {
231         *keepclient.KeepClient
232         *ApiTokenCache
233 }
234
235 type PutBlockHandler struct {
236         *keepclient.KeepClient
237         *ApiTokenCache
238 }
239
240 type IndexHandler struct {
241         *keepclient.KeepClient
242         *ApiTokenCache
243 }
244
245 type InvalidPathHandler struct{}
246
247 type OptionsHandler struct{}
248
249 // MakeRESTRouter
250 //     Returns a mux.Router that passes GET and PUT requests to the
251 //     appropriate handlers.
252 //
253 func MakeRESTRouter(
254         enable_get bool,
255         enable_put bool,
256         kc *keepclient.KeepClient) *mux.Router {
257
258         t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
259
260         rest := mux.NewRouter()
261
262         if enable_get {
263                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
264                         GetBlockHandler{kc, t}).Methods("GET", "HEAD")
265                 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
266
267                 // List all blocks
268                 rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET")
269
270                 // List blocks whose hash has the given prefix
271                 rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET")
272         }
273
274         if enable_put {
275                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
276                 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
277                 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
278                 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
279                 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
280         }
281
282         rest.NotFoundHandler = InvalidPathHandler{}
283
284         return rest
285 }
286
// SetCorsHeaders sets the Access-Control-* response headers: the
// allowed methods, a wildcard origin, the allowed request headers and
// the max age.
func SetCorsHeaders(resp http.ResponseWriter) {
	cors := [...][2]string{
		{"Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS"},
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas"},
		{"Access-Control-Max-Age", "86486400"},
	}
	header := resp.Header()
	for _, kv := range cors {
		header.Set(kv[0], kv[1])
	}
}
293
294 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
295         log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
296         http.Error(resp, "Bad request", http.StatusBadRequest)
297 }
298
299 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
300         log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
301         SetCorsHeaders(resp)
302 }
303
// Errors reported to clients by the request handlers.
var (
	BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")
	ContentLengthMismatch  = errors.New("Actual length != expected content length")
	MethodNotSupported     = errors.New("Method not supported")
)

// removeHint matches a "+K@" hint followed by five lowercase
// alphanumeric characters, up to the next "+" or the end of the string.
// The trailing "+" (if any) is captured so a replacement can keep it.
var removeHint = regexp.MustCompile("\\+K@[a-z0-9]{5}(\\+|$)")
309
310 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
311         SetCorsHeaders(resp)
312
313         locator := mux.Vars(req)["locator"]
314         var err error
315         var status int
316         var expectLength, responseLength int64
317         var proxiedURI = "-"
318
319         defer func() {
320                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
321                 if status != http.StatusOK {
322                         http.Error(resp, err.Error(), status)
323                 }
324         }()
325
326         kc := *this.KeepClient
327
328         var pass bool
329         var tok string
330         if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
331                 status, err = http.StatusForbidden, BadAuthorizationHeader
332                 return
333         }
334
335         // Copy ArvadosClient struct and use the client's API token
336         arvclient := *kc.Arvados
337         arvclient.ApiToken = tok
338         kc.Arvados = &arvclient
339
340         var reader io.ReadCloser
341
342         locator = removeHint.ReplaceAllString(locator, "$1")
343
344         switch req.Method {
345         case "HEAD":
346                 expectLength, proxiedURI, err = kc.Ask(locator)
347         case "GET":
348                 reader, expectLength, proxiedURI, err = kc.Get(locator)
349                 if reader != nil {
350                         defer reader.Close()
351                 }
352         default:
353                 status, err = http.StatusNotImplemented, MethodNotSupported
354                 return
355         }
356
357         if expectLength == -1 {
358                 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
359         }
360
361         switch respErr := err.(type) {
362         case nil:
363                 status = http.StatusOK
364                 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
365                 switch req.Method {
366                 case "HEAD":
367                         responseLength = 0
368                 case "GET":
369                         responseLength, err = io.Copy(resp, reader)
370                         if err == nil && expectLength > -1 && responseLength != expectLength {
371                                 err = ContentLengthMismatch
372                         }
373                 }
374         case keepclient.Error:
375                 if respErr == keepclient.BlockNotFound {
376                         status = http.StatusNotFound
377                 } else if respErr.Temporary() {
378                         status = http.StatusBadGateway
379                 } else {
380                         status = 422
381                 }
382         default:
383                 status = http.StatusInternalServerError
384         }
385 }
386
// Errors reported by PutBlockHandler when the request's Content-Length
// is missing or invalid, or disagrees with the locator's size hint.
var (
	LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))
	LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
)
389
390 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
391         SetCorsHeaders(resp)
392
393         kc := *this.KeepClient
394         var err error
395         var expectLength int64
396         var status = http.StatusInternalServerError
397         var wroteReplicas int
398         var locatorOut string = "-"
399
400         defer func() {
401                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
402                 if status != http.StatusOK {
403                         http.Error(resp, err.Error(), status)
404                 }
405         }()
406
407         locatorIn := mux.Vars(req)["locator"]
408
409         _, err = fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
410         if err != nil || expectLength < 0 {
411                 err = LengthRequiredError
412                 status = http.StatusLengthRequired
413                 return
414         }
415
416         if locatorIn != "" {
417                 var loc *keepclient.Locator
418                 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
419                         status = http.StatusBadRequest
420                         return
421                 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
422                         err = LengthMismatchError
423                         status = http.StatusBadRequest
424                         return
425                 }
426         }
427
428         var pass bool
429         var tok string
430         if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
431                 err = BadAuthorizationHeader
432                 status = http.StatusForbidden
433                 return
434         }
435
436         // Copy ArvadosClient struct and use the client's API token
437         arvclient := *kc.Arvados
438         arvclient.ApiToken = tok
439         kc.Arvados = &arvclient
440
441         // Check if the client specified the number of replicas
442         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
443                 var r int
444                 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
445                 if err == nil {
446                         kc.Want_replicas = r
447                 }
448         }
449
450         // Now try to put the block through
451         if locatorIn == "" {
452                 if bytes, err := ioutil.ReadAll(req.Body); err != nil {
453                         err = errors.New(fmt.Sprintf("Error reading request body: %s", err))
454                         status = http.StatusInternalServerError
455                         return
456                 } else {
457                         locatorOut, wroteReplicas, err = kc.PutB(bytes)
458                 }
459         } else {
460                 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
461         }
462
463         // Tell the client how many successful PUTs we accomplished
464         resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
465
466         switch err {
467         case nil:
468                 status = http.StatusOK
469                 _, err = io.WriteString(resp, locatorOut)
470
471         case keepclient.OversizeBlockError:
472                 // Too much data
473                 status = http.StatusRequestEntityTooLarge
474
475         case keepclient.InsufficientReplicasError:
476                 if wroteReplicas > 0 {
477                         // At least one write is considered success.  The
478                         // client can decide if getting less than the number of
479                         // replications it asked for is a fatal error.
480                         status = http.StatusOK
481                         _, err = io.WriteString(resp, locatorOut)
482                 } else {
483                         status = http.StatusServiceUnavailable
484                 }
485
486         default:
487                 status = http.StatusBadGateway
488         }
489 }
490
491 // ServeHTTP implementation for IndexHandler
492 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
493 // For each keep server found in LocalRoots:
494 //   Invokes GetIndex using keepclient
495 //   Expects "complete" response (terminating with blank new line)
496 //   Aborts on any errors
497 // Concatenates responses from all those keep servers and returns
498 func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
499         SetCorsHeaders(resp)
500
501         prefix := mux.Vars(req)["prefix"]
502         var err error
503         var status int
504
505         defer func() {
506                 if status != http.StatusOK {
507                         http.Error(resp, err.Error(), status)
508                 }
509         }()
510
511         kc := *handler.KeepClient
512
513         ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req)
514         if !ok {
515                 status, err = http.StatusForbidden, BadAuthorizationHeader
516                 return
517         }
518
519         // Copy ArvadosClient struct and use the client's API token
520         arvclient := *kc.Arvados
521         arvclient.ApiToken = token
522         kc.Arvados = &arvclient
523
524         // Only GET method is supported
525         if req.Method != "GET" {
526                 status, err = http.StatusNotImplemented, MethodNotSupported
527                 return
528         }
529
530         // Get index from all LocalRoots and write to resp
531         var reader io.Reader
532         for uuid := range kc.LocalRoots() {
533                 reader, err = kc.GetIndex(uuid, prefix)
534                 if err != nil {
535                         status = http.StatusBadGateway
536                         return
537                 }
538
539                 _, err = io.Copy(resp, reader)
540                 if err != nil {
541                         status = http.StatusBadGateway
542                         return
543                 }
544         }
545
546         // Got index from all the keep servers and wrote to resp
547         status = http.StatusOK
548         resp.Write([]byte("\n"))
549 }