14325: Note assumption concurrent dispatchers share a VM size menu.
[arvados.git] / lib / controller / fed_generic.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package controller
6
7 import (
8         "bytes"
9         "encoding/json"
10         "fmt"
11         "io/ioutil"
12         "net/http"
13         "net/url"
14         "regexp"
15         "sync"
16
17         "git.curoverse.com/arvados.git/sdk/go/httpserver"
18 )
19
20 type federatedRequestDelegate func(
21         h *genericFederatedRequestHandler,
22         effectiveMethod string,
23         clusterId *string,
24         uuid string,
25         remainder string,
26         w http.ResponseWriter,
27         req *http.Request) bool
28
29 type genericFederatedRequestHandler struct {
30         next      http.Handler
31         handler   *Handler
32         matcher   *regexp.Regexp
33         delegates []federatedRequestDelegate
34 }
35
36 func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
37         req *http.Request,
38         clusterID string, uuids []string) (rp []map[string]interface{}, kind string, err error) {
39
40         found := make(map[string]bool)
41         prev_len_uuids := len(uuids) + 1
42         // Loop while
43         // (1) there are more uuids to query
44         // (2) we're making progress - on each iteration the set of
45         // uuids we are expecting for must shrink.
46         for len(uuids) > 0 && len(uuids) < prev_len_uuids {
47                 var remoteReq http.Request
48                 remoteReq.Header = req.Header
49                 remoteReq.Method = "POST"
50                 remoteReq.URL = &url.URL{Path: req.URL.Path}
51                 remoteParams := make(url.Values)
52                 remoteParams.Set("_method", "GET")
53                 remoteParams.Set("count", "none")
54                 if req.Form.Get("select") != "" {
55                         remoteParams.Set("select", req.Form.Get("select"))
56                 }
57                 content, err := json.Marshal(uuids)
58                 if err != nil {
59                         return nil, "", err
60                 }
61                 remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
62                 enc := remoteParams.Encode()
63                 remoteReq.Body = ioutil.NopCloser(bytes.NewBufferString(enc))
64
65                 rc := multiClusterQueryResponseCollector{clusterID: clusterID}
66
67                 var resp *http.Response
68                 if clusterID == h.handler.Cluster.ClusterID {
69                         resp, err = h.handler.localClusterRequest(&remoteReq)
70                 } else {
71                         resp, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
72                 }
73                 rc.collectResponse(resp, err)
74
75                 if rc.error != nil {
76                         return nil, "", rc.error
77                 }
78
79                 kind = rc.kind
80
81                 if len(rc.responses) == 0 {
82                         // We got zero responses, no point in doing
83                         // another query.
84                         return rp, kind, nil
85                 }
86
87                 rp = append(rp, rc.responses...)
88
89                 // Go through the responses and determine what was
90                 // returned.  If there are remaining items, loop
91                 // around and do another request with just the
92                 // stragglers.
93                 for _, i := range rc.responses {
94                         uuid, ok := i["uuid"].(string)
95                         if ok {
96                                 found[uuid] = true
97                         }
98                 }
99
100                 l := []string{}
101                 for _, u := range uuids {
102                         if !found[u] {
103                                 l = append(l, u)
104                         }
105                 }
106                 prev_len_uuids = len(uuids)
107                 uuids = l
108         }
109
110         return rp, kind, nil
111 }
112
113 func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter,
114         req *http.Request, clusterId *string) bool {
115
116         var filters [][]interface{}
117         err := json.Unmarshal([]byte(req.Form.Get("filters")), &filters)
118         if err != nil {
119                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
120                 return true
121         }
122
123         // Split the list of uuids by prefix
124         queryClusters := make(map[string][]string)
125         expectCount := 0
126         for _, filter := range filters {
127                 if len(filter) != 3 {
128                         return false
129                 }
130
131                 if lhs, ok := filter[0].(string); !ok || lhs != "uuid" {
132                         return false
133                 }
134
135                 op, ok := filter[1].(string)
136                 if !ok {
137                         return false
138                 }
139
140                 if op == "in" {
141                         if rhs, ok := filter[2].([]interface{}); ok {
142                                 for _, i := range rhs {
143                                         if u, ok := i.(string); ok && len(u) == 27 {
144                                                 *clusterId = u[0:5]
145                                                 queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
146                                                 expectCount += 1
147                                         }
148                                 }
149                         }
150                 } else if op == "=" {
151                         if u, ok := filter[2].(string); ok && len(u) == 27 {
152                                 *clusterId = u[0:5]
153                                 queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
154                                 expectCount += 1
155                         }
156                 } else {
157                         return false
158                 }
159
160         }
161
162         if len(queryClusters) <= 1 {
163                 // Query does not search for uuids across multiple
164                 // clusters.
165                 return false
166         }
167
168         // Validations
169         count := req.Form.Get("count")
170         if count != "" && count != `none` && count != `"none"` {
171                 httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
172                 return true
173         }
174         if req.Form.Get("limit") != "" || req.Form.Get("offset") != "" || req.Form.Get("order") != "" {
175                 httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
176                 return true
177         }
178         if expectCount > h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse() {
179                 httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
180                         expectCount, h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse()), http.StatusBadRequest)
181                 return true
182         }
183         if req.Form.Get("select") != "" {
184                 foundUUID := false
185                 var selects []string
186                 err := json.Unmarshal([]byte(req.Form.Get("select")), &selects)
187                 if err != nil {
188                         httpserver.Error(w, err.Error(), http.StatusBadRequest)
189                         return true
190                 }
191
192                 for _, r := range selects {
193                         if r == "uuid" {
194                                 foundUUID = true
195                                 break
196                         }
197                 }
198                 if !foundUUID {
199                         httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
200                         return true
201                 }
202         }
203
204         // Perform concurrent requests to each cluster
205
206         // use channel as a semaphore to limit the number of concurrent
207         // requests at a time
208         sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
209         defer close(sem)
210         wg := sync.WaitGroup{}
211
212         req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
213         mtx := sync.Mutex{}
214         errors := []error{}
215         var completeResponses []map[string]interface{}
216         var kind string
217
218         for k, v := range queryClusters {
219                 if len(v) == 0 {
220                         // Nothing to query
221                         continue
222                 }
223
224                 // blocks until it can put a value into the
225                 // channel (which has a max queue capacity)
226                 sem <- true
227                 wg.Add(1)
228                 go func(k string, v []string) {
229                         rp, kn, err := h.remoteQueryUUIDs(w, req, k, v)
230                         mtx.Lock()
231                         if err == nil {
232                                 completeResponses = append(completeResponses, rp...)
233                                 kind = kn
234                         } else {
235                                 errors = append(errors, err)
236                         }
237                         mtx.Unlock()
238                         wg.Done()
239                         <-sem
240                 }(k, v)
241         }
242         wg.Wait()
243
244         if len(errors) > 0 {
245                 var strerr []string
246                 for _, e := range errors {
247                         strerr = append(strerr, e.Error())
248                 }
249                 httpserver.Errors(w, strerr, http.StatusBadGateway)
250                 return true
251         }
252
253         w.Header().Set("Content-Type", "application/json")
254         w.WriteHeader(http.StatusOK)
255         itemList := make(map[string]interface{})
256         itemList["items"] = completeResponses
257         itemList["kind"] = kind
258         json.NewEncoder(w).Encode(itemList)
259
260         return true
261 }
262
263 func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
264         m := h.matcher.FindStringSubmatch(req.URL.Path)
265         clusterId := ""
266
267         if len(m) > 0 && m[2] != "" {
268                 clusterId = m[2]
269         }
270
271         // Get form parameters from URL and form body (if POST).
272         if err := loadParamsFromForm(req); err != nil {
273                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
274                 return
275         }
276
277         // Check if the parameters have an explicit cluster_id
278         if req.Form.Get("cluster_id") != "" {
279                 clusterId = req.Form.Get("cluster_id")
280         }
281
282         // Handle the POST-as-GET special case (workaround for large
283         // GET requests that potentially exceed maximum URL length,
284         // like multi-object queries where the filter has 100s of
285         // items)
286         effectiveMethod := req.Method
287         if req.Method == "POST" && req.Form.Get("_method") != "" {
288                 effectiveMethod = req.Form.Get("_method")
289         }
290
291         if effectiveMethod == "GET" &&
292                 clusterId == "" &&
293                 req.Form.Get("filters") != "" &&
294                 h.handleMultiClusterQuery(w, req, &clusterId) {
295                 return
296         }
297
298         var uuid string
299         if len(m[1]) > 0 {
300                 // trim leading slash
301                 uuid = m[1][1:]
302         }
303         for _, d := range h.delegates {
304                 if d(h, effectiveMethod, &clusterId, uuid, m[3], w, req) {
305                         return
306                 }
307         }
308
309         if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
310                 h.next.ServeHTTP(w, req)
311         } else {
312                 resp, err := h.handler.remoteClusterRequest(clusterId, req)
313                 h.handler.proxy.ForwardResponse(w, resp, err)
314         }
315 }
316
// multiClusterQueryResponseCollector accumulates the outcome of a single
// list request sent to one cluster.
type multiClusterQueryResponseCollector struct {
	// responses holds the "items" array of a successful response.
	responses []map[string]interface{}
	// error holds the request error, or an error describing a response
	// that could not be decoded or whose status was not 200 OK.
	error     error
	// kind holds the "kind" field of a successful response.
	kind      string
	// clusterID names the queried cluster in error messages.
	clusterID string
}

// collectResponse records the result of a request made to c.clusterID.
//
// A non-nil requestError is stored in c.error as-is. Otherwise the body of
// resp is decoded as a JSON object with "kind", "items" and "errors" keys,
// and the body is closed. A decoding failure, or a status other than
// 200 OK, is stored in c.error. On success the items and kind are stored in
// c.responses and c.kind.
//
// Both return values are always nil; results are reported only through c.
func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
	requestError error) (newResponse *http.Response, err error) {
	if requestError != nil {
		c.error = requestError
		return nil, nil
	}
	defer resp.Body.Close()

	var payload struct {
		Kind   string                   `json:"kind"`
		Items  []map[string]interface{} `json:"items"`
		Errors []string                 `json:"errors"`
	}
	decodeErr := json.NewDecoder(resp.Body).Decode(&payload)

	switch {
	case decodeErr != nil:
		// An undecodable body is reported even for non-200 statuses,
		// since payload.Errors would be empty in that case.
		c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, decodeErr)
	case resp.StatusCode != http.StatusOK:
		c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, payload.Errors)
	default:
		c.responses, c.kind = payload.Items, payload.Kind
	}
	return nil, nil
}
351         return nil, nil
352 }