14087: Federated fetch by PDH WIP
[arvados.git] / lib / controller / federation.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package controller
6
7 import (
8         "bufio"
9         "bytes"
10         "database/sql"
11         "encoding/json"
12         "fmt"
13         "io/ioutil"
14         "net/http"
15         "net/url"
16         "regexp"
17         "strings"
18
19         "git.curoverse.com/arvados.git/sdk/go/arvados"
20         "git.curoverse.com/arvados.git/sdk/go/auth"
21         "git.curoverse.com/arvados.git/sdk/go/httpserver"
22         "git.curoverse.com/arvados.git/sdk/go/keepclient"
23 )
24
// Request path patterns used to decide whether a request is handled
// locally or forwarded to another cluster.
var (
	// wfRe matches a workflow addressed by UUID. The first
	// submatch is the 5-character prefix of the UUID, which is
	// compared against the local cluster ID and used as the remote
	// cluster ID when forwarding.
	wfRe = regexp.MustCompile(`^/arvados/v1/workflows/([0-9a-z]{5})-[^/]+$`)

	// collectionRe matches a collection addressed by UUID. The
	// first submatch is the 5-character prefix of the UUID, used
	// the same way as for workflows.
	collectionRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-z]{5})-[^/]+$`)

	// collectionByPDHRe matches a collection addressed by PDH:
	// 32 hex digits, a "+", and a decimal number. The first
	// submatch is the PDH. Exactly one PDH is accepted; the group
	// is deliberately not repeated, so several PDHs run together
	// do not match.
	collectionByPDHRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-fA-F]{32}\+[0-9]+)$`)
)
28
29 type genericFederatedRequestHandler struct {
30         next    http.Handler
31         handler *Handler
32 }
33
34 type collectionFederatedRequestHandler struct {
35         next    http.Handler
36         handler *Handler
37 }
38
39 func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, req *http.Request, filter ResponseFilter) {
40         remote, ok := h.Cluster.RemoteClusters[remoteID]
41         if !ok {
42                 httpserver.Error(w, "no proxy available for cluster "+remoteID, http.StatusNotFound)
43                 return
44         }
45         scheme := remote.Scheme
46         if scheme == "" {
47                 scheme = "https"
48         }
49         err := h.saltAuthToken(req, remoteID)
50         if err != nil {
51                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
52                 return
53         }
54         urlOut := &url.URL{
55                 Scheme:   scheme,
56                 Host:     remote.Host,
57                 Path:     req.URL.Path,
58                 RawPath:  req.URL.RawPath,
59                 RawQuery: req.URL.RawQuery,
60         }
61         client := h.secureClient
62         if remote.Insecure {
63                 client = h.insecureClient
64         }
65         h.proxy.Do(w, req, urlOut, client, filter)
66 }
67
68 func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
69         m := wfRe.FindStringSubmatch(req.URL.Path)
70         if len(m) < 2 || m[1] == h.handler.Cluster.ClusterID {
71                 h.next.ServeHTTP(w, req)
72                 return
73         }
74         h.handler.remoteClusterRequest(m[1], w, req, nil)
75 }
76
77 type rewriteSignaturesClusterId string
78
79 func (clusterId rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response) (newResponse *http.Response, err error) {
80         if resp.StatusCode != 200 {
81                 return resp, nil
82         }
83
84         originalBody := resp.Body
85         defer originalBody.Close()
86
87         var col arvados.Collection
88         err = json.NewDecoder(resp.Body).Decode(&col)
89         if err != nil {
90                 return nil, err
91         }
92
93         // rewriting signatures will make manifest text 5-10% bigger so calculate
94         // capacity accordingly
95         updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
96
97         scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
98         scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
99         for scanner.Scan() {
100                 line := scanner.Text()
101                 tokens := strings.Split(line, " ")
102                 if len(tokens) < 3 {
103                         return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
104                 }
105
106                 updatedManifest.WriteString(tokens[0])
107                 for _, token := range tokens[1:] {
108                         updatedManifest.WriteString(" ")
109                         m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
110                         if m != nil {
111                                 // Rewrite the block signature to be a remote signature
112                                 fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterId, m[5][2:], m[8])
113                         } else {
114                                 updatedManifest.WriteString(token)
115                         }
116
117                 }
118                 updatedManifest.WriteString("\n")
119         }
120
121         col.ManifestText = updatedManifest.String()
122
123         newbody, err := json.Marshal(col)
124         if err != nil {
125                 return nil, err
126         }
127
128         buf := bytes.NewBuffer(newbody)
129         resp.Body = ioutil.NopCloser(buf)
130         resp.ContentLength = int64(buf.Len())
131         resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
132
133         return resp, nil
134 }
135
// searchLocalClusterForPDH records whether a by-PDH lookup against
// the local cluster came back "not found".
type searchLocalClusterForPDH struct {
	// needSearchFederation is set once the local cluster has
	// answered 404.
	needSearchFederation bool
}

