13994: Merge branch 'master' into 13994-proxy-remote
[arvados.git] / lib / controller / federation.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package controller
6
7 import (
8         "bufio"
9         "bytes"
10         "context"
11         "database/sql"
12         "encoding/json"
13         "fmt"
14         "io/ioutil"
15         "net/http"
16         "net/url"
17         "regexp"
18         "strings"
19         "sync"
20
21         "git.curoverse.com/arvados.git/sdk/go/arvados"
22         "git.curoverse.com/arvados.git/sdk/go/auth"
23         "git.curoverse.com/arvados.git/sdk/go/httpserver"
24         "git.curoverse.com/arvados.git/sdk/go/keepclient"
25 )
26
27 var wfRe = regexp.MustCompile(`^/arvados/v1/workflows/([0-9a-z]{5})-[^/]+$`)
28 var collectionRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-z]{5})-[^/]+$`)
29 var collectionByPDHRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-fA-F]{32}\+[0-9]+)+$`)
30
31 type genericFederatedRequestHandler struct {
32         next    http.Handler
33         handler *Handler
34 }
35
36 type collectionFederatedRequestHandler struct {
37         next    http.Handler
38         handler *Handler
39 }
40
41 func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, req *http.Request, filter ResponseFilter) {
42         remote, ok := h.Cluster.RemoteClusters[remoteID]
43         if !ok {
44                 httpserver.Error(w, "no proxy available for cluster "+remoteID, http.StatusNotFound)
45                 return
46         }
47         scheme := remote.Scheme
48         if scheme == "" {
49                 scheme = "https"
50         }
51         err := h.saltAuthToken(req, remoteID)
52         if err != nil {
53                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
54                 return
55         }
56         urlOut := &url.URL{
57                 Scheme:   scheme,
58                 Host:     remote.Host,
59                 Path:     req.URL.Path,
60                 RawPath:  req.URL.RawPath,
61                 RawQuery: req.URL.RawQuery,
62         }
63         client := h.secureClient
64         if remote.Insecure {
65                 client = h.insecureClient
66         }
67         h.proxy.Do(w, req, urlOut, client, filter)
68 }
69
70 func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
71         m := wfRe.FindStringSubmatch(req.URL.Path)
72         if len(m) < 2 || m[1] == h.handler.Cluster.ClusterID {
73                 h.next.ServeHTTP(w, req)
74                 return
75         }
76         h.handler.remoteClusterRequest(m[1], w, req, nil)
77 }
78
79 type rewriteSignaturesClusterId string
80
81 func (clusterId rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
82         if requestError != nil {
83                 return resp, requestError
84         }
85
86         if resp.StatusCode != 200 {
87                 return resp, nil
88         }
89
90         originalBody := resp.Body
91         defer originalBody.Close()
92
93         var col arvados.Collection
94         err = json.NewDecoder(resp.Body).Decode(&col)
95         if err != nil {
96                 return nil, err
97         }
98
99         // rewriting signatures will make manifest text 5-10% bigger so calculate
100         // capacity accordingly
101         updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
102
103         scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
104         scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
105         for scanner.Scan() {
106                 line := scanner.Text()
107                 tokens := strings.Split(line, " ")
108                 if len(tokens) < 3 {
109                         return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
110                 }
111
112                 updatedManifest.WriteString(tokens[0])
113                 for _, token := range tokens[1:] {
114                         updatedManifest.WriteString(" ")
115                         m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
116                         if m != nil {
117                                 // Rewrite the block signature to be a remote signature
118                                 fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterId, m[5][2:], m[8])
119                         } else {
120                                 updatedManifest.WriteString(token)
121                         }
122
123                 }
124                 updatedManifest.WriteString("\n")
125         }
126
127         col.ManifestText = updatedManifest.String()
128
129         newbody, err := json.Marshal(col)
130         if err != nil {
131                 return nil, err
132         }
133
134         buf := bytes.NewBuffer(newbody)
135         resp.Body = ioutil.NopCloser(buf)
136         resp.ContentLength = int64(buf.Len())
137         resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
138
139         return resp, nil
140 }
141
142 type searchLocalClusterForPDH struct {
143         sentResponse bool
144 }
145
146 func (s *searchLocalClusterForPDH) filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
147         if requestError != nil {
148                 return resp, requestError
149         }
150
151         if resp.StatusCode == 404 {
152                 // Suppress returning this result, because we want to
153                 // search the federation.
154                 s.sentResponse = false
155                 return nil, nil
156         }
157         s.sentResponse = true
158         return resp, nil
159 }
160
161 type searchRemoteClusterForPDH struct {
162         remoteID      string
163         mtx           *sync.Mutex
164         sentResponse  *bool
165         sharedContext *context.Context
166         cancelFunc    func()
167         errors        *[]string
168         statusCode    *int
169 }
170
171 func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
172         s.mtx.Lock()
173         defer s.mtx.Unlock()
174
175         if *s.sentResponse {
176                 // Another request already returned a response
177                 return nil, nil
178         }
179
180         if requestError != nil {
181                 *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
182                 // Record the error and suppress response
183                 return nil, nil
184         }
185
186         if resp.StatusCode != 200 {
187                 // Suppress returning unsuccessful result.  Maybe
188                 // another request will find it.
189                 // TODO collect and return error responses.
190                 *s.errors = append(*s.errors, fmt.Sprintf("Response from %q: %v", s.remoteID, resp.Status))
191                 if resp.StatusCode != 404 {
192                         // Got a non-404 error response, convert into BadGateway
193                         *s.statusCode = http.StatusBadGateway
194                 }
195                 return nil, nil
196         }
197
198         s.mtx.Unlock()
199
200         // This reads the response body.  We don't want to hold the
201         // lock while doing this because other remote requests could
202         // also have made it to this point, and we don't want a
203         // slow response holding the lock to block a faster response
204         // that is waiting on the lock.
205         newResponse, err = rewriteSignaturesClusterId(s.remoteID).rewriteSignatures(resp, nil)
206
207         s.mtx.Lock()
208
209         if *s.sentResponse {
210                 // Another request already returned a response
211                 return nil, nil
212         }
213
214         if err != nil {
215                 // Suppress returning unsuccessful result.  Maybe
216                 // another request will be successful.
217                 *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
218                 return nil, nil
219         }
220
221         // We have a successful response.  Suppress/cancel all the
222         // other requests/responses.
223         *s.sentResponse = true
224         s.cancelFunc()
225
226         return newResponse, nil
227 }
228
229 func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
230         m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
231         if len(m) != 2 {
232                 // Not a collection PDH request
233                 m = collectionRe.FindStringSubmatch(req.URL.Path)
234                 if len(m) == 2 && m[1] != h.handler.Cluster.ClusterID {
235                         // request for remote collection by uuid
236                         h.handler.remoteClusterRequest(m[1], w, req,
237                                 rewriteSignaturesClusterId(m[1]).rewriteSignatures)
238                         return
239                 }
240                 // not a collection UUID request, or it is a request
241                 // for a local UUID, either way, continue down the
242                 // handler stack.
243                 h.next.ServeHTTP(w, req)
244                 return
245         }
246
247         // Request for collection by PDH.  Search the federation.
248
249         // First, query the local cluster.
250         urlOut, insecure, err := findRailsAPI(h.handler.Cluster, h.handler.NodeProfile)
251         if err != nil {
252                 httpserver.Error(w, err.Error(), http.StatusInternalServerError)
253                 return
254         }
255
256         urlOut = &url.URL{
257                 Scheme:   urlOut.Scheme,
258                 Host:     urlOut.Host,
259                 Path:     req.URL.Path,
260                 RawPath:  req.URL.RawPath,
261                 RawQuery: req.URL.RawQuery,
262         }
263         client := h.handler.secureClient
264         if insecure {
265                 client = h.handler.insecureClient
266         }
267         sf := &searchLocalClusterForPDH{}
268         h.handler.proxy.Do(w, req, urlOut, client, sf.filterLocalClusterResponse)
269         if sf.sentResponse {
270                 return
271         }
272
273         sharedContext, cancelFunc := context.WithCancel(req.Context())
274         defer cancelFunc()
275         req = req.WithContext(sharedContext)
276
277         // Create a goroutine for each cluster in the
278         // RemoteClusters map.  The first valid result gets
279         // returned to the client.  When that happens, all
280         // other outstanding requests are cancelled or
281         // suppressed.
282         sentResponse := false
283         mtx := sync.Mutex{}
284         wg := sync.WaitGroup{}
285         var errors []string
286         var errorCode int = 404
287
288         // use channel as a semaphore to limit it to 4
289         // parallel requests at a time
290         sem := make(chan bool, 4)
291         defer close(sem)
292         for remoteID := range h.handler.Cluster.RemoteClusters {
293                 // blocks until it can put a value into the
294                 // channel (which has a max queue capacity)
295                 sem <- true
296                 if sentResponse {
297                         break
298                 }
299                 search := &searchRemoteClusterForPDH{remoteID, &mtx, &sentResponse,
300                         &sharedContext, cancelFunc, &errors, &errorCode}
301                 wg.Add(1)
302                 go func() {
303                         h.handler.remoteClusterRequest(search.remoteID, w, req, search.filterRemoteClusterResponse)
304                         wg.Done()
305                         <-sem
306                 }()
307         }
308         wg.Wait()
309
310         if sentResponse {
311                 return
312         }
313
314         // No successful responses, so return the error
315         httpserver.Errors(w, errors, errorCode)
316 }
317
318 func (h *Handler) setupProxyRemoteCluster(next http.Handler) http.Handler {
319         mux := http.NewServeMux()
320         mux.Handle("/arvados/v1/workflows", next)
321         mux.Handle("/arvados/v1/workflows/", &genericFederatedRequestHandler{next, h})
322         mux.Handle("/arvados/v1/collections", next)
323         mux.Handle("/arvados/v1/collections/", &collectionFederatedRequestHandler{next, h})
324         mux.Handle("/", next)
325
326         return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
327                 parts := strings.Split(req.Header.Get("Authorization"), "/")
328                 alreadySalted := (len(parts) == 3 && parts[0] == "Bearer v2" && len(parts[2]) == 40)
329
330                 if alreadySalted ||
331                         strings.Index(req.Header.Get("Via"), "arvados-controller") != -1 {
332                         // The token is already salted, or this is a
333                         // request from another instance of
334                         // arvados-controller.  In either case, we
335                         // don't want to proxy this query, so just
336                         // continue down the instance handler stack.
337                         next.ServeHTTP(w, req)
338                         return
339                 }
340
341                 mux.ServeHTTP(w, req)
342         })
343
344         return mux
345 }
346
347 type CurrentUser struct {
348         Authorization arvados.APIClientAuthorization
349         UUID          string
350 }
351
352 func (h *Handler) validateAPItoken(req *http.Request, user *CurrentUser) error {
353         db, err := h.db(req)
354         if err != nil {
355                 return err
356         }
357         return db.QueryRowContext(req.Context(), `SELECT api_client_authorizations.uuid, users.uuid FROM api_client_authorizations JOIN users on api_client_authorizations.user_id=users.id WHERE api_token=$1 AND (expires_at IS NULL OR expires_at > current_timestamp) LIMIT 1`, user.Authorization.APIToken).Scan(&user.Authorization.UUID, &user.UUID)
358 }
359
360 // Extract the auth token supplied in req, and replace it with a
361 // salted token for the remote cluster.
362 func (h *Handler) saltAuthToken(req *http.Request, remote string) error {
363         creds := auth.NewCredentials()
364         creds.LoadTokensFromHTTPRequest(req)
365         if len(creds.Tokens) == 0 && req.Header.Get("Content-Type") == "application/x-www-form-encoded" {
366                 // Override ParseForm's 10MiB limit by ensuring
367                 // req.Body is a *http.maxBytesReader.
368                 req.Body = http.MaxBytesReader(nil, req.Body, 1<<28) // 256MiB. TODO: use MaxRequestSize from discovery doc or config.
369                 if err := creds.LoadTokensFromHTTPRequestBody(req); err != nil {
370                         return err
371                 }
372                 // Replace req.Body with a buffer that re-encodes the
373                 // form without api_token, in case we end up
374                 // forwarding the request.
375                 if req.PostForm != nil {
376                         req.PostForm.Del("api_token")
377                 }
378                 req.Body = ioutil.NopCloser(bytes.NewBufferString(req.PostForm.Encode()))
379         }
380         if len(creds.Tokens) == 0 {
381                 return nil
382         }
383         token, err := auth.SaltToken(creds.Tokens[0], remote)
384         if err == auth.ErrObsoleteToken {
385                 // If the token exists in our own database, salt it
386                 // for the remote. Otherwise, assume it was issued by
387                 // the remote, and pass it through unmodified.
388                 currentUser := CurrentUser{Authorization: arvados.APIClientAuthorization{APIToken: creds.Tokens[0]}}
389                 err = h.validateAPItoken(req, &currentUser)
390                 if err == sql.ErrNoRows {
391                         // Not ours; pass through unmodified.
392                         token = currentUser.Authorization.APIToken
393                 } else if err != nil {
394                         return err
395                 } else {
396                         // Found; make V2 version and salt it.
397                         token, err = auth.SaltToken(currentUser.Authorization.TokenV2(), remote)
398                         if err != nil {
399                                 return err
400                         }
401                 }
402         } else if err != nil {
403                 return err
404         }
405         req.Header.Set("Authorization", "Bearer "+token)
406
407         // Remove api_token=... from the the query string, in case we
408         // end up forwarding the request.
409         if values, err := url.ParseQuery(req.URL.RawQuery); err != nil {
410                 return err
411         } else if _, ok := values["api_token"]; ok {
412                 delete(values, "api_token")
413                 req.URL.RawQuery = values.Encode()
414         }
415         return nil
416 }