20722: Cleans up dependencies with 'go mod tidy'
[arvados.git] / services / keepproxy / keepproxy.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package keepproxy
6
7 import (
8         "context"
9         "errors"
10         "fmt"
11         "io"
12         "io/ioutil"
13         "net"
14         "net/http"
15         "regexp"
16         "strings"
17         "time"
18
19         "git.arvados.org/arvados.git/lib/service"
20         "git.arvados.org/arvados.git/sdk/go/arvados"
21         "git.arvados.org/arvados.git/sdk/go/arvadosclient"
22         "git.arvados.org/arvados.git/sdk/go/ctxlog"
23         "git.arvados.org/arvados.git/sdk/go/health"
24         "git.arvados.org/arvados.git/sdk/go/httpserver"
25         "git.arvados.org/arvados.git/sdk/go/keepclient"
26         "github.com/gorilla/mux"
27         lru "github.com/hashicorp/golang-lru"
28         "github.com/prometheus/client_golang/prometheus"
29         "github.com/sirupsen/logrus"
30 )
31
// rfc3339NanoFixed is a time layout like time.RFC3339Nano, except that
// the fractional seconds are always printed with exactly nine digits
// (trailing zeros are kept), so formatted timestamps have a fixed width.
const rfc3339NanoFixed = "2006-01-02T15:04:05.000000000Z07:00"

