14715: Keepproxy uses cluster config
[arvados.git] / services / keepproxy / keepproxy.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package main
6
7 import (
8         "errors"
9         "flag"
10         "fmt"
11         "io"
12         "io/ioutil"
13         "net"
14         "net/http"
15         "os"
16         "os/signal"
17         "regexp"
18         "strings"
19         "sync"
20         "syscall"
21         "time"
22
23         "git.curoverse.com/arvados.git/lib/config"
24         "git.curoverse.com/arvados.git/sdk/go/arvados"
25         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
26         "git.curoverse.com/arvados.git/sdk/go/health"
27         "git.curoverse.com/arvados.git/sdk/go/httpserver"
28         "git.curoverse.com/arvados.git/sdk/go/keepclient"
29         "github.com/coreos/go-systemd/daemon"
30         "github.com/gorilla/mux"
31         log "github.com/sirupsen/logrus"
32         "gopkg.in/yaml.v2"
33 )
34
var (
	// version is the build identifier printed by -version and in the
	// startup log line.
	version = "dev"

	// listener and router are the package-level handles for the
	// listening socket and the top-level HTTP handler.
	listener net.Listener
	router   http.Handler
)

// rfc3339NanoFixed is the timestamp layout given to the JSON log
// formatter: an RFC 3339 style layout whose fractional seconds are
// always written with nine digits.
const rfc3339NanoFixed = "2006-01-02T15:04:05.000000000Z07:00"
43
44 func configure(logger log.FieldLogger, args []string) *arvados.Cluster {
45         flags := flag.NewFlagSet(args[0], flag.ExitOnError)
46         flags.Usage = usage
47
48         dumpConfig := flags.Bool("dump-config", false, "write current configuration to stdout and exit")
49         getVersion := flags.Bool("version", false, "Print version information and exit.")
50
51         loader := config.NewLoader(os.Stdin, logger)
52         loader.SetupFlags(flags)
53
54         args = loader.MungeLegacyConfigArgs(logger, args[1:], "-legacy-keepproxy-config")
55         flags.Parse(args)
56
57         // Print version information if requested
58         if *getVersion {
59                 fmt.Printf("keepproxy %s\n", version)
60                 return nil
61         }
62
63         cfg, err := loader.Load()
64         if err != nil {
65                 log.Fatal(err)
66         }
67
68         cluster, err := cfg.GetCluster("")
69         if err != nil {
70                 log.Fatal(err)
71         }
72
73         if *dumpConfig {
74                 out, err := yaml.Marshal(cfg)
75                 if err != nil {
76                         log.Fatal(err)
77                 }
78                 _, err = os.Stdout.Write(out)
79                 if err != nil {
80                         log.Fatal(err)
81                 }
82                 return nil
83         }
84         return cluster
85 }
86
87 func main() {
88         logger := log.New()
89         logger.Formatter = &log.JSONFormatter{
90                 TimestampFormat: rfc3339NanoFixed,
91         }
92
93         cluster := configure(logger, os.Args)
94         if cluster == nil {
95                 return
96         }
97
98         log.Printf("keepproxy %s started", version)
99
100         client, err := arvados.NewClientFromConfig(cluster)
101         if err != nil {
102                 log.Fatal(err)
103         }
104         client.AuthToken = cluster.SystemRootToken
105
106         arv, err := arvadosclient.New(client)
107         if err != nil {
108                 log.Fatalf("Error setting up arvados client %s", err.Error())
109         }
110
111         if cluster.SystemLogs.LogLevel == "debug" {
112                 keepclient.DebugPrintf = log.Printf
113         }
114         kc, err := keepclient.MakeKeepClient(arv)
115         if err != nil {
116                 log.Fatalf("Error setting up keep client %s", err.Error())
117         }
118         keepclient.RefreshServiceDiscoveryOnSIGHUP()
119
120         pidFile := "keepproxy"
121         f, err := os.Create(pidFile)
122         if err != nil {
123                 log.Fatal(err)
124         }
125         defer f.Close()
126         err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
127         if err != nil {
128                 log.Fatalf("flock(%s): %s", pidFile, err)
129         }
130         defer os.Remove(pidFile)
131         err = f.Truncate(0)
132         if err != nil {
133                 log.Fatalf("truncate(%s): %s", pidFile, err)
134         }
135         _, err = fmt.Fprint(f, os.Getpid())
136         if err != nil {
137                 log.Fatalf("write(%s): %s", pidFile, err)
138         }
139         err = f.Sync()
140         if err != nil {
141                 log.Fatalf("sync(%s): %s", pidFile, err)
142         }
143
144         if cluster.Collections.DefaultReplication > 0 {
145                 kc.Want_replicas = cluster.Collections.DefaultReplication
146         }
147
148         var listen arvados.URL
149         for listen = range cluster.Services.Keepproxy.InternalURLs {
150                 break
151         }
152         listener, err := net.Listen("tcp", listen.Host)
153         if err != nil {
154                 log.Fatalf("listen(%s): %s", listen, err)
155         }
156
157         if _, err := daemon.SdNotify(false, "READY=1"); err != nil {
158                 log.Printf("Error notifying init daemon: %v", err)
159         }
160         log.Println("Listening at", listener.Addr())
161
162         // Shut down the server gracefully (by closing the listener)
163         // if SIGTERM is received.
164         term := make(chan os.Signal, 1)
165         go func(sig <-chan os.Signal) {
166                 s := <-sig
167                 log.Println("caught signal:", s)
168                 listener.Close()
169         }(term)
170         signal.Notify(term, syscall.SIGTERM)
171         signal.Notify(term, syscall.SIGINT)
172
173         // Start serving requests.
174         router = MakeRESTRouter(kc, time.Duration(cluster.API.KeepServiceRequestTimeout), cluster.SystemRootToken)
175         http.Serve(listener, httpserver.AddRequestIDs(httpserver.LogRequests(router)))
176
177         log.Println("shutting down")
178 }
179
// ApiTokenCache remembers tokens that were recently validated so that
// repeated requests do not need to be re-checked. Each entry maps a
// token to the Unix time (seconds) at which it stops being trusted.
// The zero value is usable; with expireTime 0 entries expire at once.
type ApiTokenCache struct {
	tokens     map[string]int64 // token -> expiry, Unix seconds
	lock       sync.Mutex       // guards tokens
	expireTime int64            // entry lifetime in seconds
}

