14807: Use context to pass a suitable logger to all service commands.
[arvados.git] / lib / controller / fed_collections.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package controller
6
7 import (
8         "bufio"
9         "bytes"
10         "context"
11         "crypto/md5"
12         "encoding/json"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "net/http"
17         "strings"
18         "sync"
19
20         "git.curoverse.com/arvados.git/sdk/go/arvados"
21         "git.curoverse.com/arvados.git/sdk/go/httpserver"
22         "git.curoverse.com/arvados.git/sdk/go/keepclient"
23 )
24
25 func rewriteSignatures(clusterID string, expectHash string,
26         resp *http.Response, requestError error) (newResponse *http.Response, err error) {
27
28         if requestError != nil {
29                 return resp, requestError
30         }
31
32         if resp.StatusCode != http.StatusOK {
33                 return resp, nil
34         }
35
36         originalBody := resp.Body
37         defer originalBody.Close()
38
39         var col arvados.Collection
40         err = json.NewDecoder(resp.Body).Decode(&col)
41         if err != nil {
42                 return nil, err
43         }
44
45         // rewriting signatures will make manifest text 5-10% bigger so calculate
46         // capacity accordingly
47         updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
48
49         hasher := md5.New()
50         mw := io.MultiWriter(hasher, updatedManifest)
51         sz := 0
52
53         scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
54         scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
55         for scanner.Scan() {
56                 line := scanner.Text()
57                 tokens := strings.Split(line, " ")
58                 if len(tokens) < 3 {
59                         return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
60                 }
61
62                 n, err := mw.Write([]byte(tokens[0]))
63                 if err != nil {
64                         return nil, fmt.Errorf("Error updating manifest: %v", err)
65                 }
66                 sz += n
67                 for _, token := range tokens[1:] {
68                         n, err = mw.Write([]byte(" "))
69                         if err != nil {
70                                 return nil, fmt.Errorf("Error updating manifest: %v", err)
71                         }
72                         sz += n
73
74                         m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
75                         if m != nil {
76                                 // Rewrite the block signature to be a remote signature
77                                 _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterID, m[5][2:], m[8])
78                                 if err != nil {
79                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
80                                 }
81
82                                 // for hash checking, ignore signatures
83                                 n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2])
84                                 if err != nil {
85                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
86                                 }
87                                 sz += n
88                         } else {
89                                 n, err = mw.Write([]byte(token))
90                                 if err != nil {
91                                         return nil, fmt.Errorf("Error updating manifest: %v", err)
92                                 }
93                                 sz += n
94                         }
95                 }
96                 n, err = mw.Write([]byte("\n"))
97                 if err != nil {
98                         return nil, fmt.Errorf("Error updating manifest: %v", err)
99                 }
100                 sz += n
101         }
102
103         // Check that expected hash is consistent with
104         // portable_data_hash field of the returned record
105         if expectHash == "" {
106                 expectHash = col.PortableDataHash
107         } else if expectHash != col.PortableDataHash {
108                 return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", expectHash, col.PortableDataHash)
109         }
110
111         // Certify that the computed hash of the manifest_text matches our expectation
112         sum := hasher.Sum(nil)
113         computedHash := fmt.Sprintf("%x+%v", sum, sz)
114         if computedHash != expectHash {
115                 return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, expectHash)
116         }
117
118         col.ManifestText = updatedManifest.String()
119
120         newbody, err := json.Marshal(col)
121         if err != nil {
122                 return nil, err
123         }
124
125         buf := bytes.NewBuffer(newbody)
126         resp.Body = ioutil.NopCloser(buf)
127         resp.ContentLength = int64(buf.Len())
128         resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
129
130         return resp, nil
131 }
132
// filterLocalClusterResponse decides whether a response from the local
// cluster should be returned to the client as-is.
//
// A non-nil requestError is passed through together with resp.
// A 404 response is suppressed: its body is closed and (nil, nil) is
// returned, signalling the caller to search the federation instead.
// Any other response is returned unchanged, with its body still open.
func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
	if requestError != nil {
		return resp, requestError
	}

	if resp.StatusCode == http.StatusNotFound {
		// Suppress returning this result, because we want to
		// search the federation. Nobody else will see this response,
		// so release its body here. Leaving it open leaks it and
		// prevents the underlying connection from being reused.
		if resp.Body != nil {
			resp.Body.Close()
		}
		return nil, nil
	}
	return resp, nil
}
145
// searchRemoteClusterForPDH describes the search for one collection,
// identified by portable data hash (pdh), on one remote cluster
// (remoteID). The pointer-typed fields look intended to be shared
// between several per-cluster searches running together - TODO confirm.
//
// NOTE(review): nothing in the code visible alongside this type refers
// to it. fetchRemoteCollectionByPDH tracks the same state with local
// variables and channels. Verify there are no other users before
// removing it. Storing a *context.Context in a struct is also
// unidiomatic.
type searchRemoteClusterForPDH struct {
	pdh, remoteID string

	// State presumably shared across concurrent searches.
	mtx           *sync.Mutex
	sentResponse  *bool
	sharedContext *context.Context
	cancelFunc    func()
	errors        *[]string
	statusCode    *int
}
156
157 func fetchRemoteCollectionByUUID(
158         h *genericFederatedRequestHandler,
159         effectiveMethod string,
160         clusterId *string,
161         uuid string,
162         remainder string,
163         w http.ResponseWriter,
164         req *http.Request) bool {
165
166         if effectiveMethod != "GET" {
167                 // Only handle GET requests right now
168                 return false
169         }
170
171         if uuid != "" {
172                 // Collection UUID GET request
173                 *clusterId = uuid[0:5]
174                 if *clusterId != "" && *clusterId != h.handler.Cluster.ClusterID {
175                         // request for remote collection by uuid
176                         resp, err := h.handler.remoteClusterRequest(*clusterId, req)
177                         newResponse, err := rewriteSignatures(*clusterId, "", resp, err)
178                         h.handler.proxy.ForwardResponse(w, newResponse, err)
179                         return true
180                 }
181         }
182
183         return false
184 }
185
186 func fetchRemoteCollectionByPDH(
187         h *genericFederatedRequestHandler,
188         effectiveMethod string,
189         clusterId *string,
190         uuid string,
191         remainder string,
192         w http.ResponseWriter,
193         req *http.Request) bool {
194
195         if effectiveMethod != "GET" {
196                 // Only handle GET requests right now
197                 return false
198         }
199
200         m := collectionsByPDHRe.FindStringSubmatch(req.URL.Path)
201         if len(m) != 2 {
202                 return false
203         }
204
205         // Request for collection by PDH.  Search the federation.
206
207         // First, query the local cluster.
208         resp, err := h.handler.localClusterRequest(req)
209         newResp, err := filterLocalClusterResponse(resp, err)
210         if newResp != nil || err != nil {
211                 h.handler.proxy.ForwardResponse(w, newResp, err)
212                 return true
213         }
214
215         // Create a goroutine for each cluster in the
216         // RemoteClusters map.  The first valid result gets
217         // returned to the client.  When that happens, all
218         // other outstanding requests are cancelled
219         sharedContext, cancelFunc := context.WithCancel(req.Context())
220         req = req.WithContext(sharedContext)
221         wg := sync.WaitGroup{}
222         pdh := m[1]
223         success := make(chan *http.Response)
224         errorChan := make(chan error, len(h.handler.Cluster.RemoteClusters))
225
226         // use channel as a semaphore to limit the number of concurrent
227         // requests at a time
228         sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
229
230         defer cancelFunc()
231
232         for remoteID := range h.handler.Cluster.RemoteClusters {
233                 if remoteID == h.handler.Cluster.ClusterID {
234                         // No need to query local cluster again
235                         continue
236                 }
237
238                 wg.Add(1)
239                 go func(remote string) {
240                         defer wg.Done()
241                         // blocks until it can put a value into the
242                         // channel (which has a max queue capacity)
243                         sem <- true
244                         select {
245                         case <-sharedContext.Done():
246                                 return
247                         default:
248                         }
249
250                         resp, err := h.handler.remoteClusterRequest(remote, req)
251                         wasSuccess := false
252                         defer func() {
253                                 if resp != nil && !wasSuccess {
254                                         resp.Body.Close()
255                                 }
256                         }()
257                         if err != nil {
258                                 errorChan <- err
259                                 return
260                         }
261                         if resp.StatusCode != http.StatusOK {
262                                 errorChan <- HTTPError{resp.Status, resp.StatusCode}
263                                 return
264                         }
265                         select {
266                         case <-sharedContext.Done():
267                                 return
268                         default:
269                         }
270
271                         newResponse, err := rewriteSignatures(remote, pdh, resp, nil)
272                         if err != nil {
273                                 errorChan <- err
274                                 return
275                         }
276                         select {
277                         case <-sharedContext.Done():
278                         case success <- newResponse:
279                                 wasSuccess = true
280                         }
281                         <-sem
282                 }(remoteID)
283         }
284         go func() {
285                 wg.Wait()
286                 cancelFunc()
287         }()
288
289         errorCode := http.StatusNotFound
290
291         for {
292                 select {
293                 case newResp = <-success:
294                         h.handler.proxy.ForwardResponse(w, newResp, nil)
295                         return true
296                 case <-sharedContext.Done():
297                         var errors []string
298                         for len(errorChan) > 0 {
299                                 err := <-errorChan
300                                 if httperr, ok := err.(HTTPError); ok {
301                                         if httperr.Code != http.StatusNotFound {
302                                                 errorCode = http.StatusBadGateway
303                                         }
304                                 }
305                                 errors = append(errors, err.Error())
306                         }
307                         httpserver.Errors(w, errors, errorCode)
308                         return true
309                 }
310         }
311
312         // shouldn't ever get here
313         return true
314 }