--- /dev/null
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package controller
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/md5"
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "strings"
+ "sync"
+
+ "git.curoverse.com/arvados.git/sdk/go/arvados"
+ "git.curoverse.com/arvados.git/sdk/go/httpserver"
+ "git.curoverse.com/arvados.git/sdk/go/keepclient"
+)
+
+// collectionFederatedRequestHandler handles collection API requests,
+// proxying remote-UUID requests to their home cluster and searching
+// the federation for requests by portable data hash; everything else
+// falls through to next.
+type collectionFederatedRequestHandler struct {
+ next http.Handler
+ handler *Handler
+}
+
+// rewriteSignatures rewrites a collection response fetched from the
+// remote cluster clusterID so its manifest block signatures are usable
+// here: each locator's signature hint is replaced with a +R hint that
+// records the cluster of origin (the rewritten value is built from
+// m[5][2:], i.e. the hint with its two-character prefix stripped).
+//
+// While rewriting it recomputes the md5 of the manifest (signature
+// hints excluded) and checks it against expectHash — or against the
+// record's own portable_data_hash when expectHash is empty — so a
+// corrupt or mismatched manifest is never forwarded. On success the
+// response body and Content-Length are replaced with the re-encoded
+// record.
+func rewriteSignatures(clusterID string, expectHash string,
+	resp *http.Response, requestError error) (newResponse *http.Response, err error) {
+
+	if requestError != nil {
+		return resp, requestError
+	}
+
+	if resp.StatusCode != 200 {
+		// Pass non-OK responses through untouched.
+		return resp, nil
+	}
+
+	originalBody := resp.Body
+	defer originalBody.Close()
+
+	var col arvados.Collection
+	err = json.NewDecoder(resp.Body).Decode(&col)
+	if err != nil {
+		return nil, err
+	}
+
+	// rewriting signatures will make manifest text 5-10% bigger so calculate
+	// capacity accordingly
+	updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
+
+	hasher := md5.New()
+	mw := io.MultiWriter(hasher, updatedManifest)
+	sz := 0
+
+	scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
+	scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
+	for scanner.Scan() {
+		line := scanner.Text()
+		tokens := strings.Split(line, " ")
+		if len(tokens) < 3 {
+			return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
+		}
+
+		n, err := mw.Write([]byte(tokens[0]))
+		if err != nil {
+			return nil, fmt.Errorf("Error updating manifest: %v", err)
+		}
+		sz += n
+		for _, token := range tokens[1:] {
+			n, err = mw.Write([]byte(" "))
+			if err != nil {
+				return nil, fmt.Errorf("Error updating manifest: %v", err)
+			}
+			sz += n
+
+			m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
+			if m != nil {
+				// Rewrite the block signature to be a remote signature
+				_, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterID, m[5][2:], m[8])
+				if err != nil {
+					return nil, fmt.Errorf("Error updating manifest: %v", err)
+				}
+
+				// for hash checking, ignore signatures
+				n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2])
+				if err != nil {
+					return nil, fmt.Errorf("Error updating manifest: %v", err)
+				}
+				sz += n
+			} else {
+				n, err = mw.Write([]byte(token))
+				if err != nil {
+					return nil, fmt.Errorf("Error updating manifest: %v", err)
+				}
+				sz += n
+			}
+		}
+		n, err = mw.Write([]byte("\n"))
+		if err != nil {
+			return nil, fmt.Errorf("Error updating manifest: %v", err)
+		}
+		sz += n
+	}
+	// bufio.Scanner stops silently on a read/token error (e.g. a
+	// stream line longer than the buffer); surface it here instead
+	// of letting a truncated manifest fail the hash check below
+	// with a misleading message.
+	if err := scanner.Err(); err != nil {
+		return nil, fmt.Errorf("Error scanning manifest: %v", err)
+	}
+
+	// Check that expected hash is consistent with
+	// portable_data_hash field of the returned record
+	if expectHash == "" {
+		expectHash = col.PortableDataHash
+	} else if expectHash != col.PortableDataHash {
+		return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", expectHash, col.PortableDataHash)
+	}
+
+	// Certify that the computed hash of the manifest_text matches our expectation
+	sum := hasher.Sum(nil)
+	computedHash := fmt.Sprintf("%x+%v", sum, sz)
+	if computedHash != expectHash {
+		return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, expectHash)
+	}
+
+	col.ManifestText = updatedManifest.String()
+
+	newbody, err := json.Marshal(col)
+	if err != nil {
+		return nil, err
+	}
+
+	buf := bytes.NewBuffer(newbody)
+	resp.Body = ioutil.NopCloser(buf)
+	resp.ContentLength = int64(buf.Len())
+	resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
+
+	return resp, nil
+}
+
+// filterLocalClusterResponse inspects the local cluster's answer to a
+// federated lookup. A 404 is swallowed (nil, nil) so the caller knows
+// to keep searching the rest of the federation; any other outcome —
+// including a request error — is passed through unchanged.
+func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
+	if requestError != nil {
+		return resp, requestError
+	}
+	if resp.StatusCode == http.StatusNotFound {
+		// A local "not found" means "keep looking": suppress
+		// this response so the federation search proceeds.
+		return nil, nil
+	}
+	return resp, nil
+}
+
+// searchRemoteClusterForPDH carries one remote cluster's share of the
+// state for a federated search by portable data hash. The pointer
+// fields (mtx, sentResponse, errors, statusCode) are shared by all
+// of the concurrent per-cluster searches and must only be accessed
+// while holding mtx.
+type searchRemoteClusterForPDH struct {
+ pdh string
+ remoteID string
+ mtx *sync.Mutex
+ sentResponse *bool
+ sharedContext *context.Context
+ cancelFunc func()
+ errors *[]string
+ statusCode *int
+}
+
+// filterRemoteClusterResponse decides whether this remote cluster's
+// response should be the one forwarded to the client. Exactly one
+// successful (200) response wins: it is rewritten with
+// rewriteSignatures, sentResponse is set, and cancelFunc cancels the
+// remaining requests. Request errors and non-200 statuses are
+// appended to the shared error list and suppressed (nil, nil) so a
+// sibling request can still succeed; a non-404 status also upgrades
+// the shared statusCode to 502.
+func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
+ s.mtx.Lock()
+ defer s.mtx.Unlock()
+
+ if *s.sentResponse {
+ // Another request already returned a response
+ return nil, nil
+ }
+
+ if requestError != nil {
+ *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
+ // Record the error and suppress response
+ return nil, nil
+ }
+
+ if resp.StatusCode != 200 {
+ // Suppress returning unsuccessful result. Maybe
+ // another request will find it.
+ // TODO collect and return error responses.
+ *s.errors = append(*s.errors, fmt.Sprintf("Response from %q: %v", s.remoteID, resp.Status))
+ if resp.StatusCode != 404 {
+ // Got a non-404 error response, convert into BadGateway
+ *s.statusCode = http.StatusBadGateway
+ }
+ return nil, nil
+ }
+
+ // Temporarily drop the lock (the deferred Unlock pairs with the
+ // Lock below, so the lock is balanced on every return path).
+ s.mtx.Unlock()
+
+ // This reads the response body. We don't want to hold the
+ // lock while doing this because other remote requests could
+ // also have made it to this point, and we don't want a
+ // slow response holding the lock to block a faster response
+ // that is waiting on the lock.
+ newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
+
+ s.mtx.Lock()
+
+ if *s.sentResponse {
+ // Another request already returned a response
+ return nil, nil
+ }
+
+ if err != nil {
+ // Suppress returning unsuccessful result. Maybe
+ // another request will be successful.
+ *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
+ return nil, nil
+ }
+
+ // We have a successful response. Suppress/cancel all the
+ // other requests/responses.
+ *s.sentResponse = true
+ s.cancelFunc()
+
+ return newResponse, nil
+}
+
+// ServeHTTP handles GET requests for collections. A request for a
+// remote UUID is proxied to its home cluster and the returned block
+// signatures rewritten for local use; a request by portable data hash
+// is tried locally first and, on a local 404, fanned out concurrently
+// to every configured remote cluster, forwarding the first verified
+// hit. Anything else falls through to the next handler.
+func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	if req.Method != "GET" {
+		// Only handle GET requests right now
+		h.next.ServeHTTP(w, req)
+		return
+	}
+
+	m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
+	if len(m) != 2 {
+		// Not a collection PDH GET request
+		m = collectionRe.FindStringSubmatch(req.URL.Path)
+		clusterId := ""
+
+		if len(m) > 0 {
+			clusterId = m[2]
+		}
+
+		if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
+			// request for remote collection by uuid
+			resp, err := h.handler.remoteClusterRequest(clusterId, req)
+			newResponse, err := rewriteSignatures(clusterId, "", resp, err)
+			h.handler.proxy.ForwardResponse(w, newResponse, err)
+			return
+		}
+		// not a collection UUID request, or it is a request
+		// for a local UUID, either way, continue down the
+		// handler stack.
+		h.next.ServeHTTP(w, req)
+		return
+	}
+
+	// Request for collection by PDH. Search the federation.
+
+	// First, query the local cluster.
+	resp, err := h.handler.localClusterRequest(req)
+	newResp, err := filterLocalClusterResponse(resp, err)
+	if newResp != nil || err != nil {
+		h.handler.proxy.ForwardResponse(w, newResp, err)
+		return
+	}
+
+	sharedContext, cancelFunc := context.WithCancel(req.Context())
+	defer cancelFunc()
+	req = req.WithContext(sharedContext)
+
+	// Create a goroutine for each cluster in the
+	// RemoteClusters map. The first valid result gets
+	// returned to the client. When that happens, all
+	// other outstanding requests are cancelled or
+	// suppressed.
+	sentResponse := false
+	mtx := sync.Mutex{}
+	wg := sync.WaitGroup{}
+	var errors []string
+	var errorCode int = 404
+
+	// use channel as a semaphore to limit the number of concurrent
+	// requests at a time
+	sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
+	defer close(sem)
+	for remoteID := range h.handler.Cluster.RemoteClusters {
+		if remoteID == h.handler.Cluster.ClusterID {
+			// No need to query local cluster again
+			continue
+		}
+		// blocks until it can put a value into the
+		// channel (which has a max queue capacity)
+		sem <- true
+		// sentResponse is written by the request goroutines
+		// under mtx, so it must also be read under mtx to
+		// avoid a data race; release the semaphore slot we
+		// just took before breaking out.
+		mtx.Lock()
+		done := sentResponse
+		mtx.Unlock()
+		if done {
+			<-sem
+			break
+		}
+		search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
+			&sharedContext, cancelFunc, &errors, &errorCode}
+		wg.Add(1)
+		go func() {
+			resp, err := h.handler.remoteClusterRequest(search.remoteID, req)
+			newResp, err := search.filterRemoteClusterResponse(resp, err)
+			if newResp != nil || err != nil {
+				h.handler.proxy.ForwardResponse(w, newResp, err)
+			}
+			wg.Done()
+			<-sem
+		}()
+	}
+	wg.Wait()
+
+	// Safe to read without mtx: wg.Wait() established a
+	// happens-before with every goroutine's writes.
+	if sentResponse {
+		return
+	}
+
+	// No successful responses, so return the error
+	httpserver.Errors(w, errors, errorCode)
+}
--- /dev/null
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package controller
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "regexp"
+ "sync"
+
+ "git.curoverse.com/arvados.git/sdk/go/httpserver"
+)
+
+// genericFederatedRequestHandler proxies requests for federated object
+// types to the cluster named by the UUID prefix (or cluster_id
+// parameter) and dispatches multi-cluster list queries; matcher
+// extracts the cluster ID from the request path.
+type genericFederatedRequestHandler struct {
+ next http.Handler
+ handler *Handler
+ matcher *regexp.Regexp
+}
+
+// remoteQueryUUIDs fetches the records for uuids from cluster
+// clusterID (or from the local cluster when clusterID matches it),
+// using POST-as-GET so an arbitrarily long uuid filter cannot
+// overflow the URL length limit. Because a single response may
+// contain fewer items than requested, it loops, re-querying for the
+// stragglers, until everything has arrived or a round makes no
+// progress. It returns the accumulated items and the response "kind".
+func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
+	req *http.Request,
+	clusterID string, uuids []string) (rp []map[string]interface{}, kind string, err error) {
+
+	found := make(map[string]bool)
+	prevLenUUIDs := len(uuids) + 1
+	// Loop while
+	// (1) there are more uuids to query
+	// (2) we're making progress - on each iteration the set of
+	// uuids we are expecting for must shrink.
+	for len(uuids) > 0 && len(uuids) < prevLenUUIDs {
+		var remoteReq http.Request
+		remoteReq.Header = req.Header
+		remoteReq.Method = "POST"
+		remoteReq.URL = &url.URL{Path: req.URL.Path}
+		remoteParams := make(url.Values)
+		remoteParams.Set("_method", "GET")
+		remoteParams.Set("count", "none")
+		if req.Form.Get("select") != "" {
+			remoteParams.Set("select", req.Form.Get("select"))
+		}
+		content, err := json.Marshal(uuids)
+		if err != nil {
+			return nil, "", err
+		}
+		remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
+		enc := remoteParams.Encode()
+		remoteReq.Body = ioutil.NopCloser(bytes.NewBufferString(enc))
+
+		rc := multiClusterQueryResponseCollector{clusterID: clusterID}
+
+		// Our own cluster ID means this chunk is served locally.
+		var resp *http.Response
+		if clusterID == h.handler.Cluster.ClusterID {
+			resp, err = h.handler.localClusterRequest(&remoteReq)
+		} else {
+			resp, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
+		}
+		rc.collectResponse(resp, err)
+
+		if rc.error != nil {
+			return nil, "", rc.error
+		}
+
+		kind = rc.kind
+
+		if len(rc.responses) == 0 {
+			// We got zero responses, no point in doing
+			// another query.
+			return rp, kind, nil
+		}
+
+		rp = append(rp, rc.responses...)
+
+		// Go through the responses and determine what was
+		// returned. If there are remaining items, loop
+		// around and do another request with just the
+		// stragglers.
+		for _, i := range rc.responses {
+			if uuid, ok := i["uuid"].(string); ok {
+				found[uuid] = true
+			}
+		}
+
+		// Collect the uuids that have not been seen yet.
+		remaining := []string{}
+		for _, u := range uuids {
+			if !found[u] {
+				remaining = append(remaining, u)
+			}
+		}
+		prevLenUUIDs = len(uuids)
+		uuids = remaining
+	}
+
+	return rp, kind, nil
+}
+
+// handleMultiClusterQuery handles a list query whose "uuid" filters
+// span multiple clusters: it splits the uuids by their 5-character
+// cluster prefix, validates the request (count=none, no paging, uuid
+// in select), queries each cluster concurrently via remoteQueryUUIDs,
+// and writes the merged item list. It returns true when it has fully
+// handled the request (success or error already written to w), false
+// when the query is not actually multi-cluster and should fall
+// through to the normal handler path.
+func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter,
+	req *http.Request, clusterId *string) bool {
+
+	var filters [][]interface{}
+	err := json.Unmarshal([]byte(req.Form.Get("filters")), &filters)
+	if err != nil {
+		httpserver.Error(w, err.Error(), http.StatusBadRequest)
+		return true
+	}
+
+	// Split the list of uuids by prefix
+	queryClusters := make(map[string][]string)
+	expectCount := 0
+	for _, filter := range filters {
+		if len(filter) != 3 {
+			return false
+		}
+
+		if lhs, ok := filter[0].(string); !ok || lhs != "uuid" {
+			return false
+		}
+
+		op, ok := filter[1].(string)
+		if !ok {
+			return false
+		}
+
+		if op == "in" {
+			if rhs, ok := filter[2].([]interface{}); ok {
+				for _, i := range rhs {
+					// Skip values too short to carry a
+					// 5-character cluster prefix; without
+					// this guard u[0:5] panics on short
+					// client-supplied strings.
+					if u, ok := i.(string); ok && len(u) >= 5 {
+						*clusterId = u[0:5]
+						queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
+						expectCount++
+					}
+				}
+			}
+		} else if op == "=" {
+			// Same short-string guard as above.
+			if u, ok := filter[2].(string); ok && len(u) >= 5 {
+				*clusterId = u[0:5]
+				queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
+				expectCount++
+			}
+		} else {
+			return false
+		}
+
+	}
+
+	if len(queryClusters) <= 1 {
+		// Query does not search for uuids across multiple
+		// clusters.
+		return false
+	}
+
+	// Validations
+	count := req.Form.Get("count")
+	if count != "" && count != `none` && count != `"none"` {
+		httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
+		return true
+	}
+	if req.Form.Get("limit") != "" || req.Form.Get("offset") != "" || req.Form.Get("order") != "" {
+		httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
+		return true
+	}
+	if expectCount > h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse() {
+		httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
+			expectCount, h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse()), http.StatusBadRequest)
+		return true
+	}
+	if req.Form.Get("select") != "" {
+		foundUUID := false
+		var selects []string
+		err := json.Unmarshal([]byte(req.Form.Get("select")), &selects)
+		if err != nil {
+			httpserver.Error(w, err.Error(), http.StatusBadRequest)
+			return true
+		}
+
+		for _, r := range selects {
+			if r == "uuid" {
+				foundUUID = true
+				break
+			}
+		}
+		if !foundUUID {
+			// The merge loop below relies on each item's uuid.
+			httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
+			return true
+		}
+	}
+
+	// Perform concurrent requests to each cluster
+
+	// use channel as a semaphore to limit the number of concurrent
+	// requests at a time
+	sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
+	defer close(sem)
+	wg := sync.WaitGroup{}
+
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	mtx := sync.Mutex{}
+	errors := []error{}
+	var completeResponses []map[string]interface{}
+	var kind string
+
+	for k, v := range queryClusters {
+		if len(v) == 0 {
+			// Nothing to query
+			continue
+		}
+
+		// blocks until it can put a value into the
+		// channel (which has a max queue capacity)
+		sem <- true
+		wg.Add(1)
+		go func(k string, v []string) {
+			rp, kn, err := h.remoteQueryUUIDs(w, req, k, v)
+			mtx.Lock()
+			if err == nil {
+				completeResponses = append(completeResponses, rp...)
+				kind = kn
+			} else {
+				errors = append(errors, err)
+			}
+			mtx.Unlock()
+			wg.Done()
+			<-sem
+		}(k, v)
+	}
+	wg.Wait()
+
+	if len(errors) > 0 {
+		var strerr []string
+		for _, e := range errors {
+			strerr = append(strerr, e.Error())
+		}
+		httpserver.Errors(w, strerr, http.StatusBadGateway)
+		return true
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(http.StatusOK)
+	itemList := make(map[string]interface{})
+	itemList["items"] = completeResponses
+	itemList["kind"] = kind
+	json.NewEncoder(w).Encode(itemList)
+
+	return true
+}
+
+// ServeHTTP routes a request either down the local handler stack or
+// to the remote cluster identified by the URL's UUID prefix or an
+// explicit cluster_id parameter, and dispatches federated
+// multi-cluster list queries to handleMultiClusterQuery.
+func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	clusterId := ""
+	if m := h.matcher.FindStringSubmatch(req.URL.Path); len(m) > 0 && m[2] != "" {
+		clusterId = m[2]
+	}
+
+	// Get form parameters from URL and form body (if POST).
+	if err := loadParamsFromForm(req); err != nil {
+		httpserver.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+
+	// An explicit cluster_id parameter overrides the URL prefix.
+	if cid := req.Form.Get("cluster_id"); cid != "" {
+		clusterId = cid
+	}
+
+	// POST with a _method parameter is treated as that method
+	// (workaround for GET requests, such as multi-object queries
+	// with hundreds of filter items, that would exceed the
+	// maximum URL length).
+	effectiveMethod := req.Method
+	if req.Method == "POST" && req.Form.Get("_method") != "" {
+		effectiveMethod = req.Form.Get("_method")
+	}
+
+	if effectiveMethod == "GET" &&
+		clusterId == "" &&
+		req.Form.Get("filters") != "" &&
+		h.handleMultiClusterQuery(w, req, &clusterId) {
+		return
+	}
+
+	if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
+		h.next.ServeHTTP(w, req)
+	} else {
+		resp, err := h.handler.remoteClusterRequest(clusterId, req)
+		h.handler.proxy.ForwardResponse(w, resp, err)
+	}
+}
+
+// multiClusterQueryResponseCollector accumulates the decoded items and
+// result kind from one cluster's list response; any request, decode,
+// or HTTP-status failure is recorded in the error field rather than
+// returned.
+type multiClusterQueryResponseCollector struct {
+ responses []map[string]interface{}
+ error error
+ kind string
+ clusterID string
+}
+
+// collectResponse decodes a list response into c.responses and c.kind.
+// It always returns (nil, nil); failures are stored in c.error. Note
+// the body is decoded before the status code is checked, so that the
+// "errors" array of an error response can be included in the message.
+func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
+ requestError error) (newResponse *http.Response, err error) {
+ if requestError != nil {
+ c.error = requestError
+ return nil, nil
+ }
+
+ defer resp.Body.Close()
+ var loadInto struct {
+ Kind string `json:"kind"`
+ Items []map[string]interface{} `json:"items"`
+ Errors []string `json:"errors"`
+ }
+ err = json.NewDecoder(resp.Body).Decode(&loadInto)
+
+ if err != nil {
+ c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, err)
+ return nil, nil
+ }
+ if resp.StatusCode != http.StatusOK {
+ c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, loadInto.Errors)
+ return nil, nil
+ }
+
+ c.responses = loadInto.Items
+ c.kind = loadInto.Kind
+
+ return nil, nil
+}
package controller
import (
- "bufio"
"bytes"
- "context"
- "crypto/md5"
"database/sql"
- "encoding/json"
"fmt"
"io"
"io/ioutil"
"net/url"
"regexp"
"strings"
- "sync"
"git.curoverse.com/arvados.git/sdk/go/arvados"
"git.curoverse.com/arvados.git/sdk/go/auth"
- "git.curoverse.com/arvados.git/sdk/go/httpserver"
- "git.curoverse.com/arvados.git/sdk/go/keepclient"
)
var pathPattern = `^/arvados/v1/%s(/([0-9a-z]{5})-%s-[0-9a-z]{15})?(.*)$`
var collectionRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "collections", "4zz18"))
var collectionByPDHRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-fA-F]{32}\+[0-9]+)+$`)
-type genericFederatedRequestHandler struct {
- next http.Handler
- handler *Handler
- matcher *regexp.Regexp
-}
-
-type collectionFederatedRequestHandler struct {
- next http.Handler
- handler *Handler
-}
-
func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, error) {
remote, ok := h.Cluster.RemoteClusters[remoteID]
if !ok {
return nil
}
-type multiClusterQueryResponseCollector struct {
- responses []map[string]interface{}
- error error
- kind string
- clusterID string
-}
-
-func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
- requestError error) (newResponse *http.Response, err error) {
- if requestError != nil {
- c.error = requestError
- return nil, nil
- }
-
- defer resp.Body.Close()
- var loadInto struct {
- Kind string `json:"kind"`
- Items []map[string]interface{} `json:"items"`
- Errors []string `json:"errors"`
- }
- err = json.NewDecoder(resp.Body).Decode(&loadInto)
-
- if err != nil {
- c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, err)
- return nil, nil
- }
- if resp.StatusCode != http.StatusOK {
- c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, loadInto.Errors)
- return nil, nil
- }
-
- c.responses = loadInto.Items
- c.kind = loadInto.Kind
-
- return nil, nil
-}
-
-func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
- req *http.Request,
- clusterID string, uuids []string) (rp []map[string]interface{}, kind string, err error) {
-
- found := make(map[string]bool)
- prev_len_uuids := len(uuids) + 1
- // Loop while
- // (1) there are more uuids to query
- // (2) we're making progress - on each iteration the set of
- // uuids we are expecting for must shrink.
- for len(uuids) > 0 && len(uuids) < prev_len_uuids {
- var remoteReq http.Request
- remoteReq.Header = req.Header
- remoteReq.Method = "POST"
- remoteReq.URL = &url.URL{Path: req.URL.Path}
- remoteParams := make(url.Values)
- remoteParams.Set("_method", "GET")
- remoteParams.Set("count", "none")
- if req.Form.Get("select") != "" {
- remoteParams.Set("select", req.Form.Get("select"))
- }
- content, err := json.Marshal(uuids)
- if err != nil {
- return nil, "", err
- }
- remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
- enc := remoteParams.Encode()
- remoteReq.Body = ioutil.NopCloser(bytes.NewBufferString(enc))
-
- rc := multiClusterQueryResponseCollector{clusterID: clusterID}
-
- var resp *http.Response
- if clusterID == h.handler.Cluster.ClusterID {
- resp, err = h.handler.localClusterRequest(&remoteReq)
- } else {
- resp, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
- }
- rc.collectResponse(resp, err)
-
- if rc.error != nil {
- return nil, "", rc.error
- }
-
- kind = rc.kind
-
- if len(rc.responses) == 0 {
- // We got zero responses, no point in doing
- // another query.
- return rp, kind, nil
- }
-
- rp = append(rp, rc.responses...)
-
- // Go through the responses and determine what was
- // returned. If there are remaining items, loop
- // around and do another request with just the
- // stragglers.
- for _, i := range rc.responses {
- uuid, ok := i["uuid"].(string)
- if ok {
- found[uuid] = true
- }
- }
-
- l := []string{}
- for _, u := range uuids {
- if !found[u] {
- l = append(l, u)
- }
- }
- prev_len_uuids = len(uuids)
- uuids = l
- }
-
- return rp, kind, nil
-}
-
-func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter,
- req *http.Request, clusterId *string) bool {
-
- var filters [][]interface{}
- err := json.Unmarshal([]byte(req.Form.Get("filters")), &filters)
- if err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadRequest)
- return true
- }
-
- // Split the list of uuids by prefix
- queryClusters := make(map[string][]string)
- expectCount := 0
- for _, filter := range filters {
- if len(filter) != 3 {
- return false
- }
-
- if lhs, ok := filter[0].(string); !ok || lhs != "uuid" {
- return false
- }
-
- op, ok := filter[1].(string)
- if !ok {
- return false
- }
-
- if op == "in" {
- if rhs, ok := filter[2].([]interface{}); ok {
- for _, i := range rhs {
- if u, ok := i.(string); ok {
- *clusterId = u[0:5]
- queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
- expectCount += 1
- }
- }
- }
- } else if op == "=" {
- if u, ok := filter[2].(string); ok {
- *clusterId = u[0:5]
- queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
- expectCount += 1
- }
- } else {
- return false
- }
-
- }
-
- if len(queryClusters) <= 1 {
- // Query does not search for uuids across multiple
- // clusters.
- return false
- }
-
- // Validations
- count := req.Form.Get("count")
- if count != "" && count != `none` && count != `"none"` {
- httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
- return true
- }
- if req.Form.Get("limit") != "" || req.Form.Get("offset") != "" || req.Form.Get("order") != "" {
- httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
- return true
- }
- if expectCount > h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse() {
- httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
- expectCount, h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse()), http.StatusBadRequest)
- return true
- }
- if req.Form.Get("select") != "" {
- foundUUID := false
- var selects []string
- err := json.Unmarshal([]byte(req.Form.Get("select")), &selects)
- if err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadRequest)
- return true
- }
-
- for _, r := range selects {
- if r == "uuid" {
- foundUUID = true
- break
- }
- }
- if !foundUUID {
- httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
- return true
- }
- }
-
- // Perform concurrent requests to each cluster
-
- // use channel as a semaphore to limit the number of concurrent
- // requests at a time
- sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
- defer close(sem)
- wg := sync.WaitGroup{}
-
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- mtx := sync.Mutex{}
- errors := []error{}
- var completeResponses []map[string]interface{}
- var kind string
-
- for k, v := range queryClusters {
- if len(v) == 0 {
- // Nothing to query
- continue
- }
-
- // blocks until it can put a value into the
- // channel (which has a max queue capacity)
- sem <- true
- wg.Add(1)
- go func(k string, v []string) {
- rp, kn, err := h.remoteQueryUUIDs(w, req, k, v)
- mtx.Lock()
- if err == nil {
- completeResponses = append(completeResponses, rp...)
- kind = kn
- } else {
- errors = append(errors, err)
- }
- mtx.Unlock()
- wg.Done()
- <-sem
- }(k, v)
- }
- wg.Wait()
-
- if len(errors) > 0 {
- var strerr []string
- for _, e := range errors {
- strerr = append(strerr, e.Error())
- }
- httpserver.Errors(w, strerr, http.StatusBadGateway)
- return true
- }
-
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusOK)
- itemList := make(map[string]interface{})
- itemList["items"] = completeResponses
- itemList["kind"] = kind
- json.NewEncoder(w).Encode(itemList)
-
- return true
-}
-
-func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
- m := h.matcher.FindStringSubmatch(req.URL.Path)
- clusterId := ""
-
- if len(m) > 0 && m[2] != "" {
- clusterId = m[2]
- }
-
- // Get form parameters from URL and form body (if POST).
- if err := loadParamsFromForm(req); err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
-
- // Check if the parameters have an explicit cluster_id
- if req.Form.Get("cluster_id") != "" {
- clusterId = req.Form.Get("cluster_id")
- }
-
- // Handle the POST-as-GET special case (workaround for large
- // GET requests that potentially exceed maximum URL length,
- // like multi-object queries where the filter has 100s of
- // items)
- effectiveMethod := req.Method
- if req.Method == "POST" && req.Form.Get("_method") != "" {
- effectiveMethod = req.Form.Get("_method")
- }
-
- if effectiveMethod == "GET" &&
- clusterId == "" &&
- req.Form.Get("filters") != "" &&
- h.handleMultiClusterQuery(w, req, &clusterId) {
- return
- }
-
- if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
- h.next.ServeHTTP(w, req)
- } else {
- resp, err := h.handler.remoteClusterRequest(clusterId, req)
- h.handler.proxy.ForwardResponse(w, resp, err)
- }
-}
-
-func rewriteSignatures(clusterID string, expectHash string,
- resp *http.Response, requestError error) (newResponse *http.Response, err error) {
-
- if requestError != nil {
- return resp, requestError
- }
-
- if resp.StatusCode != 200 {
- return resp, nil
- }
-
- originalBody := resp.Body
- defer originalBody.Close()
-
- var col arvados.Collection
- err = json.NewDecoder(resp.Body).Decode(&col)
- if err != nil {
- return nil, err
- }
-
- // rewriting signatures will make manifest text 5-10% bigger so calculate
- // capacity accordingly
- updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
-
- hasher := md5.New()
- mw := io.MultiWriter(hasher, updatedManifest)
- sz := 0
-
- scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
- scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
- for scanner.Scan() {
- line := scanner.Text()
- tokens := strings.Split(line, " ")
- if len(tokens) < 3 {
- return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
- }
-
- n, err := mw.Write([]byte(tokens[0]))
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
- for _, token := range tokens[1:] {
- n, err = mw.Write([]byte(" "))
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
-
- m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
- if m != nil {
- // Rewrite the block signature to be a remote signature
- _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterID, m[5][2:], m[8])
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
-
- // for hash checking, ignore signatures
- n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2])
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
- } else {
- n, err = mw.Write([]byte(token))
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
- }
- }
- n, err = mw.Write([]byte("\n"))
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
- }
-
- // Check that expected hash is consistent with
- // portable_data_hash field of the returned record
- if expectHash == "" {
- expectHash = col.PortableDataHash
- } else if expectHash != col.PortableDataHash {
- return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", expectHash, col.PortableDataHash)
- }
-
- // Certify that the computed hash of the manifest_text matches our expectation
- sum := hasher.Sum(nil)
- computedHash := fmt.Sprintf("%x+%v", sum, sz)
- if computedHash != expectHash {
- return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, expectHash)
- }
-
- col.ManifestText = updatedManifest.String()
-
- newbody, err := json.Marshal(col)
- if err != nil {
- return nil, err
- }
-
- buf := bytes.NewBuffer(newbody)
- resp.Body = ioutil.NopCloser(buf)
- resp.ContentLength = int64(buf.Len())
- resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
-
- return resp, nil
-}
-
-func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
- if requestError != nil {
- return resp, requestError
- }
-
- if resp.StatusCode == 404 {
- // Suppress returning this result, because we want to
- // search the federation.
- return nil, nil
- }
- return resp, nil
-}
-
-type searchRemoteClusterForPDH struct {
- pdh string
- remoteID string
- mtx *sync.Mutex
- sentResponse *bool
- sharedContext *context.Context
- cancelFunc func()
- errors *[]string
- statusCode *int
-}
-
-func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
- s.mtx.Lock()
- defer s.mtx.Unlock()
-
- if *s.sentResponse {
- // Another request already returned a response
- return nil, nil
- }
-
- if requestError != nil {
- *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
- // Record the error and suppress response
- return nil, nil
- }
-
- if resp.StatusCode != 200 {
- // Suppress returning unsuccessful result. Maybe
- // another request will find it.
- // TODO collect and return error responses.
- *s.errors = append(*s.errors, fmt.Sprintf("Response from %q: %v", s.remoteID, resp.Status))
- if resp.StatusCode != 404 {
- // Got a non-404 error response, convert into BadGateway
- *s.statusCode = http.StatusBadGateway
- }
- return nil, nil
- }
-
- s.mtx.Unlock()
-
- // This reads the response body. We don't want to hold the
- // lock while doing this because other remote requests could
- // also have made it to this point, and we don't want a
- // slow response holding the lock to block a faster response
- // that is waiting on the lock.
- newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
-
- s.mtx.Lock()
-
- if *s.sentResponse {
- // Another request already returned a response
- return nil, nil
- }
-
- if err != nil {
- // Suppress returning unsuccessful result. Maybe
- // another request will be successful.
- *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
- return nil, nil
- }
-
- // We have a successful response. Suppress/cancel all the
- // other requests/responses.
- *s.sentResponse = true
- s.cancelFunc()
-
- return newResponse, nil
-}
-
-func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
- if req.Method != "GET" {
- // Only handle GET requests right now
- h.next.ServeHTTP(w, req)
- return
- }
-
- m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
- if len(m) != 2 {
- // Not a collection PDH GET request
- m = collectionRe.FindStringSubmatch(req.URL.Path)
- clusterId := ""
-
- if len(m) > 0 {
- clusterId = m[2]
- }
-
- if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
- // request for remote collection by uuid
- resp, err := h.handler.remoteClusterRequest(clusterId, req)
- newResponse, err := rewriteSignatures(clusterId, "", resp, err)
- h.handler.proxy.ForwardResponse(w, newResponse, err)
- return
- }
- // not a collection UUID request, or it is a request
- // for a local UUID, either way, continue down the
- // handler stack.
- h.next.ServeHTTP(w, req)
- return
- }
-
- // Request for collection by PDH. Search the federation.
-
- // First, query the local cluster.
- resp, err := h.handler.localClusterRequest(req)
- newResp, err := filterLocalClusterResponse(resp, err)
- if newResp != nil || err != nil {
- h.handler.proxy.ForwardResponse(w, newResp, err)
- return
- }
-
- sharedContext, cancelFunc := context.WithCancel(req.Context())
- defer cancelFunc()
- req = req.WithContext(sharedContext)
-
- // Create a goroutine for each cluster in the
- // RemoteClusters map. The first valid result gets
- // returned to the client. When that happens, all
- // other outstanding requests are cancelled or
- // suppressed.
- sentResponse := false
- mtx := sync.Mutex{}
- wg := sync.WaitGroup{}
- var errors []string
- var errorCode int = 404
-
- // use channel as a semaphore to limit the number of concurrent
- // requests at a time
- sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
- defer close(sem)
- for remoteID := range h.handler.Cluster.RemoteClusters {
- if remoteID == h.handler.Cluster.ClusterID {
- // No need to query local cluster again
- continue
- }
- // blocks until it can put a value into the
- // channel (which has a max queue capacity)
- sem <- true
- if sentResponse {
- break
- }
- search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
- &sharedContext, cancelFunc, &errors, &errorCode}
- wg.Add(1)
- go func() {
- resp, err := h.handler.remoteClusterRequest(search.remoteID, req)
- newResp, err := search.filterRemoteClusterResponse(resp, err)
- if newResp != nil || err != nil {
- h.handler.proxy.ForwardResponse(w, newResp, err)
- }
- wg.Done()
- <-sem
- }()
- }
- wg.Wait()
-
- if sentResponse {
- return
- }
-
- // No successful responses, so return the error
- httpserver.Errors(w, errors, errorCode)
-}
-
func (h *Handler) setupProxyRemoteCluster(next http.Handler) http.Handler {
mux := http.NewServeMux()
mux.Handle("/arvados/v1/workflows", &genericFederatedRequestHandler{next, h, wfRe})