19146: Remove unneeded special case checks, explain the needed one.
[arvados.git] / lib / controller / fed_generic.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package controller
6
7 import (
8         "bytes"
9         "encoding/json"
10         "fmt"
11         "io/ioutil"
12         "net/http"
13         "net/url"
14         "regexp"
15         "sync"
16
17         "git.arvados.org/arvados.git/sdk/go/httpserver"
18 )
19
20 type federatedRequestDelegate func(
21         h *genericFederatedRequestHandler,
22         effectiveMethod string,
23         clusterID *string,
24         uuid string,
25         remainder string,
26         w http.ResponseWriter,
27         req *http.Request) bool
28
29 type genericFederatedRequestHandler struct {
30         next      http.Handler
31         handler   *Handler
32         matcher   *regexp.Regexp
33         delegates []federatedRequestDelegate
34 }
35
36 func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
37         req *http.Request,
38         clusterID string, uuids []string) (rp []map[string]interface{}, kind string, err error) {
39
40         found := make(map[string]bool)
41         prevLenUuids := len(uuids) + 1
42         // Loop while
43         // (1) there are more uuids to query
44         // (2) we're making progress - on each iteration the set of
45         // uuids we are expecting for must shrink.
46         for len(uuids) > 0 && len(uuids) < prevLenUuids {
47                 var remoteReq http.Request
48                 remoteReq.Header = req.Header
49                 remoteReq.Method = "POST"
50                 remoteReq.URL = &url.URL{Path: req.URL.Path}
51                 remoteParams := make(url.Values)
52                 remoteParams.Set("_method", "GET")
53                 remoteParams.Set("count", "none")
54                 if req.Form.Get("select") != "" {
55                         remoteParams.Set("select", req.Form.Get("select"))
56                 }
57                 content, err := json.Marshal(uuids)
58                 if err != nil {
59                         return nil, "", err
60                 }
61                 remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
62                 enc := remoteParams.Encode()
63                 remoteReq.Body = ioutil.NopCloser(bytes.NewBufferString(enc))
64
65                 rc := multiClusterQueryResponseCollector{clusterID: clusterID}
66
67                 var resp *http.Response
68                 if clusterID == h.handler.Cluster.ClusterID {
69                         resp, err = h.handler.localClusterRequest(&remoteReq)
70                 } else {
71                         resp, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
72                 }
73                 rc.collectResponse(resp, err)
74
75                 if rc.error != nil {
76                         return nil, "", rc.error
77                 }
78
79                 kind = rc.kind
80
81                 if len(rc.responses) == 0 {
82                         // We got zero responses, no point in doing
83                         // another query.
84                         return rp, kind, nil
85                 }
86
87                 rp = append(rp, rc.responses...)
88
89                 // Go through the responses and determine what was
90                 // returned.  If there are remaining items, loop
91                 // around and do another request with just the
92                 // stragglers.
93                 for _, i := range rc.responses {
94                         uuid, ok := i["uuid"].(string)
95                         if ok {
96                                 found[uuid] = true
97                         }
98                 }
99
100                 l := []string{}
101                 for _, u := range uuids {
102                         if !found[u] {
103                                 l = append(l, u)
104                         }
105                 }
106                 prevLenUuids = len(uuids)
107                 uuids = l
108         }
109
110         return rp, kind, nil
111 }
112
113 func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter,
114         req *http.Request, clusterID *string) bool {
115
116         var filters [][]interface{}
117         err := json.Unmarshal([]byte(req.Form.Get("filters")), &filters)
118         if err != nil {
119                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
120                 return true
121         }
122
123         // Split the list of uuids by prefix
124         queryClusters := make(map[string][]string)
125         expectCount := 0
126         for _, filter := range filters {
127                 if len(filter) != 3 {
128                         return false
129                 }
130
131                 if lhs, ok := filter[0].(string); !ok || lhs != "uuid" {
132                         return false
133                 }
134
135                 op, ok := filter[1].(string)
136                 if !ok {
137                         return false
138                 }
139
140                 if op == "in" {
141                         if rhs, ok := filter[2].([]interface{}); ok {
142                                 for _, i := range rhs {
143                                         if u, ok := i.(string); ok && len(u) == 27 {
144                                                 *clusterID = u[0:5]
145                                                 queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
146                                                 expectCount++
147                                         }
148                                 }
149                         }
150                 } else if op == "=" {
151                         if u, ok := filter[2].(string); ok && len(u) == 27 {
152                                 *clusterID = u[0:5]
153                                 queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
154                                 expectCount++
155                         }
156                 } else {
157                         return false
158                 }
159
160         }
161
162         if len(queryClusters) <= 1 {
163                 // Query does not search for uuids across multiple
164                 // clusters.
165                 return false
166         }
167
168         // Validations
169         count := req.Form.Get("count")
170         if count != "" && count != `none` && count != `"none"` {
171                 httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
172                 return true
173         }
174         if req.Form.Get("limit") != "" || req.Form.Get("offset") != "" || req.Form.Get("order") != "" {
175                 httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
176                 return true
177         }
178         if max := h.handler.Cluster.API.MaxItemsPerResponse; expectCount > max {
179                 httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
180                         expectCount, max), http.StatusBadRequest)
181                 return true
182         }
183         if req.Form.Get("select") != "" {
184                 foundUUID := false
185                 var selects []string
186                 err := json.Unmarshal([]byte(req.Form.Get("select")), &selects)
187                 if err != nil {
188                         httpserver.Error(w, err.Error(), http.StatusBadRequest)
189                         return true
190                 }
191
192                 for _, r := range selects {
193                         if r == "uuid" {
194                                 foundUUID = true
195                                 break
196                         }
197                 }
198                 if !foundUUID {
199                         httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
200                         return true
201                 }
202         }
203
204         // Perform concurrent requests to each cluster
205
206         acquire, release := semaphore(h.handler.Cluster.API.MaxRequestAmplification)
207         wg := sync.WaitGroup{}
208
209         req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
210         mtx := sync.Mutex{}
211         errors := []error{}
212         var completeResponses []map[string]interface{}
213         var kind string
214
215         for k, v := range queryClusters {
216                 if len(v) == 0 {
217                         // Nothing to query
218                         continue
219                 }
220                 acquire()
221                 wg.Add(1)
222                 go func(k string, v []string) {
223                         defer release()
224                         defer wg.Done()
225                         rp, kn, err := h.remoteQueryUUIDs(w, req, k, v)
226                         mtx.Lock()
227                         defer mtx.Unlock()
228                         if err == nil {
229                                 completeResponses = append(completeResponses, rp...)
230                                 kind = kn
231                         } else {
232                                 errors = append(errors, err)
233                         }
234                 }(k, v)
235         }
236         wg.Wait()
237
238         if len(errors) > 0 {
239                 var strerr []string
240                 for _, e := range errors {
241                         strerr = append(strerr, e.Error())
242                 }
243                 httpserver.Errors(w, strerr, http.StatusBadGateway)
244                 return true
245         }
246
247         w.Header().Set("Content-Type", "application/json")
248         w.WriteHeader(http.StatusOK)
249         itemList := make(map[string]interface{})
250         itemList["items"] = completeResponses
251         itemList["kind"] = kind
252         json.NewEncoder(w).Encode(itemList)
253
254         return true
255 }
256
257 func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
258         m := h.matcher.FindStringSubmatch(req.URL.Path)
259         clusterID := ""
260
261         if len(m) > 0 && m[2] != "" {
262                 clusterID = m[2]
263         }
264
265         // Get form parameters from URL and form body (if POST).
266         if err := loadParamsFromForm(req); err != nil {
267                 httpserver.Error(w, err.Error(), http.StatusBadRequest)
268                 return
269         }
270
271         // Check if the parameters have an explicit cluster_id
272         if req.Form.Get("cluster_id") != "" {
273                 clusterID = req.Form.Get("cluster_id")
274         }
275
276         // Handle the POST-as-GET special case (workaround for large
277         // GET requests that potentially exceed maximum URL length,
278         // like multi-object queries where the filter has 100s of
279         // items)
280         effectiveMethod := req.Method
281         if req.Method == "POST" && req.Form.Get("_method") != "" {
282                 effectiveMethod = req.Form.Get("_method")
283         }
284
285         if effectiveMethod == "GET" &&
286                 clusterID == "" &&
287                 req.Form.Get("filters") != "" &&
288                 h.handleMultiClusterQuery(w, req, &clusterID) {
289                 return
290         }
291
292         var uuid string
293         if len(m[1]) > 0 {
294                 // trim leading slash
295                 uuid = m[1][1:]
296         }
297         for _, d := range h.delegates {
298                 if d(h, effectiveMethod, &clusterID, uuid, m[3], w, req) {
299                         return
300                 }
301         }
302
303         if clusterID == "" || clusterID == h.handler.Cluster.ClusterID {
304                 h.next.ServeHTTP(w, req)
305         } else {
306                 resp, err := h.handler.remoteClusterRequest(clusterID, req)
307                 h.handler.proxy.ForwardResponse(w, resp, err)
308         }
309 }
310
// multiClusterQueryResponseCollector records the outcome of one list query
// sent to a single cluster.
type multiClusterQueryResponseCollector struct {
	responses []map[string]interface{} // "items" of a successful response
	error     error                    // request, decode, or HTTP-status error
	kind      string                   // "kind" of a successful response
	clusterID string                   // queried cluster, named in error messages
}

// collectResponse stores the result of one upstream request in c.
//
// A non-nil requestError is stored in c.error unchanged. Otherwise resp's
// body is decoded as a list response and then closed. If decoding fails,
// or the status is not 200, c.error receives a message naming the cluster
// and HTTP status. On success, c.responses and c.kind receive the decoded
// items and kind.
//
// Both return values are always nil; callers read the collector's fields.
func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
	requestError error) (newResponse *http.Response, err error) {
	if requestError != nil {
		c.error = requestError
		return nil, nil
	}
	defer resp.Body.Close()

	var body struct {
		Kind   string                   `json:"kind"`
		Items  []map[string]interface{} `json:"items"`
		Errors []string                 `json:"errors"`
	}
	decodeErr := json.NewDecoder(resp.Body).Decode(&body)

	switch {
	case decodeErr != nil:
		c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, decodeErr)
	case resp.StatusCode != http.StatusOK:
		c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, body.Errors)
	default:
		c.responses, c.kind = body.Items, body.Kind
	}
	return nil, nil
}