15317: Replace comment with better variable name.
[arvados.git] / services / keepproxy / keepproxy.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package keepproxy
6
7 import (
8         "context"
9         "errors"
10         "fmt"
11         "io"
12         "io/ioutil"
13         "net"
14         "net/http"
15         "regexp"
16         "strings"
17         "time"
18
19         "git.arvados.org/arvados.git/lib/service"
20         "git.arvados.org/arvados.git/sdk/go/arvados"
21         "git.arvados.org/arvados.git/sdk/go/arvadosclient"
22         "git.arvados.org/arvados.git/sdk/go/ctxlog"
23         "git.arvados.org/arvados.git/sdk/go/health"
24         "git.arvados.org/arvados.git/sdk/go/httpserver"
25         "git.arvados.org/arvados.git/sdk/go/keepclient"
26         "github.com/gorilla/mux"
27         lru "github.com/hashicorp/golang-lru"
28         "github.com/prometheus/client_golang/prometheus"
29         "github.com/sirupsen/logrus"
30 )
31
// rfc3339NanoFixed is a time layout in the style of time.RFC3339Nano,
// except that the fractional-second digits are written as zeros
// rather than nines, so a formatted timestamp always carries exactly
// nine fractional digits instead of having trailing zeros trimmed.
const rfc3339NanoFixed = "2006-01-02T15:04:05.000000000Z07:00"

