eaa64b0ed2bf4dcf38a10ac2490de7de051709ab
[arvados.git] / services / keepproxy / keepproxy.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package keepproxy
6
7 import (
8         "context"
9         "errors"
10         "fmt"
11         "io"
12         "io/ioutil"
13         "net"
14         "net/http"
15         "regexp"
16         "strings"
17         "time"
18
19         "git.arvados.org/arvados.git/lib/service"
20         "git.arvados.org/arvados.git/sdk/go/arvados"
21         "git.arvados.org/arvados.git/sdk/go/arvadosclient"
22         "git.arvados.org/arvados.git/sdk/go/ctxlog"
23         "git.arvados.org/arvados.git/sdk/go/health"
24         "git.arvados.org/arvados.git/sdk/go/httpserver"
25         "git.arvados.org/arvados.git/sdk/go/keepclient"
26         "github.com/gorilla/mux"
27         lru "github.com/hashicorp/golang-lru"
28         "github.com/prometheus/client_golang/prometheus"
29         "github.com/sirupsen/logrus"
30 )
31
// rfc3339NanoFixed is a fixed-width variant of time.RFC3339Nano: the
// "000000000" fraction keeps trailing zeros, so every timestamp
// formatted with it has the same length. It is not referenced in this
// part of the file; presumably it is used elsewhere in the package.
const rfc3339NanoFixed = "2006-01-02T15:04:05.000000000Z07:00"

