21935: Rename safeapi.ThreadSafeApiCache to api.ThreadSafeAPIClient
[arvados.git] / services / keepproxy / keepproxy.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package keepproxy
6
7 import (
8         "context"
9         "errors"
10         "fmt"
11         "io"
12         "io/ioutil"
13         "net"
14         "net/http"
15         "regexp"
16         "strings"
17         "time"
18
19         "git.arvados.org/arvados.git/lib/service"
20         "git.arvados.org/arvados.git/sdk/go/arvados"
21         "git.arvados.org/arvados.git/sdk/go/arvadosclient"
22         "git.arvados.org/arvados.git/sdk/go/ctxlog"
23         "git.arvados.org/arvados.git/sdk/go/health"
24         "git.arvados.org/arvados.git/sdk/go/httpserver"
25         "git.arvados.org/arvados.git/sdk/go/keepclient"
26         "git.arvados.org/arvados.git/services/keepstore"
27         "github.com/gorilla/mux"
28         lru "github.com/hashicorp/golang-lru"
29         "github.com/prometheus/client_golang/prometheus"
30         "github.com/sirupsen/logrus"
31 )
32
// rfc3339NanoFixed is a time layout like time.RFC3339Nano except that it
// always prints exactly nine fractional-second digits: ".000000000" keeps
// trailing zeros, whereas RFC3339Nano's ".999999999" drops them.
const rfc3339NanoFixed = "2006-01-02T15:04:05.000000000Z07:00"

