14262: saltAuthToken returns copy of request object
[arvados.git] / lib / controller / federation.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package controller
6
7 import (
8         "bufio"
9         "bytes"
10         "context"
11         "crypto/md5"
12         "database/sql"
13         "encoding/json"
14         "fmt"
15         "io"
16         "io/ioutil"
17         "log"
18         "net/http"
19         "net/url"
20         "regexp"
21         "strings"
22         "sync"
23
24         "git.curoverse.com/arvados.git/sdk/go/arvados"
25         "git.curoverse.com/arvados.git/sdk/go/auth"
26         "git.curoverse.com/arvados.git/sdk/go/httpserver"
27         "git.curoverse.com/arvados.git/sdk/go/keepclient"
28 )
29
// Path regexps for the resource types whose requests may be routed to
// other clusters. For each of them, FindStringSubmatch on a matching
// path yields:
//
//	m[1]: "/" followed by a uuid of that resource type, if present
//	m[2]: the five-character cluster prefix of that uuid ("" if none)
//	m[3]: the remainder of the path
var (
	pathPattern = `^/arvados/v1/%s(/([0-9a-z]{5})-%s-[0-9a-z]{15})?(.*)$`

	wfRe                = regexp.MustCompile(fmt.Sprintf(pathPattern, "workflows", "7fd4e"))
	containersRe        = regexp.MustCompile(fmt.Sprintf(pathPattern, "containers", "dz642"))
	containerRequestsRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "container_requests", "xvhdp"))
	collectionRe        = regexp.MustCompile(fmt.Sprintf(pathPattern, "collections", "4zz18"))

	// collectionByPDHRe matches a GET of a collection by portable data
	// hash ("<md5 hex>+<size>"); m[1] is the (last) hash+size match.
	collectionByPDHRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-fA-F]{32}\+[0-9]+)+$`)
)
36
37 type genericFederatedRequestHandler struct {
38         next    http.Handler
39         handler *Handler
40         matcher *regexp.Regexp
41 }
42
43 type collectionFederatedRequestHandler struct {
44         next    http.Handler
45         handler *Handler
46 }
47
48 func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, req *http.Request, filter ResponseFilter) {
49         remote, ok := h.Cluster.RemoteClusters[remoteID]
50         if !ok {
51                 err := fmt.Errorf("no proxy available for cluster %v", remoteID)
52                 if filter != nil {
53                         _, err = filter(nil, err)
54                 }
55                 if err != nil {
56                         httpserver.Error(w, err.Error(), http.StatusNotFound)
57                 }
58                 return
59         }
60         scheme := remote.Scheme
61         if scheme == "" {
62                 scheme = "https"
63         }
64         req, err := h.saltAuthToken(req, remoteID)
65         if err != nil {
66                 if filter != nil {
67                         _, err = filter(nil, err)
68                 }
69                 if err != nil {
70                         httpserver.Error(w, err.Error(), http.StatusBadRequest)
71                 }
72                 return
73         }
74         urlOut := &url.URL{
75                 Scheme:   scheme,
76                 Host:     remote.Host,
77                 Path:     req.URL.Path,
78                 RawPath:  req.URL.RawPath,
79                 RawQuery: req.URL.RawQuery,
80         }
81         client := h.secureClient
82         if remote.Insecure {
83                 client = h.insecureClient
84         }
85         h.proxy.Do(w, req, urlOut, client, filter)
86 }
87
// loadParamsFromForm fills req.Form (and, for bodies that
// req.ParseForm reads, req.PostForm) from the URL query and the request
// body.
//
// ParseForm consumes whatever part of the body it parses. So that
// downstream proxy steps can re-read the body, everything ParseForm
// reads is buffered, and req.Body is then replaced with a reader that
// replays the buffered bytes followed by any unread remainder of the
// original body. Closing the new body closes the original one.
//
// Because this does not depend on matching the Content-Type here, it
// also works for variants such as "application/x-www-form-urlencoded;
// charset=UTF-8". Bodies that ParseForm does not read at all (e.g. GET
// requests, or non-form content types) are left exactly as they were.
//
// The returned error, if any, is the one from req.ParseForm. The body
// is restored even in that case.
func loadParamsFromForm(req *http.Request) error {
	if req.Body == nil {
		return req.ParseForm()
	}

	// Pre-size the buffer from Content-Length, but don't let a
	// client-supplied length trigger a huge allocation. ParseForm reads
	// at most 10 MiB from a body not already limited by MaxBytesReader.
	const maxPrealloc = 10 << 20
	var prealloc int64
	if req.ContentLength > 0 {
		prealloc = req.ContentLength
		if prealloc > maxPrealloc {
			prealloc = maxPrealloc
		}
	}
	consumed := bytes.NewBuffer(make([]byte, 0, prealloc))

	originalBody := req.Body
	req.Body = ioutil.NopCloser(io.TeeReader(originalBody, consumed))
	err := req.ParseForm()

	if consumed.Len() == 0 {
		// ParseForm did not read the body; hand back the original untouched.
		req.Body = originalBody
	} else {
		req.Body = struct {
			io.Reader
			io.Closer
		}{io.MultiReader(consumed, originalBody), originalBody}
	}
	return err
}
114
// multiClusterQueryResponseCollector captures one cluster's reply to a
// uuid-filtered list query. Its collectResponse method is used as a
// ResponseFilter, so the reply is decoded here instead of being
// forwarded to the client.
type multiClusterQueryResponseCollector struct {
	responses []map[string]interface{} // "items" of a successful reply
	error     error                    // set if the request or reply failed
	kind      string                   // "kind" of a successful reply
	clusterID string                   // queried cluster, for error messages
}

// collectResponse records the outcome of one request in c: on success
// the decoded items and kind; on a request error, an undecodable body,
// or a non-200 status, an error in c.error. It always returns (nil, nil)
// so that nothing is passed on to the client.
func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
	requestError error) (newResponse *http.Response, err error) {
	if requestError != nil {
		c.error = requestError
		return nil, nil
	}
	defer resp.Body.Close()

	var payload struct {
		Kind   string                   `json:"kind"`
		Items  []map[string]interface{} `json:"items"`
		Errors []string                 `json:"errors"`
	}
	decodeErr := json.NewDecoder(resp.Body).Decode(&payload)

	switch {
	case decodeErr != nil:
		c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, decodeErr)
	case resp.StatusCode != http.StatusOK:
		c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, payload.Errors)
	default:
		c.responses, c.kind = payload.Items, payload.Kind
	}
	return nil, nil
}
151
152 func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
153         req *http.Request,
154         clusterID string, uuids []string) (rp []map[string]interface{}, kind string, err error) {
155
156         found := make(map[string]bool)
157         prev_len_uuids := len(uuids) + 1
158         // Loop while
159         // (1) there are more uuids to query
160         // (2) we're making progress - on each iteration the set of
161         // uuids we are expecting for must shrink.
162         for len(uuids) > 0 && len(uuids) < prev_len_uuids {
163                 var remoteReq http.Request
164                 remoteReq.Header = req.Header
165                 remoteReq.Method = "POST"
166                 remoteReq.URL = &url.URL{Path: req.URL.Path}
167                 remoteParams := make(url.Values)
168                 remoteParams.Set("_method", "GET")
169                 remoteParams.Set("count", "none")
170                 if req.Form.Get("select") != "" {
171                         remoteParams.Set("select", req.Form.Get("select"))
172                 }
173                 content, err := json.Marshal(uuids)
174                 if err != nil {
175                         return nil, "", err
176                 }
177                 remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
178                 enc := remoteParams.Encode()
179                 remoteReq.Body = ioutil.NopCloser(bytes.NewBufferString(enc))
180
181                 rc := multiClusterQueryResponseCollector{clusterID: clusterID}
182
183                 if clusterID == h.handler.Cluster.ClusterID {
184                         h.handler.localClusterRequest(w, &remoteReq,
185                                 rc.collectResponse)
186                 } else {
187                         h.handler.remoteClusterRequest(clusterID, w, &remoteReq,
188                                 rc.collectResponse)
189                 }
190                 if rc.error != nil {
191                         return nil, "", rc.error
192                 }
193
194                 kind = rc.kind
195
196                 if len(rc.responses) == 0 {
197                         // We got zero responses, no point in doing
198                         // another query.
199                         return rp, kind, nil
200                 }
201
202                 rp = append(rp, rc.responses...)
203
204                 // Go through the responses and determine what was
205                 // returned.  If there are remaining items, loop
206                 // around and do another request with just the
207                 // stragglers.
208                 for _, i := range rc.responses {
209                         uuid, ok := i["uuid"].(string)
210                         if ok {
211                                 found[uuid] = true
212                         }
213                 }
214
215                 l := []string{}
216                 for _, u := range uuids {
217                         if !found[u] {
218                                 l = append(l, u)
219                         }
220                 }
221                 prev_len_uuids = len(uuids)
222                 uuids = l
223         }
224
225         return rp, kind, nil
226 }
227
228 func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter,
229         req *http.Request, clusterId *string) bool {
230
231         var filters [][]interface{}
232         err := json.Unmarshal([]byte(req.Form.Get("filters")), &filters)
233         if err != nil {
234                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
235                 return true
236         }
237
238         // Split the list of uuids by prefix
239         queryClusters := make(map[string][]string)
240         expectCount := 0
241         for _, filter := range filters {
242                 if len(filter) != 3 {
243                         return false
244                 }
245
246                 if lhs, ok := filter[0].(string); !ok || lhs != "uuid" {
247                         return false
248                 }
249
250                 op, ok := filter[1].(string)
251                 if !ok {
252                         return false
253                 }
254
255                 if op == "in" {
256                         if rhs, ok := filter[2].([]interface{}); ok {
257                                 for _, i := range rhs {
258                                         if u, ok := i.(string); ok {
259                                                 *clusterId = u[0:5]
260                                                 queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
261                                                 expectCount += 1
262                                         }
263                                 }
264                         }
265                 } else if op == "=" {
266                         if u, ok := filter[2].(string); ok {
267                                 *clusterId = u[0:5]
268                                 queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
269                                 expectCount += 1
270                         }
271                 } else {
272                         return false
273                 }
274
275         }
276
277         if len(queryClusters) <= 1 {
278                 // Query does not search for uuids across multiple
279                 // clusters.
280                 return false
281         }
282
283         // Validations
284         count := req.Form.Get("count")
285         if count != "" && count != `none` && count != `"none"` {
286                 httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
287                 return true
288         }
289         if req.Form.Get("limit") != "" || req.Form.Get("offset") != "" || req.Form.Get("order") != "" {
290                 httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
291                 return true
292         }
293         if expectCount > h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse() {
294                 httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
295                         expectCount, h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse()), http.StatusBadRequest)
296                 return true
297         }
298         if req.Form.Get("select") != "" {
299                 foundUUID := false
300                 var selects []string
301                 err := json.Unmarshal([]byte(req.Form.Get("select")), &selects)
302                 if err != nil {
303                         httpserver.Error(w, err.Error(), http.StatusBadRequest)
304                         return true
305                 }
306
307                 for _, r := range selects {
308                         if r == "uuid" {
309                                 foundUUID = true
310                                 break
311                         }
312                 }
313                 if !foundUUID {
314                         httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
315                         return true
316                 }
317         }
318
319         // Perform concurrent requests to each cluster
320
321         // use channel as a semaphore to limit the number of concurrent
322         // requests at a time
323         sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
324         defer close(sem)
325         wg := sync.WaitGroup{}
326
327         req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
328         mtx := sync.Mutex{}
329         errors := []error{}
330         var completeResponses []map[string]interface{}
331         var kind string
332
333         for k, v := range queryClusters {
334                 if len(v) == 0 {
335                         // Nothing to query
336                         continue
337                 }
338
339                 // blocks until it can put a value into the
340                 // channel (which has a max queue capacity)
341                 sem <- true
342                 wg.Add(1)
343                 go func(k string, v []string) {
344                         rp, kn, err := h.remoteQueryUUIDs(w, req, k, v)
345                         mtx.Lock()
346                         if err == nil {
347                                 completeResponses = append(completeResponses, rp...)
348                                 kind = kn
349                         } else {
350                                 errors = append(errors, err)
351                         }
352                         mtx.Unlock()
353                         wg.Done()
354                         <-sem
355                 }(k, v)
356         }
357         wg.Wait()
358
359         if len(errors) > 0 {
360                 var strerr []string
361                 for _, e := range errors {
362                         strerr = append(strerr, e.Error())
363                 }
364                 httpserver.Errors(w, strerr, http.StatusBadGateway)
365                 return true
366         }
367
368         w.Header().Set("Content-Type", "application/json")
369         w.WriteHeader(http.StatusOK)
370         itemList := make(map[string]interface{})
371         itemList["items"] = completeResponses
372         itemList["kind"] = kind
373         json.NewEncoder(w).Encode(itemList)
374
375         return true
376 }
377
378 func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
379         m := h.matcher.FindStringSubmatch(req.URL.Path)
380         clusterId := ""
381
382         if len(m) > 0 && m[2] != "" {
383                 clusterId = m[2]
384         }
385
386         // Get form parameters from URL and form body (if POST).
387         if err := loadParamsFromForm(req); err != nil {
388                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
389                 return
390         }
391
392         // Check if the parameters have an explicit cluster_id
393         if req.Form.Get("cluster_id") != "" {
394                 clusterId = req.Form.Get("cluster_id")
395         }
396
397         // Handle the POST-as-GET special case (workaround for large
398         // GET requests that potentially exceed maximum URL length,
399         // like multi-object queries where the filter has 100s of
400         // items)
401         effectiveMethod := req.Method
402         if req.Method == "POST" && req.Form.Get("_method") != "" {
403                 effectiveMethod = req.Form.Get("_method")
404         }
405
406         if effectiveMethod == "GET" &&
407                 clusterId == "" &&
408                 req.Form.Get("filters") != "" &&
409                 h.handleMultiClusterQuery(w, req, &clusterId) {
410                 return
411         }
412
413         if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
414                 h.next.ServeHTTP(w, req)
415         } else {
416                 h.handler.remoteClusterRequest(clusterId, w, req, nil)
417         }
418 }
419
420 type rewriteSignaturesClusterId struct {
421         clusterID  string
422         expectHash string
423 }
424
425 func (rw rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
426         if requestError != nil {
427                 return resp, requestError
428         }
429
430         if resp.StatusCode != 200 {
431                 return resp, nil
432         }
433
434         originalBody := resp.Body
435         defer originalBody.Close()
436
437         var col arvados.Collection
438         err = json.NewDecoder(resp.Body).Decode(&col)
439         if err != nil {
440                 return nil, err
441         }
442
443         // rewriting signatures will make manifest text 5-10% bigger so calculate
444         // capacity accordingly
445         updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
446
447         hasher := md5.New()
448         mw := io.MultiWriter(hasher, updatedManifest)
449         sz := 0
450
451         scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
452         scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
453         for scanner.Scan() {
454                 line := scanner.Text()
455                 tokens := strings.Split(line, " ")
456                 if len(tokens) < 3 {
457                         return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
458                 }
459
460                 n, err := mw.Write([]byte(tokens[0]))
461                 if err != nil {
462                         return nil, fmt.Errorf("Error updating manifest: %v", err)
463                 }
464                 sz += n
465                 for _, token := range tokens[1:] {
466                         n, err = mw.Write([]byte(" "))
467                         if err != nil {
468                                 return nil, fmt.Errorf("Error updating manifest: %v", err)
469                         }
470                         sz += n
471
472                         m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
473                         if m != nil {
474                                 // Rewrite the block signature to be a remote signature
475                                 _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], rw.clusterID, m[5][2:], m[8])
476                                 if err != nil {
477                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
478                                 }
479
480                                 // for hash checking, ignore signatures
481                                 n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2])
482                                 if err != nil {
483                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
484                                 }
485                                 sz += n
486                         } else {
487                                 n, err = mw.Write([]byte(token))
488                                 if err != nil {
489                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
490                                 }
491                                 sz += n
492                         }
493                 }
494                 n, err = mw.Write([]byte("\n"))
495                 if err != nil {
496                         return nil, fmt.Errorf("Error updating manifest: %v", err)
497                 }
498                 sz += n
499         }
500
501         // Check that expected hash is consistent with
502         // portable_data_hash field of the returned record
503         if rw.expectHash == "" {
504                 rw.expectHash = col.PortableDataHash
505         } else if rw.expectHash != col.PortableDataHash {
506                 return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", rw.expectHash, col.PortableDataHash)
507         }
508
509         // Certify that the computed hash of the manifest_text matches our expectation
510         sum := hasher.Sum(nil)
511         computedHash := fmt.Sprintf("%x+%v", sum, sz)
512         if computedHash != rw.expectHash {
513                 return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, rw.expectHash)
514         }
515
516         col.ManifestText = updatedManifest.String()
517
518         newbody, err := json.Marshal(col)
519         if err != nil {
520                 return nil, err
521         }
522
523         buf := bytes.NewBuffer(newbody)
524         resp.Body = ioutil.NopCloser(buf)
525         resp.ContentLength = int64(buf.Len())
526         resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
527
528         return resp, nil
529 }
530
// filterLocalClusterResponse is the ResponseFilter for the local-cluster
// step of a lookup by portable data hash. A 404 is swallowed, returning
// (nil, nil), so the caller can go on to search the federation. Anything
// else, including a request error, is passed through unchanged.
func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
	switch {
	case requestError != nil:
		return resp, requestError
	case resp.StatusCode == 404:
		// Suppress returning this result, because we want to
		// search the federation.
		return nil, nil
	default:
		return resp, nil
	}
}
543
544 type searchRemoteClusterForPDH struct {
545         pdh           string
546         remoteID      string
547         mtx           *sync.Mutex
548         sentResponse  *bool
549         sharedContext *context.Context
550         cancelFunc    func()
551         errors        *[]string
552         statusCode    *int
553 }
554
555 func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
556         s.mtx.Lock()
557         defer s.mtx.Unlock()
558
559         if *s.sentResponse {
560                 // Another request already returned a response
561                 return nil, nil
562         }
563
564         if requestError != nil {
565                 *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
566                 // Record the error and suppress response
567                 return nil, nil
568         }
569
570         if resp.StatusCode != 200 {
571                 // Suppress returning unsuccessful result.  Maybe
572                 // another request will find it.
573                 // TODO collect and return error responses.
574                 *s.errors = append(*s.errors, fmt.Sprintf("Response from %q: %v", s.remoteID, resp.Status))
575                 if resp.StatusCode != 404 {
576                         // Got a non-404 error response, convert into BadGateway
577                         *s.statusCode = http.StatusBadGateway
578                 }
579                 return nil, nil
580         }
581
582         s.mtx.Unlock()
583
584         // This reads the response body.  We don't want to hold the
585         // lock while doing this because other remote requests could
586         // also have made it to this point, and we don't want a
587         // slow response holding the lock to block a faster response
588         // that is waiting on the lock.
589         newResponse, err = rewriteSignaturesClusterId{s.remoteID, s.pdh}.rewriteSignatures(resp, nil)
590
591         s.mtx.Lock()
592
593         if *s.sentResponse {
594                 // Another request already returned a response
595                 return nil, nil
596         }
597
598         if err != nil {
599                 // Suppress returning unsuccessful result.  Maybe
600                 // another request will be successful.
601                 *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
602                 return nil, nil
603         }
604
605         // We have a successful response.  Suppress/cancel all the
606         // other requests/responses.
607         *s.sentResponse = true
608         s.cancelFunc()
609
610         return newResponse, nil
611 }
612
613 func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
614         if req.Method != "GET" {
615                 // Only handle GET requests right now
616                 h.next.ServeHTTP(w, req)
617                 return
618         }
619
620         m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
621         if len(m) != 2 {
622                 // Not a collection PDH GET request
623                 m = collectionRe.FindStringSubmatch(req.URL.Path)
624                 clusterId := ""
625
626                 if len(m) > 0 {
627                         clusterId = m[2]
628                 }
629
630                 if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
631                         // request for remote collection by uuid
632                         h.handler.remoteClusterRequest(clusterId, w, req,
633                                 rewriteSignaturesClusterId{clusterId, ""}.rewriteSignatures)
634                         return
635                 }
636                 // not a collection UUID request, or it is a request
637                 // for a local UUID, either way, continue down the
638                 // handler stack.
639                 h.next.ServeHTTP(w, req)
640                 return
641         }
642
643         // Request for collection by PDH.  Search the federation.
644
645         // First, query the local cluster.
646         if h.handler.localClusterRequest(w, req, filterLocalClusterResponse) {
647                 return
648         }
649
650         sharedContext, cancelFunc := context.WithCancel(req.Context())
651         defer cancelFunc()
652         req = req.WithContext(sharedContext)
653
654         // Create a goroutine for each cluster in the
655         // RemoteClusters map.  The first valid result gets
656         // returned to the client.  When that happens, all
657         // other outstanding requests are cancelled or
658         // suppressed.
659         sentResponse := false
660         mtx := sync.Mutex{}
661         wg := sync.WaitGroup{}
662         var errors []string
663         var errorCode int = 404
664
665         // use channel as a semaphore to limit the number of concurrent
666         // requests at a time
667         sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
668         defer close(sem)
669         for remoteID := range h.handler.Cluster.RemoteClusters {
670                 if remoteID == h.handler.Cluster.ClusterID {
671                         // No need to query local cluster again
672                         continue
673                 }
674                 // blocks until it can put a value into the
675                 // channel (which has a max queue capacity)
676                 sem <- true
677                 if sentResponse {
678                         break
679                 }
680                 search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
681                         &sharedContext, cancelFunc, &errors, &errorCode}
682                 wg.Add(1)
683                 go func() {
684                         h.handler.remoteClusterRequest(search.remoteID, w, req, search.filterRemoteClusterResponse)
685                         wg.Done()
686                         <-sem
687                 }()
688         }
689         wg.Wait()
690
691         if sentResponse {
692                 return
693         }
694
695         // No successful responses, so return the error
696         httpserver.Errors(w, errors, errorCode)
697 }
698
699 func (h *Handler) setupProxyRemoteCluster(next http.Handler) http.Handler {
700         mux := http.NewServeMux()
701         mux.Handle("/arvados/v1/workflows", &genericFederatedRequestHandler{next, h, wfRe})
702         mux.Handle("/arvados/v1/workflows/", &genericFederatedRequestHandler{next, h, wfRe})
703         mux.Handle("/arvados/v1/containers", &genericFederatedRequestHandler{next, h, containersRe})
704         mux.Handle("/arvados/v1/containers/", &genericFederatedRequestHandler{next, h, containersRe})
705         mux.Handle("/arvados/v1/container_requests", &genericFederatedRequestHandler{next, h, containerRequestsRe})
706         mux.Handle("/arvados/v1/container_requests/", &genericFederatedRequestHandler{next, h, containerRequestsRe})
707         mux.Handle("/arvados/v1/collections", next)
708         mux.Handle("/arvados/v1/collections/", &collectionFederatedRequestHandler{next, h})
709         mux.Handle("/", next)
710
711         return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
712                 parts := strings.Split(req.Header.Get("Authorization"), "/")
713                 alreadySalted := (len(parts) == 3 && parts[0] == "Bearer v2" && len(parts[2]) == 40)
714
715                 if alreadySalted ||
716                         strings.Index(req.Header.Get("Via"), "arvados-controller") != -1 {
717                         // The token is already salted, or this is a
718                         // request from another instance of
719                         // arvados-controller.  In either case, we
720                         // don't want to proxy this query, so just
721                         // continue down the instance handler stack.
722                         next.ServeHTTP(w, req)
723                         return
724                 }
725
726                 mux.ServeHTTP(w, req)
727         })
728
729         return mux
730 }
731
732 type CurrentUser struct {
733         Authorization arvados.APIClientAuthorization
734         UUID          string
735 }
736
737 func (h *Handler) validateAPItoken(req *http.Request, user *CurrentUser) error {
738         db, err := h.db(req)
739         if err != nil {
740                 return err
741         }
742         return db.QueryRowContext(req.Context(), `SELECT api_client_authorizations.uuid, users.uuid FROM api_client_authorizations JOIN users on api_client_authorizations.user_id=users.id WHERE api_token=$1 AND (expires_at IS NULL OR expires_at > current_timestamp) LIMIT 1`, user.Authorization.APIToken).Scan(&user.Authorization.UUID, &user.UUID)
743 }
744
745 // Extract the auth token supplied in req, and replace it with a
746 // salted token for the remote cluster.
747 func (h *Handler) saltAuthToken(req *http.Request, remote string) (updatedReq *http.Request, err error) {
748         updatedReq = (&http.Request{
749                 Method:        req.Method,
750                 URL:           req.URL,
751                 Header:        req.Header,
752                 Body:          req.Body,
753                 ContentLength: req.ContentLength,
754                 Host:          req.Host,
755         }).WithContext(req.Context())
756
757         creds := auth.NewCredentials()
758         creds.LoadTokensFromHTTPRequest(updatedReq)
759         if len(creds.Tokens) == 0 && updatedReq.Header.Get("Content-Type") == "application/x-www-form-encoded" {
760                 // Override ParseForm's 10MiB limit by ensuring
761                 // req.Body is a *http.maxBytesReader.
762                 updatedReq.Body = http.MaxBytesReader(nil, updatedReq.Body, 1<<28) // 256MiB. TODO: use MaxRequestSize from discovery doc or config.
763                 if err := creds.LoadTokensFromHTTPRequestBody(updatedReq); err != nil {
764                         return nil, err
765                 }
766                 // Replace req.Body with a buffer that re-encodes the
767                 // form without api_token, in case we end up
768                 // forwarding the request.
769                 if updatedReq.PostForm != nil {
770                         updatedReq.PostForm.Del("api_token")
771                 }
772                 updatedReq.Body = ioutil.NopCloser(bytes.NewBufferString(updatedReq.PostForm.Encode()))
773         }
774         if len(creds.Tokens) == 0 {
775                 return updatedReq, nil
776         }
777
778         token, err := auth.SaltToken(creds.Tokens[0], remote)
779
780         log.Printf("Salting %q %q to get %q %q", creds.Tokens[0], remote, token, err)
781         if err == auth.ErrObsoleteToken {
782                 // If the token exists in our own database, salt it
783                 // for the remote. Otherwise, assume it was issued by
784                 // the remote, and pass it through unmodified.
785                 currentUser := CurrentUser{Authorization: arvados.APIClientAuthorization{APIToken: creds.Tokens[0]}}
786                 err = h.validateAPItoken(req, &currentUser)
787                 if err == sql.ErrNoRows {
788                         // Not ours; pass through unmodified.
789                         token = currentUser.Authorization.APIToken
790                 } else if err != nil {
791                         return nil, err
792                 } else {
793                         // Found; make V2 version and salt it.
794                         token, err = auth.SaltToken(currentUser.Authorization.TokenV2(), remote)
795                         if err != nil {
796                                 return nil, err
797                         }
798                 }
799         } else if err != nil {
800                 return nil, err
801         }
802         updatedReq.Header = http.Header{}
803         for k, v := range req.Header {
804                 if k == "Authorization" {
805                         updatedReq.Header[k] = []string{"Bearer " + token}
806                 } else {
807                         updatedReq.Header[k] = v
808                 }
809         }
810
811         log.Printf("Salted %q %q to get %q", creds.Tokens[0], remote, token)
812
813         // Remove api_token=... from the the query string, in case we
814         // end up forwarding the request.
815         if values, err := url.ParseQuery(updatedReq.URL.RawQuery); err != nil {
816                 return nil, err
817         } else if _, ok := values["api_token"]; ok {
818                 delete(values, "api_token")
819                 updatedReq.URL = &url.URL{
820                         Scheme:   req.URL.Scheme,
821                         Host:     req.URL.Host,
822                         Path:     req.URL.Path,
823                         RawPath:  req.URL.RawPath,
824                         RawQuery: values.Encode(),
825                 }
826         }
827         return updatedReq, nil
828 }