14262: Make sure cancel() from proxy.Do() gets called
[arvados.git] / lib / controller / fed_collections.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package controller
6
7 import (
8         "bufio"
9         "bytes"
10         "context"
11         "crypto/md5"
12         "encoding/json"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "net/http"
17         "strings"
18         "sync"
19
20         "git.curoverse.com/arvados.git/sdk/go/arvados"
21         "git.curoverse.com/arvados.git/sdk/go/httpserver"
22         "git.curoverse.com/arvados.git/sdk/go/keepclient"
23 )
24
25 type collectionFederatedRequestHandler struct {
26         next    http.Handler
27         handler *Handler
28 }
29
30 func rewriteSignatures(clusterID string, expectHash string,
31         resp *http.Response, requestError error) (newResponse *http.Response, err error) {
32
33         if requestError != nil {
34                 return resp, requestError
35         }
36
37         if resp.StatusCode != http.StatusOK {
38                 return resp, nil
39         }
40
41         originalBody := resp.Body
42         defer originalBody.Close()
43
44         var col arvados.Collection
45         err = json.NewDecoder(resp.Body).Decode(&col)
46         if err != nil {
47                 return nil, err
48         }
49
50         // rewriting signatures will make manifest text 5-10% bigger so calculate
51         // capacity accordingly
52         updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
53
54         hasher := md5.New()
55         mw := io.MultiWriter(hasher, updatedManifest)
56         sz := 0
57
58         scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
59         scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
60         for scanner.Scan() {
61                 line := scanner.Text()
62                 tokens := strings.Split(line, " ")
63                 if len(tokens) < 3 {
64                         return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
65                 }
66
67                 n, err := mw.Write([]byte(tokens[0]))
68                 if err != nil {
69                         return nil, fmt.Errorf("Error updating manifest: %v", err)
70                 }
71                 sz += n
72                 for _, token := range tokens[1:] {
73                         n, err = mw.Write([]byte(" "))
74                         if err != nil {
75                                 return nil, fmt.Errorf("Error updating manifest: %v", err)
76                         }
77                         sz += n
78
79                         m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
80                         if m != nil {
81                                 // Rewrite the block signature to be a remote signature
82                                 _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterID, m[5][2:], m[8])
83                                 if err != nil {
84                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
85                                 }
86
87                                 // for hash checking, ignore signatures
88                                 n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2])
89                                 if err != nil {
90                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
91                                 }
92                                 sz += n
93                         } else {
94                                 n, err = mw.Write([]byte(token))
95                                 if err != nil {
96                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
97                                 }
98                                 sz += n
99                         }
100                 }
101                 n, err = mw.Write([]byte("\n"))
102                 if err != nil {
103                         return nil, fmt.Errorf("Error updating manifest: %v", err)
104                 }
105                 sz += n
106         }
107
108         // Check that expected hash is consistent with
109         // portable_data_hash field of the returned record
110         if expectHash == "" {
111                 expectHash = col.PortableDataHash
112         } else if expectHash != col.PortableDataHash {
113                 return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", expectHash, col.PortableDataHash)
114         }
115
116         // Certify that the computed hash of the manifest_text matches our expectation
117         sum := hasher.Sum(nil)
118         computedHash := fmt.Sprintf("%x+%v", sum, sz)
119         if computedHash != expectHash {
120                 return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, expectHash)
121         }
122
123         col.ManifestText = updatedManifest.String()
124
125         newbody, err := json.Marshal(col)
126         if err != nil {
127                 return nil, err
128         }
129
130         buf := bytes.NewBuffer(newbody)
131         resp.Body = ioutil.NopCloser(buf)
132         resp.ContentLength = int64(buf.Len())
133         resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
134
135         return resp, nil
136 }
137
138 func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
139         if requestError != nil {
140                 return resp, requestError
141         }
142
143         if resp.StatusCode == http.StatusNotFound {
144                 // Suppress returning this result, because we want to
145                 // search the federation.
146                 return nil, nil
147         }
148         return resp, nil
149 }
150
151 type searchRemoteClusterForPDH struct {
152         pdh           string
153         remoteID      string
154         mtx           *sync.Mutex
155         sentResponse  *bool
156         sharedContext *context.Context
157         cancelFunc    func()
158         errors        *[]string
159         statusCode    *int
160 }
161
162 func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
163         s.mtx.Lock()
164         defer s.mtx.Unlock()
165
166         if *s.sentResponse {
167                 // Another request already returned a response
168                 return nil, nil
169         }
170
171         if requestError != nil {
172                 *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
173                 // Record the error and suppress response
174                 return nil, nil
175         }
176
177         if resp.StatusCode != http.StatusOK {
178                 // Suppress returning unsuccessful result.  Maybe
179                 // another request will find it.
180                 *s.errors = append(*s.errors, fmt.Sprintf("Response to %q from %q: %v", resp.Header.Get(httpserver.HeaderRequestID), s.remoteID, resp.Status))
181                 if resp.StatusCode != http.StatusNotFound {
182                         // Got a non-404 error response, convert into BadGateway
183                         *s.statusCode = http.StatusBadGateway
184                 }
185                 return nil, nil
186         }
187
188         s.mtx.Unlock()
189
190         // This reads the response body.  We don't want to hold the
191         // lock while doing this because other remote requests could
192         // also have made it to this point, and we don't want a
193         // slow response holding the lock to block a faster response
194         // that is waiting on the lock.
195         newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
196
197         s.mtx.Lock()
198
199         if *s.sentResponse {
200                 // Another request already returned a response
201                 return nil, nil
202         }
203
204         if err != nil {
205                 // Suppress returning unsuccessful result.  Maybe
206                 // another request will be successful.
207                 *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
208                 return nil, nil
209         }
210
211         // We have a successful response.  Suppress/cancel all the
212         // other requests/responses.
213         *s.sentResponse = true
214         s.cancelFunc()
215
216         return newResponse, nil
217 }
218
219 func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
220         if req.Method != "GET" {
221                 // Only handle GET requests right now
222                 h.next.ServeHTTP(w, req)
223                 return
224         }
225
226         m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
227         if len(m) != 2 {
228                 // Not a collection PDH GET request
229                 m = collectionRe.FindStringSubmatch(req.URL.Path)
230                 clusterId := ""
231
232                 if len(m) > 0 {
233                         clusterId = m[2]
234                 }
235
236                 if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
237                         // request for remote collection by uuid
238                         resp, cancel, err := h.handler.remoteClusterRequest(clusterId, req)
239                         if cancel != nil {
240                                 defer cancel()
241                         }
242                         newResponse, err := rewriteSignatures(clusterId, "", resp, err)
243                         h.handler.proxy.ForwardResponse(w, newResponse, err)
244                         return
245                 }
246                 // not a collection UUID request, or it is a request
247                 // for a local UUID, either way, continue down the
248                 // handler stack.
249                 h.next.ServeHTTP(w, req)
250                 return
251         }
252
253         // Request for collection by PDH.  Search the federation.
254
255         // First, query the local cluster.
256         resp, localClusterRequestCancel, err := h.handler.localClusterRequest(req)
257         if localClusterRequestCancel != nil {
258                 defer localClusterRequestCancel()
259         }
260         newResp, err := filterLocalClusterResponse(resp, err)
261         if newResp != nil || err != nil {
262                 h.handler.proxy.ForwardResponse(w, newResp, err)
263                 return
264         }
265
266         sharedContext, cancelFunc := context.WithCancel(req.Context())
267         defer cancelFunc()
268         req = req.WithContext(sharedContext)
269
270         // Create a goroutine for each cluster in the
271         // RemoteClusters map.  The first valid result gets
272         // returned to the client.  When that happens, all
273         // other outstanding requests are cancelled or
274         // suppressed.
275         sentResponse := false
276         mtx := sync.Mutex{}
277         wg := sync.WaitGroup{}
278         var errors []string
279         var errorCode int = http.StatusNotFound
280
281         // use channel as a semaphore to limit the number of concurrent
282         // requests at a time
283         sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
284         defer close(sem)
285         for remoteID := range h.handler.Cluster.RemoteClusters {
286                 if remoteID == h.handler.Cluster.ClusterID {
287                         // No need to query local cluster again
288                         continue
289                 }
290                 // blocks until it can put a value into the
291                 // channel (which has a max queue capacity)
292                 sem <- true
293                 if sentResponse {
294                         break
295                 }
296                 search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
297                         &sharedContext, cancelFunc, &errors, &errorCode}
298                 wg.Add(1)
299                 go func() {
300                         resp, cancel, err := h.handler.remoteClusterRequest(search.remoteID, req)
301                         if cancel != nil {
302                                 defer cancel()
303                         }
304                         newResp, err := search.filterRemoteClusterResponse(resp, err)
305                         if newResp != nil || err != nil {
306                                 h.handler.proxy.ForwardResponse(w, newResp, err)
307                         }
308                         wg.Done()
309                         <-sem
310                 }()
311         }
312         wg.Wait()
313
314         if sentResponse {
315                 return
316         }
317
318         // No successful responses, so return the error
319         httpserver.Errors(w, errors, errorCode)
320 }