14262: Refactoring proxy
[arvados.git] / lib / controller / federation.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package controller
6
7 import (
8         "bufio"
9         "bytes"
10         "context"
11         "crypto/md5"
12         "database/sql"
13         "encoding/json"
14         "fmt"
15         "io"
16         "io/ioutil"
17         "net/http"
18         "net/url"
19         "regexp"
20         "strings"
21         "sync"
22
23         "git.curoverse.com/arvados.git/sdk/go/arvados"
24         "git.curoverse.com/arvados.git/sdk/go/auth"
25         "git.curoverse.com/arvados.git/sdk/go/httpserver"
26         "git.curoverse.com/arvados.git/sdk/go/keepclient"
27 )
28
29 var pathPattern = `^/arvados/v1/%s(/([0-9a-z]{5})-%s-[0-9a-z]{15})?(.*)$`
30 var wfRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "workflows", "7fd4e"))
31 var containersRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "containers", "dz642"))
32 var containerRequestsRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "container_requests", "xvhdp"))
33 var collectionRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "collections", "4zz18"))
34 var collectionByPDHRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-fA-F]{32}\+[0-9]+)+$`)
35
36 type genericFederatedRequestHandler struct {
37         next    http.Handler
38         handler *Handler
39         matcher *regexp.Regexp
40 }
41
42 type collectionFederatedRequestHandler struct {
43         next    http.Handler
44         handler *Handler
45 }
46
47 func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, error) {
48         remote, ok := h.Cluster.RemoteClusters[remoteID]
49         if !ok {
50                 return nil, HTTPError{fmt.Sprintf("no proxy available for cluster %v", remoteID), http.StatusNotFound}
51         }
52         scheme := remote.Scheme
53         if scheme == "" {
54                 scheme = "https"
55         }
56         saltedReq, err := h.saltAuthToken(req, remoteID)
57         if err != nil {
58                 return nil, err
59         }
60         urlOut := &url.URL{
61                 Scheme:   scheme,
62                 Host:     remote.Host,
63                 Path:     saltedReq.URL.Path,
64                 RawPath:  saltedReq.URL.RawPath,
65                 RawQuery: saltedReq.URL.RawQuery,
66         }
67         client := h.secureClient
68         if remote.Insecure {
69                 client = h.insecureClient
70         }
71         return h.proxy.ForwardRequest(saltedReq, urlOut, client)
72 }
73
74 // Buffer request body, parse form parameters in request, and then
75 // replace original body with the buffer so it can be re-read by
76 // downstream proxy steps.
77 func loadParamsFromForm(req *http.Request) error {
78         var postBody *bytes.Buffer
79         if req.Body != nil && req.Header.Get("Content-Type") == "application/x-www-form-urlencoded" {
80                 var cl int64
81                 if req.ContentLength > 0 {
82                         cl = req.ContentLength
83                 }
84                 postBody = bytes.NewBuffer(make([]byte, 0, cl))
85                 originalBody := req.Body
86                 defer originalBody.Close()
87                 req.Body = ioutil.NopCloser(io.TeeReader(req.Body, postBody))
88         }
89
90         err := req.ParseForm()
91         if err != nil {
92                 return err
93         }
94
95         if req.Body != nil && postBody != nil {
96                 req.Body = ioutil.NopCloser(postBody)
97         }
98         return nil
99 }
100
101 type multiClusterQueryResponseCollector struct {
102         responses []map[string]interface{}
103         error     error
104         kind      string
105         clusterID string
106 }
107
108 func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
109         requestError error) (newResponse *http.Response, err error) {
110         if requestError != nil {
111                 c.error = requestError
112                 return nil, nil
113         }
114
115         defer resp.Body.Close()
116         var loadInto struct {
117                 Kind   string                   `json:"kind"`
118                 Items  []map[string]interface{} `json:"items"`
119                 Errors []string                 `json:"errors"`
120         }
121         err = json.NewDecoder(resp.Body).Decode(&loadInto)
122
123         if err != nil {
124                 c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, err)
125                 return nil, nil
126         }
127         if resp.StatusCode != http.StatusOK {
128                 c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, loadInto.Errors)
129                 return nil, nil
130         }
131
132         c.responses = loadInto.Items
133         c.kind = loadInto.Kind
134
135         return nil, nil
136 }
137
138 func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
139         req *http.Request,
140         clusterID string, uuids []string) (rp []map[string]interface{}, kind string, err error) {
141
142         found := make(map[string]bool)
143         prev_len_uuids := len(uuids) + 1
144         // Loop while
145         // (1) there are more uuids to query
146         // (2) we're making progress - on each iteration the set of
147         // uuids we are expecting for must shrink.
148         for len(uuids) > 0 && len(uuids) < prev_len_uuids {
149                 var remoteReq http.Request
150                 remoteReq.Header = req.Header
151                 remoteReq.Method = "POST"
152                 remoteReq.URL = &url.URL{Path: req.URL.Path}
153                 remoteParams := make(url.Values)
154                 remoteParams.Set("_method", "GET")
155                 remoteParams.Set("count", "none")
156                 if req.Form.Get("select") != "" {
157                         remoteParams.Set("select", req.Form.Get("select"))
158                 }
159                 content, err := json.Marshal(uuids)
160                 if err != nil {
161                         return nil, "", err
162                 }
163                 remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
164                 enc := remoteParams.Encode()
165                 remoteReq.Body = ioutil.NopCloser(bytes.NewBufferString(enc))
166
167                 rc := multiClusterQueryResponseCollector{clusterID: clusterID}
168
169                 var resp *http.Response
170                 if clusterID == h.handler.Cluster.ClusterID {
171                         resp, err = h.handler.localClusterRequest(&remoteReq)
172                 } else {
173                         resp, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
174                 }
175                 rc.collectResponse(resp, err)
176
177                 if rc.error != nil {
178                         return nil, "", rc.error
179                 }
180
181                 kind = rc.kind
182
183                 if len(rc.responses) == 0 {
184                         // We got zero responses, no point in doing
185                         // another query.
186                         return rp, kind, nil
187                 }
188
189                 rp = append(rp, rc.responses...)
190
191                 // Go through the responses and determine what was
192                 // returned.  If there are remaining items, loop
193                 // around and do another request with just the
194                 // stragglers.
195                 for _, i := range rc.responses {
196                         uuid, ok := i["uuid"].(string)
197                         if ok {
198                                 found[uuid] = true
199                         }
200                 }
201
202                 l := []string{}
203                 for _, u := range uuids {
204                         if !found[u] {
205                                 l = append(l, u)
206                         }
207                 }
208                 prev_len_uuids = len(uuids)
209                 uuids = l
210         }
211
212         return rp, kind, nil
213 }
214
215 func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter,
216         req *http.Request, clusterId *string) bool {
217
218         var filters [][]interface{}
219         err := json.Unmarshal([]byte(req.Form.Get("filters")), &filters)
220         if err != nil {
221                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
222                 return true
223         }
224
225         // Split the list of uuids by prefix
226         queryClusters := make(map[string][]string)
227         expectCount := 0
228         for _, filter := range filters {
229                 if len(filter) != 3 {
230                         return false
231                 }
232
233                 if lhs, ok := filter[0].(string); !ok || lhs != "uuid" {
234                         return false
235                 }
236
237                 op, ok := filter[1].(string)
238                 if !ok {
239                         return false
240                 }
241
242                 if op == "in" {
243                         if rhs, ok := filter[2].([]interface{}); ok {
244                                 for _, i := range rhs {
245                                         if u, ok := i.(string); ok {
246                                                 *clusterId = u[0:5]
247                                                 queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
248                                                 expectCount += 1
249                                         }
250                                 }
251                         }
252                 } else if op == "=" {
253                         if u, ok := filter[2].(string); ok {
254                                 *clusterId = u[0:5]
255                                 queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
256                                 expectCount += 1
257                         }
258                 } else {
259                         return false
260                 }
261
262         }
263
264         if len(queryClusters) <= 1 {
265                 // Query does not search for uuids across multiple
266                 // clusters.
267                 return false
268         }
269
270         // Validations
271         count := req.Form.Get("count")
272         if count != "" && count != `none` && count != `"none"` {
273                 httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
274                 return true
275         }
276         if req.Form.Get("limit") != "" || req.Form.Get("offset") != "" || req.Form.Get("order") != "" {
277                 httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
278                 return true
279         }
280         if expectCount > h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse() {
281                 httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
282                         expectCount, h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse()), http.StatusBadRequest)
283                 return true
284         }
285         if req.Form.Get("select") != "" {
286                 foundUUID := false
287                 var selects []string
288                 err := json.Unmarshal([]byte(req.Form.Get("select")), &selects)
289                 if err != nil {
290                         httpserver.Error(w, err.Error(), http.StatusBadRequest)
291                         return true
292                 }
293
294                 for _, r := range selects {
295                         if r == "uuid" {
296                                 foundUUID = true
297                                 break
298                         }
299                 }
300                 if !foundUUID {
301                         httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
302                         return true
303                 }
304         }
305
306         // Perform concurrent requests to each cluster
307
308         // use channel as a semaphore to limit the number of concurrent
309         // requests at a time
310         sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
311         defer close(sem)
312         wg := sync.WaitGroup{}
313
314         req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
315         mtx := sync.Mutex{}
316         errors := []error{}
317         var completeResponses []map[string]interface{}
318         var kind string
319
320         for k, v := range queryClusters {
321                 if len(v) == 0 {
322                         // Nothing to query
323                         continue
324                 }
325
326                 // blocks until it can put a value into the
327                 // channel (which has a max queue capacity)
328                 sem <- true
329                 wg.Add(1)
330                 go func(k string, v []string) {
331                         rp, kn, err := h.remoteQueryUUIDs(w, req, k, v)
332                         mtx.Lock()
333                         if err == nil {
334                                 completeResponses = append(completeResponses, rp...)
335                                 kind = kn
336                         } else {
337                                 errors = append(errors, err)
338                         }
339                         mtx.Unlock()
340                         wg.Done()
341                         <-sem
342                 }(k, v)
343         }
344         wg.Wait()
345
346         if len(errors) > 0 {
347                 var strerr []string
348                 for _, e := range errors {
349                         strerr = append(strerr, e.Error())
350                 }
351                 httpserver.Errors(w, strerr, http.StatusBadGateway)
352                 return true
353         }
354
355         w.Header().Set("Content-Type", "application/json")
356         w.WriteHeader(http.StatusOK)
357         itemList := make(map[string]interface{})
358         itemList["items"] = completeResponses
359         itemList["kind"] = kind
360         json.NewEncoder(w).Encode(itemList)
361
362         return true
363 }
364
365 func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
366         m := h.matcher.FindStringSubmatch(req.URL.Path)
367         clusterId := ""
368
369         if len(m) > 0 && m[2] != "" {
370                 clusterId = m[2]
371         }
372
373         // Get form parameters from URL and form body (if POST).
374         if err := loadParamsFromForm(req); err != nil {
375                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
376                 return
377         }
378
379         // Check if the parameters have an explicit cluster_id
380         if req.Form.Get("cluster_id") != "" {
381                 clusterId = req.Form.Get("cluster_id")
382         }
383
384         // Handle the POST-as-GET special case (workaround for large
385         // GET requests that potentially exceed maximum URL length,
386         // like multi-object queries where the filter has 100s of
387         // items)
388         effectiveMethod := req.Method
389         if req.Method == "POST" && req.Form.Get("_method") != "" {
390                 effectiveMethod = req.Form.Get("_method")
391         }
392
393         if effectiveMethod == "GET" &&
394                 clusterId == "" &&
395                 req.Form.Get("filters") != "" &&
396                 h.handleMultiClusterQuery(w, req, &clusterId) {
397                 return
398         }
399
400         if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
401                 h.next.ServeHTTP(w, req)
402         } else {
403                 resp, err := h.handler.remoteClusterRequest(clusterId, req)
404                 h.handler.proxy.ForwardResponse(w, resp, err)
405         }
406 }
407
408 func rewriteSignatures(clusterID string, expectHash string,
409         resp *http.Response, requestError error) (newResponse *http.Response, err error) {
410
411         if requestError != nil {
412                 return resp, requestError
413         }
414
415         if resp.StatusCode != 200 {
416                 return resp, nil
417         }
418
419         originalBody := resp.Body
420         defer originalBody.Close()
421
422         var col arvados.Collection
423         err = json.NewDecoder(resp.Body).Decode(&col)
424         if err != nil {
425                 return nil, err
426         }
427
428         // rewriting signatures will make manifest text 5-10% bigger so calculate
429         // capacity accordingly
430         updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
431
432         hasher := md5.New()
433         mw := io.MultiWriter(hasher, updatedManifest)
434         sz := 0
435
436         scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
437         scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
438         for scanner.Scan() {
439                 line := scanner.Text()
440                 tokens := strings.Split(line, " ")
441                 if len(tokens) < 3 {
442                         return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
443                 }
444
445                 n, err := mw.Write([]byte(tokens[0]))
446                 if err != nil {
447                         return nil, fmt.Errorf("Error updating manifest: %v", err)
448                 }
449                 sz += n
450                 for _, token := range tokens[1:] {
451                         n, err = mw.Write([]byte(" "))
452                         if err != nil {
453                                 return nil, fmt.Errorf("Error updating manifest: %v", err)
454                         }
455                         sz += n
456
457                         m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
458                         if m != nil {
459                                 // Rewrite the block signature to be a remote signature
460                                 _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterID, m[5][2:], m[8])
461                                 if err != nil {
462                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
463                                 }
464
465                                 // for hash checking, ignore signatures
466                                 n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2])
467                                 if err != nil {
468                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
469                                 }
470                                 sz += n
471                         } else {
472                                 n, err = mw.Write([]byte(token))
473                                 if err != nil {
474                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
475                                 }
476                                 sz += n
477                         }
478                 }
479                 n, err = mw.Write([]byte("\n"))
480                 if err != nil {
481                         return nil, fmt.Errorf("Error updating manifest: %v", err)
482                 }
483                 sz += n
484         }
485
486         // Check that expected hash is consistent with
487         // portable_data_hash field of the returned record
488         if expectHash == "" {
489                 expectHash = col.PortableDataHash
490         } else if expectHash != col.PortableDataHash {
491                 return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", expectHash, col.PortableDataHash)
492         }
493
494         // Certify that the computed hash of the manifest_text matches our expectation
495         sum := hasher.Sum(nil)
496         computedHash := fmt.Sprintf("%x+%v", sum, sz)
497         if computedHash != expectHash {
498                 return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, expectHash)
499         }
500
501         col.ManifestText = updatedManifest.String()
502
503         newbody, err := json.Marshal(col)
504         if err != nil {
505                 return nil, err
506         }
507
508         buf := bytes.NewBuffer(newbody)
509         resp.Body = ioutil.NopCloser(buf)
510         resp.ContentLength = int64(buf.Len())
511         resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
512
513         return resp, nil
514 }
515
516 func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
517         if requestError != nil {
518                 return resp, requestError
519         }
520
521         if resp.StatusCode == 404 {
522                 // Suppress returning this result, because we want to
523                 // search the federation.
524                 return nil, nil
525         }
526         return resp, nil
527 }
528
529 type searchRemoteClusterForPDH struct {
530         pdh           string
531         remoteID      string
532         mtx           *sync.Mutex
533         sentResponse  *bool
534         sharedContext *context.Context
535         cancelFunc    func()
536         errors        *[]string
537         statusCode    *int
538 }
539
540 func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
541         s.mtx.Lock()
542         defer s.mtx.Unlock()
543
544         if *s.sentResponse {
545                 // Another request already returned a response
546                 return nil, nil
547         }
548
549         if requestError != nil {
550                 *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
551                 // Record the error and suppress response
552                 return nil, nil
553         }
554
555         if resp.StatusCode != 200 {
556                 // Suppress returning unsuccessful result.  Maybe
557                 // another request will find it.
558                 // TODO collect and return error responses.
559                 *s.errors = append(*s.errors, fmt.Sprintf("Response from %q: %v", s.remoteID, resp.Status))
560                 if resp.StatusCode != 404 {
561                         // Got a non-404 error response, convert into BadGateway
562                         *s.statusCode = http.StatusBadGateway
563                 }
564                 return nil, nil
565         }
566
567         s.mtx.Unlock()
568
569         // This reads the response body.  We don't want to hold the
570         // lock while doing this because other remote requests could
571         // also have made it to this point, and we don't want a
572         // slow response holding the lock to block a faster response
573         // that is waiting on the lock.
574         newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
575
576         s.mtx.Lock()
577
578         if *s.sentResponse {
579                 // Another request already returned a response
580                 return nil, nil
581         }
582
583         if err != nil {
584                 // Suppress returning unsuccessful result.  Maybe
585                 // another request will be successful.
586                 *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
587                 return nil, nil
588         }
589
590         // We have a successful response.  Suppress/cancel all the
591         // other requests/responses.
592         *s.sentResponse = true
593         s.cancelFunc()
594
595         return newResponse, nil
596 }
597
598 func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
599         if req.Method != "GET" {
600                 // Only handle GET requests right now
601                 h.next.ServeHTTP(w, req)
602                 return
603         }
604
605         m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
606         if len(m) != 2 {
607                 // Not a collection PDH GET request
608                 m = collectionRe.FindStringSubmatch(req.URL.Path)
609                 clusterId := ""
610
611                 if len(m) > 0 {
612                         clusterId = m[2]
613                 }
614
615                 if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
616                         // request for remote collection by uuid
617                         resp, err := h.handler.remoteClusterRequest(clusterId, req)
618                         newResponse, err := rewriteSignatures(clusterId, "", resp, err)
619                         h.handler.proxy.ForwardResponse(w, newResponse, err)
620                         return
621                 }
622                 // not a collection UUID request, or it is a request
623                 // for a local UUID, either way, continue down the
624                 // handler stack.
625                 h.next.ServeHTTP(w, req)
626                 return
627         }
628
629         // Request for collection by PDH.  Search the federation.
630
631         // First, query the local cluster.
632         resp, err := h.handler.localClusterRequest(req)
633         newResp, err := filterLocalClusterResponse(resp, err)
634         if newResp != nil || err != nil {
635                 h.handler.proxy.ForwardResponse(w, newResp, err)
636                 return
637         }
638
639         sharedContext, cancelFunc := context.WithCancel(req.Context())
640         defer cancelFunc()
641         req = req.WithContext(sharedContext)
642
643         // Create a goroutine for each cluster in the
644         // RemoteClusters map.  The first valid result gets
645         // returned to the client.  When that happens, all
646         // other outstanding requests are cancelled or
647         // suppressed.
648         sentResponse := false
649         mtx := sync.Mutex{}
650         wg := sync.WaitGroup{}
651         var errors []string
652         var errorCode int = 404
653
654         // use channel as a semaphore to limit the number of concurrent
655         // requests at a time
656         sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
657         defer close(sem)
658         for remoteID := range h.handler.Cluster.RemoteClusters {
659                 if remoteID == h.handler.Cluster.ClusterID {
660                         // No need to query local cluster again
661                         continue
662                 }
663                 // blocks until it can put a value into the
664                 // channel (which has a max queue capacity)
665                 sem <- true
666                 if sentResponse {
667                         break
668                 }
669                 search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
670                         &sharedContext, cancelFunc, &errors, &errorCode}
671                 wg.Add(1)
672                 go func() {
673                         resp, err := h.handler.remoteClusterRequest(search.remoteID, req)
674                         newResp, err := search.filterRemoteClusterResponse(resp, err)
675                         if newResp != nil || err != nil {
676                                 h.handler.proxy.ForwardResponse(w, newResp, err)
677                         }
678                         wg.Done()
679                         <-sem
680                 }()
681         }
682         wg.Wait()
683
684         if sentResponse {
685                 return
686         }
687
688         // No successful responses, so return the error
689         httpserver.Errors(w, errors, errorCode)
690 }
691
692 func (h *Handler) setupProxyRemoteCluster(next http.Handler) http.Handler {
693         mux := http.NewServeMux()
694         mux.Handle("/arvados/v1/workflows", &genericFederatedRequestHandler{next, h, wfRe})
695         mux.Handle("/arvados/v1/workflows/", &genericFederatedRequestHandler{next, h, wfRe})
696         mux.Handle("/arvados/v1/containers", &genericFederatedRequestHandler{next, h, containersRe})
697         mux.Handle("/arvados/v1/containers/", &genericFederatedRequestHandler{next, h, containersRe})
698         mux.Handle("/arvados/v1/container_requests", &genericFederatedRequestHandler{next, h, containerRequestsRe})
699         mux.Handle("/arvados/v1/container_requests/", &genericFederatedRequestHandler{next, h, containerRequestsRe})
700         mux.Handle("/arvados/v1/collections", next)
701         mux.Handle("/arvados/v1/collections/", &collectionFederatedRequestHandler{next, h})
702         mux.Handle("/", next)
703
704         return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
705                 parts := strings.Split(req.Header.Get("Authorization"), "/")
706                 alreadySalted := (len(parts) == 3 && parts[0] == "Bearer v2" && len(parts[2]) == 40)
707
708                 if alreadySalted ||
709                         strings.Index(req.Header.Get("Via"), "arvados-controller") != -1 {
710                         // The token is already salted, or this is a
711                         // request from another instance of
712                         // arvados-controller.  In either case, we
713                         // don't want to proxy this query, so just
714                         // continue down the instance handler stack.
715                         next.ServeHTTP(w, req)
716                         return
717                 }
718
719                 mux.ServeHTTP(w, req)
720         })
721
722         return mux
723 }
724
725 type CurrentUser struct {
726         Authorization arvados.APIClientAuthorization
727         UUID          string
728 }
729
730 func (h *Handler) validateAPItoken(req *http.Request, user *CurrentUser) error {
731         db, err := h.db(req)
732         if err != nil {
733                 return err
734         }
735         return db.QueryRowContext(req.Context(), `SELECT api_client_authorizations.uuid, users.uuid FROM api_client_authorizations JOIN users on api_client_authorizations.user_id=users.id WHERE api_token=$1 AND (expires_at IS NULL OR expires_at > current_timestamp) LIMIT 1`, user.Authorization.APIToken).Scan(&user.Authorization.UUID, &user.UUID)
736 }
737
738 // Extract the auth token supplied in req, and replace it with a
739 // salted token for the remote cluster.
740 func (h *Handler) saltAuthToken(req *http.Request, remote string) (updatedReq *http.Request, err error) {
741         updatedReq = (&http.Request{
742                 Method:        req.Method,
743                 URL:           req.URL,
744                 Header:        req.Header,
745                 Body:          req.Body,
746                 ContentLength: req.ContentLength,
747                 Host:          req.Host,
748         }).WithContext(req.Context())
749
750         creds := auth.NewCredentials()
751         creds.LoadTokensFromHTTPRequest(updatedReq)
752         if len(creds.Tokens) == 0 && updatedReq.Header.Get("Content-Type") == "application/x-www-form-encoded" {
753                 // Override ParseForm's 10MiB limit by ensuring
754                 // req.Body is a *http.maxBytesReader.
755                 updatedReq.Body = http.MaxBytesReader(nil, updatedReq.Body, 1<<28) // 256MiB. TODO: use MaxRequestSize from discovery doc or config.
756                 if err := creds.LoadTokensFromHTTPRequestBody(updatedReq); err != nil {
757                         return nil, err
758                 }
759                 // Replace req.Body with a buffer that re-encodes the
760                 // form without api_token, in case we end up
761                 // forwarding the request.
762                 if updatedReq.PostForm != nil {
763                         updatedReq.PostForm.Del("api_token")
764                 }
765                 updatedReq.Body = ioutil.NopCloser(bytes.NewBufferString(updatedReq.PostForm.Encode()))
766         }
767         if len(creds.Tokens) == 0 {
768                 return updatedReq, nil
769         }
770
771         token, err := auth.SaltToken(creds.Tokens[0], remote)
772
773         if err == auth.ErrObsoleteToken {
774                 // If the token exists in our own database, salt it
775                 // for the remote. Otherwise, assume it was issued by
776                 // the remote, and pass it through unmodified.
777                 currentUser := CurrentUser{Authorization: arvados.APIClientAuthorization{APIToken: creds.Tokens[0]}}
778                 err = h.validateAPItoken(req, &currentUser)
779                 if err == sql.ErrNoRows {
780                         // Not ours; pass through unmodified.
781                         token = currentUser.Authorization.APIToken
782                 } else if err != nil {
783                         return nil, err
784                 } else {
785                         // Found; make V2 version and salt it.
786                         token, err = auth.SaltToken(currentUser.Authorization.TokenV2(), remote)
787                         if err != nil {
788                                 return nil, err
789                         }
790                 }
791         } else if err != nil {
792                 return nil, err
793         }
794         updatedReq.Header = http.Header{}
795         for k, v := range req.Header {
796                 if k != "Authorization" {
797                         updatedReq.Header[k] = v
798                 }
799         }
800         updatedReq.Header.Set("Authorization", "Bearer "+token)
801
802         // Remove api_token=... from the the query string, in case we
803         // end up forwarding the request.
804         if values, err := url.ParseQuery(updatedReq.URL.RawQuery); err != nil {
805                 return nil, err
806         } else if _, ok := values["api_token"]; ok {
807                 delete(values, "api_token")
808                 updatedReq.URL = &url.URL{
809                         Scheme:   req.URL.Scheme,
810                         Host:     req.URL.Host,
811                         Path:     req.URL.Path,
812                         RawPath:  req.URL.RawPath,
813                         RawQuery: values.Encode(),
814                 }
815         }
816         return updatedReq, nil
817 }