// Command is the keepproxy service command. It is built with
// service.Command, which uses newHandlerOrErrorHandler to construct the
// HTTP handler from the cluster configuration.
var Command = service.Command(arvados.ServiceNameKeepproxy, newHandlerOrErrorHandler)
35
36 func newHandlerOrErrorHandler(ctx context.Context, cluster *arvados.Cluster, token string, reg *prometheus.Registry) service.Handler {
37         client, err := arvados.NewClientFromConfig(cluster)
38         if err != nil {
39                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up arvados client: %w", err))
40         }
41         arv, err := arvadosclient.New(client)
42         if err != nil {
43                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up arvados client: %w", err))
44         }
45         kc, err := keepclient.MakeKeepClient(arv)
46         if err != nil {
47                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("Error setting up keep client: %w", err))
48         }
49         keepclient.RefreshServiceDiscoveryOnSIGHUP()
50         router, err := newHandler(ctx, kc, time.Duration(keepclient.DefaultProxyRequestTimeout), cluster)
51         if err != nil {
52                 return service.ErrorHandler(ctx, cluster, err)
53         }
54         return router
55 }
56
57 type tokenCacheEntry struct {
58         expire int64
59         user   *arvados.User
60 }
61
62 type apiTokenCache struct {
63         tokens     *lru.TwoQueueCache
64         expireTime int64
65 }
66
67 // RememberToken caches the token and set an expire time.  If the
68 // token is already in the cache, it is not updated.
69 func (cache *apiTokenCache) RememberToken(token string, user *arvados.User) {
70         now := time.Now().Unix()
71         _, ok := cache.tokens.Get(token)
72         if !ok {
73                 cache.tokens.Add(token, tokenCacheEntry{
74                         expire: now + cache.expireTime,
75                         user:   user,
76                 })
77         }
78 }
79
80 // RecallToken checks if the cached token is known and still believed to be
81 // valid.
82 func (cache *apiTokenCache) RecallToken(token string) (bool, *arvados.User) {
83         val, ok := cache.tokens.Get(token)
84         if !ok {
85                 return false, nil
86         }
87
88         cacheEntry := val.(tokenCacheEntry)
89         now := time.Now().Unix()
90         if now < cacheEntry.expire {
91                 // Token is known and still valid
92                 return true, cacheEntry.user
93         } else {
94                 // Token is expired
95                 cache.tokens.Remove(token)
96                 return false, nil
97         }
98 }
99
100 func (h *proxyHandler) Done() <-chan struct{} {
101         return nil
102 }
103
104 func (h *proxyHandler) CheckHealth() error {
105         return nil
106 }
107
108 func (h *proxyHandler) checkAuthorizationHeader(req *http.Request) (pass bool, tok string, user *arvados.User) {
109         parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
110         if len(parts) < 2 || !(parts[0] == "OAuth2" || parts[0] == "Bearer") || len(parts[1]) == 0 {
111                 return false, "", nil
112         }
113         tok = parts[1]
114
115         // Tokens are validated differently depending on what kind of
116         // operation is being performed. For example, tokens in
117         // collection-sharing links permit GET requests, but not
118         // PUT requests.
119         var op string
120         if req.Method == "GET" || req.Method == "HEAD" {
121                 op = "read"
122         } else {
123                 op = "write"
124         }
125
126         if ok, user := h.apiTokenCache.RecallToken(op + ":" + tok); ok {
127                 // Valid in the cache, short circuit
128                 return true, tok, user
129         }
130
131         var err error
132         arv := *h.KeepClient.Arvados
133         arv.ApiToken = tok
134         arv.RequestID = req.Header.Get("X-Request-Id")
135         user = &arvados.User{}
136         userCurrentError := arv.Call("GET", "users", "", "current", nil, user)
137         err = userCurrentError
138         if err != nil && op == "read" {
139                 apiError, ok := err.(arvadosclient.APIServerError)
140                 if ok && apiError.HttpStatusCode == http.StatusForbidden {
141                         // If it was a scoped "sharing" token it will
142                         // return 403 instead of 401 for the current
143                         // user check.  If it is a download operation
144                         // and they have permission to read the
145                         // keep_services table, we can allow it.
146                         err = arv.Call("HEAD", "keep_services", "", "accessible", nil, nil)
147                 }
148         }
149         if err != nil {
150                 ctxlog.FromContext(req.Context()).WithError(err).Info("checkAuthorizationHeader error")
151                 return false, "", nil
152         }
153
154         if userCurrentError == nil && user.IsAdmin {
155                 // checking userCurrentError is probably redundant,
156                 // IsAdmin would be false anyway. But can't hurt.
157                 if op == "read" && !h.cluster.Collections.KeepproxyPermission.Admin.Download {
158                         return false, "", nil
159                 }
160                 if op == "write" && !h.cluster.Collections.KeepproxyPermission.Admin.Upload {
161                         return false, "", nil
162                 }
163         } else {
164                 if op == "read" && !h.cluster.Collections.KeepproxyPermission.User.Download {
165                         return false, "", nil
166                 }
167                 if op == "write" && !h.cluster.Collections.KeepproxyPermission.User.Upload {
168                         return false, "", nil
169                 }
170         }
171
172         // Success!  Update cache
173         h.apiTokenCache.RememberToken(op+":"+tok, user)
174
175         return true, tok, user
176 }
177
// defaultTransport is a private copy of http.DefaultTransport, taken
// when package-level variables are initialized. newHandler derives each
// handler's transport from this template.
//
// We need to make a private copy of the default http transport early
// in initialization, then make copies of our private copy later. It
// won't be safe to copy http.DefaultTransport itself later, because
// its private mutexes might have already been used. (Without this,
// the test suite sometimes panics "concurrent map writes" in
// net/http.(*Transport).removeIdleConnLocked().)
//
// NOTE(review): go vet's copylocks check may flag this struct copy;
// it is deliberate, for the reason above.
var defaultTransport = *(http.DefaultTransport.(*http.Transport))
185
// proxyHandler is the keepproxy service handler (see newHandler).
type proxyHandler struct {
	// Handler is the router that dispatches requests to the
	// handler methods (Get, Put, Index, Options).
	http.Handler
	// KeepClient is the template client. makeKeepClient makes a
	// per-request copy of it.
	*keepclient.KeepClient
	// apiTokenCache caches recently validated API tokens.
	*apiTokenCache
	// timeout is the overall timeout for each upstream HTTP request.
	timeout time.Duration
	// transport is shared by the HTTP clients made in makeKeepClient.
	transport *http.Transport
	// cluster supplies the config that is consulted at request time
	// (e.g. Collections.KeepproxyPermission).
	cluster *arvados.Cluster
}
194
195 func newHandler(ctx context.Context, kc *keepclient.KeepClient, timeout time.Duration, cluster *arvados.Cluster) (service.Handler, error) {
196         rest := mux.NewRouter()
197
198         transport := defaultTransport
199         transport.DialContext = (&net.Dialer{
200                 Timeout:   keepclient.DefaultConnectTimeout,
201                 KeepAlive: keepclient.DefaultKeepAlive,
202                 DualStack: true,
203         }).DialContext
204         transport.TLSClientConfig = arvadosclient.MakeTLSConfig(kc.Arvados.ApiInsecure)
205         transport.TLSHandshakeTimeout = keepclient.DefaultTLSHandshakeTimeout
206
207         cacheQ, err := lru.New2Q(500)
208         if err != nil {
209                 return nil, fmt.Errorf("Error from lru.New2Q: %v", err)
210         }
211
212         h := &proxyHandler{
213                 Handler:    rest,
214                 KeepClient: kc,
215                 timeout:    timeout,
216                 transport:  &transport,
217                 apiTokenCache: &apiTokenCache{
218                         tokens:     cacheQ,
219                         expireTime: 300,
220                 },
221                 cluster: cluster,
222         }
223
224         rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Get).Methods("GET", "HEAD")
225         rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Get).Methods("GET", "HEAD")
226
227         // List all blocks
228         rest.HandleFunc(`/index`, h.Index).Methods("GET")
229
230         // List blocks whose hash has the given prefix
231         rest.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, h.Index).Methods("GET")
232
233         rest.HandleFunc(`/{locator:[0-9a-f]{32}\+.*}`, h.Put).Methods("PUT")
234         rest.HandleFunc(`/{locator:[0-9a-f]{32}}`, h.Put).Methods("PUT")
235         rest.HandleFunc(`/`, h.Put).Methods("POST")
236         rest.HandleFunc(`/{any}`, h.Options).Methods("OPTIONS")
237         rest.HandleFunc(`/`, h.Options).Methods("OPTIONS")
238
239         rest.Handle("/_health/{check}", &health.Handler{
240                 Token:  cluster.ManagementToken,
241                 Prefix: "/_health/",
242         }).Methods("GET")
243
244         rest.NotFoundHandler = invalidPathHandler{}
245         return h, nil
246 }
247
248 var errLoopDetected = errors.New("loop detected")
249
250 func (h *proxyHandler) checkLoop(resp http.ResponseWriter, req *http.Request) error {
251         if via := req.Header.Get("Via"); strings.Index(via, " "+viaAlias) >= 0 {
252                 ctxlog.FromContext(req.Context()).Printf("proxy loop detected (request has Via: %q): perhaps keepproxy is misidentified by gateway config as an external client, or its keep_services record does not have service_type=proxy?", via)
253                 http.Error(resp, errLoopDetected.Error(), http.StatusInternalServerError)
254                 return errLoopDetected
255         }
256         return nil
257 }
258
// setCORSHeaders sets the Cross-Origin Resource Sharing headers that let
// browser clients on any origin use this proxy. The headers cover the
// allowed methods and request headers, and how long (in seconds) a
// preflight result may be cached.
func setCORSHeaders(resp http.ResponseWriter) {
	corsHeaders := [...][2]string{
		{"Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS"},
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas"},
		{"Access-Control-Max-Age", "86486400"},
	}
	hdr := resp.Header()
	for _, kv := range corsHeaders {
		hdr.Set(kv[0], kv[1])
	}
}
265
// invalidPathHandler is the router's fallback handler. It answers every
// request it receives with 400 Bad Request.
type invalidPathHandler struct{}

