14262: Only create runtime_token on home cluster for the authorization
[arvados.git] / lib / controller / fed_generic.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package controller
6
7 import (
8         "bytes"
9         "context"
10         "encoding/json"
11         "fmt"
12         "io/ioutil"
13         "net/http"
14         "net/url"
15         "regexp"
16         "sync"
17
18         "git.curoverse.com/arvados.git/sdk/go/httpserver"
19 )
20
21 type federatedRequestDelegate func(
22         h *genericFederatedRequestHandler,
23         effectiveMethod string,
24         clusterId *string,
25         uuid string,
26         remainder string,
27         w http.ResponseWriter,
28         req *http.Request) bool
29
30 type genericFederatedRequestHandler struct {
31         next      http.Handler
32         handler   *Handler
33         matcher   *regexp.Regexp
34         delegates []federatedRequestDelegate
35 }
36
37 func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
38         req *http.Request,
39         clusterID string, uuids []string) (rp []map[string]interface{}, kind string, err error) {
40
41         found := make(map[string]bool)
42         prev_len_uuids := len(uuids) + 1
43         // Loop while
44         // (1) there are more uuids to query
45         // (2) we're making progress - on each iteration the set of
46         // uuids we are expecting for must shrink.
47         for len(uuids) > 0 && len(uuids) < prev_len_uuids {
48                 var remoteReq http.Request
49                 remoteReq.Header = req.Header
50                 remoteReq.Method = "POST"
51                 remoteReq.URL = &url.URL{Path: req.URL.Path}
52                 remoteParams := make(url.Values)
53                 remoteParams.Set("_method", "GET")
54                 remoteParams.Set("count", "none")
55                 if req.Form.Get("select") != "" {
56                         remoteParams.Set("select", req.Form.Get("select"))
57                 }
58                 content, err := json.Marshal(uuids)
59                 if err != nil {
60                         return nil, "", err
61                 }
62                 remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
63                 enc := remoteParams.Encode()
64                 remoteReq.Body = ioutil.NopCloser(bytes.NewBufferString(enc))
65
66                 rc := multiClusterQueryResponseCollector{clusterID: clusterID}
67
68                 var resp *http.Response
69                 var cancel context.CancelFunc
70                 if clusterID == h.handler.Cluster.ClusterID {
71                         resp, cancel, err = h.handler.localClusterRequest(&remoteReq)
72                 } else {
73                         resp, cancel, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
74                 }
75                 rc.collectResponse(resp, err)
76                 if cancel != nil {
77                         cancel()
78                 }
79
80                 if rc.error != nil {
81                         return nil, "", rc.error
82                 }
83
84                 kind = rc.kind
85
86                 if len(rc.responses) == 0 {
87                         // We got zero responses, no point in doing
88                         // another query.
89                         return rp, kind, nil
90                 }
91
92                 rp = append(rp, rc.responses...)
93
94                 // Go through the responses and determine what was
95                 // returned.  If there are remaining items, loop
96                 // around and do another request with just the
97                 // stragglers.
98                 for _, i := range rc.responses {
99                         uuid, ok := i["uuid"].(string)
100                         if ok {
101                                 found[uuid] = true
102                         }
103                 }
104
105                 l := []string{}
106                 for _, u := range uuids {
107                         if !found[u] {
108                                 l = append(l, u)
109                         }
110                 }
111                 prev_len_uuids = len(uuids)
112                 uuids = l
113         }
114
115         return rp, kind, nil
116 }
117
118 func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter,
119         req *http.Request, clusterId *string) bool {
120
121         var filters [][]interface{}
122         err := json.Unmarshal([]byte(req.Form.Get("filters")), &filters)
123         if err != nil {
124                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
125                 return true
126         }
127
128         // Split the list of uuids by prefix
129         queryClusters := make(map[string][]string)
130         expectCount := 0
131         for _, filter := range filters {
132                 if len(filter) != 3 {
133                         return false
134                 }
135
136                 if lhs, ok := filter[0].(string); !ok || lhs != "uuid" {
137                         return false
138                 }
139
140                 op, ok := filter[1].(string)
141                 if !ok {
142                         return false
143                 }
144
145                 if op == "in" {
146                         if rhs, ok := filter[2].([]interface{}); ok {
147                                 for _, i := range rhs {
148                                         if u, ok := i.(string); ok {
149                                                 *clusterId = u[0:5]
150                                                 queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
151                                                 expectCount += 1
152                                         }
153                                 }
154                         }
155                 } else if op == "=" {
156                         if u, ok := filter[2].(string); ok {
157                                 *clusterId = u[0:5]
158                                 queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
159                                 expectCount += 1
160                         }
161                 } else {
162                         return false
163                 }
164
165         }
166
167         if len(queryClusters) <= 1 {
168                 // Query does not search for uuids across multiple
169                 // clusters.
170                 return false
171         }
172
173         // Validations
174         count := req.Form.Get("count")
175         if count != "" && count != `none` && count != `"none"` {
176                 httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
177                 return true
178         }
179         if req.Form.Get("limit") != "" || req.Form.Get("offset") != "" || req.Form.Get("order") != "" {
180                 httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
181                 return true
182         }
183         if expectCount > h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse() {
184                 httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
185                         expectCount, h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse()), http.StatusBadRequest)
186                 return true
187         }
188         if req.Form.Get("select") != "" {
189                 foundUUID := false
190                 var selects []string
191                 err := json.Unmarshal([]byte(req.Form.Get("select")), &selects)
192                 if err != nil {
193                         httpserver.Error(w, err.Error(), http.StatusBadRequest)
194                         return true
195                 }
196
197                 for _, r := range selects {
198                         if r == "uuid" {
199                                 foundUUID = true
200                                 break
201                         }
202                 }
203                 if !foundUUID {
204                         httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
205                         return true
206                 }
207         }
208
209         // Perform concurrent requests to each cluster
210
211         // use channel as a semaphore to limit the number of concurrent
212         // requests at a time
213         sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
214         defer close(sem)
215         wg := sync.WaitGroup{}
216
217         req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
218         mtx := sync.Mutex{}
219         errors := []error{}
220         var completeResponses []map[string]interface{}
221         var kind string
222
223         for k, v := range queryClusters {
224                 if len(v) == 0 {
225                         // Nothing to query
226                         continue
227                 }
228
229                 // blocks until it can put a value into the
230                 // channel (which has a max queue capacity)
231                 sem <- true
232                 wg.Add(1)
233                 go func(k string, v []string) {
234                         rp, kn, err := h.remoteQueryUUIDs(w, req, k, v)
235                         mtx.Lock()
236                         if err == nil {
237                                 completeResponses = append(completeResponses, rp...)
238                                 kind = kn
239                         } else {
240                                 errors = append(errors, err)
241                         }
242                         mtx.Unlock()
243                         wg.Done()
244                         <-sem
245                 }(k, v)
246         }
247         wg.Wait()
248
249         if len(errors) > 0 {
250                 var strerr []string
251                 for _, e := range errors {
252                         strerr = append(strerr, e.Error())
253                 }
254                 httpserver.Errors(w, strerr, http.StatusBadGateway)
255                 return true
256         }
257
258         w.Header().Set("Content-Type", "application/json")
259         w.WriteHeader(http.StatusOK)
260         itemList := make(map[string]interface{})
261         itemList["items"] = completeResponses
262         itemList["kind"] = kind
263         json.NewEncoder(w).Encode(itemList)
264
265         return true
266 }
267
268 func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
269         m := h.matcher.FindStringSubmatch(req.URL.Path)
270         clusterId := ""
271
272         if len(m) > 0 && m[2] != "" {
273                 clusterId = m[2]
274         }
275
276         // Get form parameters from URL and form body (if POST).
277         if err := loadParamsFromForm(req); err != nil {
278                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
279                 return
280         }
281
282         // Check if the parameters have an explicit cluster_id
283         if req.Form.Get("cluster_id") != "" {
284                 clusterId = req.Form.Get("cluster_id")
285         }
286
287         // Handle the POST-as-GET special case (workaround for large
288         // GET requests that potentially exceed maximum URL length,
289         // like multi-object queries where the filter has 100s of
290         // items)
291         effectiveMethod := req.Method
292         if req.Method == "POST" && req.Form.Get("_method") != "" {
293                 effectiveMethod = req.Form.Get("_method")
294         }
295
296         if effectiveMethod == "GET" &&
297                 clusterId == "" &&
298                 req.Form.Get("filters") != "" &&
299                 h.handleMultiClusterQuery(w, req, &clusterId) {
300                 return
301         }
302
303         for _, d := range h.delegates {
304                 if d(h, effectiveMethod, &clusterId, m[1], m[3], w, req) {
305                         return
306                 }
307         }
308
309         if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
310                 h.next.ServeHTTP(w, req)
311         } else {
312                 resp, cancel, err := h.handler.remoteClusterRequest(clusterId, req)
313                 if cancel != nil {
314                         defer cancel()
315                 }
316                 h.handler.proxy.ForwardResponse(w, resp, err)
317         }
318 }
319
// multiClusterQueryResponseCollector captures the outcome of one
// list request sent to a single cluster.  After collectResponse has
// run, either error is set, or responses and kind hold the decoded
// "items" and "kind" of the reply.
type multiClusterQueryResponseCollector struct {
	responses []map[string]interface{}
	error     error
	kind      string
	// clusterID is only used to label error messages.
	clusterID string
}

// collectResponse decodes the JSON body of resp into the collector
// and closes the body.
//
// If requestError is non-nil it is stored in c.error and resp is not
// touched.  If the body cannot be decoded, or the status code is not
// 200, c.error is set to a message naming the cluster and the
// response status.  Otherwise c.responses and c.kind are filled in
// from the reply.
//
// Both return values are always nil; the outcome is reported solely
// through the collector's fields.
func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
	requestError error) (newResponse *http.Response, err error) {
	if requestError != nil {
		c.error = requestError
		return nil, nil
	}
	defer resp.Body.Close()

	var payload struct {
		Kind   string                   `json:"kind"`
		Items  []map[string]interface{} `json:"items"`
		Errors []string                 `json:"errors"`
	}
	decodeErr := json.NewDecoder(resp.Body).Decode(&payload)

	switch {
	case decodeErr != nil:
		// A body that is not JSON takes precedence over the
		// status code check.
		c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, decodeErr)
	case resp.StatusCode != http.StatusOK:
		c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, payload.Errors)
	default:
		c.responses = payload.Items
		c.kind = payload.Kind
	}

	return nil, nil
}
355 }