// Command is the keepproxy service command. It is built by
// service.Command, with newHandlerOrErrorHandler constructing the HTTP
// handler.
var Command = service.Command(arvados.ServiceNameKeepproxy, newHandlerOrErrorHandler)
36
37 func newHandlerOrErrorHandler(ctx context.Context, cluster *arvados.Cluster, token string, reg *prometheus.Registry) service.Handler {
38         client, err := arvados.NewClientFromConfig(cluster)
39         if err != nil {
40                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up arvados client: %w", err))
41         }
42         arv, err := arvadosclient.New(client)
43         if err != nil {
44                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up arvados client: %w", err))
45         }
46         kc, err := keepclient.MakeKeepClient(arv)
47         if err != nil {
48                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up keep client: %w", err))
49         }
50         keepclient.RefreshServiceDiscoveryOnSIGHUP()
51         router, err := newHandler(ctx, kc, time.Duration(keepclient.DefaultProxyRequestTimeout), cluster)
52         if err != nil {
53                 return service.ErrorHandler(ctx, cluster, err)
54         }
55         return router
56 }
57
58 type tokenCacheEntry struct {
59         expire int64
60         user   *arvados.User
61 }
62
63 type apiTokenCache struct {
64         tokens     *lru.TwoQueueCache
65         expireTime int64
66 }
67
68 // RememberToken caches the token and set an expire time.  If the
69 // token is already in the cache, it is not updated.
70 func (cache *apiTokenCache) RememberToken(token string, user *arvados.User) {
71         now := time.Now().Unix()
72         _, ok := cache.tokens.Get(token)
73         if !ok {
74                 cache.tokens.Add(token, tokenCacheEntry{
75                         expire: now + cache.expireTime,
76                         user:   user,
77                 })
78         }
79 }
80
81 // RecallToken checks if the cached token is known and still believed to be
82 // valid.
83 func (cache *apiTokenCache) RecallToken(token string) (bool, *arvados.User) {
84         val, ok := cache.tokens.Get(token)
85         if !ok {
86                 return false, nil
87         }
88
89         cacheEntry := val.(tokenCacheEntry)
90         now := time.Now().Unix()
91         if now < cacheEntry.expire {
92                 // Token is known and still valid
93                 return true, cacheEntry.user
94         } else {
95                 // Token is expired
96                 cache.tokens.Remove(token)
97                 return false, nil
98         }
99 }
100
// Done returns a nil channel. A receive from a nil channel blocks
// forever, so this handler never signals that it has stopped.
// NOTE(review): presumably this is how lib/service expects a handler
// with no background work to behave — confirm against service.Handler.
func (h *proxyHandler) Done() <-chan struct{} {
	return nil
}
104
// CheckHealth always reports healthy (returns nil); this method performs
// no checks of its own.
func (h *proxyHandler) CheckHealth() error {
	return nil
}
108
109 func (h *proxyHandler) checkAuthorizationHeader(req *http.Request) (pass bool, tok string, user *arvados.User) {
110         parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
111         if len(parts) < 2 || !(parts[0] == "OAuth2" || parts[0] == "Bearer") || len(parts[1]) == 0 {
112                 return false, "", nil
113         }
114         tok = parts[1]
115
116         // Tokens are validated differently depending on what kind of
117         // operation is being performed. For example, tokens in
118         // collection-sharing links permit GET requests, but not
119         // PUT requests.
120         var op string
121         if req.Method == "GET" || req.Method == "HEAD" {
122                 op = "read"
123         } else {
124                 op = "write"
125         }
126
127         if ok, user := h.apiTokenCache.RecallToken(op + ":" + tok); ok {
128                 // Valid in the cache, short circuit
129                 return true, tok, user
130         }
131
132         var err error
133         arv := *h.KeepClient.Arvados
134         arv.ApiToken = tok
135         arv.RequestID = req.Header.Get("X-Request-Id")
136         user = &arvados.User{}
137         userCurrentError := arv.Call("GET", "users", "", "current", nil, user)
138         err = userCurrentError
139         if err != nil && op == "read" {
140                 apiError, ok := err.(arvadosclient.APIServerError)
141                 if ok && apiError.HttpStatusCode == http.StatusForbidden {
142                         // If it was a scoped "sharing" token it will
143                         // return 403 instead of 401 for the current
144                         // user check.  If it is a download operation
145                         // and they have permission to read the
146                         // keep_services table, we can allow it.
147                         err = arv.Call("HEAD", "keep_services", "", "accessible", nil, nil)
148                 }
149         }
150         if err != nil {
151                 ctxlog.FromContext(req.Context()).WithError(err).Info("checkAuthorizationHeader error")
152                 return false, "", nil
153         }
154
155         if userCurrentError == nil && user.IsAdmin {
156                 // checking userCurrentError is probably redundant,
157                 // IsAdmin would be false anyway. But can't hurt.
158                 if op == "read" && !h.cluster.Collections.KeepproxyPermission.Admin.Download {
159                         return false, "", nil
160                 }
161                 if op == "write" && !h.cluster.Collections.KeepproxyPermission.Admin.Upload {
162                         return false, "", nil
163                 }
164         } else {
165                 if op == "read" && !h.cluster.Collections.KeepproxyPermission.User.Download {
166                         return false, "", nil
167                 }
168                 if op == "write" && !h.cluster.Collections.KeepproxyPermission.User.Upload {
169                         return false, "", nil
170                 }
171         }
172
173         // Success!  Update cache
174         h.apiTokenCache.RememberToken(op+":"+tok, user)
175
176         return true, tok, user
177 }
178
// defaultTransport carries selected exported settings of
// http.DefaultTransport. http.Transport contains a mutex, so the
// default transport cannot simply be copied by value; instead a new
// Transport is built field by field.
var defaultTransport = func() http.Transport {
	base := http.DefaultTransport.(*http.Transport)
	return http.Transport{
		Proxy:                 base.Proxy,
		DialContext:           base.DialContext,
		ForceAttemptHTTP2:     base.ForceAttemptHTTP2,
		MaxIdleConns:          base.MaxIdleConns,
		IdleConnTimeout:       base.IdleConnTimeout,
		TLSHandshakeTimeout:   base.TLSHandshakeTimeout,
		ExpectContinueTimeout: base.ExpectContinueTimeout,
	}
}()
191
// proxyHandler implements the keepproxy HTTP service.
//
// Embedded fields:
//   - http.Handler: the router that dispatches requests to the methods below.
//   - *keepclient.KeepClient: the proxy's own Keep client; makeKeepClient
//     clones it for each request.
//   - *apiTokenCache: recently validated client tokens.
//
// timeout is the http.Client timeout applied to upstream requests,
// transport is the http.Transport shared by those clients, and cluster
// supplies configuration such as Collections.KeepproxyPermission.
type proxyHandler struct {
	http.Handler
	*keepclient.KeepClient
	*apiTokenCache
	timeout   time.Duration
	transport *http.Transport
	cluster   *arvados.Cluster
}
200
201 func newHandler(ctx context.Context, kc *keepclient.KeepClient, timeout time.Duration, cluster *arvados.Cluster) (service.Handler, error) {
202         rest := mux.NewRouter()
203
204         // We can't copy the default http transport because
205         // http.Transport has a mutex field, so we copy the fields
206         // that we know have non-zero values in http.DefaultTransport.
207         transport := &http.Transport{
208                 Proxy:                 http.DefaultTransport.(*http.Transport).Proxy,
209                 ForceAttemptHTTP2:     http.DefaultTransport.(*http.Transport).ForceAttemptHTTP2,
210                 MaxIdleConns:          http.DefaultTransport.(*http.Transport).MaxIdleConns,
211                 IdleConnTimeout:       http.DefaultTransport.(*http.Transport).IdleConnTimeout,
212                 ExpectContinueTimeout: http.DefaultTransport.(*http.Transport).ExpectContinueTimeout,
213                 DialContext: (&net.Dialer{
214                         Timeout:   keepclient.DefaultConnectTimeout,
215                         KeepAlive: keepclient.DefaultKeepAlive,
216                         DualStack: true,
217                 }).DialContext,
218                 TLSClientConfig:     arvadosclient.MakeTLSConfig(kc.Arvados.ApiInsecure),
219                 TLSHandshakeTimeout: keepclient.DefaultTLSHandshakeTimeout,
220         }
221
222         cacheQ, err := lru.New2Q(500)
223         if err != nil {
224                 return nil, fmt.Errorf("Error from lru.New2Q: %v", err)
225         }
226
227         h := &proxyHandler{
228                 Handler:    rest,
229                 KeepClient: kc,
230                 timeout:    timeout,
231                 transport:  transport,
232                 apiTokenCache: &apiTokenCache{
233                         tokens:     cacheQ,
234                         expireTime: 300,
235                 },
236                 cluster: cluster,
237         }
238
239         rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Get).Methods("GET", "HEAD")
240         rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Get).Methods("GET", "HEAD")
241
242         // List all blocks
243         rest.HandleFunc(`/index`, h.Index).Methods("GET")
244
245         // List blocks whose hash has the given prefix
246         rest.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, h.Index).Methods("GET")
247
248         rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Put).Methods("PUT")
249         rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Put).Methods("PUT")
250         rest.HandleFunc(`/`, h.Put).Methods("POST")
251         rest.HandleFunc(`/{any}`, h.Options).Methods("OPTIONS")
252         rest.HandleFunc(`/`, h.Options).Methods("OPTIONS")
253
254         rest.Handle("/_health/{check}", &health.Handler{
255                 Token:  cluster.ManagementToken,
256                 Prefix: "/_health/",
257         }).Methods("GET")
258
259         rest.NotFoundHandler = invalidPathHandler{}
260         return h, nil
261 }
262
263 var errLoopDetected = errors.New("loop detected")
264
265 func (h *proxyHandler) checkLoop(resp http.ResponseWriter, req *http.Request) error {
266         if via := req.Header.Get("Via"); strings.Index(via, " "+viaAlias) >= 0 {
267                 ctxlog.FromContext(req.Context()).Printf("proxy loop detected (request has Via: %q): perhaps keepproxy is misidentified by gateway config as an external client, or its keep_services record does not have service_type=proxy?", via)
268                 http.Error(resp, errLoopDetected.Error(), http.StatusInternalServerError)
269                 return errLoopDetected
270         }
271         return nil
272 }
273
274 func setCORSHeaders(resp http.ResponseWriter) {
275         keepstore.SetCORSHeaders(resp)
276         acam := "Access-Control-Allow-Methods"
277         resp.Header().Set(acam, resp.Header().Get(acam)+", POST")
278 }
279
// invalidPathHandler answers requests whose path matches no route.
type invalidPathHandler struct{}

