20318: Bypass disk cache in certain tests, keepstore, and keepproxy.
[arvados.git] / services / keepproxy / keepproxy.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package keepproxy
6
7 import (
8         "context"
9         "errors"
10         "fmt"
11         "io"
12         "io/ioutil"
13         "net"
14         "net/http"
15         "regexp"
16         "strings"
17         "time"
18
19         "git.arvados.org/arvados.git/lib/service"
20         "git.arvados.org/arvados.git/sdk/go/arvados"
21         "git.arvados.org/arvados.git/sdk/go/arvadosclient"
22         "git.arvados.org/arvados.git/sdk/go/ctxlog"
23         "git.arvados.org/arvados.git/sdk/go/health"
24         "git.arvados.org/arvados.git/sdk/go/httpserver"
25         "git.arvados.org/arvados.git/sdk/go/keepclient"
26         "github.com/gorilla/mux"
27         lru "github.com/hashicorp/golang-lru"
28         "github.com/prometheus/client_golang/prometheus"
29         "github.com/sirupsen/logrus"
30 )
31
32 const rfc3339NanoFixed = "2006-01-02T15:04:05.000000000Z07:00"
33
34 var Command = service.Command(arvados.ServiceNameKeepproxy, newHandlerOrErrorHandler)
35
36 func newHandlerOrErrorHandler(ctx context.Context, cluster *arvados.Cluster, token string, reg *prometheus.Registry) service.Handler {
37         client, err := arvados.NewClientFromConfig(cluster)
38         if err != nil {
39                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up arvados client: %w", err))
40         }
41         arv, err := arvadosclient.New(client)
42         if err != nil {
43                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up arvados client: %w", err))
44         }
45         kc, err := keepclient.MakeKeepClient(arv)
46         if err != nil {
47                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up keep client: %w", err))
48         }
49         keepclient.RefreshServiceDiscoveryOnSIGHUP()
50         router, err := newHandler(ctx, kc, time.Duration(keepclient.DefaultProxyRequestTimeout), cluster)
51         if err != nil {
52                 return service.ErrorHandler(ctx, cluster, err)
53         }
54         return router
55 }
56
57 type tokenCacheEntry struct {
58         expire int64
59         user   *arvados.User
60 }
61
62 type apiTokenCache struct {
63         tokens     *lru.TwoQueueCache
64         expireTime int64
65 }
66
67 // RememberToken caches the token and set an expire time.  If the
68 // token is already in the cache, it is not updated.
69 func (cache *apiTokenCache) RememberToken(token string, user *arvados.User) {
70         now := time.Now().Unix()
71         _, ok := cache.tokens.Get(token)
72         if !ok {
73                 cache.tokens.Add(token, tokenCacheEntry{
74                         expire: now + cache.expireTime,
75                         user:   user,
76                 })
77         }
78 }
79
80 // RecallToken checks if the cached token is known and still believed to be
81 // valid.
82 func (cache *apiTokenCache) RecallToken(token string) (bool, *arvados.User) {
83         val, ok := cache.tokens.Get(token)
84         if !ok {
85                 return false, nil
86         }
87
88         cacheEntry := val.(tokenCacheEntry)
89         now := time.Now().Unix()
90         if now < cacheEntry.expire {
91                 // Token is known and still valid
92                 return true, cacheEntry.user
93         } else {
94                 // Token is expired
95                 cache.tokens.Remove(token)
96                 return false, nil
97         }
98 }
99
100 func (h *proxyHandler) Done() <-chan struct{} {
101         return nil
102 }
103
104 func (h *proxyHandler) CheckHealth() error {
105         return nil
106 }
107
108 func (h *proxyHandler) checkAuthorizationHeader(req *http.Request) (pass bool, tok string, user *arvados.User) {
109         parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
110         if len(parts) < 2 || !(parts[0] == "OAuth2" || parts[0] == "Bearer") || len(parts[1]) == 0 {
111                 return false, "", nil
112         }
113         tok = parts[1]
114
115         // Tokens are validated differently depending on what kind of
116         // operation is being performed. For example, tokens in
117         // collection-sharing links permit GET requests, but not
118         // PUT requests.
119         var op string
120         if req.Method == "GET" || req.Method == "HEAD" {
121                 op = "read"
122         } else {
123                 op = "write"
124         }
125
126         if ok, user := h.apiTokenCache.RecallToken(op + ":" + tok); ok {
127                 // Valid in the cache, short circuit
128                 return true, tok, user
129         }
130
131         var err error
132         arv := *h.KeepClient.Arvados
133         arv.ApiToken = tok
134         arv.RequestID = req.Header.Get("X-Request-Id")
135         user = &arvados.User{}
136         userCurrentError := arv.Call("GET", "users", "", "current", nil, user)
137         err = userCurrentError
138         if err != nil && op == "read" {
139                 apiError, ok := err.(arvadosclient.APIServerError)
140                 if ok && apiError.HttpStatusCode == http.StatusForbidden {
141                         // If it was a scoped "sharing" token it will
142                         // return 403 instead of 401 for the current
143                         // user check.  If it is a download operation
144                         // and they have permission to read the
145                         // keep_services table, we can allow it.
146                         err = arv.Call("HEAD", "keep_services", "", "accessible", nil, nil)
147                 }
148         }
149         if err != nil {
150                 ctxlog.FromContext(req.Context()).WithError(err).Info("checkAuthorizationHeader error")
151                 return false, "", nil
152         }
153
154         if userCurrentError == nil && user.IsAdmin {
155                 // checking userCurrentError is probably redundant,
156                 // IsAdmin would be false anyway. But can't hurt.
157                 if op == "read" && !h.cluster.Collections.KeepproxyPermission.Admin.Download {
158                         return false, "", nil
159                 }
160                 if op == "write" && !h.cluster.Collections.KeepproxyPermission.Admin.Upload {
161                         return false, "", nil
162                 }
163         } else {
164                 if op == "read" && !h.cluster.Collections.KeepproxyPermission.User.Download {
165                         return false, "", nil
166                 }
167                 if op == "write" && !h.cluster.Collections.KeepproxyPermission.User.Upload {
168                         return false, "", nil
169                 }
170         }
171
172         // Success!  Update cache
173         h.apiTokenCache.RememberToken(op+":"+tok, user)
174
175         return true, tok, user
176 }
177
// defaultTransport copies a subset of http.DefaultTransport's exported
// settings. http.Transport contains a mutex and must not be copied, so
// rather than dereferencing http.DefaultTransport a fresh Transport is
// built from those fields. (newHandler in this file builds its own
// transport instead of using this one.)
var defaultTransport = func() http.Transport {
	base := http.DefaultTransport.(*http.Transport)
	return http.Transport{
		Proxy:                 base.Proxy,
		DialContext:           base.DialContext,
		ForceAttemptHTTP2:     base.ForceAttemptHTTP2,
		MaxIdleConns:          base.MaxIdleConns,
		IdleConnTimeout:       base.IdleConnTimeout,
		TLSHandshakeTimeout:   base.TLSHandshakeTimeout,
		ExpectContinueTimeout: base.ExpectContinueTimeout,
	}
}()
190
191 type proxyHandler struct {
192         http.Handler
193         *keepclient.KeepClient
194         *apiTokenCache
195         timeout   time.Duration
196         transport *http.Transport
197         cluster   *arvados.Cluster
198 }
199
200 func newHandler(ctx context.Context, kc *keepclient.KeepClient, timeout time.Duration, cluster *arvados.Cluster) (service.Handler, error) {
201         rest := mux.NewRouter()
202
203         // We can't copy the default http transport because
204         // http.Transport has a mutex field, so we copy the fields
205         // that we know have non-zero values in http.DefaultTransport.
206         transport := &http.Transport{
207                 Proxy:                 http.DefaultTransport.(*http.Transport).Proxy,
208                 ForceAttemptHTTP2:     http.DefaultTransport.(*http.Transport).ForceAttemptHTTP2,
209                 MaxIdleConns:          http.DefaultTransport.(*http.Transport).MaxIdleConns,
210                 IdleConnTimeout:       http.DefaultTransport.(*http.Transport).IdleConnTimeout,
211                 ExpectContinueTimeout: http.DefaultTransport.(*http.Transport).ExpectContinueTimeout,
212                 DialContext: (&net.Dialer{
213                         Timeout:   keepclient.DefaultConnectTimeout,
214                         KeepAlive: keepclient.DefaultKeepAlive,
215                         DualStack: true,
216                 }).DialContext,
217                 TLSClientConfig:     arvadosclient.MakeTLSConfig(kc.Arvados.ApiInsecure),
218                 TLSHandshakeTimeout: keepclient.DefaultTLSHandshakeTimeout,
219         }
220
221         cacheQ, err := lru.New2Q(500)
222         if err != nil {
223                 return nil, fmt.Errorf("Error from lru.New2Q: %v", err)
224         }
225
226         h := &proxyHandler{
227                 Handler:    rest,
228                 KeepClient: kc,
229                 timeout:    timeout,
230                 transport:  transport,
231                 apiTokenCache: &apiTokenCache{
232                         tokens:     cacheQ,
233                         expireTime: 300,
234                 },
235                 cluster: cluster,
236         }
237
238         rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Get).Methods("GET", "HEAD")
239         rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Get).Methods("GET", "HEAD")
240
241         // List all blocks
242         rest.HandleFunc(`/index`, h.Index).Methods("GET")
243
244         // List blocks whose hash has the given prefix
245         rest.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, h.Index).Methods("GET")
246
247         rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Put).Methods("PUT")
248         rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Put).Methods("PUT")
249         rest.HandleFunc(`/`, h.Put).Methods("POST")
250         rest.HandleFunc(`/{any}`, h.Options).Methods("OPTIONS")
251         rest.HandleFunc(`/`, h.Options).Methods("OPTIONS")
252
253         rest.Handle("/_health/{check}", &health.Handler{
254                 Token:  cluster.ManagementToken,
255                 Prefix: "/_health/",
256         }).Methods("GET")
257
258         rest.NotFoundHandler = invalidPathHandler{}
259         return h, nil
260 }
261
262 var errLoopDetected = errors.New("loop detected")
263
264 func (h *proxyHandler) checkLoop(resp http.ResponseWriter, req *http.Request) error {
265         if via := req.Header.Get("Via"); strings.Index(via, " "+viaAlias) >= 0 {
266                 ctxlog.FromContext(req.Context()).Printf("proxy loop detected (request has Via: %q): perhaps keepproxy is misidentified by gateway config as an external client, or its keep_services record does not have service_type=proxy?", via)
267                 http.Error(resp, errLoopDetected.Error(), http.StatusInternalServerError)
268                 return errLoopDetected
269         }
270         return nil
271 }
272
// setCORSHeaders adds the cross-origin headers that let browser clients
// call keepproxy from any origin.
//
// NOTE(review): a Max-Age of 86486400 seconds is about 1001 days and is
// probably a typo for 86400 (one day). Browsers clamp this value anyway.
// It is left unchanged here because clients or tests may check it.
func setCORSHeaders(resp http.ResponseWriter) {
	hdr := resp.Header()
	for _, kv := range [...][2]string{
		{"Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS"},
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas"},
		{"Access-Control-Max-Age", "86486400"},
	} {
		hdr.Set(kv[0], kv[1])
	}
}
279
// invalidPathHandler answers every request with a plain-text 400 Bad
// Request. newHandler installs it as the router's NotFoundHandler.
type invalidPathHandler struct{}