// Command is the value returned by service.Command for the keepproxy
// service name, with newHandlerOrErrorHandler supplied as the
// function that constructs the request handler.
var Command = service.Command(arvados.ServiceNameKeepproxy, newHandlerOrErrorHandler)
35
36 func newHandlerOrErrorHandler(ctx context.Context, cluster *arvados.Cluster, token string, reg *prometheus.Registry) service.Handler {
37         client, err := arvados.NewClientFromConfig(cluster)
38         if err != nil {
39                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up arvados client: %w", err))
40         }
41         arv, err := arvadosclient.New(client)
42         if err != nil {
43                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up arvados client: %w", err))
44         }
45         kc, err := keepclient.MakeKeepClient(arv)
46         if err != nil {
47                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up keep client: %w", err))
48         }
49         keepclient.RefreshServiceDiscoveryOnSIGHUP()
50         router, err := newHandler(ctx, kc, time.Duration(keepclient.DefaultProxyRequestTimeout), cluster)
51         if err != nil {
52                 return service.ErrorHandler(ctx, cluster, err)
53         }
54         return router
55 }
56
// tokenCacheEntry is the value stored in an apiTokenCache for one
// cache key.
type tokenCacheEntry struct {
	// expire is the Unix time, in seconds, from which RecallToken
	// treats the entry as expired.
	expire int64
	// user is the user record saved together with the token.
	user   *arvados.User
}
61
// apiTokenCache remembers tokens that were recently accepted, so that
// they do not have to be re-validated on every request.
type apiTokenCache struct {
	// tokens maps a cache key (a string) to a tokenCacheEntry.
	tokens     *lru.TwoQueueCache
	// expireTime is the lifetime, in seconds, given to each new
	// entry: it is added to the current Unix time when the entry
	// is stored.
	expireTime int64
}
66
67 // RememberToken caches the token and set an expire time.  If the
68 // token is already in the cache, it is not updated.
69 func (cache *apiTokenCache) RememberToken(token string, user *arvados.User) {
70         now := time.Now().Unix()
71         _, ok := cache.tokens.Get(token)
72         if !ok {
73                 cache.tokens.Add(token, tokenCacheEntry{
74                         expire: now + cache.expireTime,
75                         user:   user,
76                 })
77         }
78 }
79
80 // RecallToken checks if the cached token is known and still believed to be
81 // valid.
82 func (cache *apiTokenCache) RecallToken(token string) (bool, *arvados.User) {
83         val, ok := cache.tokens.Get(token)
84         if !ok {
85                 return false, nil
86         }
87
88         cacheEntry := val.(tokenCacheEntry)
89         now := time.Now().Unix()
90         if now < cacheEntry.expire {
91                 // Token is known and still valid
92                 return true, cacheEntry.user
93         } else {
94                 // Token is expired
95                 cache.tokens.Remove(token)
96                 return false, nil
97         }
98 }
99
100 func (h *proxyHandler) Done() <-chan struct{} {
101         return nil
102 }
103
104 func (h *proxyHandler) CheckHealth() error {
105         return nil
106 }
107
108 func (h *proxyHandler) checkAuthorizationHeader(req *http.Request) (pass bool, tok string, user *arvados.User) {
109         parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
110         if len(parts) < 2 || !(parts[0] == "OAuth2" || parts[0] == "Bearer") || len(parts[1]) == 0 {
111                 return false, "", nil
112         }
113         tok = parts[1]
114
115         // Tokens are validated differently depending on what kind of
116         // operation is being performed. For example, tokens in
117         // collection-sharing links permit GET requests, but not
118         // PUT requests.
119         var op string
120         if req.Method == "GET" || req.Method == "HEAD" {
121                 op = "read"
122         } else {
123                 op = "write"
124         }
125
126         if ok, user := h.apiTokenCache.RecallToken(op + ":" + tok); ok {
127                 // Valid in the cache, short circuit
128                 return true, tok, user
129         }
130
131         var err error
132         arv := *h.KeepClient.Arvados
133         arv.ApiToken = tok
134         arv.RequestID = req.Header.Get("X-Request-Id")
135         user = &arvados.User{}
136         userCurrentError := arv.Call("GET", "users", "", "current", nil, user)
137         err = userCurrentError
138         if err != nil && op == "read" {
139                 apiError, ok := err.(arvadosclient.APIServerError)
140                 if ok && apiError.HttpStatusCode == http.StatusForbidden {
141                         // If it was a scoped "sharing" token it will
142                         // return 403 instead of 401 for the current
143                         // user check.  If it is a download operation
144                         // and they have permission to read the
145                         // keep_services table, we can allow it.
146                         err = arv.Call("HEAD", "keep_services", "", "accessible", nil, nil)
147                 }
148         }
149         if err != nil {
150                 ctxlog.FromContext(req.Context()).WithError(err).Info("checkAuthorizationHeader error")
151                 return false, "", nil
152         }
153
154         if userCurrentError == nil && user.IsAdmin {
155                 // checking userCurrentError is probably redundant,
156                 // IsAdmin would be false anyway. But can't hurt.
157                 if op == "read" && !h.cluster.Collections.KeepproxyPermission.Admin.Download {
158                         return false, "", nil
159                 }
160                 if op == "write" && !h.cluster.Collections.KeepproxyPermission.Admin.Upload {
161                         return false, "", nil
162                 }
163         } else {
164                 if op == "read" && !h.cluster.Collections.KeepproxyPermission.User.Download {
165                         return false, "", nil
166                 }
167                 if op == "write" && !h.cluster.Collections.KeepproxyPermission.User.Upload {
168                         return false, "", nil
169                 }
170         }
171
172         // Success!  Update cache
173         h.apiTokenCache.RememberToken(op+":"+tok, user)
174
175         return true, tok, user
176 }
177
// defaultTransport is a fresh http.Transport populated from the
// exported settings of http.DefaultTransport. The default transport
// cannot simply be copied because http.Transport has a mutex field,
// so each value is carried over individually; every field not listed
// here keeps its zero value.
var defaultTransport = func() http.Transport {
	base := http.DefaultTransport.(*http.Transport)
	return http.Transport{
		Proxy:                 base.Proxy,
		DialContext:           base.DialContext,
		ForceAttemptHTTP2:     base.ForceAttemptHTTP2,
		MaxIdleConns:          base.MaxIdleConns,
		IdleConnTimeout:       base.IdleConnTimeout,
		TLSHandshakeTimeout:   base.TLSHandshakeTimeout,
		ExpectContinueTimeout: base.ExpectContinueTimeout,
	}
}()
190
// proxyHandler is the keepproxy request handler. It embeds the
// http.Handler that dispatches requests, the keep client used as a
// template for per-request clients, and the cache of recently
// validated API tokens.
type proxyHandler struct {
	http.Handler
	*keepclient.KeepClient
	*apiTokenCache
	// timeout is the Timeout given to the http.Client created for
	// each request.
	timeout   time.Duration
	// transport is the Transport given to the http.Client created
	// for each request.
	transport *http.Transport
	// cluster is the cluster configuration; its KeepproxyPermission
	// settings are consulted when authorizing requests.
	cluster   *arvados.Cluster
}
199
200 func newHandler(ctx context.Context, kc *keepclient.KeepClient, timeout time.Duration, cluster *arvados.Cluster) (service.Handler, error) {
201         rest := mux.NewRouter()
202
203         // We can't copy the default http transport because
204         // http.Transport has a mutex field, so we copy the fields
205         // that we know have non-zero values in http.DefaultTransport.
206         transport := &http.Transport{
207                 Proxy:                 http.DefaultTransport.(*http.Transport).Proxy,
208                 ForceAttemptHTTP2:     http.DefaultTransport.(*http.Transport).ForceAttemptHTTP2,
209                 MaxIdleConns:          http.DefaultTransport.(*http.Transport).MaxIdleConns,
210                 IdleConnTimeout:       http.DefaultTransport.(*http.Transport).IdleConnTimeout,
211                 ExpectContinueTimeout: http.DefaultTransport.(*http.Transport).ExpectContinueTimeout,
212                 DialContext: (&net.Dialer{
213                         Timeout:   keepclient.DefaultConnectTimeout,
214                         KeepAlive: keepclient.DefaultKeepAlive,
215                         DualStack: true,
216                 }).DialContext,
217                 TLSClientConfig:     arvadosclient.MakeTLSConfig(kc.Arvados.ApiInsecure),
218                 TLSHandshakeTimeout: keepclient.DefaultTLSHandshakeTimeout,
219         }
220
221         cacheQ, err := lru.New2Q(500)
222         if err != nil {
223                 return nil, fmt.Errorf("Error from lru.New2Q: %v", err)
224         }
225
226         h := &proxyHandler{
227                 Handler:    rest,
228                 KeepClient: kc,
229                 timeout:    timeout,
230                 transport:  transport,
231                 apiTokenCache: &apiTokenCache{
232                         tokens:     cacheQ,
233                         expireTime: 300,
234                 },
235                 cluster: cluster,
236         }
237
238         rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Get).Methods("GET", "HEAD")
239         rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Get).Methods("GET", "HEAD")
240
241         // List all blocks
242         rest.HandleFunc(`/index`, h.Index).Methods("GET")
243
244         // List blocks whose hash has the given prefix
245         rest.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, h.Index).Methods("GET")
246
247         rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Put).Methods("PUT")
248         rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Put).Methods("PUT")
249         rest.HandleFunc(`/`, h.Put).Methods("POST")
250         rest.HandleFunc(`/{any}`, h.Options).Methods("OPTIONS")
251         rest.HandleFunc(`/`, h.Options).Methods("OPTIONS")
252
253         rest.Handle("/_health/{check}", &health.Handler{
254                 Token:  cluster.ManagementToken,
255                 Prefix: "/_health/",
256         }).Methods("GET")
257
258         rest.NotFoundHandler = invalidPathHandler{}
259         return h, nil
260 }
261
262 var errLoopDetected = errors.New("loop detected")
263
264 func (h *proxyHandler) checkLoop(resp http.ResponseWriter, req *http.Request) error {
265         if via := req.Header.Get("Via"); strings.Index(via, " "+viaAlias) >= 0 {
266                 ctxlog.FromContext(req.Context()).Printf("proxy loop detected (request has Via: %q): perhaps keepproxy is misidentified by gateway config as an external client, or its keep_services record does not have service_type=proxy?", via)
267                 http.Error(resp, errLoopDetected.Error(), http.StatusInternalServerError)
268                 return errLoopDetected
269         }
270         return nil
271 }
272
// setCORSHeaders sets the four Access-Control-* response headers
// (allowed methods, origin, request headers and max age) on resp,
// replacing any values already present for those headers.
func setCORSHeaders(resp http.ResponseWriter) {
	hdr := resp.Header()
	for _, kv := range [...][2]string{
		{"Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS"},
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas"},
		{"Access-Control-Max-Age", "86486400"},
	} {
		hdr.Set(kv[0], kv[1])
	}
}
279
// invalidPathHandler is an http.Handler that rejects every request.
type invalidPathHandler struct{}

