13619: More tests for paging, error conditions
[arvados.git] / lib / controller / federation.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package controller
6
7 import (
8         "bufio"
9         "bytes"
10         "context"
11         "crypto/md5"
12         "database/sql"
13         "encoding/json"
14         "fmt"
15         "io"
16         "io/ioutil"
17         "net/http"
18         "net/url"
19         "regexp"
20         "strings"
21         "sync"
22
23         "git.curoverse.com/arvados.git/sdk/go/arvados"
24         "git.curoverse.com/arvados.git/sdk/go/auth"
25         "git.curoverse.com/arvados.git/sdk/go/httpserver"
26         "git.curoverse.com/arvados.git/sdk/go/keepclient"
27 )
28
// Path patterns for the federated resource types.  pathPattern is
// filled in with the resource name and the uuid type infix.  For each
// compiled expression, FindStringSubmatch yields:
//
//	m[1] - "/" followed by the uuid, when the path names one object
//	m[2] - the 5-character cluster prefix of that uuid ("" if none)
//	m[3] - whatever follows the uuid (or the resource name)
var (
	pathPattern         = `^/arvados/v1/%s(/([0-9a-z]{5})-%s-[0-9a-z]{15})?(.*)$`
	wfRe                = regexp.MustCompile(fmt.Sprintf(pathPattern, "workflows", "7fd4e"))
	containersRe        = regexp.MustCompile(fmt.Sprintf(pathPattern, "containers", "dz642"))
	containerRequestsRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "container_requests", "xvhdp"))
	collectionRe        = regexp.MustCompile(fmt.Sprintf(pathPattern, "collections", "4zz18"))

	// collectionByPDHRe matches a request for exactly one collection
	// by portable data hash ("<32 hex digits>+<size>"); m[1] is the
	// PDH.  The group is deliberately not repeated: a path holding
	// several concatenated hashes is not a valid PDH lookup.
	collectionByPDHRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-fA-F]{32}\+[0-9]+)$`)
)
35
36 type genericFederatedRequestHandler struct {
37         next    http.Handler
38         handler *Handler
39         matcher *regexp.Regexp
40 }
41
42 type collectionFederatedRequestHandler struct {
43         next    http.Handler
44         handler *Handler
45 }
46
47 func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, req *http.Request, filter ResponseFilter) {
48         remote, ok := h.Cluster.RemoteClusters[remoteID]
49         if !ok {
50                 httpserver.Error(w, "no proxy available for cluster "+remoteID, http.StatusNotFound)
51                 return
52         }
53         scheme := remote.Scheme
54         if scheme == "" {
55                 scheme = "https"
56         }
57         err := h.saltAuthToken(req, remoteID)
58         if err != nil {
59                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
60                 return
61         }
62         urlOut := &url.URL{
63                 Scheme:   scheme,
64                 Host:     remote.Host,
65                 Path:     req.URL.Path,
66                 RawPath:  req.URL.RawPath,
67                 RawQuery: req.URL.RawQuery,
68         }
69         client := h.secureClient
70         if remote.Insecure {
71                 client = h.insecureClient
72         }
73         h.proxy.Do(w, req, urlOut, client, filter)
74 }
75
// loadParamsFromForm consumes an application/x-www-form-urlencoded
// request body, appends every decoded field to params (keeping any
// values already present), and installs an in-memory copy of the raw
// body as req.Body so that later proxy steps can read it again.
// The body is restored even when it fails to parse; in that case the
// parse error is returned and params is left untouched.
func loadParamsFromForm(req *http.Request, params url.Values) error {
	raw, readErr := ioutil.ReadAll(req.Body)
	if readErr != nil {
		return readErr
	}
	// Hand the consumed bytes back so the body can be re-read.
	req.Body = ioutil.NopCloser(bytes.NewBuffer(raw))

	decoded, parseErr := url.ParseQuery(string(raw))
	if parseErr == nil {
		for key := range decoded {
			params[key] = append(params[key], decoded[key]...)
		}
	}
	return parseErr
}
96
// loadParamsFromJson expects a request with application/json body.
// It decodes the first JSON value in the body into "loadInto", and
// replaces the request body with a buffer holding the complete
// original body contents so it can be re-read by downstream proxy
// steps.  The original body is closed.
//
// The whole body is read before decoding.  Decoding straight from the
// stream would stop at the end of the first JSON value, so any bytes
// after it (e.g. a trailing newline) would be missing from the copy.
func loadParamsFromJson(req *http.Request, loadInto interface{}) error {
	defer req.Body.Close()

	body, err := ioutil.ReadAll(req.Body)
	if err != nil {
		return err
	}
	req.Body = ioutil.NopCloser(bytes.NewReader(body))

	// A Decoder (rather than json.Unmarshal) keeps accepting data
	// after the first JSON value, as before.
	return json.NewDecoder(bytes.NewReader(body)).Decode(loadInto)
}
118
// multiClusterQueryResponseCollector captures the result of one list
// request issued on behalf of a federated multi-object query.  Its
// collectResponse method is used as a ResponseFilter so that the
// response is recorded here instead of being written to the client.
type multiClusterQueryResponseCollector struct {
	responses []interface{} // "items" from a successful response
	error     error         // why no items could be collected, if so
	kind      string        // "kind" from a successful response
}

// collectResponse records the "items" and "kind" of a successful (200)
// list response, or otherwise the error explaining why it could not.
// It always returns (nil, nil) so nothing is forwarded to the client.
func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
	requestError error) (newResponse *http.Response, err error) {
	if requestError != nil {
		c.error = requestError
		return nil, nil
	}
	defer resp.Body.Close()

	loadInto := make(map[string]interface{})
	decodeErr := json.NewDecoder(resp.Body).Decode(&loadInto)

	switch {
	case decodeErr != nil && resp.StatusCode != http.StatusOK:
		// An error response that is not a JSON API error (e.g. a
		// proxy's HTML error page): the status says more than
		// the JSON syntax error would.
		c.error = fmt.Errorf("error %v", resp.Status)
	case decodeErr != nil:
		c.error = decodeErr
	case resp.StatusCode != http.StatusOK:
		c.error = fmt.Errorf("error %v", loadInto["errors"])
	default:
		// Checked assertion: a malformed 200 response (missing or
		// non-list "items") must become an error, not a panic.
		items, ok := loadInto["items"].([]interface{})
		if !ok {
			c.error = fmt.Errorf("response did not contain an \"items\" list")
			break
		}
		c.responses = items
		c.kind, _ = loadInto["kind"].(string)
	}
	return nil, nil
}
148
149 func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
150         req *http.Request, params url.Values,
151         clusterID string, uuids []string) (rp []interface{}, kind string, err error) {
152
153         found := make(map[string]bool)
154         for len(uuids) > 0 {
155                 var remoteReq http.Request
156                 remoteReq.Header = req.Header
157                 remoteReq.Method = "POST"
158                 remoteReq.URL = &url.URL{Path: req.URL.Path}
159                 remoteParams := make(url.Values)
160                 remoteParams["_method"] = []string{"GET"}
161                 remoteParams["count"] = []string{"none"}
162                 if len(params["select"]) != 0 {
163                         remoteParams["select"] = params["select"]
164                 }
165                 content, err := json.Marshal(uuids)
166                 if err != nil {
167                         return nil, "", err
168                 }
169                 remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
170                 enc := remoteParams.Encode()
171                 remoteReq.Body = ioutil.NopCloser(bytes.NewBufferString(enc))
172
173                 rc := multiClusterQueryResponseCollector{}
174
175                 if clusterID == h.handler.Cluster.ClusterID {
176                         h.handler.localClusterRequest(w, &remoteReq,
177                                 rc.collectResponse)
178                 } else {
179                         h.handler.remoteClusterRequest(clusterID, w, &remoteReq,
180                                 rc.collectResponse)
181                 }
182                 if rc.error != nil {
183                         return nil, "", rc.error
184                 }
185
186                 kind = rc.kind
187
188                 if len(rc.responses) == 0 {
189                         // We got zero responses, no point in doing
190                         // another query.
191                         return rp, kind, nil
192                 }
193
194                 rp = append(rp, rc.responses...)
195
196                 // Go through the responses and determine what was
197                 // returned.  If there are remaining items, loop
198                 // around and do another request with just the
199                 // stragglers.
200                 for _, i := range rc.responses {
201                         m, ok := i.(map[string]interface{})
202                         if ok {
203                                 uuid, ok := m["uuid"].(string)
204                                 if ok {
205                                         found[uuid] = true
206                                 }
207                         }
208                 }
209
210                 l := []string{}
211                 for _, u := range uuids {
212                         if !found[u] {
213                                 l = append(l, u)
214                         }
215                 }
216                 uuids = l
217         }
218
219         return rp, kind, nil
220 }
221
222 func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter, req *http.Request,
223         params url.Values, clusterId *string) bool {
224
225         var filters [][]interface{}
226         err := json.Unmarshal([]byte(params["filters"][0]), &filters)
227         if err != nil {
228                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
229                 return true
230         }
231
232         // Split the list of uuids by prefix
233         queryClusters := make(map[string][]string)
234         expectCount := 0
235         for _, f1 := range filters {
236                 if len(f1) != 3 {
237                         return false
238                 }
239                 lhs, ok := f1[0].(string)
240                 if ok && lhs == "uuid" {
241                         op, ok := f1[1].(string)
242                         if !ok {
243                                 return false
244                         }
245
246                         if op == "in" {
247                                 rhs, ok := f1[2].([]interface{})
248                                 if ok {
249                                         for _, i := range rhs {
250                                                 u := i.(string)
251                                                 *clusterId = u[0:5]
252                                                 queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
253                                         }
254                                         expectCount += len(rhs)
255                                 }
256                         } else if op == "=" {
257                                 u, ok := f1[2].(string)
258                                 if ok {
259                                         *clusterId = u[0:5]
260                                         queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
261                                         expectCount += 1
262                                 }
263                         } else {
264                                 return false
265                         }
266                 } else {
267                         return false
268                 }
269         }
270
271         if len(queryClusters) <= 1 {
272                 // Query does not search for uuids across multiple
273                 // clusters.
274                 return false
275         }
276
277         // Validations
278         if !(len(params["count"]) == 1 && (params["count"][0] == `none` ||
279                 params["count"][0] == `"none"`)) {
280                 httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
281                 return true
282         }
283         if len(params["limit"]) != 0 || len(params["offset"]) != 0 || len(params["order"]) != 0 {
284                 httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
285                 return true
286         }
287         if expectCount > h.handler.Cluster.MaxItemsPerResponse {
288                 httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
289                         expectCount, h.handler.Cluster.MaxItemsPerResponse), http.StatusBadRequest)
290                 return true
291         }
292         if len(params["select"]) == 1 {
293                 foundUUID := false
294                 var selects []interface{}
295                 err := json.Unmarshal([]byte(params["select"][0]), &selects)
296                 if err != nil {
297                         httpserver.Error(w, err.Error(), http.StatusBadRequest)
298                         return true
299                 }
300
301                 for _, r := range selects {
302                         if r.(string) == "uuid" {
303                                 foundUUID = true
304                                 break
305                         }
306                 }
307                 if !foundUUID {
308                         httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
309                         return true
310                 }
311         }
312
313         // Perform parallel requests to each cluster
314
315         // use channel as a semaphore to limit the number of parallel
316         // requests at a time
317         sem := make(chan bool, h.handler.Cluster.ParallelRemoteRequests)
318         defer close(sem)
319         wg := sync.WaitGroup{}
320
321         req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
322         mtx := sync.Mutex{}
323         errors := []error{}
324         var completeResponses []interface{}
325         var kind string
326
327         for k, v := range queryClusters {
328                 if len(v) == 0 {
329                         // Nothing to query
330                         continue
331                 }
332
333                 // blocks until it can put a value into the
334                 // channel (which has a max queue capacity)
335                 sem <- true
336                 wg.Add(1)
337                 go func(k string, v []string) {
338                         rp, kn, err := h.remoteQueryUUIDs(w, req, params, k, v)
339                         mtx.Lock()
340                         if err == nil {
341                                 completeResponses = append(completeResponses, rp...)
342                                 kind = kn
343                         } else {
344                                 errors = append(errors, err)
345                         }
346                         mtx.Unlock()
347                         wg.Done()
348                         <-sem
349                 }(k, v)
350         }
351         wg.Wait()
352
353         if len(errors) > 0 {
354                 var strerr []string
355                 for _, e := range errors {
356                         strerr = append(strerr, e.Error())
357                 }
358                 httpserver.Errors(w, strerr, http.StatusBadGateway)
359                 return true
360         }
361
362         w.Header().Set("Content-Type", "application/json")
363         w.WriteHeader(http.StatusOK)
364         itemList := make(map[string]interface{})
365         itemList["items"] = completeResponses
366         itemList["kind"] = kind
367         json.NewEncoder(w).Encode(itemList)
368
369         return true
370 }
371
372 func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
373         m := h.matcher.FindStringSubmatch(req.URL.Path)
374         clusterId := ""
375
376         if len(m) > 0 && m[2] != "" {
377                 clusterId = m[2]
378         }
379
380         // First, parse the query portion of the URL.
381         var params url.Values
382         var err error
383         if params, err = url.ParseQuery(req.URL.RawQuery); err != nil {
384                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
385                 return
386         }
387
388         // Next, if appropriate, merge in parameters from the form POST body.
389         if req.Method == "POST" && req.Header.Get("Content-Type") == "application/x-www-form-urlencoded" {
390                 if err = loadParamsFromForm(req, params); err != nil {
391                         httpserver.Error(w, err.Error(), http.StatusBadRequest)
392                         return
393                 }
394         }
395
396         // Check if the parameters have an explicit cluster_id
397         if len(params["cluster_id"]) == 1 {
398                 clusterId = params["cluster_id"][0]
399         }
400
401         // Handle the POST-as-GET special case (workaround for large
402         // GET requests that potentially exceed maximum URL length,
403         // like multi-object queries where the filter has 100s of
404         // items)
405         effectiveMethod := req.Method
406         if req.Method == "POST" && len(params["_method"]) == 1 {
407                 effectiveMethod = params["_method"][0]
408         }
409
410         if effectiveMethod == "GET" && clusterId == "" && len(params["filters"]) == 1 {
411                 if h.handleMultiClusterQuery(w, req, params, &clusterId) {
412                         return
413                 }
414         }
415
416         if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
417                 h.next.ServeHTTP(w, req)
418         } else {
419                 h.handler.remoteClusterRequest(clusterId, w, req, nil)
420         }
421 }
422
423 type rewriteSignaturesClusterId struct {
424         clusterID  string
425         expectHash string
426 }
427
428 func (rw rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
429         if requestError != nil {
430                 return resp, requestError
431         }
432
433         if resp.StatusCode != 200 {
434                 return resp, nil
435         }
436
437         originalBody := resp.Body
438         defer originalBody.Close()
439
440         var col arvados.Collection
441         err = json.NewDecoder(resp.Body).Decode(&col)
442         if err != nil {
443                 return nil, err
444         }
445
446         // rewriting signatures will make manifest text 5-10% bigger so calculate
447         // capacity accordingly
448         updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
449
450         hasher := md5.New()
451         mw := io.MultiWriter(hasher, updatedManifest)
452         sz := 0
453
454         scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
455         scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
456         for scanner.Scan() {
457                 line := scanner.Text()
458                 tokens := strings.Split(line, " ")
459                 if len(tokens) < 3 {
460                         return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
461                 }
462
463                 n, err := mw.Write([]byte(tokens[0]))
464                 if err != nil {
465                         return nil, fmt.Errorf("Error updating manifest: %v", err)
466                 }
467                 sz += n
468                 for _, token := range tokens[1:] {
469                         n, err = mw.Write([]byte(" "))
470                         if err != nil {
471                                 return nil, fmt.Errorf("Error updating manifest: %v", err)
472                         }
473                         sz += n
474
475                         m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
476                         if m != nil {
477                                 // Rewrite the block signature to be a remote signature
478                                 _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], rw.clusterID, m[5][2:], m[8])
479                                 if err != nil {
480                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
481                                 }
482
483                                 // for hash checking, ignore signatures
484                                 n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2])
485                                 if err != nil {
486                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
487                                 }
488                                 sz += n
489                         } else {
490                                 n, err = mw.Write([]byte(token))
491                                 if err != nil {
492                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
493                                 }
494                                 sz += n
495                         }
496                 }
497                 n, err = mw.Write([]byte("\n"))
498                 if err != nil {
499                         return nil, fmt.Errorf("Error updating manifest: %v", err)
500                 }
501                 sz += n
502         }
503
504         // Check that expected hash is consistent with
505         // portable_data_hash field of the returned record
506         if rw.expectHash == "" {
507                 rw.expectHash = col.PortableDataHash
508         } else if rw.expectHash != col.PortableDataHash {
509                 return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", rw.expectHash, col.PortableDataHash)
510         }
511
512         // Certify that the computed hash of the manifest_text matches our expectation
513         sum := hasher.Sum(nil)
514         computedHash := fmt.Sprintf("%x+%v", sum, sz)
515         if computedHash != rw.expectHash {
516                 return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, rw.expectHash)
517         }
518
519         col.ManifestText = updatedManifest.String()
520
521         newbody, err := json.Marshal(col)
522         if err != nil {
523                 return nil, err
524         }
525
526         buf := bytes.NewBuffer(newbody)
527         resp.Body = ioutil.NopCloser(buf)
528         resp.ContentLength = int64(buf.Len())
529         resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
530
531         return resp, nil
532 }
533
// filterLocalClusterResponse is the ResponseFilter applied to the local
// lookup of a collection by PDH.  A 404 is suppressed (nil, nil) so the
// caller can go on to search the remote clusters.  Every other outcome,
// including a request error, is passed through untouched.
func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
	switch {
	case requestError != nil:
		return resp, requestError
	case resp.StatusCode == 404:
		// Not here; let the federation search continue.
		return nil, nil
	default:
		return resp, nil
	}
}
546
547 type searchRemoteClusterForPDH struct {
548         pdh           string
549         remoteID      string
550         mtx           *sync.Mutex
551         sentResponse  *bool
552         sharedContext *context.Context
553         cancelFunc    func()
554         errors        *[]string
555         statusCode    *int
556 }
557
558 func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
559         s.mtx.Lock()
560         defer s.mtx.Unlock()
561
562         if *s.sentResponse {
563                 // Another request already returned a response
564                 return nil, nil
565         }
566
567         if requestError != nil {
568                 *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
569                 // Record the error and suppress response
570                 return nil, nil
571         }
572
573         if resp.StatusCode != 200 {
574                 // Suppress returning unsuccessful result.  Maybe
575                 // another request will find it.
576                 // TODO collect and return error responses.
577                 *s.errors = append(*s.errors, fmt.Sprintf("Response from %q: %v", s.remoteID, resp.Status))
578                 if resp.StatusCode != 404 {
579                         // Got a non-404 error response, convert into BadGateway
580                         *s.statusCode = http.StatusBadGateway
581                 }
582                 return nil, nil
583         }
584
585         s.mtx.Unlock()
586
587         // This reads the response body.  We don't want to hold the
588         // lock while doing this because other remote requests could
589         // also have made it to this point, and we don't want a
590         // slow response holding the lock to block a faster response
591         // that is waiting on the lock.
592         newResponse, err = rewriteSignaturesClusterId{s.remoteID, s.pdh}.rewriteSignatures(resp, nil)
593
594         s.mtx.Lock()
595
596         if *s.sentResponse {
597                 // Another request already returned a response
598                 return nil, nil
599         }
600
601         if err != nil {
602                 // Suppress returning unsuccessful result.  Maybe
603                 // another request will be successful.
604                 *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
605                 return nil, nil
606         }
607
608         // We have a successful response.  Suppress/cancel all the
609         // other requests/responses.
610         *s.sentResponse = true
611         s.cancelFunc()
612
613         return newResponse, nil
614 }
615
616 func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
617         if req.Method != "GET" {
618                 // Only handle GET requests right now
619                 h.next.ServeHTTP(w, req)
620                 return
621         }
622
623         m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
624         if len(m) != 2 {
625                 // Not a collection PDH GET request
626                 m = collectionRe.FindStringSubmatch(req.URL.Path)
627                 clusterId := ""
628
629                 if len(m) > 0 {
630                         clusterId = m[2]
631                 }
632
633                 if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
634                         // request for remote collection by uuid
635                         h.handler.remoteClusterRequest(clusterId, w, req,
636                                 rewriteSignaturesClusterId{clusterId, ""}.rewriteSignatures)
637                         return
638                 }
639                 // not a collection UUID request, or it is a request
640                 // for a local UUID, either way, continue down the
641                 // handler stack.
642                 h.next.ServeHTTP(w, req)
643                 return
644         }
645
646         // Request for collection by PDH.  Search the federation.
647
648         // First, query the local cluster.
649         if h.handler.localClusterRequest(w, req, filterLocalClusterResponse) {
650                 return
651         }
652
653         sharedContext, cancelFunc := context.WithCancel(req.Context())
654         defer cancelFunc()
655         req = req.WithContext(sharedContext)
656
657         // Create a goroutine for each cluster in the
658         // RemoteClusters map.  The first valid result gets
659         // returned to the client.  When that happens, all
660         // other outstanding requests are cancelled or
661         // suppressed.
662         sentResponse := false
663         mtx := sync.Mutex{}
664         wg := sync.WaitGroup{}
665         var errors []string
666         var errorCode int = 404
667
668         // use channel as a semaphore to limit the number of parallel
669         // requests at a time
670         sem := make(chan bool, h.handler.Cluster.ParallelRemoteRequests)
671         defer close(sem)
672         for remoteID := range h.handler.Cluster.RemoteClusters {
673                 // blocks until it can put a value into the
674                 // channel (which has a max queue capacity)
675                 sem <- true
676                 if sentResponse {
677                         break
678                 }
679                 search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
680                         &sharedContext, cancelFunc, &errors, &errorCode}
681                 wg.Add(1)
682                 go func() {
683                         h.handler.remoteClusterRequest(search.remoteID, w, req, search.filterRemoteClusterResponse)
684                         wg.Done()
685                         <-sem
686                 }()
687         }
688         wg.Wait()
689
690         if sentResponse {
691                 return
692         }
693
694         // No successful responses, so return the error
695         httpserver.Errors(w, errors, errorCode)
696 }
697
698 func (h *Handler) setupProxyRemoteCluster(next http.Handler) http.Handler {
699         mux := http.NewServeMux()
700         mux.Handle("/arvados/v1/workflows", &genericFederatedRequestHandler{next, h, wfRe})
701         mux.Handle("/arvados/v1/workflows/", &genericFederatedRequestHandler{next, h, wfRe})
702         mux.Handle("/arvados/v1/containers", &genericFederatedRequestHandler{next, h, containersRe})
703         mux.Handle("/arvados/v1/containers/", &genericFederatedRequestHandler{next, h, containersRe})
704         mux.Handle("/arvados/v1/container_requests", &genericFederatedRequestHandler{next, h, containerRequestsRe})
705         mux.Handle("/arvados/v1/container_requests/", &genericFederatedRequestHandler{next, h, containerRequestsRe})
706         mux.Handle("/arvados/v1/collections", next)
707         mux.Handle("/arvados/v1/collections/", &collectionFederatedRequestHandler{next, h})
708         mux.Handle("/", next)
709
710         return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
711                 parts := strings.Split(req.Header.Get("Authorization"), "/")
712                 alreadySalted := (len(parts) == 3 && parts[0] == "Bearer v2" && len(parts[2]) == 40)
713
714                 if alreadySalted ||
715                         strings.Index(req.Header.Get("Via"), "arvados-controller") != -1 {
716                         // The token is already salted, or this is a
717                         // request from another instance of
718                         // arvados-controller.  In either case, we
719                         // don't want to proxy this query, so just
720                         // continue down the instance handler stack.
721                         next.ServeHTTP(w, req)
722                         return
723                 }
724
725                 mux.ServeHTTP(w, req)
726         })
727
728         return mux
729 }
730
731 type CurrentUser struct {
732         Authorization arvados.APIClientAuthorization
733         UUID          string
734 }
735
736 func (h *Handler) validateAPItoken(req *http.Request, user *CurrentUser) error {
737         db, err := h.db(req)
738         if err != nil {
739                 return err
740         }
741         return db.QueryRowContext(req.Context(), `SELECT api_client_authorizations.uuid, users.uuid FROM api_client_authorizations JOIN users on api_client_authorizations.user_id=users.id WHERE api_token=$1 AND (expires_at IS NULL OR expires_at > current_timestamp) LIMIT 1`, user.Authorization.APIToken).Scan(&user.Authorization.UUID, &user.UUID)
742 }
743
744 // Extract the auth token supplied in req, and replace it with a
745 // salted token for the remote cluster.
746 func (h *Handler) saltAuthToken(req *http.Request, remote string) error {
747         creds := auth.NewCredentials()
748         creds.LoadTokensFromHTTPRequest(req)
749         if len(creds.Tokens) == 0 && req.Header.Get("Content-Type") == "application/x-www-form-encoded" {
750                 // Override ParseForm's 10MiB limit by ensuring
751                 // req.Body is a *http.maxBytesReader.
752                 req.Body = http.MaxBytesReader(nil, req.Body, 1<<28) // 256MiB. TODO: use MaxRequestSize from discovery doc or config.
753                 if err := creds.LoadTokensFromHTTPRequestBody(req); err != nil {
754                         return err
755                 }
756                 // Replace req.Body with a buffer that re-encodes the
757                 // form without api_token, in case we end up
758                 // forwarding the request.
759                 if req.PostForm != nil {
760                         req.PostForm.Del("api_token")
761                 }
762                 req.Body = ioutil.NopCloser(bytes.NewBufferString(req.PostForm.Encode()))
763         }
764         if len(creds.Tokens) == 0 {
765                 return nil
766         }
767         token, err := auth.SaltToken(creds.Tokens[0], remote)
768         if err == auth.ErrObsoleteToken {
769                 // If the token exists in our own database, salt it
770                 // for the remote. Otherwise, assume it was issued by
771                 // the remote, and pass it through unmodified.
772                 currentUser := CurrentUser{Authorization: arvados.APIClientAuthorization{APIToken: creds.Tokens[0]}}
773                 err = h.validateAPItoken(req, &currentUser)
774                 if err == sql.ErrNoRows {
775                         // Not ours; pass through unmodified.
776                         token = currentUser.Authorization.APIToken
777                 } else if err != nil {
778                         return err
779                 } else {
780                         // Found; make V2 version and salt it.
781                         token, err = auth.SaltToken(currentUser.Authorization.TokenV2(), remote)
782                         if err != nil {
783                                 return err
784                         }
785                 }
786         } else if err != nil {
787                 return err
788         }
789         req.Header.Set("Authorization", "Bearer "+token)
790
791         // Remove api_token=... from the the query string, in case we
792         // end up forwarding the request.
793         if values, err := url.ParseQuery(req.URL.RawQuery); err != nil {
794                 return err
795         } else if _, ok := values["api_token"]; ok {
796                 delete(values, "api_token")
797                 req.URL.RawQuery = values.Encode()
798         }
799         return nil
800 }