// Copyright (C) The Arvados Authors. All rights reserved. // // SPDX-License-Identifier: AGPL-3.0 package controller import ( "bufio" "bytes" "context" "crypto/md5" "encoding/json" "fmt" "io" "io/ioutil" "net/http" "strings" "sync" "git.arvados.org/arvados.git/sdk/go/arvados" "git.arvados.org/arvados.git/sdk/go/httpserver" "git.arvados.org/arvados.git/sdk/go/keepclient" ) func rewriteSignatures(clusterID string, expectHash string, resp *http.Response, requestError error) (newResponse *http.Response, err error) { if requestError != nil { return resp, requestError } if resp.StatusCode != http.StatusOK { return resp, nil } originalBody := resp.Body defer originalBody.Close() var col arvados.Collection err = json.NewDecoder(resp.Body).Decode(&col) if err != nil { return nil, err } // rewriting signatures will make manifest text 5-10% bigger so calculate // capacity accordingly updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1))) hasher := md5.New() mw := io.MultiWriter(hasher, updatedManifest) sz := 0 scanner := bufio.NewScanner(strings.NewReader(col.ManifestText)) scanner.Buffer(make([]byte, 1048576), len(col.ManifestText)) for scanner.Scan() { line := scanner.Text() tokens := strings.Split(line, " ") if len(tokens) < 3 { return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line) } n, err := mw.Write([]byte(tokens[0])) if err != nil { return nil, fmt.Errorf("Error updating manifest: %v", err) } sz += n for _, token := range tokens[1:] { n, err = mw.Write([]byte(" ")) if err != nil { return nil, fmt.Errorf("Error updating manifest: %v", err) } sz += n m := keepclient.SignedLocatorRe.FindStringSubmatch(token) if m != nil { // Rewrite the block signature to be a remote signature _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterID, m[5][2:], m[8]) if err != nil { return nil, fmt.Errorf("Error updating manifest: %v", err) } // for hash checking, ignore signatures n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2]) if err != nil { return nil, fmt.Errorf("Error updating manifest: %v", err) } sz += n } else { n, err = mw.Write([]byte(token)) if err != nil { return nil, fmt.Errorf("Error updating manifest: %v", err) } sz += n } } n, err = mw.Write([]byte("\n")) if err != nil { return nil, fmt.Errorf("Error updating manifest: %v", err) } sz += n } // Check that expected hash is consistent with // portable_data_hash field of the returned record if expectHash == "" { expectHash = col.PortableDataHash } else if expectHash != col.PortableDataHash { return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", expectHash, col.PortableDataHash) } // Certify that the computed hash of the manifest_text matches our expectation sum := hasher.Sum(nil) computedHash := fmt.Sprintf("%x+%v", sum, sz) if computedHash != expectHash { return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, expectHash) } col.ManifestText = updatedManifest.String() newbody, err := json.Marshal(col) if err != nil { return nil, err } buf := bytes.NewBuffer(newbody) resp.Body = ioutil.NopCloser(buf) resp.ContentLength = int64(buf.Len()) resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len())) return resp, nil } func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) { if requestError != nil { return resp, requestError } if resp.StatusCode == http.StatusNotFound { // Suppress returning this result, because we want to // search the federation. return nil, nil } return resp, nil } type searchRemoteClusterForPDH struct { pdh string remoteID string mtx *sync.Mutex sentResponse *bool sharedContext *context.Context cancelFunc func() errors *[]string statusCode *int } func fetchRemoteCollectionByUUID( h *genericFederatedRequestHandler, effectiveMethod string, clusterID *string, uuid string, remainder string, w http.ResponseWriter, req *http.Request) bool { if effectiveMethod != "GET" { // Only handle GET requests right now return false } if uuid != "" { // Collection UUID GET request *clusterID = uuid[0:5] if *clusterID != "" && *clusterID != h.handler.Cluster.ClusterID { // request for remote collection by uuid resp, err := h.handler.remoteClusterRequest(*clusterID, req) newResponse, err := rewriteSignatures(*clusterID, "", resp, err) h.handler.proxy.ForwardResponse(w, newResponse, err) return true } } return false } func fetchRemoteCollectionByPDH( h *genericFederatedRequestHandler, effectiveMethod string, clusterID *string, uuid string, remainder string, w http.ResponseWriter, req *http.Request) bool { if effectiveMethod != "GET" { // Only handle GET requests right now return false } m := collectionsByPDHRe.FindStringSubmatch(req.URL.Path) if len(m) != 2 { return false } // Request for collection by PDH. Search the federation. // First, query the local cluster. resp, err := h.handler.localClusterRequest(req) newResp, err := filterLocalClusterResponse(resp, err) if newResp != nil || err != nil { h.handler.proxy.ForwardResponse(w, newResp, err) return true } // Create a goroutine for each cluster in the // RemoteClusters map. The first valid result gets // returned to the client. When that happens, all // other outstanding requests are cancelled sharedContext, cancelFunc := context.WithCancel(req.Context()) defer cancelFunc() req = req.WithContext(sharedContext) wg := sync.WaitGroup{} pdh := m[1] success := make(chan *http.Response) errorChan := make(chan error, len(h.handler.Cluster.RemoteClusters)) acquire, release := semaphore(h.handler.Cluster.API.MaxRequestAmplification) for remoteID := range h.handler.Cluster.RemoteClusters { if remoteID == h.handler.Cluster.ClusterID { // No need to query local cluster again continue } if remoteID == "*" { // This isn't a real remote cluster: it just sets defaults for unlisted remotes. continue } wg.Add(1) go func(remote string) { defer wg.Done() acquire() defer release() select { case <-sharedContext.Done(): return default: } resp, err := h.handler.remoteClusterRequest(remote, req) wasSuccess := false defer func() { if resp != nil && !wasSuccess { resp.Body.Close() } }() if err != nil { errorChan <- err return } if resp.StatusCode != http.StatusOK { errorChan <- HTTPError{resp.Status, resp.StatusCode} return } select { case <-sharedContext.Done(): return default: } newResponse, err := rewriteSignatures(remote, pdh, resp, nil) if err != nil { errorChan <- err return } select { case <-sharedContext.Done(): case success <- newResponse: wasSuccess = true } }(remoteID) } go func() { wg.Wait() cancelFunc() }() errorCode := http.StatusNotFound for { select { case newResp = <-success: h.handler.proxy.ForwardResponse(w, newResp, nil) return true case <-sharedContext.Done(): var errors []string for len(errorChan) > 0 { err := <-errorChan if httperr, ok := err.(HTTPError); !ok || httperr.Code != http.StatusNotFound { errorCode = http.StatusBadGateway } errors = append(errors, err.Error()) } httpserver.Errors(w, errors, errorCode) return true } } // shouldn't ever get here return true }