// ServeHTTP writes a plain-text "Bad request" error with status 400.
func (invalidPathHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	const msg = "Bad request"
	http.Error(w, msg, http.StatusBadRequest)
}
271
272 func (h *proxyHandler) Options(resp http.ResponseWriter, req *http.Request) {
273         setCORSHeaders(resp)
274 }
275
var (
	// errBadAuthorizationHeader is reported when
	// checkAuthorizationHeader rejects a request.
	errBadAuthorizationHeader = errors.New("Missing or invalid Authorization header, or method not allowed")
	// errContentLengthMismatch is recorded when the number of bytes
	// proxied differs from the length announced upstream.
	errContentLengthMismatch = errors.New("Actual length != expected content length")
	// errMethodNotSupported is reported for HTTP methods a handler
	// does not implement.
	errMethodNotSupported = errors.New("Method not supported")
)

// removeHint matches a "+K@xxxxx" locator hint (five lowercase letters
// or digits) together with the "+" that follows it, if any. Replacing
// each match with "$1" deletes the hint but keeps the separator in
// front of any following hint.
//
// MustCompile is used so that a bad pattern fails loudly at init,
// instead of leaving removeHint nil and panicking at first use.
var removeHint = regexp.MustCompile(`\+K@[a-z0-9]{5}(\+|$)`)
281
282 func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) {
283         if err := h.checkLoop(resp, req); err != nil {
284                 return
285         }
286         setCORSHeaders(resp)
287         resp.Header().Set("Via", req.Proto+" "+viaAlias)
288
289         locator := mux.Vars(req)["locator"]
290         var err error
291         var status int
292         var expectLength, responseLength int64
293         var proxiedURI = "-"
294
295         logger := ctxlog.FromContext(req.Context())
296         defer func() {
297                 httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
298                         "locator":        locator,
299                         "expectLength":   expectLength,
300                         "responseLength": responseLength,
301                         "proxiedURI":     proxiedURI,
302                         "err":            err,
303                 })
304                 if status != http.StatusOK {
305                         http.Error(resp, err.Error(), status)
306                 }
307         }()
308
309         kc := h.makeKeepClient(req)
310
311         var pass bool
312         var tok string
313         var user *arvados.User
314         if pass, tok, user = h.checkAuthorizationHeader(req); !pass {
315                 status, err = http.StatusForbidden, errBadAuthorizationHeader
316                 return
317         }
318         httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
319                 "userUUID":     user.UUID,
320                 "userFullName": user.FullName,
321         })
322
323         // Copy ArvadosClient struct and use the client's API token
324         arvclient := *kc.Arvados
325         arvclient.ApiToken = tok
326         kc.Arvados = &arvclient
327
328         var reader io.ReadCloser
329
330         locator = removeHint.ReplaceAllString(locator, "$1")
331
332         switch req.Method {
333         case "HEAD":
334                 expectLength, proxiedURI, err = kc.Ask(locator)
335         case "GET":
336                 reader, expectLength, proxiedURI, err = kc.Get(locator)
337                 if reader != nil {
338                         defer reader.Close()
339                 }
340         default:
341                 status, err = http.StatusNotImplemented, errMethodNotSupported
342                 return
343         }
344
345         if expectLength == -1 {
346                 logger.Warn("Content-Length not provided")
347         }
348
349         switch respErr := err.(type) {
350         case nil:
351                 status = http.StatusOK
352                 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
353                 switch req.Method {
354                 case "HEAD":
355                         responseLength = 0
356                 case "GET":
357                         responseLength, err = io.Copy(resp, reader)
358                         if err == nil && expectLength > -1 && responseLength != expectLength {
359                                 err = errContentLengthMismatch
360                         }
361                 }
362         case keepclient.Error:
363                 if respErr == keepclient.BlockNotFound {
364                         status = http.StatusNotFound
365                 } else if respErr.Temporary() {
366                         status = http.StatusBadGateway
367                 } else {
368                         status = 422
369                 }
370         default:
371                 status = http.StatusInternalServerError
372         }
373 }
374
// Errors reported by Put when checking the request's declared length.
var (
	// errLengthRequired: no usable Content-Length header was supplied.
	errLengthRequired = errors.New(http.StatusText(http.StatusLengthRequired))
	// errLengthMismatch: the locator's size hint disagrees with the
	// Content-Length header.
	errLengthMismatch = errors.New("Locator size hint does not match Content-Length header")
)
377
378 func (h *proxyHandler) Put(resp http.ResponseWriter, req *http.Request) {
379         if err := h.checkLoop(resp, req); err != nil {
380                 return
381         }
382         setCORSHeaders(resp)
383         resp.Header().Set("Via", "HTTP/1.1 "+viaAlias)
384
385         kc := h.makeKeepClient(req)
386
387         var err error
388         var expectLength int64
389         var status = http.StatusInternalServerError
390         var wroteReplicas int
391         var locatorOut string = "-"
392
393         defer func() {
394                 httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
395                         "expectLength":  expectLength,
396                         "wantReplicas":  kc.Want_replicas,
397                         "wroteReplicas": wroteReplicas,
398                         "locator":       strings.SplitN(locatorOut, "+A", 2)[0],
399                         "err":           err,
400                 })
401                 if status != http.StatusOK {
402                         http.Error(resp, err.Error(), status)
403                 }
404         }()
405
406         locatorIn := mux.Vars(req)["locator"]
407
408         // Check if the client specified storage classes
409         if req.Header.Get("X-Keep-Storage-Classes") != "" {
410                 var scl []string
411                 for _, sc := range strings.Split(req.Header.Get("X-Keep-Storage-Classes"), ",") {
412                         scl = append(scl, strings.Trim(sc, " "))
413                 }
414                 kc.SetStorageClasses(scl)
415         }
416
417         _, err = fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
418         if err != nil || expectLength < 0 {
419                 err = errLengthRequired
420                 status = http.StatusLengthRequired
421                 return
422         }
423
424         if locatorIn != "" {
425                 var loc *keepclient.Locator
426                 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
427                         status = http.StatusBadRequest
428                         return
429                 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
430                         err = errLengthMismatch
431                         status = http.StatusBadRequest
432                         return
433                 }
434         }
435
436         var pass bool
437         var tok string
438         var user *arvados.User
439         if pass, tok, user = h.checkAuthorizationHeader(req); !pass {
440                 err = errBadAuthorizationHeader
441                 status = http.StatusForbidden
442                 return
443         }
444         httpserver.SetResponseLogFields(req.Context(), logrus.Fields{
445                 "userUUID":     user.UUID,
446                 "userFullName": user.FullName,
447         })
448
449         // Copy ArvadosClient struct and use the client's API token
450         arvclient := *kc.Arvados
451         arvclient.ApiToken = tok
452         kc.Arvados = &arvclient
453
454         // Check if the client specified the number of replicas
455         if desiredReplicas := req.Header.Get(keepclient.XKeepDesiredReplicas); desiredReplicas != "" {
456                 var r int
457                 _, err := fmt.Sscanf(desiredReplicas, "%d", &r)
458                 if err == nil {
459                         kc.Want_replicas = r
460                 }
461         }
462
463         // Now try to put the block through
464         if locatorIn == "" {
465                 bytes, err2 := ioutil.ReadAll(req.Body)
466                 if err2 != nil {
467                         err = fmt.Errorf("Error reading request body: %s", err2)
468                         status = http.StatusInternalServerError
469                         return
470                 }
471                 locatorOut, wroteReplicas, err = kc.PutB(bytes)
472         } else {
473                 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
474         }
475
476         // Tell the client how many successful PUTs we accomplished
477         resp.Header().Set(keepclient.XKeepReplicasStored, fmt.Sprintf("%d", wroteReplicas))
478
479         switch err.(type) {
480         case nil:
481                 status = http.StatusOK
482                 if len(kc.StorageClasses) > 0 {
483                         // A successful PUT request with storage classes means that all
484                         // storage classes were fulfilled, so the client will get a
485                         // confirmation via the X-Storage-Classes-Confirmed header.
486                         hdr := ""
487                         isFirst := true
488                         for _, sc := range kc.StorageClasses {
489                                 if isFirst {
490                                         hdr = fmt.Sprintf("%s=%d", sc, wroteReplicas)
491                                         isFirst = false
492                                 } else {
493                                         hdr += fmt.Sprintf(", %s=%d", sc, wroteReplicas)
494                                 }
495                         }
496                         resp.Header().Set(keepclient.XKeepStorageClassesConfirmed, hdr)
497                 }
498                 _, err = io.WriteString(resp, locatorOut)
499         case keepclient.OversizeBlockError:
500                 // Too much data
501                 status = http.StatusRequestEntityTooLarge
502         case keepclient.InsufficientReplicasError:
503                 status = http.StatusServiceUnavailable
504         default:
505                 status = http.StatusBadGateway
506         }
507 }
508
509 // ServeHTTP implementation for IndexHandler
510 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
511 // For each keep server found in LocalRoots:
512 //   Invokes GetIndex using keepclient
513 //   Expects "complete" response (terminating with blank new line)
514 //   Aborts on any errors
515 // Concatenates responses from all those keep servers and returns
516 func (h *proxyHandler) Index(resp http.ResponseWriter, req *http.Request) {
517         setCORSHeaders(resp)
518
519         prefix := mux.Vars(req)["prefix"]
520         var err error
521         var status int
522
523         defer func() {
524                 if status != http.StatusOK {
525                         http.Error(resp, err.Error(), status)
526                 }
527         }()
528
529         kc := h.makeKeepClient(req)
530         ok, token, _ := h.checkAuthorizationHeader(req)
531         if !ok {
532                 status, err = http.StatusForbidden, errBadAuthorizationHeader
533                 return
534         }
535
536         // Copy ArvadosClient struct and use the client's API token
537         arvclient := *kc.Arvados
538         arvclient.ApiToken = token
539         kc.Arvados = &arvclient
540
541         // Only GET method is supported
542         if req.Method != "GET" {
543                 status, err = http.StatusNotImplemented, errMethodNotSupported
544                 return
545         }
546
547         // Get index from all LocalRoots and write to resp
548         var reader io.Reader
549         for uuid := range kc.LocalRoots() {
550                 reader, err = kc.GetIndex(uuid, prefix)
551                 if err != nil {
552                         status = http.StatusBadGateway
553                         return
554                 }
555
556                 _, err = io.Copy(resp, reader)
557                 if err != nil {
558                         status = http.StatusBadGateway
559                         return
560                 }
561         }
562
563         // Got index from all the keep servers and wrote to resp
564         status = http.StatusOK
565         resp.Write([]byte("\n"))
566 }
567
568 func (h *proxyHandler) makeKeepClient(req *http.Request) *keepclient.KeepClient {
569         kc := *h.KeepClient
570         kc.RequestID = req.Header.Get("X-Request-Id")
571         kc.HTTPClient = &proxyClient{
572                 client: &http.Client{
573                         Timeout:   h.timeout,
574                         Transport: h.transport,
575                 },
576                 proto: req.Proto,
577         }
578         return &kc
579 }