// RememberToken caches the token and sets an expire time. If the
// token already has an expire time, it is not updated.
func (c *ApiTokenCache) RememberToken(token string) {
	c.lock.Lock()
	defer c.lock.Unlock()

	// Allocate lazily so a zero-value cache does not panic on write.
	if c.tokens == nil {
		c.tokens = make(map[string]int64)
	}
	if c.tokens[token] == 0 {
		c.tokens[token] = time.Now().Unix() + c.expireTime
	}
}

// RecallToken reports whether the token is cached and still believed
// to be valid. An expired entry is removed from the cache so that the
// map does not accumulate dead tokens.
func (c *ApiTokenCache) RecallToken(token string) bool {
	c.lock.Lock()
	defer c.lock.Unlock()

	expiry, known := c.tokens[token]
	switch {
	case !known || expiry == 0:
		// Unknown token
		return false
	case time.Now().Unix() < expiry:
		// Token is known and still valid
		return true
	default:
		// Token is expired: forget it entirely.
		delete(c.tokens, token)
		return false
	}
}
216
// GetRemoteAddress returns the client address for logging. When the
// request carries an X-Forwarded-For header, the result is that header
// value followed by a comma and req.RemoteAddr; otherwise it is just
// req.RemoteAddr.
func GetRemoteAddress(req *http.Request) string {
	forwarded := req.Header.Get("X-Forwarded-For")
	if forwarded == "" {
		return req.RemoteAddr
	}
	return forwarded + "," + req.RemoteAddr
}
223
224 func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
225         parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
226         if len(parts) < 2 || !(parts[0] == "OAuth2" || parts[0] == "Bearer") || len(parts[1]) == 0 {
227                 return false, ""
228         }
229         tok = parts[1]
230
231         // Tokens are validated differently depending on what kind of
232         // operation is being performed. For example, tokens in
233         // collection-sharing links permit GET requests, but not
234         // PUT requests.
235         var op string
236         if req.Method == "GET" || req.Method == "HEAD" {
237                 op = "read"
238         } else {
239                 op = "write"
240         }
241
242         if cache.RecallToken(op + ":" + tok) {
243                 // Valid in the cache, short circuit
244                 return true, tok
245         }
246
247         var err error
248         arv := *kc.Arvados
249         arv.ApiToken = tok
250         arv.RequestID = req.Header.Get("X-Request-Id")
251         if op == "read" {
252                 err = arv.Call("HEAD", "keep_services", "", "accessible", nil, nil)
253         } else {
254                 err = arv.Call("HEAD", "users", "", "current", nil, nil)
255         }
256         if err != nil {
257                 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
258                 return false, ""
259         }
260
261         // Success!  Update cache
262         cache.RememberToken(op + ":" + tok)
263
264         return true, tok
265 }
266
// We need to make a private copy of the default http transport early
// in initialization, then make copies of our private copy later. It
// won't be safe to copy http.DefaultTransport itself later, because
// its private mutexes might have already been used. (Without this,
// the test suite sometimes panics "concurrent map writes" in
// net/http.(*Transport).removeIdleConnLocked().)
var defaultTransport = *(http.DefaultTransport.(*http.Transport))