func (invalidPathHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	const msg = "Bad request"
	http.Error(w, msg, http.StatusBadRequest)
}
285
286 func (h *proxyHandler) Options(resp http.ResponseWriter, req *http.Request) {
287         setCORSHeaders(resp)
288 }
289
// Sentinel errors used by the request handlers.
var (
	errBadAuthorizationHeader = errors.New("Missing or invalid Authorization header, or method not allowed")
	errContentLengthMismatch  = errors.New("Actual length != expected content length")
	errMethodNotSupported     = errors.New("Method not supported")
)

// removeHint matches a "+K@" locator hint followed by exactly five
// lowercase alphanumerics, together with the "+" that follows it (if
// any). Replacing matches with "$1" deletes the hint but keeps the
// separator before the next locator part. Two adjacent hints share one
// "+", so a single ReplaceAllString pass removes only the first of them.
//
// MustCompile is used so that an invalid pattern fails at program start,
// instead of its error being discarded and leaving removeHint nil.
var removeHint = regexp.MustCompile(`\+K@[a-z0-9]{5}(\+|$)`)
295
296 func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) {
297         if err := h.checkLoop(resp, req); err != nil {
298                 return
299         }
300         setCORSHeaders(resp)
301         resp.Header().Set("Via", req.Proto+" "+viaAlias)
302
303         locator := mux.Vars(req)["locator"]
304         var err error
305         var status int
306         var expectLength, responseLength int64
307         var proxiedURI = "-"
308
309         logger := ctxlog.FromContext(req.Context())
310         defer func() {
311                 httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
312                         "locator":        locator,
313                         "expectLength":   expectLength,
314                         "responseLength": responseLength,
315                         "proxiedURI":     proxiedURI,
316                         "err":            err,
317                 })
318                 if status != http.StatusOK {
319                         http.Error(resp, err.Error(), status)
320                 }
321         }()
322
323         kc := h.makeKeepClient(req)
324         kc.DiskCacheSize = keepclient.DiskCacheDisabled
325
326         var pass bool
327         var tok string
328         var user *arvados.User
329         if pass, tok, user = h.checkAuthorizationHeader(req); !pass {
330                 status, err = http.StatusForbidden, errBadAuthorizationHeader
331                 return
332         }
333         httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
334                 "userUUID":     user.UUID,
335                 "userFullName": user.FullName,
336         })
337
338         // Copy ArvadosClient struct and use the client's API token
339         arvclient := *kc.Arvados
340         arvclient.ApiToken = tok
341         kc.Arvados = &arvclient
342
343         var reader io.ReadCloser
344
345         locator = removeHint.ReplaceAllString(locator, "$1")
346
347         switch req.Method {
348         case "HEAD":
349                 expectLength, proxiedURI, err = kc.Ask(locator)
350         case "GET":
351                 reader, expectLength, proxiedURI, err = kc.Get(locator)
352                 if reader != nil {
353                         defer reader.Close()
354                 }
355         default:
356                 status, err = http.StatusNotImplemented, errMethodNotSupported
357                 return
358         }
359
360         if expectLength == -1 {
361                 logger.Warn("Content-Length not provided")
362         }
363
364         switch respErr := err.(type) {
365         case nil:
366                 status = http.StatusOK
367                 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
368                 switch req.Method {
369                 case "HEAD":
370                         responseLength = 0
371                 case "GET":
372                         responseLength, err = io.Copy(resp, reader)
373                         if err == nil && expectLength > -1 && responseLength != expectLength {
374                                 err = errContentLengthMismatch
375                         }
376                 }
377         case keepclient.Error:
378                 if respErr == keepclient.BlockNotFound {
379                         status = http.StatusNotFound
380                 } else if respErr.Temporary() {
381                         status = http.StatusBadGateway
382                 } else {
383                         status = 422
384                 }
385         default:
386                 status = http.StatusInternalServerError
387         }
388 }
389
390 var errLengthRequired = errors.New(http.StatusText(http.StatusLengthRequired))
391 var errLengthMismatch = errors.New("Locator size hint does not match Content-Length header")
392
393 func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
394         if err := h.checkLoop(resp, req); err != nil {
395                 return
396         }
397         setCORSHeaders(resp)
398         resp.Header().Set("Via", "HTTP/1.1 "+viaAlias)
399
400         kc := h.makeKeepClient(req)
401
402         var err error
403         var expectLength int64
404         var status = http.StatusInternalServerError
405         var wroteReplicas int
406         var locatorOut string = "-"
407
408         defer func() {
409                 httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
410                         "expectLength":  expectLength,
411                         "wantReplicas":  kc.Want_replicas,
412                         "wroteReplicas": wroteReplicas,
413                         "locator":       strings.SplitN(locatorOut, "+A", 2)[0],
414                         "err":           err,
415                 })
416                 if status != http.StatusOK {
417                         http.Error(resp, err.Error(), status)
418                 }
419         }()
420
421         locatorIn := mux.Vars(req)["locator"]
422
423         // Check if the client specified storage classes
424         if req.Header.Get("X-Keep-Storage-Classes") != "" {
425                 var scl []string
426                 for _, sc := range strings.Split(req.Header.Get("X-Keep-Storage-Classes"), ",") {
427                         scl = append(scl, strings.Trim(sc, " "))
428                 }
429                 kc.SetStorageClasses(scl)
430         }
431
432         _, err = fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
433         if err != nil || expectLength < 0 {
434                 err = errLengthRequired
435                 status = http.StatusLengthRequired
436                 return
437         }
438
439         if locatorIn != "" {
440                 var loc *keepclient.Locator
441                 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
442                         status = http.StatusBadRequest
443                         return
444                 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
445                         err = errLengthMismatch
446                         status = http.StatusBadRequest
447                         return
448                 }
449         }
450
451         var pass bool
452         var tok string
453         var user *arvados.User
454         if pass, tok, user = h.checkAuthorizationHeader(req); !pass {
455                 err = errBadAuthorizationHeader
456                 status = http.StatusForbidden
457                 return
458         }
459         httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
460                 "userUUID":     user.UUID,
461                 "userFullName": user.FullName,
462         })
463
464         // Copy ArvadosClient struct and use the client's API token
465         arvclient := *kc.Arvados
466         arvclient.ApiToken = tok
467         kc.Arvados = &arvclient
468
469         // Check if the client specified the number of replicas
470         if desiredReplicas := req.Header.Get(keepclient.XKeepDesiredReplicas); desiredReplicas != "" {
471                 var r int
472                 _, err := fmt.Sscanf(desiredReplicas, "%d", &r)
473                 if err == nil {
474                         kc.Want_replicas = r
475                 }
476         }
477
478         // Now try to put the block through
479         if locatorIn == "" {
480                 bytes, err2 := ioutil.ReadAll(req.Body)
481                 if err2 != nil {
482                         err = fmt.Errorf("Error reading request body: %s", err2)
483                         status = http.StatusInternalServerError
484                         return
485                 }
486                 locatorOut, wroteReplicas, err = kc.PutB(bytes)
487         } else {
488                 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
489         }
490
491         // Tell the client how many successful PUTs we accomplished
492         resp.Header().Set(keepclient.XKeepReplicasStored, fmt.Sprintf("%d", wroteReplicas))
493
494         switch err.(type) {
495         case nil:
496                 status = http.StatusOK
497                 if len(kc.StorageClasses) > 0 {
498                         // A successful PUT request with storage classes means that all
499                         // storage classes were fulfilled, so the client will get a
500                         // confirmation via the X-Storage-Classes-Confirmed header.
501                         hdr := ""
502                         isFirst := true
503                         for _, sc := range kc.StorageClasses {
504                                 if isFirst {
505                                         hdr = fmt.Sprintf("%s=%d", sc, wroteReplicas)
506                                         isFirst = false
507                                 } else {
508                                         hdr += fmt.Sprintf(", %s=%d", sc, wroteReplicas)
509                                 }
510                         }
511                         resp.Header().Set(keepclient.XKeepStorageClassesConfirmed, hdr)
512                 }
513                 _, err = io.WriteString(resp, locatorOut)
514         case keepclient.OversizeBlockError:
515                 // Too much data
516                 status = http.StatusRequestEntityTooLarge
517         case keepclient.InsufficientReplicasError:
518                 status = http.StatusServiceUnavailable
519         default:
520                 status = http.StatusBadGateway
521         }
522 }
523
524 // ServeHTTP implementation for IndexHandler
525 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
526 // For each keep server found in LocalRoots:
527 // - Invokes GetIndex using keepclient
528 // - Expects "complete" response (terminating with blank new line)
529 // - Aborts on any errors
530 // Concatenates responses from all those keep servers and returns
531 func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) {
532         setCORSHeaders(resp)
533
534         prefix := mux.Vars(req)["prefix"]
535         var err error
536         var status int
537
538         defer func() {
539                 if status != http.StatusOK {
540                         http.Error(resp, err.Error(), status)
541                 }
542         }()
543
544         kc := h.makeKeepClient(req)
545         ok, token, _ := h.checkAuthorizationHeader(req)
546         if !ok {
547                 status, err = http.StatusForbidden, errBadAuthorizationHeader
548                 return
549         }
550
551         // Copy ArvadosClient struct and use the client's API token
552         arvclient := *kc.Arvados
553         arvclient.ApiToken = token
554         kc.Arvados = &arvclient
555
556         // Only GET method is supported
557         if req.Method != "GET" {
558                 status, err = http.StatusNotImplemented, errMethodNotSupported
559                 return
560         }
561
562         // Get index from all LocalRoots and write to resp
563         var reader io.Reader
564         for uuid := range kc.LocalRoots() {
565                 reader, err = kc.GetIndex(uuid, prefix)
566                 if err != nil {
567                         status = http.StatusBadGateway
568                         return
569                 }
570
571                 _, err = io.Copy(resp, reader)
572                 if err != nil {
573                         status = http.StatusBadGateway
574                         return
575                 }
576         }
577
578         // Got index from all the keep servers and wrote to resp
579         status = http.StatusOK
580         resp.Write([]byte("\n"))
581 }
582
583 func (h *proxyHandler) makeKeepClient(req *http.Request) *keepclient.KeepClient {
584         kc := h.KeepClient.Clone()
585         kc.RequestID = req.Header.Get("X-Request-Id")
586         kc.HTTPClient = &proxyClient{
587                 client: &http.Client{
588                         Timeout:   h.timeout,
589                         Transport: h.transport,
590                 },
591                 proto: req.Proto,
592         }
593         return kc
594 }