// ServeHTTP rejects every request with 400 Bad Request.
func (invalidPathHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	const msg = "Bad request"
	http.Error(w, msg, http.StatusBadRequest)
}
285
286 func (h *proxyHandler) Options(resp http.ResponseWriter, req *http.Request) {
287         setCORSHeaders(resp)
288 }
289
// Errors reported to clients by the block handlers.
var (
	errBadAuthorizationHeader = errors.New("Missing or invalid Authorization header, or method not allowed")
	errContentLengthMismatch  = errors.New("Actual length != expected content length")
	errMethodNotSupported     = errors.New("Method not supported")
)

// removeHint matches a "+K@" locator hint followed by five lowercase
// alphanumerics, together with the "+" (or end of string) that follows
// it. The trailing part is captured as $1, so replacing matches with
// "$1" removes just the hint.
//
// MustCompile (instead of Compile with a discarded error) makes an
// invalid pattern fail loudly at startup rather than leaving a nil
// *Regexp behind.
var removeHint = regexp.MustCompile(`\+K@[a-z0-9]{5}(\+|$)`)
295
// Get serves GET and HEAD requests for a block locator.
//
// For GET it fetches the block from Keep services; for HEAD it asks
// whether the block exists. Either way it uses the client's own API
// token. Before the locator is passed to the Keep client, any "+K@xxxxx"
// hint is removed (see removeHint).
//
// Response status:
//   - 200: success. Content-Length is taken from the upstream reply, and
//     for GET the block data is streamed to the client.
//   - 403: the Authorization header is missing or was rejected.
//   - 404: keepclient.BlockNotFound.
//   - 502: any other keepclient.Error whose Temporary() is true.
//   - 422: any other keepclient.Error.
//   - 500: any other error.
//   - 501: a method other than GET/HEAD.
func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) {
	if err := h.checkLoop(resp, req); err != nil {
		return
	}
	setCORSHeaders(resp)
	resp.Header().Set("Via", req.Proto+" "+viaAlias)

	locator := mux.Vars(req)["locator"]
	var err error
	var status int
	var expectLength, responseLength int64

	logger := ctxlog.FromContext(req.Context())
	// Runs on every return below. It adds request details to the access
	// log and, unless status is 200, sends err to the client with that
	// status. The closure reads locator at return time, so the logged
	// value is the hint-stripped locator once stripping has happened.
	defer func() {
		httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
			"locator":        locator,
			"expectLength":   expectLength,
			"responseLength": responseLength,
			"err":            err,
		})
		if status != http.StatusOK {
			http.Error(resp, err.Error(), status)
		}
	}()

	kc := h.makeKeepClient(req)
	kc.DiskCacheSize = keepclient.DiskCacheDisabled

	var pass bool
	var tok string
	var user *arvados.User
	if pass, tok, user = h.checkAuthorizationHeader(req); !pass {
		status, err = http.StatusForbidden, errBadAuthorizationHeader
		return
	}
	httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
		"userUUID":     user.UUID,
		"userFullName": user.FullName,
	})

	// Copy ArvadosClient struct and use the client's API token
	arvclient := *kc.Arvados
	arvclient.ApiToken = tok
	kc.Arvados = &arvclient

	var reader io.ReadCloser

	// Drop any "+K@xxxxx" hint before asking Keep for the block.
	locator = removeHint.ReplaceAllString(locator, "$1")

	switch req.Method {
	case "HEAD":
		expectLength, _, err = kc.Ask(locator)
	case "GET":
		reader, expectLength, _, err = kc.Get(locator)
		if reader != nil {
			defer reader.Close()
		}
	default:
		status, err = http.StatusNotImplemented, errMethodNotSupported
		return
	}

	if expectLength == -1 {
		logger.Warn("Content-Length not provided")
	}

	switch respErr := err.(type) {
	case nil:
		status = http.StatusOK
		resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
		switch req.Method {
		case "HEAD":
			responseLength = 0
		case "GET":
			// Headers are already committed once copying starts. A
			// failure here is only recorded in err for the log, because
			// status stays 200 and the deferred http.Error is skipped.
			responseLength, err = io.Copy(resp, reader)
			if err == nil && expectLength > -1 && responseLength != expectLength {
				err = errContentLengthMismatch
			}
		}
	case keepclient.Error:
		if respErr == keepclient.BlockNotFound {
			status = http.StatusNotFound
		} else if respErr.Temporary() {
			status = http.StatusBadGateway
		} else {
			status = 422
		}
	default:
		status = http.StatusInternalServerError
	}
}
387
// Errors reported by Put when the Content-Length header is missing or
// invalid, or does not match the size hint in the locator.
var (
	errLengthRequired = errors.New(http.StatusText(http.StatusLengthRequired))
	errLengthMismatch = errors.New("Locator size hint does not match Content-Length header")
)
390
391 func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
392         if err := h.checkLoop(resp, req); err != nil {
393                 return
394         }
395         setCORSHeaders(resp)
396         resp.Header().Set("Via", "HTTP/1.1 "+viaAlias)
397
398         kc := h.makeKeepClient(req)
399
400         var err error
401         var expectLength int64
402         var status = http.StatusInternalServerError
403         var wroteReplicas int
404         var locatorOut string = "-"
405
406         defer func() {
407                 httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
408                         "expectLength":  expectLength,
409                         "wantReplicas":  kc.Want_replicas,
410                         "wroteReplicas": wroteReplicas,
411                         "locator":       strings.SplitN(locatorOut, "+A", 2)[0],
412                         "err":           err,
413                 })
414                 if status != http.StatusOK {
415                         http.Error(resp, err.Error(), status)
416                 }
417         }()
418
419         locatorIn := mux.Vars(req)["locator"]
420
421         // Check if the client specified storage classes
422         if req.Header.Get(keepclient.XKeepStorageClasses) != "" {
423                 var scl []string
424                 for _, sc := range strings.Split(req.Header.Get(keepclient.XKeepStorageClasses), ",") {
425                         scl = append(scl, strings.Trim(sc, " "))
426                 }
427                 kc.SetStorageClasses(scl)
428         }
429
430         _, err = fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
431         if err != nil || expectLength < 0 {
432                 err = errLengthRequired
433                 status = http.StatusLengthRequired
434                 return
435         }
436
437         if locatorIn != "" {
438                 var loc *keepclient.Locator
439                 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
440                         status = http.StatusBadRequest
441                         return
442                 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
443                         err = errLengthMismatch
444                         status = http.StatusBadRequest
445                         return
446                 }
447         }
448
449         var pass bool
450         var tok string
451         var user *arvados.User
452         if pass, tok, user = h.checkAuthorizationHeader(req); !pass {
453                 err = errBadAuthorizationHeader
454                 status = http.StatusForbidden
455                 return
456         }
457         httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
458                 "userUUID":     user.UUID,
459                 "userFullName": user.FullName,
460         })
461
462         // Copy ArvadosClient struct and use the client's API token
463         arvclient := *kc.Arvados
464         arvclient.ApiToken = tok
465         kc.Arvados = &arvclient
466
467         // Check if the client specified the number of replicas
468         if desiredReplicas := req.Header.Get(keepclient.XKeepDesiredReplicas); desiredReplicas != "" {
469                 var r int
470                 _, err := fmt.Sscanf(desiredReplicas, "%d", &r)
471                 if err == nil {
472                         kc.Want_replicas = r
473                 }
474         }
475
476         // Now try to put the block through
477         if locatorIn == "" {
478                 bytes, err2 := ioutil.ReadAll(req.Body)
479                 if err2 != nil {
480                         err = fmt.Errorf("Error reading request body: %s", err2)
481                         status = http.StatusInternalServerError
482                         return
483                 }
484                 locatorOut, wroteReplicas, err = kc.PutB(bytes)
485         } else {
486                 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
487         }
488
489         // Tell the client how many successful PUTs we accomplished
490         resp.Header().Set(keepclient.XKeepReplicasStored, fmt.Sprintf("%d", wroteReplicas))
491
492         switch err.(type) {
493         case nil:
494                 status = http.StatusOK
495                 if len(kc.StorageClasses) > 0 {
496                         // A successful PUT request with storage classes means that all
497                         // storage classes were fulfilled, so the client will get a
498                         // confirmation via the X-Storage-Classes-Confirmed header.
499                         hdr := ""
500                         isFirst := true
501                         for _, sc := range kc.StorageClasses {
502                                 if isFirst {
503                                         hdr = fmt.Sprintf("%s=%d", sc, wroteReplicas)
504                                         isFirst = false
505                                 } else {
506                                         hdr += fmt.Sprintf(", %s=%d", sc, wroteReplicas)
507                                 }
508                         }
509                         resp.Header().Set(keepclient.XKeepStorageClassesConfirmed, hdr)
510                 }
511                 _, err = io.WriteString(resp, locatorOut)
512         case keepclient.OversizeBlockError:
513                 // Too much data
514                 status = http.StatusRequestEntityTooLarge
515         case keepclient.InsufficientReplicasError:
516                 status = http.StatusServiceUnavailable
517         default:
518                 status = http.StatusBadGateway
519         }
520 }
521
522 // ServeHTTP implementation for IndexHandler
523 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
524 // For each keep server found in LocalRoots:
525 // - Invokes GetIndex using keepclient
526 // - Expects "complete" response (terminating with blank new line)
527 // - Aborts on any errors
528 // Concatenates responses from all those keep servers and returns
529 func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) {
530         setCORSHeaders(resp)
531
532         prefix := mux.Vars(req)["prefix"]
533         var err error
534         var status int
535
536         defer func() {
537                 if status != http.StatusOK {
538                         http.Error(resp, err.Error(), status)
539                 }
540         }()
541
542         kc := h.makeKeepClient(req)
543         ok, token, _ := h.checkAuthorizationHeader(req)
544         if !ok {
545                 status, err = http.StatusForbidden, errBadAuthorizationHeader
546                 return
547         }
548
549         // Copy ArvadosClient struct and use the client's API token
550         arvclient := *kc.Arvados
551         arvclient.ApiToken = token
552         kc.Arvados = &arvclient
553
554         // Only GET method is supported
555         if req.Method != "GET" {
556                 status, err = http.StatusNotImplemented, errMethodNotSupported
557                 return
558         }
559
560         // Get index from all LocalRoots and write to resp
561         var reader io.Reader
562         for uuid := range kc.LocalRoots() {
563                 reader, err = kc.GetIndex(uuid, prefix)
564                 if err != nil {
565                         status = http.StatusBadGateway
566                         return
567                 }
568
569                 _, err = io.Copy(resp, reader)
570                 if err != nil {
571                         status = http.StatusBadGateway
572                         return
573                 }
574         }
575
576         // Got index from all the keep servers and wrote to resp
577         status = http.StatusOK
578         resp.Write([]byte("\n"))
579 }
580
581 func (h *proxyHandler) makeKeepClient(req *http.Request) *keepclient.KeepClient {
582         kc := h.KeepClient.Clone()
583         kc.RequestID = req.Header.Get("X-Request-Id")
584         kc.HTTPClient = &proxyClient{
585                 client: &http.Client{
586                         Timeout:   h.timeout,
587                         Transport: h.transport,
588                 },
589                 proto: req.Proto,
590         }
591         return kc
592 }