// filterLocalClusterResponse is a response filter for the local
// lookup. A 404 is swallowed (nil response, nil error) and noted in
// s.needSearchFederation; every other response is returned
// unchanged.
func (s *searchLocalClusterForPDH) filterLocalClusterResponse(resp *http.Response) (newResponse *http.Response, err error) {
	if resp.StatusCode != http.StatusNotFound {
		return resp, nil
	}
	// Not found locally: withhold this result so the caller can
	// go on to search the federation.
	s.needSearchFederation = true
	return nil, nil
}
149
150 func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
151         m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
152         if len(m) == 2 {
153                 urlOut, insecure, err := findRailsAPI(h.handler.Cluster, h.handler.NodeProfile)
154                 if err != nil {
155                         httpserver.Error(w, err.Error(), http.StatusInternalServerError)
156                         return
157                 }
158
159                 urlOut = &url.URL{
160                         Scheme:   urlOut.Scheme,
161                         Host:     urlOut.Host,
162                         Path:     req.URL.Path,
163                         RawPath:  req.URL.RawPath,
164                         RawQuery: req.URL.RawQuery,
165                 }
166                 client := h.handler.secureClient
167                 if insecure {
168                         client = h.handler.insecureClient
169                 }
170                 sf := &searchLocalClusterForPDH{false}
171                 h.handler.proxy.Do(w, req, urlOut, client, sf.filterLocalClusterResponse)
172                 if !sf.needSearchFederation {
173                         // a response was sent
174                         return
175                 }
176         }
177
178         m = collectionRe.FindStringSubmatch(req.URL.Path)
179         if len(m) < 2 || m[1] == h.handler.Cluster.ClusterID {
180                 h.next.ServeHTTP(w, req)
181                 return
182         }
183         h.handler.remoteClusterRequest(m[1], w, req,
184                 rewriteSignaturesClusterId(m[1]).rewriteSignatures)
185 }
186
187 func (h *Handler) setupProxyRemoteCluster(next http.Handler) http.Handler {
188         mux := http.NewServeMux()
189         mux.Handle("/arvados/v1/workflows", next)
190         mux.Handle("/arvados/v1/workflows/", &genericFederatedRequestHandler{next, h})
191         mux.Handle("/arvados/v1/collections", next)
192         mux.Handle("/arvados/v1/collections/", &collectionFederatedRequestHandler{next, h})
193         mux.Handle("/", next)
194
195         return mux
196 }
197
198 type CurrentUser struct {
199         Authorization arvados.APIClientAuthorization
200         UUID          string
201 }
202
203 func (h *Handler) validateAPItoken(req *http.Request, user *CurrentUser) error {
204         db, err := h.db(req)
205         if err != nil {
206                 return err
207         }
208         return db.QueryRowContext(req.Context(), `SELECT api_client_authorizations.uuid, users.uuid FROM api_client_authorizations JOIN users on api_client_authorizations.user_id=users.id WHERE api_token=$1 AND (expires_at IS NULL OR expires_at > current_timestamp) LIMIT 1`, user.Authorization.APIToken).Scan(&user.Authorization.UUID, &user.UUID)
209 }
210
211 // Extract the auth token supplied in req, and replace it with a
212 // salted token for the remote cluster.
213 func (h *Handler) saltAuthToken(req *http.Request, remote string) error {
214         creds := auth.NewCredentials()
215         creds.LoadTokensFromHTTPRequest(req)
216         if len(creds.Tokens) == 0 && req.Header.Get("Content-Type") == "application/x-www-form-encoded" {
217                 // Override ParseForm's 10MiB limit by ensuring
218                 // req.Body is a *http.maxBytesReader.
219                 req.Body = http.MaxBytesReader(nil, req.Body, 1<<28) // 256MiB. TODO: use MaxRequestSize from discovery doc or config.
220                 if err := creds.LoadTokensFromHTTPRequestBody(req); err != nil {
221                         return err
222                 }
223                 // Replace req.Body with a buffer that re-encodes the
224                 // form without api_token, in case we end up
225                 // forwarding the request.
226                 if req.PostForm != nil {
227                         req.PostForm.Del("api_token")
228                 }
229                 req.Body = ioutil.NopCloser(bytes.NewBufferString(req.PostForm.Encode()))
230         }
231         if len(creds.Tokens) == 0 {
232                 return nil
233         }
234         token, err := auth.SaltToken(creds.Tokens[0], remote)
235         if err == auth.ErrObsoleteToken {
236                 // If the token exists in our own database, salt it
237                 // for the remote. Otherwise, assume it was issued by
238                 // the remote, and pass it through unmodified.
239                 currentUser := CurrentUser{Authorization: arvados.APIClientAuthorization{APIToken: creds.Tokens[0]}}
240                 err = h.validateAPItoken(req, &currentUser)
241                 if err == sql.ErrNoRows {
242                         // Not ours; pass through unmodified.
243                         token = currentUser.Authorization.APIToken
244                 } else if err != nil {
245                         return err
246                 } else {
247                         // Found; make V2 version and salt it.
248                         token, err = auth.SaltToken(currentUser.Authorization.TokenV2(), remote)
249                         if err != nil {
250                                 return err
251                         }
252                 }
253         } else if err != nil {
254                 return err
255         }
256         req.Header.Set("Authorization", "Bearer "+token)
257
258         // Remove api_token=... from the the query string, in case we
259         // end up forwarding the request.
260         if values, err := url.ParseQuery(req.URL.RawQuery); err != nil {
261                 return err
262         } else if _, ok := values["api_token"]; ok {
263                 delete(values, "api_token")
264                 req.URL.RawQuery = values.Encode()
265         }
266         return nil
267 }