// Command is the keepproxy service command, built by service.Command
// from the keepproxy service name and newHandlerOrErrorHandler (the
// constructor for the service's HTTP handler).
var Command = service.Command(arvados.ServiceNameKeepproxy, newHandlerOrErrorHandler)
35
36 func newHandlerOrErrorHandler(ctx context.Context, cluster *arvados.Cluster, token string, reg *prometheus.Registry) service.Handler {
37         client, err := arvados.NewClientFromConfig(cluster)
38         if err != nil {
39                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up arvados client: %w", err))
40         }
41         arv, err := arvadosclient.New(client)
42         if err != nil {
43                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up arvados client: %w", err))
44         }
45         kc, err := keepclient.MakeKeepClient(arv)
46         if err != nil {
47                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up keep client: %w", err))
48         }
49         keepclient.RefreshServiceDiscoveryOnSIGHUP()
50         router, err := newHandler(ctx, kc, time.Duration(keepclient.DefaultProxyRequestTimeout), cluster)
51         if err != nil {
52                 return service.ErrorHandler(ctx, cluster, err)
53         }
54         return router
55 }
56
57 type tokenCacheEntry struct {
58         expire int64
59         user   *arvados.User
60 }
61
62 type apiTokenCache struct {
63         tokens     *lru.TwoQueueCache
64         expireTime int64
65 }
66
67 // RememberToken caches the token and set an expire time.  If the
68 // token is already in the cache, it is not updated.
69 func (cache *apiTokenCache) RememberToken(token string, user *arvados.User) {
70         now := time.Now().Unix()
71         _, ok := cache.tokens.Get(token)
72         if !ok {
73                 cache.tokens.Add(token, tokenCacheEntry{
74                         expire: now + cache.expireTime,
75                         user:   user,
76                 })
77         }
78 }
79
80 // RecallToken checks if the cached token is known and still believed to be
81 // valid.
82 func (cache *apiTokenCache) RecallToken(token string) (bool, *arvados.User) {
83         val, ok := cache.tokens.Get(token)
84         if !ok {
85                 return false, nil
86         }
87
88         cacheEntry := val.(tokenCacheEntry)
89         now := time.Now().Unix()
90         if now < cacheEntry.expire {
91                 // Token is known and still valid
92                 return true, cacheEntry.user
93         } else {
94                 // Token is expired
95                 cache.tokens.Remove(token)
96                 return false, nil
97         }
98 }
99
100 func (h *proxyHandler) Done() <-chan struct{} {
101         return nil
102 }
103
104 func (h *proxyHandler) CheckHealth() error {
105         return nil
106 }
107
108 func (h *proxyHandler) checkAuthorizationHeader(req *http.Request) (pass bool, tok string, user *arvados.User) {
109         parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
110         if len(parts) < 2 || !(parts[0] == "OAuth2" || parts[0] == "Bearer") || len(parts[1]) == 0 {
111                 return false, "", nil
112         }
113         tok = parts[1]
114
115         // Tokens are validated differently depending on what kind of
116         // operation is being performed. For example, tokens in
117         // collection-sharing links permit GET requests, but not
118         // PUT requests.
119         var op string
120         if req.Method == "GET" || req.Method == "HEAD" {
121                 op = "read"
122         } else {
123                 op = "write"
124         }
125
126         if ok, user := h.apiTokenCache.RecallToken(op + ":" + tok); ok {
127                 // Valid in the cache, short circuit
128                 return true, tok, user
129         }
130
131         var err error
132         arv := *h.KeepClient.Arvados
133         arv.ApiToken = tok
134         arv.RequestID = req.Header.Get("X-Request-Id")
135         user = &arvados.User{}
136         userCurrentError := arv.Call("GET", "users", "", "current", nil, user)
137         err = userCurrentError
138         if err != nil && op == "read" {
139                 apiError, ok := err.(arvadosclient.APIServerError)
140                 if ok && apiError.HttpStatusCode == http.StatusForbidden {
141                         // If it was a scoped "sharing" token it will
142                         // return 403 instead of 401 for the current
143                         // user check.  If it is a download operation
144                         // and they have permission to read the
145                         // keep_services table, we can allow it.
146                         err = arv.Call("HEAD", "keep_services", "", "accessible", nil, nil)
147                 }
148         }
149         if err != nil {
150                 ctxlog.FromContext(req.Context()).WithError(err).Info("checkAuthorizationHeader error")
151                 return false, "", nil
152         }
153
154         if userCurrentError == nil && user.IsAdmin {
155                 // checking userCurrentError is probably redundant,
156                 // IsAdmin would be false anyway. But can't hurt.
157                 if op == "read" && !h.cluster.Collections.KeepproxyPermission.Admin.Download {
158                         return false, "", nil
159                 }
160                 if op == "write" && !h.cluster.Collections.KeepproxyPermission.Admin.Upload {
161                         return false, "", nil
162                 }
163         } else {
164                 if op == "read" && !h.cluster.Collections.KeepproxyPermission.User.Download {
165                         return false, "", nil
166                 }
167                 if op == "write" && !h.cluster.Collections.KeepproxyPermission.User.Upload {
168                         return false, "", nil
169                 }
170         }
171
172         // Success!  Update cache
173         h.apiTokenCache.RememberToken(op+":"+tok, user)
174
175         return true, tok, user
176 }
177
// defaultTransport is a private copy of http.DefaultTransport, taken
// during package initialization. newHandler derives each handler's
// transport from this value instead of from http.DefaultTransport.
//
// We need to make a private copy of the default http transport early
// in initialization, then make copies of our private copy later. It
// won't be safe to copy http.DefaultTransport itself later, because
// its private mutexes might have already been used. (Without this,
// the test suite sometimes panics "concurrent map writes" in
// net/http.(*Transport).removeIdleConnLocked().)
//
// NOTE(review): this dereference deliberately copies a struct that
// contains mutexes, which go vet's copylocks check will report.
var defaultTransport = *(http.DefaultTransport.(*http.Transport))
185
// proxyHandler implements the keepproxy HTTP service.
//
// The embedded http.Handler is the mux router built by newHandler, so
// requests are dispatched to the route methods (Get, Put, Index,
// Options). The embedded *keepclient.KeepClient is the template that
// makeKeepClient copies for each request. The embedded *apiTokenCache
// remembers tokens recently accepted by checkAuthorizationHeader.
type proxyHandler struct {
	http.Handler
	*keepclient.KeepClient
	*apiTokenCache
	// timeout is the http.Client timeout for outgoing keep requests.
	timeout   time.Duration
	// transport is shared by every per-request http.Client.
	transport *http.Transport
	// cluster supplies config such as Collections.KeepproxyPermission.
	cluster   *arvados.Cluster
}
194
195 func newHandler(ctx context.Context, kc *keepclient.KeepClient, timeout time.Duration, cluster *arvados.Cluster) (service.Handler, error) {
196         rest := mux.NewRouter()
197
198         transport := defaultTransport
199         transport.DialContext = (&net.Dialer{
200                 Timeout:   keepclient.DefaultConnectTimeout,
201                 KeepAlive: keepclient.DefaultKeepAlive,
202                 DualStack: true,
203         }).DialContext
204         transport.TLSClientConfig = arvadosclient.MakeTLSConfig(kc.Arvados.ApiInsecure)
205         transport.TLSHandshakeTimeout = keepclient.DefaultTLSHandshakeTimeout
206
207         cacheQ, err := lru.New2Q(500)
208         if err != nil {
209                 return nil, fmt.Errorf("Error from lru.New2Q: %v", err)
210         }
211
212         h := &proxyHandler{
213                 Handler:    rest,
214                 KeepClient: kc,
215                 timeout:    timeout,
216                 transport:  &transport,
217                 apiTokenCache: &apiTokenCache{
218                         tokens:     cacheQ,
219                         expireTime: 300,
220                 },
221                 cluster: cluster,
222         }
223
224         rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Get).Methods("GET", "HEAD")
225         rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Get).Methods("GET", "HEAD")
226
227         // List all blocks
228         rest.HandleFunc(`/index`, h.Index).Methods("GET")
229
230         // List blocks whose hash has the given prefix
231         rest.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, h.Index).Methods("GET")
232
233         rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Put).Methods("PUT")
234         rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Put).Methods("PUT")
235         rest.HandleFunc(`/`, h.Put).Methods("POST")
236         rest.HandleFunc(`/{any}`, h.Options).Methods("OPTIONS")
237         rest.HandleFunc(`/`, h.Options).Methods("OPTIONS")
238
239         rest.Handle("/_health/{check}", &health.Handler{
240                 Token:  cluster.ManagementToken,
241                 Prefix: "/_health/",
242         }).Methods("GET")
243
244         rest.NotFoundHandler = invalidPathHandler{}
245         return h, nil
246 }
247
248 var errLoopDetected = errors.New("loop detected")
249
250 func (h *proxyHandler) checkLoop(resp http.ResponseWriter, req *http.Request) error {
251         if via := req.Header.Get("Via"); strings.Index(via, " "+viaAlias) >= 0 {
252                 ctxlog.FromContext(req.Context()).Printf("proxy loop detected (request has Via: %q): perhaps keepproxy is misidentified by gateway config as an external client, or its keep_services record does not have service_type=proxy?", via)
253                 http.Error(resp, errLoopDetected.Error(), http.StatusInternalServerError)
254                 return errLoopDetected
255         }
256         return nil
257 }
258
// setCORSHeaders adds the CORS response headers that let browsers on
// any origin use the proxy's GET/HEAD/POST/PUT/OPTIONS endpoints.
//
// NOTE(review): a Max-Age of 86486400 seconds is about 1001 days, and
// browsers cap it much lower; possibly 86400 (one day) was meant —
// confirm. Also, X-Keep-Storage-Classes (read by Put) is not listed in
// Allow-Headers, so cross-origin clients cannot send it.
func setCORSHeaders(resp http.ResponseWriter) {
	hdr := resp.Header()
	hdr.Set("Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS")
	hdr.Set("Access-Control-Allow-Origin", "*")
	hdr.Set("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas")
	hdr.Set("Access-Control-Max-Age", "86486400")
}
265
// invalidPathHandler is installed as the router's NotFoundHandler:
// any request whose path matches no route is rejected with 400.
type invalidPathHandler struct{}