// ServeHTTP answers any request with status 400 and the text
// "Bad request"; the request itself is not examined.
func (invalidPathHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	const msg = "Bad request"
	http.Error(w, msg, http.StatusBadRequest)
}
285
286 func (h *proxyHandler) Options(resp http.ResponseWriter, req *http.Request) {
287         setCORSHeaders(resp)
288 }
289
// Errors reported to clients by the request handlers.
var errBadAuthorizationHeader = errors.New("Missing or invalid Authorization header, or method not allowed")
var errContentLengthMismatch = errors.New("Actual length != expected content length")
var errMethodNotSupported = errors.New("Method not supported")

// removeHint matches a "+K@xxxxx" hint (five lowercase alphanumeric
// characters after "K@") that is followed by either another "+" or
// the end of the string. The follower is captured as group 1 so that
// replacing a match with "$1" removes the hint while keeping the
// separator of whatever comes next.
//
// MustCompile is used so that a malformed pattern is reported at
// package initialization instead of leaving removeHint nil.
var removeHint = regexp.MustCompile(`\+K@[a-z0-9]{5}(\+|$)`)
295
296 func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) {
297         if err := h.checkLoop(resp, req); err != nil {
298                 return
299         }
300         setCORSHeaders(resp)
301         resp.Header().Set("Via", req.Proto+" "+viaAlias)
302
303         locator := mux.Vars(req)["locator"]
304         var err error
305         var status int
306         var expectLength, responseLength int64
307
308         logger := ctxlog.FromContext(req.Context())
309         defer func() {
310                 httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
311                         "locator":        locator,
312                         "expectLength":   expectLength,
313                         "responseLength": responseLength,
314                         "err":            err,
315                 })
316                 if status != http.StatusOK {
317                         http.Error(resp, err.Error(), status)
318                 }
319         }()
320
321         kc := h.makeKeepClient(req)
322         kc.DiskCacheSize = keepclient.DiskCacheDisabled
323
324         var pass bool
325         var tok string
326         var user *arvados.User
327         if pass, tok, user = h.checkAuthorizationHeader(req); !pass {
328                 status, err = http.StatusForbidden, errBadAuthorizationHeader
329                 return
330         }
331         httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
332                 "userUUID":     user.UUID,
333                 "userFullName": user.FullName,
334         })
335
336         // Copy ArvadosClient struct and use the client's API token
337         arvclient := *kc.Arvados
338         arvclient.ApiToken = tok
339         kc.Arvados = &arvclient
340
341         var reader io.ReadCloser
342
343         locator = removeHint.ReplaceAllString(locator, "$1")
344
345         switch req.Method {
346         case "HEAD":
347                 expectLength, _, err = kc.Ask(locator)
348         case "GET":
349                 reader, expectLength, _, err = kc.Get(locator)
350                 if reader != nil {
351                         defer reader.Close()
352                 }
353         default:
354                 status, err = http.StatusNotImplemented, errMethodNotSupported
355                 return
356         }
357
358         if expectLength == -1 {
359                 logger.Warn("Content-Length not provided")
360         }
361
362         switch respErr := err.(type) {
363         case nil:
364                 status = http.StatusOK
365                 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
366                 switch req.Method {
367                 case "HEAD":
368                         responseLength = 0
369                 case "GET":
370                         responseLength, err = io.Copy(resp, reader)
371                         if err == nil && expectLength > -1 && responseLength != expectLength {
372                                 err = errContentLengthMismatch
373                         }
374                 }
375         case keepclient.Error:
376                 if respErr == keepclient.BlockNotFound {
377                         status = http.StatusNotFound
378                 } else if respErr.Temporary() {
379                         status = http.StatusBadGateway
380                 } else {
381                         status = 422
382                 }
383         default:
384                 status = http.StatusInternalServerError
385         }
386 }
387
388 var errLengthRequired = errors.New(http.StatusText(http.StatusLengthRequired))
389 var errLengthMismatch = errors.New("Locator size hint does not match Content-Length header")
390
391 func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
392         if err := h.checkLoop(resp, req); err != nil {
393                 return
394         }
395         setCORSHeaders(resp)
396         resp.Header().Set("Via", "HTTP/1.1 "+viaAlias)
397
398         kc := h.makeKeepClient(req)
399
400         var err error
401         var expectLength int64
402         var status = http.StatusInternalServerError
403         var wroteReplicas int
404         var locatorOut string = "-"
405
406         defer func() {
407                 httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
408                         "expectLength":  expectLength,
409                         "wantReplicas":  kc.Want_replicas,
410                         "wroteReplicas": wroteReplicas,
411                         "locator":       strings.SplitN(locatorOut, "+A", 2)[0],
412                         "err":           err,
413                 })
414                 if status != http.StatusOK {
415                         http.Error(resp, err.Error(), status)
416                 }
417         }()
418
419         locatorIn := mux.Vars(req)["locator"]
420
421         // Check if the client specified storage classes
422         if req.Header.Get("X-Keep-Storage-Classes") != "" {
423                 var scl []string
424                 for _, sc := range strings.Split(req.Header.Get("X-Keep-Storage-Classes"), ",") {
425                         scl = append(scl, strings.Trim(sc, " "))
426                 }
427                 kc.SetStorageClasses(scl)
428         }
429
430         _, err = fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
431         if err != nil || expectLength < 0 {
432                 err = errLengthRequired
433                 status = http.StatusLengthRequired
434                 return
435         }
436
437         if locatorIn != "" {
438                 var loc *keepclient.Locator
439                 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
440                         status = http.StatusBadRequest
441                         return
442                 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
443                         err = errLengthMismatch
444                         status = http.StatusBadRequest
445                         return
446                 }
447         }
448
449         var pass bool
450         var tok string
451         var user *arvados.User
452         if pass, tok, user = h.checkAuthorizationHeader(req); !pass {
453                 err = errBadAuthorizationHeader
454                 status = http.StatusForbidden
455                 return
456         }
457         httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
458                 "userUUID":     user.UUID,
459                 "userFullName": user.FullName,
460         })
461
462         // Copy ArvadosClient struct and use the client's API token
463         arvclient := *kc.Arvados
464         arvclient.ApiToken = tok
465         kc.Arvados = &arvclient
466
467         // Check if the client specified the number of replicas
468         if desiredReplicas := req.Header.Get(keepclient.XKeepDesiredReplicas); desiredReplicas != "" {
469                 var r int
470                 _, err := fmt.Sscanf(desiredReplicas, "%d", &r)
471                 if err == nil {
472                         kc.Want_replicas = r
473                 }
474         }
475
476         // Now try to put the block through
477         if locatorIn == "" {
478                 bytes, err2 := ioutil.ReadAll(req.Body)
479                 if err2 != nil {
480                         err = fmt.Errorf("Error reading request body: %s", err2)
481                         status = http.StatusInternalServerError
482                         return
483                 }
484                 locatorOut, wroteReplicas, err = kc.PutB(bytes)
485         } else {
486                 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
487         }
488
489         // Tell the client how many successful PUTs we accomplished
490         resp.Header().Set(keepclient.XKeepReplicasStored, fmt.Sprintf("%d", wroteReplicas))
491
492         switch err.(type) {
493         case nil:
494                 status = http.StatusOK
495                 if len(kc.StorageClasses) > 0 {
496                         // A successful PUT request with storage classes means that all
497                         // storage classes were fulfilled, so the client will get a
498                         // confirmation via the X-Storage-Classes-Confirmed header.
499                         hdr := ""
500                         isFirst := true
501                         for _, sc := range kc.StorageClasses {
502                                 if isFirst {
503                                         hdr = fmt.Sprintf("%s=%d", sc, wroteReplicas)
504                                         isFirst = false
505                                 } else {
506                                         hdr += fmt.Sprintf(", %s=%d", sc, wroteReplicas)
507                                 }
508                         }
509                         resp.Header().Set(keepclient.XKeepStorageClassesConfirmed, hdr)
510                 }
511                 _, err = io.WriteString(resp, locatorOut)
512         case keepclient.OversizeBlockError:
513                 // Too much data
514                 status = http.StatusRequestEntityTooLarge
515         case keepclient.InsufficientReplicasError:
516                 status = http.StatusServiceUnavailable
517         default:
518                 status = http.StatusBadGateway
519         }
520 }
521
522 // ServeHTTP implementation for IndexHandler
523 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
524 // For each keep server found in LocalRoots:
525 // - Invokes GetIndex using keepclient
526 // - Expects "complete" response (terminating with blank new line)
527 // - Aborts on any errors
528 // Concatenates responses from all those keep servers and returns
529 func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) {
530         setCORSHeaders(resp)
531
532         prefix := mux.Vars(req)["prefix"]
533         var err error
534         var status int
535
536         defer func() {
537                 if status != http.StatusOK {
538                         http.Error(resp, err.Error(), status)
539                 }
540         }()
541
542         kc := h.makeKeepClient(req)
543         ok, token, _ := h.checkAuthorizationHeader(req)
544         if !ok {
545                 status, err = http.StatusForbidden, errBadAuthorizationHeader
546                 return
547         }
548
549         // Copy ArvadosClient struct and use the client's API token
550         arvclient := *kc.Arvados
551         arvclient.ApiToken = token
552         kc.Arvados = &arvclient
553
554         // Only GET method is supported
555         if req.Method != "GET" {
556                 status, err = http.StatusNotImplemented, errMethodNotSupported
557                 return
558         }
559
560         // Get index from all LocalRoots and write to resp
561         var reader io.Reader
562         for uuid := range kc.LocalRoots() {
563                 reader, err = kc.GetIndex(uuid, prefix)
564                 if err != nil {
565                         status = http.StatusBadGateway
566                         return
567                 }
568
569                 _, err = io.Copy(resp, reader)
570                 if err != nil {
571                         status = http.StatusBadGateway
572                         return
573                 }
574         }
575
576         // Got index from all the keep servers and wrote to resp
577         status = http.StatusOK
578         resp.Write([]byte("\n"))
579 }
580
581 func (h *proxyHandler) makeKeepClient(req *http.Request) *keepclient.KeepClient {
582         kc := h.KeepClient.Clone()
583         kc.RequestID = req.Header.Get("X-Request-Id")
584         kc.HTTPClient = &proxyClient{
585                 client: &http.Client{
586                         Timeout:   h.timeout,
587                         Transport: h.transport,
588                 },
589                 proto: req.Proto,
590         }
591         return kc
592 }