// proxyHandler is the http.Handler returned by MakeRESTRouter. It
// bundles the request router with the state the individual endpoint
// methods (Get, Put, Index, Options) need.
type proxyHandler struct {
        // Handler is the router that dispatches requests to the
        // endpoint methods.
        http.Handler
        // KeepClient is the template client; makeKeepClient copies it
        // for each request.
        *keepclient.KeepClient
        // ApiTokenCache remembers tokens already accepted by
        // CheckAuthorizationHeader.
        *ApiTokenCache
        // timeout is the http.Client timeout for proxied requests.
        timeout   time.Duration
        // transport is the http transport shared by the per-request
        // clients.
        transport *http.Transport
}
282
283 // MakeRESTRouter returns an http.Handler that passes GET and PUT
284 // requests to the appropriate handlers.
285 func MakeRESTRouter(kc *keepclient.KeepClient, timeout time.Duration, mgmtToken string) http.Handler {
286         rest := mux.NewRouter()
287
288         transport := defaultTransport
289         transport.DialContext = (&net.Dialer{
290                 Timeout:   keepclient.DefaultConnectTimeout,
291                 KeepAlive: keepclient.DefaultKeepAlive,
292                 DualStack: true,
293         }).DialContext
294         transport.TLSClientConfig = arvadosclient.MakeTLSConfig(kc.Arvados.ApiInsecure)
295         transport.TLSHandshakeTimeout = keepclient.DefaultTLSHandshakeTimeout
296
297         h := &proxyHandler{
298                 Handler:    rest,
299                 KeepClient: kc,
300                 timeout:    timeout,
301                 transport:  &transport,
302                 ApiTokenCache: &ApiTokenCache{
303                         tokens:     make(map[string]int64),
304                         expireTime: 300,
305                 },
306         }
307
308         rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Get).Methods("GET", "HEAD")
309         rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Get).Methods("GET", "HEAD")
310
311         // List all blocks
312         rest.HandleFunc(`/index`, h.Index).Methods("GET")
313
314         // List blocks whose hash has the given prefix
315         rest.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, h.Index).Methods("GET")
316
317         rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Put).Methods("PUT")
318         rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Put).Methods("PUT")
319         rest.HandleFunc(`/`, h.Put).Methods("POST")
320         rest.HandleFunc(`/{any}`, h.Options).Methods("OPTIONS")
321         rest.HandleFunc(`/`, h.Options).Methods("OPTIONS")
322
323         rest.Handle("/_health/{check}", &health.Handler{
324                 Token:  mgmtToken,
325                 Prefix: "/_health/",
326         }).Methods("GET")
327
328         rest.NotFoundHandler = InvalidPathHandler{}
329         return h
330 }
331
332 var errLoopDetected = errors.New("loop detected")
333
334 func (*proxyHandler) checkLoop(resp http.ResponseWriter, req *http.Request) error {
335         if via := req.Header.Get("Via"); strings.Index(via, " "+viaAlias) >= 0 {
336                 log.Printf("proxy loop detected (request has Via: %q): perhaps keepproxy is misidentified by gateway config as an external client, or its keep_services record does not have service_type=proxy?", via)
337                 http.Error(resp, errLoopDetected.Error(), http.StatusInternalServerError)
338                 return errLoopDetected
339         }
340         return nil
341 }
342
// SetCorsHeaders adds the Access-Control-* response headers that allow
// cross-origin requests from any origin for the methods and request
// headers this proxy accepts.
func SetCorsHeaders(resp http.ResponseWriter) {
	hdr := resp.Header()
	for _, kv := range [...][2]string{
		{"Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS"},
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas"},
		{"Access-Control-Max-Age", "86486400"},
	} {
		hdr.Set(kv[0], kv[1])
	}
}
349
350 type InvalidPathHandler struct{}
351
352 func (InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
353         log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
354         http.Error(resp, "Bad request", http.StatusBadRequest)
355 }
356
357 func (h *proxyHandler) Options(resp http.ResponseWriter, req *http.Request) {
358         log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
359         SetCorsHeaders(resp)
360 }
361
// Errors reported to clients by the proxy handlers.
var (
	// BadAuthorizationHeader is used when the Authorization header is
	// absent, malformed, or carries a token that fails validation.
	BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")
	// ContentLengthMismatch is used when the number of bytes copied
	// differs from the expected content length.
	ContentLengthMismatch = errors.New("Actual length != expected content length")
	// MethodNotSupported is used when a handler is reached with an
	// HTTP method it does not implement.
	MethodNotSupported = errors.New("Method not supported")
)
365
// removeHint matches a "+K@xxxxx" locator hint (five lowercase
// alphanumeric characters) that is followed by another "+" or the end
// of the string. Replacing matches with "$1" drops the hint while
// keeping the separator that followed it. MustCompile is used so a
// broken pattern fails loudly at start-up instead of leaving a nil
// regexp behind.
var removeHint = regexp.MustCompile(`\+K@[a-z0-9]{5}(\+|$)`)
367
368 func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) {
369         if err := h.checkLoop(resp, req); err != nil {
370                 return
371         }
372         SetCorsHeaders(resp)
373         resp.Header().Set("Via", req.Proto+" "+viaAlias)
374
375         locator := mux.Vars(req)["locator"]
376         var err error
377         var status int
378         var expectLength, responseLength int64
379         var proxiedURI = "-"
380
381         defer func() {
382                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
383                 if status != http.StatusOK {
384                         http.Error(resp, err.Error(), status)
385                 }
386         }()
387
388         kc := h.makeKeepClient(req)
389
390         var pass bool
391         var tok string
392         if pass, tok = CheckAuthorizationHeader(kc, h.ApiTokenCache, req); !pass {
393                 status, err = http.StatusForbidden, BadAuthorizationHeader
394                 return
395         }
396
397         // Copy ArvadosClient struct and use the client's API token
398         arvclient := *kc.Arvados
399         arvclient.ApiToken = tok
400         kc.Arvados = &arvclient
401
402         var reader io.ReadCloser
403
404         locator = removeHint.ReplaceAllString(locator, "$1")
405
406         switch req.Method {
407         case "HEAD":
408                 expectLength, proxiedURI, err = kc.Ask(locator)
409         case "GET":
410                 reader, expectLength, proxiedURI, err = kc.Get(locator)
411                 if reader != nil {
412                         defer reader.Close()
413                 }
414         default:
415                 status, err = http.StatusNotImplemented, MethodNotSupported
416                 return
417         }
418
419         if expectLength == -1 {
420                 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
421         }
422
423         switch respErr := err.(type) {
424         case nil:
425                 status = http.StatusOK
426                 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
427                 switch req.Method {
428                 case "HEAD":
429                         responseLength = 0
430                 case "GET":
431                         responseLength, err = io.Copy(resp, reader)
432                         if err == nil && expectLength > -1 && responseLength != expectLength {
433                                 err = ContentLengthMismatch
434                         }
435                 }
436         case keepclient.Error:
437                 if respErr == keepclient.BlockNotFound {
438                         status = http.StatusNotFound
439                 } else if respErr.Temporary() {
440                         status = http.StatusBadGateway
441                 } else {
442                         status = 422
443                 }
444         default:
445                 status = http.StatusInternalServerError
446         }
447 }
448
// Errors reported by Put for problems with the request's length.
var (
	// LengthRequiredError carries the standard status text for
	// 411 Length Required.
	LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))
	// LengthMismatchError is used when the size in the locator and the
	// Content-Length header disagree.
	LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
)
451
452 func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
453         if err := h.checkLoop(resp, req); err != nil {
454                 return
455         }
456         SetCorsHeaders(resp)
457         resp.Header().Set("Via", "HTTP/1.1 "+viaAlias)
458
459         kc := h.makeKeepClient(req)
460
461         var err error
462         var expectLength int64
463         var status = http.StatusInternalServerError
464         var wroteReplicas int
465         var locatorOut string = "-"
466
467         defer func() {
468                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
469                 if status != http.StatusOK {
470                         http.Error(resp, err.Error(), status)
471                 }
472         }()
473
474         locatorIn := mux.Vars(req)["locator"]
475
476         // Check if the client specified storage classes
477         if req.Header.Get("X-Keep-Storage-Classes") != "" {
478                 var scl []string
479                 for _, sc := range strings.Split(req.Header.Get("X-Keep-Storage-Classes"), ",") {
480                         scl = append(scl, strings.Trim(sc, " "))
481                 }
482                 kc.StorageClasses = scl
483         }
484
485         _, err = fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
486         if err != nil || expectLength < 0 {
487                 err = LengthRequiredError
488                 status = http.StatusLengthRequired
489                 return
490         }
491
492         if locatorIn != "" {
493                 var loc *keepclient.Locator
494                 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
495                         status = http.StatusBadRequest
496                         return
497                 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
498                         err = LengthMismatchError
499                         status = http.StatusBadRequest
500                         return
501                 }
502         }
503
504         var pass bool
505         var tok string
506         if pass, tok = CheckAuthorizationHeader(kc, h.ApiTokenCache, req); !pass {
507                 err = BadAuthorizationHeader
508                 status = http.StatusForbidden
509                 return
510         }
511
512         // Copy ArvadosClient struct and use the client's API token
513         arvclient := *kc.Arvados
514         arvclient.ApiToken = tok
515         kc.Arvados = &arvclient
516
517         // Check if the client specified the number of replicas
518         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
519                 var r int
520                 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
521                 if err == nil {
522                         kc.Want_replicas = r
523                 }
524         }
525
526         // Now try to put the block through
527         if locatorIn == "" {
528                 bytes, err2 := ioutil.ReadAll(req.Body)
529                 if err2 != nil {
530                         err = fmt.Errorf("Error reading request body: %s", err2)
531                         status = http.StatusInternalServerError
532                         return
533                 }
534                 locatorOut, wroteReplicas, err = kc.PutB(bytes)
535         } else {
536                 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
537         }
538
539         // Tell the client how many successful PUTs we accomplished
540         resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
541
542         switch err.(type) {
543         case nil:
544                 status = http.StatusOK
545                 _, err = io.WriteString(resp, locatorOut)
546
547         case keepclient.OversizeBlockError:
548                 // Too much data
549                 status = http.StatusRequestEntityTooLarge
550
551         case keepclient.InsufficientReplicasError:
552                 if wroteReplicas > 0 {
553                         // At least one write is considered success.  The
554                         // client can decide if getting less than the number of
555                         // replications it asked for is a fatal error.
556                         status = http.StatusOK
557                         _, err = io.WriteString(resp, locatorOut)
558                 } else {
559                         status = http.StatusServiceUnavailable
560                 }
561
562         default:
563                 status = http.StatusBadGateway
564         }
565 }
566
567 // ServeHTTP implementation for IndexHandler
568 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
569 // For each keep server found in LocalRoots:
570 //   Invokes GetIndex using keepclient
571 //   Expects "complete" response (terminating with blank new line)
572 //   Aborts on any errors
573 // Concatenates responses from all those keep servers and returns
574 func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) {
575         SetCorsHeaders(resp)
576
577         prefix := mux.Vars(req)["prefix"]
578         var err error
579         var status int
580
581         defer func() {
582                 if status != http.StatusOK {
583                         http.Error(resp, err.Error(), status)
584                 }
585         }()
586
587         kc := h.makeKeepClient(req)
588         ok, token := CheckAuthorizationHeader(kc, h.ApiTokenCache, req)
589         if !ok {
590                 status, err = http.StatusForbidden, BadAuthorizationHeader
591                 return
592         }
593
594         // Copy ArvadosClient struct and use the client's API token
595         arvclient := *kc.Arvados
596         arvclient.ApiToken = token
597         kc.Arvados = &arvclient
598
599         // Only GET method is supported
600         if req.Method != "GET" {
601                 status, err = http.StatusNotImplemented, MethodNotSupported
602                 return
603         }
604
605         // Get index from all LocalRoots and write to resp
606         var reader io.Reader
607         for uuid := range kc.LocalRoots() {
608                 reader, err = kc.GetIndex(uuid, prefix)
609                 if err != nil {
610                         status = http.StatusBadGateway
611                         return
612                 }
613
614                 _, err = io.Copy(resp, reader)
615                 if err != nil {
616                         status = http.StatusBadGateway
617                         return
618                 }
619         }
620
621         // Got index from all the keep servers and wrote to resp
622         status = http.StatusOK
623         resp.Write([]byte("\n"))
624 }
625
626 func (h *proxyHandler) makeKeepClient(req *http.Request) *keepclient.KeepClient {
627         kc := *h.KeepClient
628         kc.RequestID = req.Header.Get("X-Request-Id")
629         kc.HTTPClient = &proxyClient{
630                 client: &http.Client{
631                         Timeout:   h.timeout,
632                         Transport: h.transport,
633                 },
634                 proto: req.Proto,
635         }
636         return &kc
637 }