// ServeHTTP replies "Bad request" with status 400, whatever the request.
func (invalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
	const msg = "Bad request"
	http.Error(resp, msg, http.StatusBadRequest)
}
271
272 func (h *proxyHandler) Options(resp http.ResponseWriter, req *http.Request) {
273         setCORSHeaders(resp)
274 }
275
// Errors reported to clients by the Get, Put and Index handlers.
var (
	errBadAuthorizationHeader = errors.New("Missing or invalid Authorization header, or method not allowed")
	errContentLengthMismatch  = errors.New("Actual length != expected content length")
	errMethodNotSupported     = errors.New("Method not supported")
)

// removeHint matches a "+K@" locator hint followed by exactly five
// lowercase alphanumerics (presumably a remote cluster ID — confirm),
// together with the "+" that follows it, if any. Get uses it with
// ReplaceAllString(locator, "$1") to drop the hint while keeping the
// remaining hints "+"-separated.
//
// MustCompile replaces the former regexp.Compile call whose error was
// silently discarded: a bad pattern now fails loudly at init rather
// than leaving a nil *Regexp to panic on first use.
var removeHint = regexp.MustCompile(`\+K@[a-z0-9]{5}(\+|$)`)
281
282 func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) {
283         if err := h.checkLoop(resp, req); err != nil {
284                 return
285         }
286         setCORSHeaders(resp)
287         resp.Header().Set("Via", req.Proto+" "+viaAlias)
288
289         locator := mux.Vars(req)["locator"]
290         var err error
291         var status int
292         var expectLength, responseLength int64
293         var proxiedURI = "-"
294
295         logger := ctxlog.FromContext(req.Context())
296         defer func() {
297                 httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
298                         "expectLength":   expectLength,
299                         "responseLength": responseLength,
300                         "proxiedURI":     proxiedURI,
301                         "err":            err,
302                 })
303                 if status != http.StatusOK {
304                         http.Error(resp, err.Error(), status)
305                 }
306         }()
307
308         kc := h.makeKeepClient(req)
309
310         var pass bool
311         var tok string
312         var user *arvados.User
313         if pass, tok, user = h.checkAuthorizationHeader(req); !pass {
314                 status, err = http.StatusForbidden, errBadAuthorizationHeader
315                 return
316         }
317         httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
318                 "userUUID":     user.UUID,
319                 "userFullName": user.FullName,
320         })
321
322         // Copy ArvadosClient struct and use the client's API token
323         arvclient := *kc.Arvados
324         arvclient.ApiToken = tok
325         kc.Arvados = &arvclient
326
327         var reader io.ReadCloser
328
329         locator = removeHint.ReplaceAllString(locator, "$1")
330
331         switch req.Method {
332         case "HEAD":
333                 expectLength, proxiedURI, err = kc.Ask(locator)
334         case "GET":
335                 reader, expectLength, proxiedURI, err = kc.Get(locator)
336                 if reader != nil {
337                         defer reader.Close()
338                 }
339         default:
340                 status, err = http.StatusNotImplemented, errMethodNotSupported
341                 return
342         }
343
344         if expectLength == -1 {
345                 logger.Warn("Content-Length not provided")
346         }
347
348         switch respErr := err.(type) {
349         case nil:
350                 status = http.StatusOK
351                 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
352                 switch req.Method {
353                 case "HEAD":
354                         responseLength = 0
355                 case "GET":
356                         responseLength, err = io.Copy(resp, reader)
357                         if err == nil && expectLength > -1 && responseLength != expectLength {
358                                 err = errContentLengthMismatch
359                         }
360                 }
361         case keepclient.Error:
362                 if respErr == keepclient.BlockNotFound {
363                         status = http.StatusNotFound
364                 } else if respErr.Temporary() {
365                         status = http.StatusBadGateway
366                 } else {
367                         status = 422
368                 }
369         default:
370                 status = http.StatusInternalServerError
371         }
372 }
373
// Errors reported by Put when checking the request's declared size:
// a missing or unparseable Content-Length, or a Content-Length that
// disagrees with the size hint in the locator.
var (
	errLengthRequired = errors.New(http.StatusText(http.StatusLengthRequired))
	errLengthMismatch = errors.New("Locator size hint does not match Content-Length header")
)
376
377 func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
378         if err := h.checkLoop(resp, req); err != nil {
379                 return
380         }
381         setCORSHeaders(resp)
382         resp.Header().Set("Via", "HTTP/1.1 "+viaAlias)
383
384         kc := h.makeKeepClient(req)
385
386         var err error
387         var expectLength int64
388         var status = http.StatusInternalServerError
389         var wroteReplicas int
390         var locatorOut string = "-"
391
392         defer func() {
393                 httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
394                         "expectLength":  expectLength,
395                         "wantReplicas":  kc.Want_replicas,
396                         "wroteReplicas": wroteReplicas,
397                         "locator":       strings.SplitN(locatorOut, "+A", 2)[0],
398                         "err":           err,
399                 })
400                 if status != http.StatusOK {
401                         http.Error(resp, err.Error(), status)
402                 }
403         }()
404
405         locatorIn := mux.Vars(req)["locator"]
406
407         // Check if the client specified storage classes
408         if req.Header.Get("X-Keep-Storage-Classes") != "" {
409                 var scl []string
410                 for _, sc := range strings.Split(req.Header.Get("X-Keep-Storage-Classes"), ",") {
411                         scl = append(scl, strings.Trim(sc, " "))
412                 }
413                 kc.SetStorageClasses(scl)
414         }
415
416         _, err = fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
417         if err != nil || expectLength < 0 {
418                 err = errLengthRequired
419                 status = http.StatusLengthRequired
420                 return
421         }
422
423         if locatorIn != "" {
424                 var loc *keepclient.Locator
425                 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
426                         status = http.StatusBadRequest
427                         return
428                 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
429                         err = errLengthMismatch
430                         status = http.StatusBadRequest
431                         return
432                 }
433         }
434
435         var pass bool
436         var tok string
437         var user *arvados.User
438         if pass, tok, user = h.checkAuthorizationHeader(req); !pass {
439                 err = errBadAuthorizationHeader
440                 status = http.StatusForbidden
441                 return
442         }
443         httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
444                 "userUUID":     user.UUID,
445                 "userFullName": user.FullName,
446         })
447
448         // Copy ArvadosClient struct and use the client's API token
449         arvclient := *kc.Arvados
450         arvclient.ApiToken = tok
451         kc.Arvados = &arvclient
452
453         // Check if the client specified the number of replicas
454         if desiredReplicas := req.Header.Get(keepclient.XKeepDesiredReplicas); desiredReplicas != "" {
455                 var r int
456                 _, err := fmt.Sscanf(desiredReplicas, "%d", &r)
457                 if err == nil {
458                         kc.Want_replicas = r
459                 }
460         }
461
462         // Now try to put the block through
463         if locatorIn == "" {
464                 bytes, err2 := ioutil.ReadAll(req.Body)
465                 if err2 != nil {
466                         err = fmt.Errorf("Error reading request body: %s", err2)
467                         status = http.StatusInternalServerError
468                         return
469                 }
470                 locatorOut, wroteReplicas, err = kc.PutB(bytes)
471         } else {
472                 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
473         }
474
475         // Tell the client how many successful PUTs we accomplished
476         resp.Header().Set(keepclient.XKeepReplicasStored, fmt.Sprintf("%d", wroteReplicas))
477
478         switch err.(type) {
479         case nil:
480                 status = http.StatusOK
481                 if len(kc.StorageClasses) > 0 {
482                         // A successful PUT request with storage classes means that all
483                         // storage classes were fulfilled, so the client will get a
484                         // confirmation via the X-Storage-Classes-Confirmed header.
485                         hdr := ""
486                         isFirst := true
487                         for _, sc := range kc.StorageClasses {
488                                 if isFirst {
489                                         hdr = fmt.Sprintf("%s=%d", sc, wroteReplicas)
490                                         isFirst = false
491                                 } else {
492                                         hdr += fmt.Sprintf(", %s=%d", sc, wroteReplicas)
493                                 }
494                         }
495                         resp.Header().Set(keepclient.XKeepStorageClassesConfirmed, hdr)
496                 }
497                 _, err = io.WriteString(resp, locatorOut)
498         case keepclient.OversizeBlockError:
499                 // Too much data
500                 status = http.StatusRequestEntityTooLarge
501         case keepclient.InsufficientReplicasError:
502                 status = http.StatusServiceUnavailable
503         default:
504                 status = http.StatusBadGateway
505         }
506 }
507
508 // ServeHTTP implementation for IndexHandler
509 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
510 // For each keep server found in LocalRoots:
511 //   Invokes GetIndex using keepclient
512 //   Expects "complete" response (terminating with blank new line)
513 //   Aborts on any errors
514 // Concatenates responses from all those keep servers and returns
515 func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) {
516         setCORSHeaders(resp)
517
518         prefix := mux.Vars(req)["prefix"]
519         var err error
520         var status int
521
522         defer func() {
523                 if status != http.StatusOK {
524                         http.Error(resp, err.Error(), status)
525                 }
526         }()
527
528         kc := h.makeKeepClient(req)
529         ok, token, _ := h.checkAuthorizationHeader(req)
530         if !ok {
531                 status, err = http.StatusForbidden, errBadAuthorizationHeader
532                 return
533         }
534
535         // Copy ArvadosClient struct and use the client's API token
536         arvclient := *kc.Arvados
537         arvclient.ApiToken = token
538         kc.Arvados = &arvclient
539
540         // Only GET method is supported
541         if req.Method != "GET" {
542                 status, err = http.StatusNotImplemented, errMethodNotSupported
543                 return
544         }
545
546         // Get index from all LocalRoots and write to resp
547         var reader io.Reader
548         for uuid := range kc.LocalRoots() {
549                 reader, err = kc.GetIndex(uuid, prefix)
550                 if err != nil {
551                         status = http.StatusBadGateway
552                         return
553                 }
554
555                 _, err = io.Copy(resp, reader)
556                 if err != nil {
557                         status = http.StatusBadGateway
558                         return
559                 }
560         }
561
562         // Got index from all the keep servers and wrote to resp
563         status = http.StatusOK
564         resp.Write([]byte("\n"))
565 }
566
567 func (h *proxyHandler) makeKeepClient(req *http.Request) *keepclient.KeepClient {
568         kc := *h.KeepClient
569         kc.RequestID = req.Header.Get("X-Request-Id")
570         kc.HTTPClient = &proxyClient{
571                 client: &http.Client{
572                         Timeout:   h.timeout,
573                         Transport: h.transport,
574                 },
575                 proto: req.Proto,
576         }
577         return &kc
578 }