X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/2b0b06579199967eca3d44d955ad64195d2db3c3..720e76bb1d82d5a5448ce395df634310ceee473e:/lib/controller/federation.go diff --git a/lib/controller/federation.go b/lib/controller/federation.go index da4c00da26..3715edae9a 100644 --- a/lib/controller/federation.go +++ b/lib/controller/federation.go @@ -7,14 +7,18 @@ package controller import ( "bufio" "bytes" + "context" + "crypto/md5" "database/sql" "encoding/json" "fmt" + "io" "io/ioutil" "net/http" "net/url" "regexp" "strings" + "sync" "git.curoverse.com/arvados.git/sdk/go/arvados" "git.curoverse.com/arvados.git/sdk/go/auth" @@ -24,6 +28,7 @@ import ( var wfRe = regexp.MustCompile(`^/arvados/v1/workflows/([0-9a-z]{5})-[^/]+$`) var collectionRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-z]{5})-[^/]+$`) +var collectionByPDHRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-fA-F]{32}\+[0-9]+)+$`) type genericFederatedRequestHandler struct { next http.Handler @@ -73,9 +78,16 @@ func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *h h.handler.remoteClusterRequest(m[1], w, req, nil) } -type rewriteSignaturesClusterId string +type rewriteSignaturesClusterId struct { + clusterID string + expectHash string +} + +func (rw rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requestError error) (newResponse *http.Response, err error) { + if requestError != nil { + return resp, requestError + } -func (clusterId rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response) (newResponse *http.Response, err error) { if resp.StatusCode != 200 { return resp, nil } @@ -93,6 +105,10 @@ func (clusterId rewriteSignaturesClusterId) rewriteSignatures(resp *http.Respons // capacity accordingly updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1))) + hasher := md5.New() + mw := io.MultiWriter(hasher, updatedManifest) + sz := 0 + scanner := bufio.NewScanner(strings.NewReader(col.ManifestText)) scanner.Buffer(make([]byte, 1048576), len(col.ManifestText)) for scanner.Scan() { @@ -102,19 +118,56 @@ func (clusterId rewriteSignaturesClusterId) rewriteSignatures(resp *http.Respons return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line) } - updatedManifest.WriteString(tokens[0]) + n, err := mw.Write([]byte(tokens[0])) + if err != nil { + return nil, fmt.Errorf("Error updating manifest: %v", err) + } + sz += n for _, token := range tokens[1:] { - updatedManifest.WriteString(" ") + n, err = mw.Write([]byte(" ")) + if err != nil { + return nil, fmt.Errorf("Error updating manifest: %v", err) + } + sz += n + m := keepclient.SignedLocatorRe.FindStringSubmatch(token) if m != nil { // Rewrite the block signature to be a remote signature - fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterId, m[5][2:], m[8]) + _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], rw.clusterID, m[5][2:], m[8]) + if err != nil { + return nil, fmt.Errorf("Error updating manifest: %v", err) + } + + // for hash checking, ignore signatures + n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2]) + if err != nil { + return nil, fmt.Errorf("Error updating manifest: %v", err) + } + sz += n } else { - updatedManifest.WriteString(token) + n, err = mw.Write([]byte(token)) + if err != nil { + return nil, fmt.Errorf("Error updating manifest: %v", err) + } + sz += n } - } - updatedManifest.WriteString("\n") + n, err = mw.Write([]byte("\n")) + if err != nil { + return nil, fmt.Errorf("Error updating manifest: %v", err) + } + sz += n + } + + // Certify that the computed hash of the manifest matches our expectation + if rw.expectHash == "" { + rw.expectHash = col.PortableDataHash + } + + sum := hasher.Sum(nil) + computedHash := fmt.Sprintf("%x+%v", sum, sz) + if computedHash != rw.expectHash { + return nil, fmt.Errorf("Computed hash %q did not match expected hash %q", computedHash, rw.expectHash) } col.ManifestText = updatedManifest.String() @@ -132,14 +185,181 @@ func (clusterId rewriteSignaturesClusterId) rewriteSignatures(resp *http.Respons return resp, nil } +type searchLocalClusterForPDH struct { + sentResponse bool +} + +func (s *searchLocalClusterForPDH) filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) { + if requestError != nil { + return resp, requestError + } + + if resp.StatusCode == 404 { + // Suppress returning this result, because we want to + // search the federation. + s.sentResponse = false + return nil, nil + } + s.sentResponse = true + return resp, nil +} + +type searchRemoteClusterForPDH struct { + pdh string + remoteID string + mtx *sync.Mutex + sentResponse *bool + sharedContext *context.Context + cancelFunc func() + errors *[]string + statusCode *int +} + +func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) { + s.mtx.Lock() + defer s.mtx.Unlock() + + if *s.sentResponse { + // Another request already returned a response + return nil, nil + } + + if requestError != nil { + *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError)) + // Record the error and suppress response + return nil, nil + } + + if resp.StatusCode != 200 { + // Suppress returning unsuccessful result. Maybe + // another request will find it. + // TODO collect and return error responses. + *s.errors = append(*s.errors, fmt.Sprintf("Response from %q: %v", s.remoteID, resp.Status)) + if resp.StatusCode != 404 { + // Got a non-404 error response, convert into BadGateway + *s.statusCode = http.StatusBadGateway + } + return nil, nil + } + + s.mtx.Unlock() + + // This reads the response body. We don't want to hold the + // lock while doing this because other remote requests could + // also have made it to this point, and we don't want a + // slow response holding the lock to block a faster response + // that is waiting on the lock. + newResponse, err = rewriteSignaturesClusterId{s.remoteID, s.pdh}.rewriteSignatures(resp, nil) + + s.mtx.Lock() + + if *s.sentResponse { + // Another request already returned a response + return nil, nil + } + + if err != nil { + // Suppress returning unsuccessful result. Maybe + // another request will be successful. + *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err)) + return nil, nil + } + + // We have a successful response. Suppress/cancel all the + // other requests/responses. + *s.sentResponse = true + s.cancelFunc() + + return newResponse, nil +} + func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - m := collectionRe.FindStringSubmatch(req.URL.Path) - if len(m) < 2 || m[1] == h.handler.Cluster.ClusterID { + m := collectionByPDHRe.FindStringSubmatch(req.URL.Path) + if len(m) != 2 { + // Not a collection PDH request + m = collectionRe.FindStringSubmatch(req.URL.Path) + if len(m) == 2 && m[1] != h.handler.Cluster.ClusterID { + // request for remote collection by uuid + h.handler.remoteClusterRequest(m[1], w, req, + rewriteSignaturesClusterId{m[1], ""}.rewriteSignatures) + return + } + // not a collection UUID request, or it is a request + // for a local UUID, either way, continue down the + // handler stack. h.next.ServeHTTP(w, req) return } - h.handler.remoteClusterRequest(m[1], w, req, - rewriteSignaturesClusterId(m[1]).rewriteSignatures) + + // Request for collection by PDH. Search the federation. + + // First, query the local cluster. + urlOut, insecure, err := findRailsAPI(h.handler.Cluster, h.handler.NodeProfile) + if err != nil { + httpserver.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + urlOut = &url.URL{ + Scheme: urlOut.Scheme, + Host: urlOut.Host, + Path: req.URL.Path, + RawPath: req.URL.RawPath, + RawQuery: req.URL.RawQuery, + } + client := h.handler.secureClient + if insecure { + client = h.handler.insecureClient + } + sf := &searchLocalClusterForPDH{} + h.handler.proxy.Do(w, req, urlOut, client, sf.filterLocalClusterResponse) + if sf.sentResponse { + return + } + + sharedContext, cancelFunc := context.WithCancel(req.Context()) + defer cancelFunc() + req = req.WithContext(sharedContext) + + // Create a goroutine for each cluster in the + // RemoteClusters map. The first valid result gets + // returned to the client. When that happens, all + // other outstanding requests are cancelled or + // suppressed. + sentResponse := false + mtx := sync.Mutex{} + wg := sync.WaitGroup{} + var errors []string + var errorCode int = 404 + + // use channel as a semaphore to limit it to 4 + // parallel requests at a time + sem := make(chan bool, 4) + defer close(sem) + for remoteID := range h.handler.Cluster.RemoteClusters { + // blocks until it can put a value into the + // channel (which has a max queue capacity) + sem <- true + if sentResponse { + break + } + search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse, + &sharedContext, cancelFunc, &errors, &errorCode} + wg.Add(1) + go func() { + h.handler.remoteClusterRequest(search.remoteID, w, req, search.filterRemoteClusterResponse) + wg.Done() + <-sem + }() + } + wg.Wait() + + if sentResponse { + return + } + + // No successful responses, so return the error + httpserver.Errors(w, errors, errorCode) } func (h *Handler) setupProxyRemoteCluster(next http.Handler) http.Handler { @@ -150,6 +370,24 @@ func (h *Handler) setupProxyRemoteCluster(next http.Handler) http.Handler { mux.Handle("/arvados/v1/collections/", &collectionFederatedRequestHandler{next, h}) mux.Handle("/", next) + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + parts := strings.Split(req.Header.Get("Authorization"), "/") + alreadySalted := (len(parts) == 3 && parts[0] == "Bearer v2" && len(parts[2]) == 40) + + if alreadySalted || + strings.Index(req.Header.Get("Via"), "arvados-controller") != -1 { + // The token is already salted, or this is a + // request from another instance of + // arvados-controller. In either case, we + // don't want to proxy this query, so just + // continue down the instance handler stack. + next.ServeHTTP(w, req) + return + } + + mux.ServeHTTP(w, req) + }) + return mux }