--- /dev/null
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package controller
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/md5"
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "strings"
+ "sync"
+
+ "git.curoverse.com/arvados.git/sdk/go/arvados"
+ "git.curoverse.com/arvados.git/sdk/go/httpserver"
+ "git.curoverse.com/arvados.git/sdk/go/keepclient"
+)
+
+// collectionFederatedRequestHandler is an http.Handler for collection
+// requests. Its ServeHTTP method serves GET requests for collections
+// by remote UUID or by portable data hash and hands every other
+// request to next.
+type collectionFederatedRequestHandler struct {
+ // next receives requests this handler does not serve itself.
+ next http.Handler
+ // handler supplies the cluster configuration, the proxy, and the
+ // local/remote request helpers used by ServeHTTP.
+ handler *Handler
+}
+
+// rewriteSignatures converts the signed block locators in a collection
+// record returned by another cluster into remote-signed locators
+// ("+R<clusterID>-...") and verifies the manifest against its portable
+// data hash.
+//
+// clusterID is the ID of the cluster that produced resp. expectHash is
+// the portable data hash the caller asked for; when it is empty, the
+// portable_data_hash field of the returned record is used instead.
+// resp and requestError are the outcome of the upstream request.
+//
+// If requestError is non-nil, or the response status is not 200, the
+// inputs are passed through unchanged. Otherwise the original body is
+// consumed and closed; on success resp is returned with its Body,
+// ContentLength and Content-Length header replaced by the re-encoded
+// collection. Any decoding, scanning or verification failure yields a
+// nil response and a non-nil error.
+func rewriteSignatures(clusterID string, expectHash string,
+	resp *http.Response, requestError error) (newResponse *http.Response, err error) {
+
+	if requestError != nil {
+		return resp, requestError
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		return resp, nil
+	}
+
+	originalBody := resp.Body
+	defer originalBody.Close()
+
+	var col arvados.Collection
+	err = json.NewDecoder(resp.Body).Decode(&col)
+	if err != nil {
+		return nil, err
+	}
+
+	// rewriting signatures will make manifest text 5-10% bigger so calculate
+	// capacity accordingly
+	updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
+
+	// The hasher sees the manifest with signatures stripped; mw feeds
+	// both the hasher and the rewritten manifest for every token that
+	// is identical in both. sz counts the bytes given to the hasher.
+	hasher := md5.New()
+	mw := io.MultiWriter(hasher, updatedManifest)
+	sz := 0
+
+	// Allow a single line to be as long as the whole manifest.
+	scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
+	scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
+	for scanner.Scan() {
+		line := scanner.Text()
+		tokens := strings.Split(line, " ")
+		if len(tokens) < 3 {
+			return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
+		}
+
+		n, err := mw.Write([]byte(tokens[0]))
+		if err != nil {
+			return nil, fmt.Errorf("Error updating manifest: %v", err)
+		}
+		sz += n
+		for _, token := range tokens[1:] {
+			n, err = mw.Write([]byte(" "))
+			if err != nil {
+				return nil, fmt.Errorf("Error updating manifest: %v", err)
+			}
+			sz += n
+
+			m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
+			if m != nil {
+				// Rewrite the block signature to be a remote signature.
+				// NOTE(review): the submatch indexes (and dropping the
+				// first two characters of m[5]) depend on the group
+				// layout of keepclient.SignedLocatorRe, which is not
+				// visible here -- confirm.
+				_, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterID, m[5][2:], m[8])
+				if err != nil {
+					return nil, fmt.Errorf("Error updating manifest: %v", err)
+				}
+
+				// for hash checking, ignore signatures
+				n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2])
+				if err != nil {
+					return nil, fmt.Errorf("Error updating manifest: %v", err)
+				}
+				sz += n
+			} else {
+				n, err = mw.Write([]byte(token))
+				if err != nil {
+					return nil, fmt.Errorf("Error updating manifest: %v", err)
+				}
+				sz += n
+			}
+		}
+		n, err = mw.Write([]byte("\n"))
+		if err != nil {
+			return nil, fmt.Errorf("Error updating manifest: %v", err)
+		}
+		sz += n
+	}
+	// Scan() returning false can mean a scanning failure rather than
+	// end of input; report it directly instead of letting it surface
+	// as a confusing hash mismatch below.
+	if err = scanner.Err(); err != nil {
+		return nil, fmt.Errorf("Error updating manifest: %v", err)
+	}
+
+	// Check that expected hash is consistent with
+	// portable_data_hash field of the returned record
+	if expectHash == "" {
+		expectHash = col.PortableDataHash
+	} else if expectHash != col.PortableDataHash {
+		// Arguments are ordered to match the wording of the message:
+		// the record's hash first, then the hash that was expected.
+		return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", col.PortableDataHash, expectHash)
+	}
+
+	// Certify that the computed hash of the manifest_text matches our expectation
+	sum := hasher.Sum(nil)
+	computedHash := fmt.Sprintf("%x+%v", sum, sz)
+	if computedHash != expectHash {
+		return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, expectHash)
+	}
+
+	col.ManifestText = updatedManifest.String()
+
+	newbody, err := json.Marshal(col)
+	if err != nil {
+		return nil, err
+	}
+
+	// Swap in the rewritten record and keep the length metadata in
+	// step with the new body.
+	buf := bytes.NewBuffer(newbody)
+	resp.Body = ioutil.NopCloser(buf)
+	resp.ContentLength = int64(buf.Len())
+	resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
+
+	return resp, nil
+}
+
+// filterLocalClusterResponse decides whether the local cluster's answer
+// to a collection-by-PDH request should be returned to the client.
+//
+// A request error is passed through together with resp. A 404 response
+// is suppressed (nil, nil) so that the caller goes on to search the
+// federation; because the suppressed response is discarded, its body is
+// closed here. Every other response is returned unchanged.
+func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
+	if requestError != nil {
+		return resp, requestError
+	}
+
+	if resp.StatusCode == http.StatusNotFound {
+		// Suppress returning this result, because we want to
+		// search the federation. Nobody else will see this
+		// response, so release its body now.
+		if resp.Body != nil {
+			resp.Body.Close()
+		}
+		return nil, nil
+	}
+	return resp, nil
+}
+
+// searchRemoteClusterForPDH bundles state for looking up a collection
+// by portable data hash on one remote cluster.
+//
+// NOTE(review): none of the functions visible alongside this
+// declaration reference the type; it looks like a leftover from an
+// earlier callback-based implementation -- confirm before relying on
+// or removing it. The field notes below are inferred from names only.
+type searchRemoteClusterForPDH struct {
+ // presumably the portable data hash being searched for
+ pdh string
+ // presumably the ID of the remote cluster being queried
+ remoteID string
+ // presumably guards the pointer fields below, which appear meant
+ // to be shared between concurrent searches -- TODO confirm
+ mtx *sync.Mutex
+ sentResponse *bool
+ sharedContext *context.Context
+ cancelFunc func()
+ errors *[]string
+ statusCode *int
+}
+
+// ServeHTTP serves GET requests for collections across the federation.
+//
+//   - Non-GET requests go to the next handler.
+//   - A request for a collection UUID belonging to another cluster is
+//     sent to that cluster and the response is returned with its block
+//     signatures rewritten.
+//   - A request by UUID for the local cluster, or any other path that
+//     is not a portable data hash, goes to the next handler.
+//   - A request by portable data hash is tried on the local cluster
+//     first; if that yields 404, every configured remote cluster is
+//     queried concurrently (bounded by the multi-cluster request
+//     concurrency limit) and the first verified result is returned.
+//     If no remote produces a result, the collected errors are
+//     returned with status 404, or 502 if any remote answered with a
+//     non-404 HTTP error.
+func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	if req.Method != "GET" {
+		// Only handle GET requests right now
+		h.next.ServeHTTP(w, req)
+		return
+	}
+
+	m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
+	if len(m) != 2 {
+		// Not a collection PDH GET request
+		m = collectionRe.FindStringSubmatch(req.URL.Path)
+		clusterId := ""
+
+		if len(m) > 0 {
+			clusterId = m[2]
+		}
+
+		if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
+			// request for remote collection by uuid
+			resp, err := h.handler.remoteClusterRequest(clusterId, req)
+			newResponse, err := rewriteSignatures(clusterId, "", resp, err)
+			h.handler.proxy.ForwardResponse(w, newResponse, err)
+			return
+		}
+		// not a collection UUID request, or it is a request
+		// for a local UUID, either way, continue down the
+		// handler stack.
+		h.next.ServeHTTP(w, req)
+		return
+	}
+
+	// Request for collection by PDH. Search the federation.
+
+	// First, query the local cluster.
+	resp, err := h.handler.localClusterRequest(req)
+	newResp, err := filterLocalClusterResponse(resp, err)
+	if newResp != nil || err != nil {
+		h.handler.proxy.ForwardResponse(w, newResp, err)
+		return
+	}
+
+	// Create a goroutine for each cluster in the
+	// RemoteClusters map. The first valid result gets
+	// returned to the client. When that happens, all
+	// other outstanding requests are cancelled
+	sharedContext, cancelFunc := context.WithCancel(req.Context())
+	defer cancelFunc()
+	req = req.WithContext(sharedContext)
+	pdh := m[1]
+	var wg sync.WaitGroup
+	success := make(chan *http.Response)
+
+	// Each worker sends at most one error. Giving the channel room
+	// for all of them means a worker never blocks on (or panics
+	// while) reporting an error after this handler has returned, so
+	// none of the channels needs to be closed.
+	errorChan := make(chan error, len(h.handler.Cluster.RemoteClusters))
+
+	// use channel as a semaphore to limit the number of concurrent
+	// requests at a time
+	sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
+
+	for remoteID := range h.handler.Cluster.RemoteClusters {
+		if remoteID == h.handler.Cluster.ClusterID {
+			// No need to query local cluster again
+			continue
+		}
+
+		wg.Add(1)
+		go func(remote string) {
+			defer wg.Done()
+
+			// Wait for a free slot, but give up as soon as
+			// the search is over.
+			select {
+			case sem <- true:
+			case <-sharedContext.Done():
+				return
+			}
+			// Release the slot on every exit path so the
+			// remaining workers are not starved.
+			defer func() { <-sem }()
+
+			if sharedContext.Err() != nil {
+				return
+			}
+
+			resp, err := h.handler.remoteClusterRequest(remote, req)
+			wasSuccess := false
+			defer func() {
+				// A response that was not handed to the
+				// main goroutine must be closed here.
+				if resp != nil && !wasSuccess {
+					resp.Body.Close()
+				}
+			}()
+			if err != nil {
+				errorChan <- err
+				return
+			}
+			if resp.StatusCode != http.StatusOK {
+				errorChan <- HTTPError{resp.Status, resp.StatusCode}
+				return
+			}
+			if sharedContext.Err() != nil {
+				return
+			}
+
+			newResponse, err := rewriteSignatures(remote, pdh, resp, nil)
+			if err != nil {
+				errorChan <- err
+				return
+			}
+			select {
+			case <-sharedContext.Done():
+			case success <- newResponse:
+				wasSuccess = true
+			}
+		}(remoteID)
+	}
+	go func() {
+		// Once every worker has finished without a success,
+		// cancelling the context wakes the select below.
+		wg.Wait()
+		cancelFunc()
+	}()
+
+	select {
+	case newResp = <-success:
+		h.handler.proxy.ForwardResponse(w, newResp, nil)
+	case <-sharedContext.Done():
+		// Either all workers are done (every error is already
+		// buffered) or the client went away; report what we have.
+		var errors []string
+		errorCode := http.StatusNotFound
+		for len(errorChan) > 0 {
+			err := <-errorChan
+			if httperr, ok := err.(HTTPError); ok && httperr.Code != http.StatusNotFound {
+				errorCode = http.StatusBadGateway
+			}
+			errors = append(errors, err.Error())
+		}
+		httpserver.Errors(w, errors, errorCode)
+	}
+}
--- /dev/null
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package controller
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "strings"
+
+ "git.curoverse.com/arvados.git/sdk/go/auth"
+ "git.curoverse.com/arvados.git/sdk/go/httpserver"
+)
+
+// remoteContainerRequestCreate is a federatedRequestDelegate that
+// handles creation of a container request on a remote cluster.
+//
+// It returns false (not handled) unless effectiveMethod is POST, uuid
+// and remainder are empty, and *clusterId names a cluster other than
+// the local one. Otherwise it always returns true, having either
+// written an error response or forwarded the request.
+//
+// The request body must be JSON. The container request is taken from
+// the "container_request" key (which may itself be a JSON-encoded
+// string) or, failing that, from the top-level object. If it has no
+// runtime_token, the caller's token is validated, must have scope
+// [all], and -- when the token's authorization belongs to the local
+// cluster -- a new token is created and stored as runtime_token. The
+// possibly-amended body is then re-encoded and sent to the remote
+// cluster.
+func remoteContainerRequestCreate(
+	h *genericFederatedRequestHandler,
+	effectiveMethod string,
+	clusterId *string,
+	uuid string,
+	remainder string,
+	w http.ResponseWriter,
+	req *http.Request) bool {
+
+	if effectiveMethod != "POST" || uuid != "" || remainder != "" ||
+		*clusterId == "" || *clusterId == h.handler.Cluster.ClusterID {
+		return false
+	}
+
+	if req.Header.Get("Content-Type") != "application/json" {
+		httpserver.Error(w, "Expected Content-Type: application/json, got "+req.Header.Get("Content-Type"), http.StatusBadRequest)
+		return true
+	}
+
+	originalBody := req.Body
+	defer originalBody.Close()
+	var request map[string]interface{}
+	err := json.NewDecoder(req.Body).Decode(&request)
+	if err != nil {
+		httpserver.Error(w, err.Error(), http.StatusBadRequest)
+		return true
+	}
+
+	// container_request may arrive as a JSON-encoded string; if so,
+	// decode it and store the object form back into the request.
+	crString, ok := request["container_request"].(string)
+	if ok {
+		var crJson map[string]interface{}
+		err := json.Unmarshal([]byte(crString), &crJson)
+		if err != nil {
+			httpserver.Error(w, err.Error(), http.StatusBadRequest)
+			return true
+		}
+
+		request["container_request"] = crJson
+	}
+
+	containerRequest, ok := request["container_request"].(map[string]interface{})
+	if !ok {
+		// Use toplevel object as the container_request object
+		containerRequest = request
+	}
+	if containerRequest == nil {
+		// A JSON null decodes into a nil map without an error;
+		// writing runtime_token into it below would panic.
+		httpserver.Error(w, "container_request must be a JSON object", http.StatusBadRequest)
+		return true
+	}
+
+	// If runtime_token is not set, create a new token
+	if _, ok := containerRequest["runtime_token"]; !ok {
+		// First make sure supplied token is valid.
+		creds := auth.NewCredentials()
+		creds.LoadTokensFromHTTPRequest(req)
+		if len(creds.Tokens) == 0 {
+			// Without this check, indexing Tokens[0] panics
+			// for a request that carries no token at all.
+			httpserver.Error(w, "no API token provided", http.StatusUnauthorized)
+			return true
+		}
+
+		currentUser, err := h.handler.validateAPItoken(req, creds.Tokens[0])
+		if err != nil {
+			httpserver.Error(w, err.Error(), http.StatusForbidden)
+			return true
+		}
+
+		if len(currentUser.Authorization.Scopes) != 1 || currentUser.Authorization.Scopes[0] != "all" {
+			httpserver.Error(w, "Token scope is not [all]", http.StatusForbidden)
+			return true
+		}
+
+		// Must be home cluster for this authorization. For any
+		// other authorization the request is forwarded without
+		// adding a runtime_token.
+		if strings.HasPrefix(currentUser.Authorization.UUID, h.handler.Cluster.ClusterID) {
+			newtok, err := h.handler.createAPItoken(req, currentUser.UUID, nil)
+			if err != nil {
+				httpserver.Error(w, err.Error(), http.StatusForbidden)
+				return true
+			}
+			containerRequest["runtime_token"] = newtok.TokenV2()
+		}
+	}
+
+	newbody, err := json.Marshal(request)
+	if err != nil {
+		httpserver.Error(w, err.Error(), http.StatusInternalServerError)
+		return true
+	}
+
+	// Replace the consumed body with the re-encoded request and keep
+	// the length metadata consistent with it.
+	buf := bytes.NewBuffer(newbody)
+	req.Body = ioutil.NopCloser(buf)
+	req.ContentLength = int64(buf.Len())
+	req.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
+
+	resp, err := h.handler.remoteClusterRequest(*clusterId, req)
+	h.handler.proxy.ForwardResponse(w, resp, err)
+	return true
+}
--- /dev/null
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package controller
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "regexp"
+ "sync"
+
+ "git.curoverse.com/arvados.git/sdk/go/httpserver"
+)
+
+// federatedRequestDelegate is a hook that genericFederatedRequestHandler's
+// ServeHTTP tries, in order, before falling back to its default
+// routing. A delegate returns true when it has dealt with the request,
+// which stops further processing, and false to let ServeHTTP continue
+// with the next delegate or the default routing.
+//
+// effectiveMethod is the request method after applying the "_method"
+// form override. clusterId points at the cluster ID ServeHTTP has
+// determined so far. uuid and remainder are submatches 1 and 3 of the
+// handler's path matcher.
+type federatedRequestDelegate func(
+ h *genericFederatedRequestHandler,
+ effectiveMethod string,
+ clusterId *string,
+ uuid string,
+ remainder string,
+ w http.ResponseWriter,
+ req *http.Request) bool
+
+// genericFederatedRequestHandler is an http.Handler that routes a
+// request either to the local handler stack or to a remote cluster,
+// based on the cluster ID found in the request path or parameters.
+type genericFederatedRequestHandler struct {
+ // next serves requests that are not sent to a remote cluster.
+ next http.Handler
+ // handler supplies the cluster configuration, the proxy, and the
+ // local/remote request helpers.
+ handler *Handler
+ // matcher is applied to the request path; ServeHTTP reads the
+ // cluster ID from submatch 2 and passes submatches 1 and 3 to
+ // the delegates.
+ matcher *regexp.Regexp
+ // delegates are tried in order before the default routing.
+ delegates []federatedRequestDelegate
+}
+
+// remoteQueryUUIDs fetches the objects with the given uuids from the
+// cluster identified by clusterID (the local cluster when it matches
+// this cluster's own ID).
+//
+// Each round issues a POST-as-GET list request filtered on the uuids
+// still outstanding. Rounds repeat while some uuids remain and the
+// previous round reduced their number. It returns the accumulated
+// items and the "kind" reported by the most recent response, or the
+// first error encountered.
+func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
+	req *http.Request,
+	clusterID string, uuids []string) (rp []map[string]interface{}, kind string, err error) {
+
+	seen := make(map[string]bool)
+	pending := uuids
+	// Start one above the real count so the first round always runs.
+	lastCount := len(pending) + 1
+
+	for len(pending) > 0 && len(pending) < lastCount {
+		// Build the form-encoded query for this round.
+		params := url.Values{}
+		params.Set("_method", "GET")
+		params.Set("count", "none")
+		if sel := req.Form.Get("select"); sel != "" {
+			params.Set("select", sel)
+		}
+		encoded, err := json.Marshal(pending)
+		if err != nil {
+			return nil, "", err
+		}
+		params["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, encoded)}
+
+		remoteReq := http.Request{
+			Method: "POST",
+			Header: req.Header,
+			URL:    &url.URL{Path: req.URL.Path},
+			Body:   ioutil.NopCloser(bytes.NewBufferString(params.Encode())),
+		}
+
+		var resp *http.Response
+		if clusterID == h.handler.Cluster.ClusterID {
+			resp, err = h.handler.localClusterRequest(&remoteReq)
+		} else {
+			resp, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
+		}
+
+		collector := multiClusterQueryResponseCollector{clusterID: clusterID}
+		collector.collectResponse(resp, err)
+		if collector.error != nil {
+			return nil, "", collector.error
+		}
+
+		kind = collector.kind
+		if len(collector.responses) == 0 {
+			// An empty page means another round cannot make
+			// progress.
+			return rp, kind, nil
+		}
+		rp = append(rp, collector.responses...)
+
+		// Record which uuids came back, then carry the rest over
+		// into the next round.
+		for _, item := range collector.responses {
+			if id, ok := item["uuid"].(string); ok {
+				seen[id] = true
+			}
+		}
+		remaining := []string{}
+		for _, id := range pending {
+			if !seen[id] {
+				remaining = append(remaining, id)
+			}
+		}
+		lastCount = len(pending)
+		pending = remaining
+	}
+
+	return rp, kind, nil
+}
+
+// handleMultiClusterQuery serves a list request whose "filters"
+// parameter selects objects by uuid from more than one cluster, by
+// querying each cluster concurrently and merging the results.
+//
+// It returns true when it has written a response (results or an
+// error), and false when the filters are not a pure uuid "="/"in"
+// query spanning several clusters, in which case the caller continues
+// with its normal routing. Note that *clusterId is overwritten with
+// the five-character prefix of each uuid encountered, so it may have
+// been changed even when false is returned.
+func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter,
+ req *http.Request, clusterId *string) bool {
+
+ var filters [][]interface{}
+ err := json.Unmarshal([]byte(req.Form.Get("filters")), &filters)
+ if err != nil {
+ httpserver.Error(w, err.Error(), http.StatusBadRequest)
+ return true
+ }
+
+ // Split the list of uuids by prefix
+ queryClusters := make(map[string][]string)
+ expectCount := 0
+ for _, filter := range filters {
+ // Anything other than a three-element ["uuid", op, operand]
+ // filter means this is not a query handled here.
+ if len(filter) != 3 {
+ return false
+ }
+
+ if lhs, ok := filter[0].(string); !ok || lhs != "uuid" {
+ return false
+ }
+
+ op, ok := filter[1].(string)
+ if !ok {
+ return false
+ }
+
+ // Only operands that are strings of length 27 are counted;
+ // their first five characters are used as the cluster ID.
+ if op == "in" {
+ if rhs, ok := filter[2].([]interface{}); ok {
+ for _, i := range rhs {
+ if u, ok := i.(string); ok && len(u) == 27 {
+ *clusterId = u[0:5]
+ queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
+ expectCount += 1
+ }
+ }
+ }
+ } else if op == "=" {
+ if u, ok := filter[2].(string); ok && len(u) == 27 {
+ *clusterId = u[0:5]
+ queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
+ expectCount += 1
+ }
+ } else {
+ return false
+ }
+
+ }
+
+ if len(queryClusters) <= 1 {
+ // Query does not search for uuids across multiple
+ // clusters.
+ return false
+ }
+
+ // Validations
+ count := req.Form.Get("count")
+ if count != "" && count != `none` && count != `"none"` {
+ httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
+ return true
+ }
+ if req.Form.Get("limit") != "" || req.Form.Get("offset") != "" || req.Form.Get("order") != "" {
+ httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
+ return true
+ }
+ if expectCount > h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse() {
+ httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
+ expectCount, h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse()), http.StatusBadRequest)
+ return true
+ }
+ // remoteQueryUUIDs matches returned items to requested uuids by
+ // their "uuid" field, so a "select" list must include it.
+ if req.Form.Get("select") != "" {
+ foundUUID := false
+ var selects []string
+ err := json.Unmarshal([]byte(req.Form.Get("select")), &selects)
+ if err != nil {
+ httpserver.Error(w, err.Error(), http.StatusBadRequest)
+ return true
+ }
+
+ for _, r := range selects {
+ if r == "uuid" {
+ foundUUID = true
+ break
+ }
+ }
+ if !foundUUID {
+ httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
+ return true
+ }
+ }
+
+ // Perform concurrent requests to each cluster
+
+ // use channel as a semaphore to limit the number of concurrent
+ // requests at a time
+ sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
+ defer close(sem)
+ wg := sync.WaitGroup{}
+
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ // mtx guards errors, completeResponses and kind, which the
+ // goroutines below update.
+ mtx := sync.Mutex{}
+ errors := []error{}
+ var completeResponses []map[string]interface{}
+ var kind string
+
+ for k, v := range queryClusters {
+ if len(v) == 0 {
+ // Nothing to query
+ continue
+ }
+
+ // blocks until it can put a value into the
+ // channel (which has a max queue capacity)
+ sem <- true
+ wg.Add(1)
+ go func(k string, v []string) {
+ rp, kn, err := h.remoteQueryUUIDs(w, req, k, v)
+ mtx.Lock()
+ if err == nil {
+ completeResponses = append(completeResponses, rp...)
+ // kind ends up as the value from whichever
+ // successful query stored its result last.
+ kind = kn
+ } else {
+ errors = append(errors, err)
+ }
+ mtx.Unlock()
+ wg.Done()
+ <-sem
+ }(k, v)
+ }
+ wg.Wait()
+
+ // Any failed cluster query fails the whole request.
+ if len(errors) > 0 {
+ var strerr []string
+ for _, e := range errors {
+ strerr = append(strerr, e.Error())
+ }
+ httpserver.Errors(w, strerr, http.StatusBadGateway)
+ return true
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ itemList := make(map[string]interface{})
+ itemList["items"] = completeResponses
+ itemList["kind"] = kind
+ // The status line has already been sent, so an encoding/write
+ // error here is not reported to the client.
+ json.NewEncoder(w).Encode(itemList)
+
+ return true
+}
+
+func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ m := h.matcher.FindStringSubmatch(req.URL.Path)
+ clusterId := ""
+
+ if len(m) > 0 && m[2] != "" {
+ clusterId = m[2]
+ }
+
+ // Get form parameters from URL and form body (if POST).
+ if err := loadParamsFromForm(req); err != nil {
+ httpserver.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+
+ // Check if the parameters have an explicit cluster_id
+ if req.Form.Get("cluster_id") != "" {
+ clusterId = req.Form.Get("cluster_id")
+ }
+
+ // Handle the POST-as-GET special case (workaround for large
+ // GET requests that potentially exceed maximum URL length,
+ // like multi-object queries where the filter has 100s of
+ // items)
+ effectiveMethod := req.Method
+ if req.Method == "POST" && req.Form.Get("_method") != "" {
+ effectiveMethod = req.Form.Get("_method")
+ }
+
+ if effectiveMethod == "GET" &&
+ clusterId == "" &&
+ req.Form.Get("filters") != "" &&
+ h.handleMultiClusterQuery(w, req, &clusterId) {
+ return
+ }
+
+ for _, d := range h.delegates {
+ if d(h, effectiveMethod, &clusterId, m[1], m[3], w, req) {
+ return
+ }
+ }
+
+ if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
+ h.next.ServeHTTP(w, req)
+ } else {
+ resp, err := h.handler.remoteClusterRequest(clusterId, req)
+ h.handler.proxy.ForwardResponse(w, resp, err)
+ }
+}
+
+// multiClusterQueryResponseCollector holds the decoded outcome of one
+// list request made to a single cluster. collectResponse fills in
+// either error, or responses and kind.
+type multiClusterQueryResponseCollector struct {
+ // responses is the "items" array of a successful response.
+ responses []map[string]interface{}
+ // error is set when the request failed, the body could not be
+ // decoded, or the status was not 200.
+ error error
+ // kind is the "kind" field of a successful response.
+ kind string
+ // clusterID names the queried cluster; it is used in error
+ // messages.
+ clusterID string
+}
+
+// collectResponse decodes the response to a list request and stores
+// the outcome on c. On a request error, an undecodable body, or a
+// non-200 status it sets c.error; otherwise it sets c.responses and
+// c.kind from the "items" and "kind" fields of the body. The response
+// body is closed before returning. The return values are always
+// (nil, nil); callers read the outcome from c.
+func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
+	requestError error) (newResponse *http.Response, err error) {
+	if requestError != nil {
+		c.error = requestError
+		return nil, nil
+	}
+	defer resp.Body.Close()
+
+	// Shape of a list response, including the error form.
+	var page struct {
+		Kind   string                   `json:"kind"`
+		Items  []map[string]interface{} `json:"items"`
+		Errors []string                 `json:"errors"`
+	}
+	decodeErr := json.NewDecoder(resp.Body).Decode(&page)
+
+	switch {
+	case decodeErr != nil:
+		c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, decodeErr)
+	case resp.StatusCode != http.StatusOK:
+		c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, page.Errors)
+	default:
+		c.responses = page.Items
+		c.kind = page.Kind
+	}
+	return nil, nil
+}
package controller
import (
- "bufio"
"bytes"
- "context"
- "crypto/md5"
"database/sql"
"encoding/json"
"fmt"
"net/url"
"regexp"
"strings"
- "sync"
"git.curoverse.com/arvados.git/sdk/go/arvados"
"git.curoverse.com/arvados.git/sdk/go/auth"
- "git.curoverse.com/arvados.git/sdk/go/httpserver"
- "git.curoverse.com/arvados.git/sdk/go/keepclient"
+ "github.com/jmcvetta/randutil"
)
var pathPattern = `^/arvados/v1/%s(/([0-9a-z]{5})-%s-[0-9a-z]{15})?(.*)$`
var collectionRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "collections", "4zz18"))
var collectionByPDHRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-fA-F]{32}\+[0-9]+)+$`)
-type genericFederatedRequestHandler struct {
- next http.Handler
- handler *Handler
- matcher *regexp.Regexp
-}
-
-type collectionFederatedRequestHandler struct {
- next http.Handler
- handler *Handler
-}
-
-func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, req *http.Request, filter ResponseFilter) {
+func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, error) {
remote, ok := h.Cluster.RemoteClusters[remoteID]
if !ok {
- httpserver.Error(w, "no proxy available for cluster "+remoteID, http.StatusNotFound)
- return
+ return nil, HTTPError{fmt.Sprintf("no proxy available for cluster %v", remoteID), http.StatusNotFound}
}
scheme := remote.Scheme
if scheme == "" {
scheme = "https"
}
- err := h.saltAuthToken(req, remoteID)
+ saltedReq, err := h.saltAuthToken(req, remoteID)
if err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadRequest)
- return
+ return nil, err
}
urlOut := &url.URL{
Scheme: scheme,
Host: remote.Host,
- Path: req.URL.Path,
- RawPath: req.URL.RawPath,
- RawQuery: req.URL.RawQuery,
+ Path: saltedReq.URL.Path,
+ RawPath: saltedReq.URL.RawPath,
+ RawQuery: saltedReq.URL.RawQuery,
}
client := h.secureClient
if remote.Insecure {
client = h.insecureClient
}
- h.proxy.Do(w, req, urlOut, client, filter)
+ return h.proxy.Do(saltedReq, urlOut, client)
}
// Buffer request body, parse form parameters in request, and then
return nil
}
-type multiClusterQueryResponseCollector struct {
- responses []map[string]interface{}
- error error
- kind string
- clusterID string
-}
-
-func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
- requestError error) (newResponse *http.Response, err error) {
- if requestError != nil {
- c.error = requestError
- return nil, nil
- }
-
- defer resp.Body.Close()
- var loadInto struct {
- Kind string `json:"kind"`
- Items []map[string]interface{} `json:"items"`
- Errors []string `json:"errors"`
- }
- err = json.NewDecoder(resp.Body).Decode(&loadInto)
-
- if err != nil {
- c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, err)
- return nil, nil
- }
- if resp.StatusCode != http.StatusOK {
- c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, loadInto.Errors)
- return nil, nil
- }
-
- c.responses = loadInto.Items
- c.kind = loadInto.Kind
-
- return nil, nil
-}
-
-func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
- req *http.Request,
- clusterID string, uuids []string) (rp []map[string]interface{}, kind string, err error) {
-
- found := make(map[string]bool)
- prev_len_uuids := len(uuids) + 1
- // Loop while
- // (1) there are more uuids to query
- // (2) we're making progress - on each iteration the set of
- // uuids we are expecting for must shrink.
- for len(uuids) > 0 && len(uuids) < prev_len_uuids {
- var remoteReq http.Request
- remoteReq.Header = req.Header
- remoteReq.Method = "POST"
- remoteReq.URL = &url.URL{Path: req.URL.Path}
- remoteParams := make(url.Values)
- remoteParams.Set("_method", "GET")
- remoteParams.Set("count", "none")
- if req.Form.Get("select") != "" {
- remoteParams.Set("select", req.Form.Get("select"))
- }
- content, err := json.Marshal(uuids)
- if err != nil {
- return nil, "", err
- }
- remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
- enc := remoteParams.Encode()
- remoteReq.Body = ioutil.NopCloser(bytes.NewBufferString(enc))
-
- rc := multiClusterQueryResponseCollector{clusterID: clusterID}
-
- if clusterID == h.handler.Cluster.ClusterID {
- h.handler.localClusterRequest(w, &remoteReq,
- rc.collectResponse)
- } else {
- h.handler.remoteClusterRequest(clusterID, w, &remoteReq,
- rc.collectResponse)
- }
- if rc.error != nil {
- return nil, "", rc.error
- }
-
- kind = rc.kind
-
- if len(rc.responses) == 0 {
- // We got zero responses, no point in doing
- // another query.
- return rp, kind, nil
- }
-
- rp = append(rp, rc.responses...)
-
- // Go through the responses and determine what was
- // returned. If there are remaining items, loop
- // around and do another request with just the
- // stragglers.
- for _, i := range rc.responses {
- uuid, ok := i["uuid"].(string)
- if ok {
- found[uuid] = true
- }
- }
-
- l := []string{}
- for _, u := range uuids {
- if !found[u] {
- l = append(l, u)
- }
- }
- prev_len_uuids = len(uuids)
- uuids = l
- }
-
- return rp, kind, nil
-}
-
-func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter,
- req *http.Request, clusterId *string) bool {
-
- var filters [][]interface{}
- err := json.Unmarshal([]byte(req.Form.Get("filters")), &filters)
- if err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadRequest)
- return true
- }
-
- // Split the list of uuids by prefix
- queryClusters := make(map[string][]string)
- expectCount := 0
- for _, filter := range filters {
- if len(filter) != 3 {
- return false
- }
-
- if lhs, ok := filter[0].(string); !ok || lhs != "uuid" {
- return false
- }
-
- op, ok := filter[1].(string)
- if !ok {
- return false
- }
-
- if op == "in" {
- if rhs, ok := filter[2].([]interface{}); ok {
- for _, i := range rhs {
- if u, ok := i.(string); ok {
- *clusterId = u[0:5]
- queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
- expectCount += 1
- }
- }
- }
- } else if op == "=" {
- if u, ok := filter[2].(string); ok {
- *clusterId = u[0:5]
- queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
- expectCount += 1
- }
- } else {
- return false
- }
-
- }
-
- if len(queryClusters) <= 1 {
- // Query does not search for uuids across multiple
- // clusters.
- return false
- }
-
- // Validations
- count := req.Form.Get("count")
- if count != "" && count != `none` && count != `"none"` {
- httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
- return true
- }
- if req.Form.Get("limit") != "" || req.Form.Get("offset") != "" || req.Form.Get("order") != "" {
- httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
- return true
- }
- if expectCount > h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse() {
- httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
- expectCount, h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse()), http.StatusBadRequest)
- return true
- }
- if req.Form.Get("select") != "" {
- foundUUID := false
- var selects []string
- err := json.Unmarshal([]byte(req.Form.Get("select")), &selects)
- if err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadRequest)
- return true
- }
-
- for _, r := range selects {
- if r == "uuid" {
- foundUUID = true
- break
- }
- }
- if !foundUUID {
- httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
- return true
- }
- }
-
- // Perform concurrent requests to each cluster
-
- // use channel as a semaphore to limit the number of concurrent
- // requests at a time
- sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
- defer close(sem)
- wg := sync.WaitGroup{}
-
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- mtx := sync.Mutex{}
- errors := []error{}
- var completeResponses []map[string]interface{}
- var kind string
-
- for k, v := range queryClusters {
- if len(v) == 0 {
- // Nothing to query
- continue
- }
-
- // blocks until it can put a value into the
- // channel (which has a max queue capacity)
- sem <- true
- wg.Add(1)
- go func(k string, v []string) {
- rp, kn, err := h.remoteQueryUUIDs(w, req, k, v)
- mtx.Lock()
- if err == nil {
- completeResponses = append(completeResponses, rp...)
- kind = kn
- } else {
- errors = append(errors, err)
- }
- mtx.Unlock()
- wg.Done()
- <-sem
- }(k, v)
- }
- wg.Wait()
-
- if len(errors) > 0 {
- var strerr []string
- for _, e := range errors {
- strerr = append(strerr, e.Error())
- }
- httpserver.Errors(w, strerr, http.StatusBadGateway)
- return true
- }
-
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusOK)
- itemList := make(map[string]interface{})
- itemList["items"] = completeResponses
- itemList["kind"] = kind
- json.NewEncoder(w).Encode(itemList)
-
- return true
-}
-
-func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
- m := h.matcher.FindStringSubmatch(req.URL.Path)
- clusterId := ""
-
- if len(m) > 0 && m[2] != "" {
- clusterId = m[2]
- }
-
- // Get form parameters from URL and form body (if POST).
- if err := loadParamsFromForm(req); err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
-
- // Check if the parameters have an explicit cluster_id
- if req.Form.Get("cluster_id") != "" {
- clusterId = req.Form.Get("cluster_id")
- }
-
- // Handle the POST-as-GET special case (workaround for large
- // GET requests that potentially exceed maximum URL length,
- // like multi-object queries where the filter has 100s of
- // items)
- effectiveMethod := req.Method
- if req.Method == "POST" && req.Form.Get("_method") != "" {
- effectiveMethod = req.Form.Get("_method")
- }
-
- if effectiveMethod == "GET" &&
- clusterId == "" &&
- req.Form.Get("filters") != "" &&
- h.handleMultiClusterQuery(w, req, &clusterId) {
- return
- }
-
- if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
- h.next.ServeHTTP(w, req)
- } else {
- h.handler.remoteClusterRequest(clusterId, w, req, nil)
- }
-}
-
-type rewriteSignaturesClusterId struct {
- clusterID string
- expectHash string
-}
-
-func (rw rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
- if requestError != nil {
- return resp, requestError
- }
-
- if resp.StatusCode != 200 {
- return resp, nil
- }
-
- originalBody := resp.Body
- defer originalBody.Close()
-
- var col arvados.Collection
- err = json.NewDecoder(resp.Body).Decode(&col)
- if err != nil {
- return nil, err
- }
-
- // rewriting signatures will make manifest text 5-10% bigger so calculate
- // capacity accordingly
- updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
-
- hasher := md5.New()
- mw := io.MultiWriter(hasher, updatedManifest)
- sz := 0
-
- scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
- scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
- for scanner.Scan() {
- line := scanner.Text()
- tokens := strings.Split(line, " ")
- if len(tokens) < 3 {
- return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
- }
-
- n, err := mw.Write([]byte(tokens[0]))
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
- for _, token := range tokens[1:] {
- n, err = mw.Write([]byte(" "))
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
-
- m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
- if m != nil {
- // Rewrite the block signature to be a remote signature
- _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], rw.clusterID, m[5][2:], m[8])
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
-
- // for hash checking, ignore signatures
- n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2])
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
- } else {
- n, err = mw.Write([]byte(token))
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
- }
- }
- n, err = mw.Write([]byte("\n"))
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
- }
-
- // Check that expected hash is consistent with
- // portable_data_hash field of the returned record
- if rw.expectHash == "" {
- rw.expectHash = col.PortableDataHash
- } else if rw.expectHash != col.PortableDataHash {
- return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", rw.expectHash, col.PortableDataHash)
- }
-
- // Certify that the computed hash of the manifest_text matches our expectation
- sum := hasher.Sum(nil)
- computedHash := fmt.Sprintf("%x+%v", sum, sz)
- if computedHash != rw.expectHash {
- return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, rw.expectHash)
- }
-
- col.ManifestText = updatedManifest.String()
-
- newbody, err := json.Marshal(col)
- if err != nil {
- return nil, err
- }
-
- buf := bytes.NewBuffer(newbody)
- resp.Body = ioutil.NopCloser(buf)
- resp.ContentLength = int64(buf.Len())
- resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
-
- return resp, nil
-}
-
-func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
- if requestError != nil {
- return resp, requestError
- }
-
- if resp.StatusCode == 404 {
- // Suppress returning this result, because we want to
- // search the federation.
- return nil, nil
- }
- return resp, nil
-}
-
-type searchRemoteClusterForPDH struct {
- pdh string
- remoteID string
- mtx *sync.Mutex
- sentResponse *bool
- sharedContext *context.Context
- cancelFunc func()
- errors *[]string
- statusCode *int
-}
-
-func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
- s.mtx.Lock()
- defer s.mtx.Unlock()
-
- if *s.sentResponse {
- // Another request already returned a response
- return nil, nil
- }
-
- if requestError != nil {
- *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
- // Record the error and suppress response
- return nil, nil
- }
-
- if resp.StatusCode != 200 {
- // Suppress returning unsuccessful result. Maybe
- // another request will find it.
- // TODO collect and return error responses.
- *s.errors = append(*s.errors, fmt.Sprintf("Response from %q: %v", s.remoteID, resp.Status))
- if resp.StatusCode != 404 {
- // Got a non-404 error response, convert into BadGateway
- *s.statusCode = http.StatusBadGateway
- }
- return nil, nil
- }
-
- s.mtx.Unlock()
-
- // This reads the response body. We don't want to hold the
- // lock while doing this because other remote requests could
- // also have made it to this point, and we don't want a
- // slow response holding the lock to block a faster response
- // that is waiting on the lock.
- newResponse, err = rewriteSignaturesClusterId{s.remoteID, s.pdh}.rewriteSignatures(resp, nil)
-
- s.mtx.Lock()
-
- if *s.sentResponse {
- // Another request already returned a response
- return nil, nil
- }
-
- if err != nil {
- // Suppress returning unsuccessful result. Maybe
- // another request will be successful.
- *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
- return nil, nil
- }
-
- // We have a successful response. Suppress/cancel all the
- // other requests/responses.
- *s.sentResponse = true
- s.cancelFunc()
-
- return newResponse, nil
-}
-
-func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
- if req.Method != "GET" {
- // Only handle GET requests right now
- h.next.ServeHTTP(w, req)
- return
- }
-
- m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
- if len(m) != 2 {
- // Not a collection PDH GET request
- m = collectionRe.FindStringSubmatch(req.URL.Path)
- clusterId := ""
-
- if len(m) > 0 {
- clusterId = m[2]
- }
-
- if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
- // request for remote collection by uuid
- h.handler.remoteClusterRequest(clusterId, w, req,
- rewriteSignaturesClusterId{clusterId, ""}.rewriteSignatures)
- return
- }
- // not a collection UUID request, or it is a request
- // for a local UUID, either way, continue down the
- // handler stack.
- h.next.ServeHTTP(w, req)
- return
- }
-
- // Request for collection by PDH. Search the federation.
-
- // First, query the local cluster.
- if h.handler.localClusterRequest(w, req, filterLocalClusterResponse) {
- return
- }
-
- sharedContext, cancelFunc := context.WithCancel(req.Context())
- defer cancelFunc()
- req = req.WithContext(sharedContext)
-
- // Create a goroutine for each cluster in the
- // RemoteClusters map. The first valid result gets
- // returned to the client. When that happens, all
- // other outstanding requests are cancelled or
- // suppressed.
- sentResponse := false
- mtx := sync.Mutex{}
- wg := sync.WaitGroup{}
- var errors []string
- var errorCode int = 404
-
- // use channel as a semaphore to limit the number of concurrent
- // requests at a time
- sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
- defer close(sem)
- for remoteID := range h.handler.Cluster.RemoteClusters {
- // blocks until it can put a value into the
- // channel (which has a max queue capacity)
- sem <- true
- if sentResponse {
- break
- }
- search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
- &sharedContext, cancelFunc, &errors, &errorCode}
- wg.Add(1)
- go func() {
- h.handler.remoteClusterRequest(search.remoteID, w, req, search.filterRemoteClusterResponse)
- wg.Done()
- <-sem
- }()
- }
- wg.Wait()
-
- if sentResponse {
- return
- }
-
- // No successful responses, so return the error
- httpserver.Errors(w, errors, errorCode)
-}
-
func (h *Handler) setupProxyRemoteCluster(next http.Handler) http.Handler {
mux := http.NewServeMux()
- mux.Handle("/arvados/v1/workflows", &genericFederatedRequestHandler{next, h, wfRe})
- mux.Handle("/arvados/v1/workflows/", &genericFederatedRequestHandler{next, h, wfRe})
- mux.Handle("/arvados/v1/containers", &genericFederatedRequestHandler{next, h, containersRe})
- mux.Handle("/arvados/v1/containers/", &genericFederatedRequestHandler{next, h, containersRe})
- mux.Handle("/arvados/v1/container_requests", &genericFederatedRequestHandler{next, h, containerRequestsRe})
- mux.Handle("/arvados/v1/container_requests/", &genericFederatedRequestHandler{next, h, containerRequestsRe})
+
+ wfHandler := &genericFederatedRequestHandler{next, h, wfRe, nil}
+ containersHandler := &genericFederatedRequestHandler{next, h, containersRe, nil}
+ containerRequestsHandler := &genericFederatedRequestHandler{next, h, containerRequestsRe,
+ []federatedRequestDelegate{remoteContainerRequestCreate}}
+
+ mux.Handle("/arvados/v1/workflows", wfHandler)
+ mux.Handle("/arvados/v1/workflows/", wfHandler)
+ mux.Handle("/arvados/v1/containers", containersHandler)
+ mux.Handle("/arvados/v1/containers/", containersHandler)
+ mux.Handle("/arvados/v1/container_requests", containerRequestsHandler)
+ mux.Handle("/arvados/v1/container_requests/", containerRequestsHandler)
mux.Handle("/arvados/v1/collections", next)
mux.Handle("/arvados/v1/collections/", &collectionFederatedRequestHandler{next, h})
mux.Handle("/", next)
UUID string
}
-func (h *Handler) validateAPItoken(req *http.Request, user *CurrentUser) error {
+// validateAPItoken checks the given token against the
+// api_client_authorizations table in the database and, when a
+// matching unexpired record is found, returns a CurrentUser with
+// the authorization UUID, token secret, scopes and user UUID filled
+// in. Does not handle remote tokens unless they are already in the
+// database and not expired.
+//
+// req is used only to obtain the database handle and the request
+// context; the token itself is passed in explicitly. token may be
+// a bare secret or a v2 token of the form "v2/{uuid}/{secret}". For
+// a v2 token only the secret part is looked up, and the embedded
+// UUID must match the UUID of the record that was found.
+//
+// Errors from the database lookup are returned unchanged, so callers
+// can compare against sql.ErrNoRows (as saltAuthToken does) to detect
+// a token that is not in the local database. A token that starts
+// with "v2/" but has fewer than three "/"-separated parts is
+// rejected with an error.
+func (h *Handler) validateAPItoken(req *http.Request, token string) (*CurrentUser, error) {
+ user := CurrentUser{Authorization: arvados.APIClientAuthorization{APIToken: token}}
 db, err := h.db(req)
 if err != nil {
- return err
+ return nil, err
+ }
+
+ var uuid string
+ if strings.HasPrefix(token, "v2/") {
+ sp := strings.Split(token, "/")
+ if len(sp) < 3 {
+ // Input such as "v2/abc" has no secret part; indexing
+ // sp[2] below would panic, so reject it here.
+ return nil, fmt.Errorf("malformed v2 token")
+ }
+ uuid = sp[1]
+ token = sp[2]
+ }
+ // From here on token is the bare secret: that is what gets
+ // recorded in the returned authorization and used for the lookup.
+ user.Authorization.APIToken = token
+ var scopes string
+ err = db.QueryRowContext(req.Context(), `SELECT api_client_authorizations.uuid, api_client_authorizations.scopes, users.uuid FROM api_client_authorizations JOIN users on api_client_authorizations.user_id=users.id WHERE api_token=$1 AND (expires_at IS NULL OR expires_at > current_timestamp) LIMIT 1`, token).Scan(&user.Authorization.UUID, &scopes, &user.UUID)
+ if err != nil {
+ return nil, err
+ }
+ if uuid != "" && user.Authorization.UUID != uuid {
+ return nil, fmt.Errorf("UUID embedded in v2 token did not match record")
+ }
+ // The scopes column is scanned as text and decoded as a JSON list.
+ err = json.Unmarshal([]byte(scopes), &user.Authorization.Scopes)
+ if err != nil {
+ return nil, err
+ }
+ return &user, nil
+}
+
+func (h *Handler) createAPItoken(req *http.Request, userUUID string, scopes []string) (*arvados.APIClientAuthorization, error) {
+ db, err := h.db(req)
+ if err != nil {
+ return nil, err
+ }
+ rd, err := randutil.String(15, "abcdefghijklmnopqrstuvwxyz0123456789")
+ if err != nil {
+ return nil, err
+ }
+ uuid := fmt.Sprintf("%v-gj3su-%v", h.Cluster.ClusterID, rd)
+ token, err := randutil.String(50, "abcdefghijklmnopqrstuvwxyz0123456789")
+ if err != nil {
+ return nil, err
+ }
+ if len(scopes) == 0 {
+ scopes = append(scopes, "all")
+ }
+ scopesjson, err := json.Marshal(scopes)
+ if err != nil {
+ return nil, err
+ }
+ _, err = db.ExecContext(req.Context(),
+ `INSERT INTO api_client_authorizations
+(uuid, api_token, expires_at, scopes,
+user_id,
+api_client_id, created_at, updated_at)
+VALUES ($1, $2, CURRENT_TIMESTAMP + INTERVAL '2 weeks', $3,
+(SELECT id FROM users WHERE users.uuid=$4 LIMIT 1),
+0, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)`,
+ uuid, token, string(scopesjson), userUUID)
+
+ if err != nil {
+ return nil, err
}
- return db.QueryRowContext(req.Context(), `SELECT api_client_authorizations.uuid, users.uuid FROM api_client_authorizations JOIN users on api_client_authorizations.user_id=users.id WHERE api_token=$1 AND (expires_at IS NULL OR expires_at > current_timestamp) LIMIT 1`, user.Authorization.APIToken).Scan(&user.Authorization.UUID, &user.UUID)
+
+ return &arvados.APIClientAuthorization{
+ UUID: uuid,
+ APIToken: token,
+ ExpiresAt: "",
+ Scopes: scopes}, nil
}
// Extract the auth token supplied in req, and replace it with a
// salted token for the remote cluster.
-func (h *Handler) saltAuthToken(req *http.Request, remote string) error {
+func (h *Handler) saltAuthToken(req *http.Request, remote string) (updatedReq *http.Request, err error) {
+ updatedReq = (&http.Request{
+ Method: req.Method,
+ URL: req.URL,
+ Header: req.Header,
+ Body: req.Body,
+ ContentLength: req.ContentLength,
+ Host: req.Host,
+ }).WithContext(req.Context())
+
creds := auth.NewCredentials()
- creds.LoadTokensFromHTTPRequest(req)
- if len(creds.Tokens) == 0 && req.Header.Get("Content-Type") == "application/x-www-form-encoded" {
+ creds.LoadTokensFromHTTPRequest(updatedReq)
+ if len(creds.Tokens) == 0 && updatedReq.Header.Get("Content-Type") == "application/x-www-form-encoded" {
// Override ParseForm's 10MiB limit by ensuring
// req.Body is a *http.maxBytesReader.
- req.Body = http.MaxBytesReader(nil, req.Body, 1<<28) // 256MiB. TODO: use MaxRequestSize from discovery doc or config.
- if err := creds.LoadTokensFromHTTPRequestBody(req); err != nil {
- return err
+ updatedReq.Body = http.MaxBytesReader(nil, updatedReq.Body, 1<<28) // 256MiB. TODO: use MaxRequestSize from discovery doc or config.
+ if err := creds.LoadTokensFromHTTPRequestBody(updatedReq); err != nil {
+ return nil, err
}
// Replace req.Body with a buffer that re-encodes the
// form without api_token, in case we end up
// forwarding the request.
- if req.PostForm != nil {
- req.PostForm.Del("api_token")
+ if updatedReq.PostForm != nil {
+ updatedReq.PostForm.Del("api_token")
}
- req.Body = ioutil.NopCloser(bytes.NewBufferString(req.PostForm.Encode()))
+ updatedReq.Body = ioutil.NopCloser(bytes.NewBufferString(updatedReq.PostForm.Encode()))
}
if len(creds.Tokens) == 0 {
- return nil
+ return updatedReq, nil
}
+
token, err := auth.SaltToken(creds.Tokens[0], remote)
+
if err == auth.ErrObsoleteToken {
// If the token exists in our own database, salt it
// for the remote. Otherwise, assume it was issued by
// the remote, and pass it through unmodified.
- currentUser := CurrentUser{Authorization: arvados.APIClientAuthorization{APIToken: creds.Tokens[0]}}
- err = h.validateAPItoken(req, ¤tUser)
+ currentUser, err := h.validateAPItoken(req, creds.Tokens[0])
if err == sql.ErrNoRows {
// Not ours; pass through unmodified.
- token = currentUser.Authorization.APIToken
+ token = creds.Tokens[0]
} else if err != nil {
- return err
+ return nil, err
} else {
// Found; make V2 version and salt it.
token, err = auth.SaltToken(currentUser.Authorization.TokenV2(), remote)
if err != nil {
- return err
+ return nil, err
}
}
} else if err != nil {
- return err
+ return nil, err
+ }
+ updatedReq.Header = http.Header{}
+ for k, v := range req.Header {
+ if k != "Authorization" {
+ updatedReq.Header[k] = v
+ }
}
- req.Header.Set("Authorization", "Bearer "+token)
+ updatedReq.Header.Set("Authorization", "Bearer "+token)
// Remove api_token=... from the the query string, in case we
// end up forwarding the request.
- if values, err := url.ParseQuery(req.URL.RawQuery); err != nil {
- return err
+ if values, err := url.ParseQuery(updatedReq.URL.RawQuery); err != nil {
+ return nil, err
} else if _, ok := values["api_token"]; ok {
delete(values, "api_token")
- req.URL.RawQuery = values.Encode()
+ updatedReq.URL = &url.URL{
+ Scheme: req.URL.Scheme,
+ Host: req.URL.Host,
+ Path: req.URL.Path,
+ RawPath: req.URL.RawPath,
+ RawQuery: values.Encode(),
+ }
}
- return nil
+ return updatedReq, nil
}
package controller
import (
+ "bytes"
"encoding/json"
"fmt"
+ "io"
"io/ioutil"
"net/http"
"net/http/httptest"
}
func (s *FederationSuite) remoteMockHandler(w http.ResponseWriter, req *http.Request) {
+ b := &bytes.Buffer{}
+ io.Copy(b, req.Body)
+ req.Body.Close()
+ req.Body = ioutil.NopCloser(b)
s.remoteMockRequests = append(s.remoteMockRequests, *req)
}
c.Check(strings.HasPrefix(cr.UUID, "zzzzz-"), check.Equals, true)
}
+func (s *FederationSuite) TestCreateRemoteContainerRequestCheckRuntimeToken(c *check.C) {
+ // Send request to zmock and check that outgoing request has
+ // runtime_token sent (because runtime_token isn't returned in
+ // the response).
+
+ defer s.localServiceReturns404(c).Close()
+ // pass cluster_id via query parameter, this allows arvados-controller
+ // to avoid parsing the body
+ req := httptest.NewRequest("POST", "/arvados/v1/container_requests?cluster_id=zmock",
+ strings.NewReader(`{
+ "container_request": {
+ "name": "hello world",
+ "state": "Uncommitted",
+ "output_path": "/",
+ "container_image": "123",
+ "command": ["abc"]
+ }
+}
+`))
+ req.Header.Set("Authorization", "Bearer "+arvadostest.ActiveToken)
+ req.Header.Set("Content-type", "application/json")
+
+ np := arvados.NodeProfile{
+ Controller: arvados.SystemServiceInstance{Listen: ":"},
+ RailsAPI: arvados.SystemServiceInstance{Listen: os.Getenv("ARVADOS_TEST_API_HOST"),
+ TLS: true, Insecure: true}}
+ s.testHandler.Cluster.ClusterID = "zzzzz"
+ s.testHandler.Cluster.NodeProfiles["*"] = np
+ s.testHandler.NodeProfile = &np
+
+ resp := s.testRequest(req)
+ c.Check(resp.StatusCode, check.Equals, http.StatusOK)
+ var cr struct {
+ arvados.ContainerRequest `json:"container_request"`
+ }
+ c.Check(json.NewDecoder(s.remoteMockRequests[0].Body).Decode(&cr), check.IsNil)
+ c.Check(strings.HasPrefix(cr.ContainerRequest.RuntimeToken, "v2/"), check.Equals, true)
+}
+
+func (s *FederationSuite) TestCreateRemoteContainerRequestCheckSetRuntimeToken(c *check.C) {
+ // Send request to zmock with runtime_token already set, and
+ // check that the outgoing request carries the same
+ // runtime_token (checked on the outgoing request because
+ // runtime_token isn't returned in the response).
+
+ defer s.localServiceReturns404(c).Close()
+ // pass cluster_id via query parameter, this allows arvados-controller
+ // to avoid parsing the body
+ req := httptest.NewRequest("POST", "/arvados/v1/container_requests?cluster_id=zmock",
+ strings.NewReader(`{
+ "container_request": {
+ "name": "hello world",
+ "state": "Uncommitted",
+ "output_path": "/",
+ "container_image": "123",
+ "command": ["abc"],
+ "runtime_token": "xyz"
+ }
+}
+`))
+ req.Header.Set("Authorization", "Bearer "+arvadostest.ActiveToken)
+ req.Header.Set("Content-type", "application/json")
+ resp := s.testRequest(req)
+ c.Check(resp.StatusCode, check.Equals, http.StatusOK)
+ var cr struct {
+ arvados.ContainerRequest `json:"container_request"`
+ }
+ c.Check(json.NewDecoder(s.remoteMockRequests[0].Body).Decode(&cr), check.IsNil)
+ c.Check(cr.ContainerRequest.RuntimeToken, check.Equals, "xyz")
+}
+
func (s *FederationSuite) TestCreateRemoteContainerRequestError(c *check.C) {
defer s.localServiceReturns404(c).Close()
// pass cluster_id via query parameter, this allows arvados-controller
package controller
import (
+ "context"
"database/sql"
"errors"
"net"
req.URL.Path = strings.Replace(req.URL.Path, "//", "/", -1)
}
}
+ if h.Cluster.HTTPRequestTimeout > 0 {
+ ctx, cancel := context.WithDeadline(req.Context(), time.Now().Add(time.Duration(h.Cluster.HTTPRequestTimeout)))
+ req = req.WithContext(ctx)
+ defer cancel()
+ }
+
h.handlerStack.ServeHTTP(w, req)
}
h.insecureClient = &ic
h.proxy = &proxy{
- Name: "arvados-controller",
- RequestTimeout: time.Duration(h.Cluster.HTTPRequestTimeout),
+ Name: "arvados-controller",
}
}
})
}
-// localClusterRequest sets up a request so it can be proxied to the
-// local API server using proxy.Do(). Returns true if a response was
-// written, false if not.
-func (h *Handler) localClusterRequest(w http.ResponseWriter, req *http.Request, filter ResponseFilter) bool {
+func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, error) {
urlOut, insecure, err := findRailsAPI(h.Cluster, h.NodeProfile)
if err != nil {
- httpserver.Error(w, err.Error(), http.StatusInternalServerError)
- return true
+ return nil, err
}
urlOut = &url.URL{
Scheme: urlOut.Scheme,
if insecure {
client = h.insecureClient
}
- return h.proxy.Do(w, req, urlOut, client, filter)
+ return h.proxy.Do(req, urlOut, client)
}
func (h *Handler) proxyRailsAPI(w http.ResponseWriter, req *http.Request, next http.Handler) {
- if !h.localClusterRequest(w, req, nil) && next != nil {
- next.ServeHTTP(w, req)
+ resp, err := h.localClusterRequest(req)
+ n, err := h.proxy.ForwardResponse(w, resp, err)
+ if err != nil {
+ httpserver.Logger(req).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
}
}
c.Check(resp.Code, check.Equals, http.StatusFound)
c.Check(resp.Header().Get("Location"), check.Matches, `https://0.0.0.0:1/auth/joshid\?return_to=foo&?`)
}
+
+func (s *HandlerSuite) TestValidateV1APIToken(c *check.C) {
+ req := httptest.NewRequest("GET", "/arvados/v1/users/current", nil)
+ user, err := s.handler.(*Handler).validateAPItoken(req, arvadostest.ActiveToken)
+ c.Assert(err, check.IsNil)
+ c.Check(user.Authorization.UUID, check.Equals, arvadostest.ActiveTokenUUID)
+ c.Check(user.Authorization.APIToken, check.Equals, arvadostest.ActiveToken)
+ c.Check(user.Authorization.Scopes, check.DeepEquals, []string{"all"})
+ c.Check(user.UUID, check.Equals, arvadostest.ActiveUserUUID)
+}
+
+func (s *HandlerSuite) TestValidateV2APIToken(c *check.C) {
+ req := httptest.NewRequest("GET", "/arvados/v1/users/current", nil)
+ user, err := s.handler.(*Handler).validateAPItoken(req, arvadostest.ActiveTokenV2)
+ c.Assert(err, check.IsNil)
+ c.Check(user.Authorization.UUID, check.Equals, arvadostest.ActiveTokenUUID)
+ c.Check(user.Authorization.APIToken, check.Equals, arvadostest.ActiveToken)
+ c.Check(user.Authorization.Scopes, check.DeepEquals, []string{"all"})
+ c.Check(user.UUID, check.Equals, arvadostest.ActiveUserUUID)
+ c.Check(user.Authorization.TokenV2(), check.Equals, arvadostest.ActiveTokenV2)
+}
+
+func (s *HandlerSuite) TestCreateAPIToken(c *check.C) {
+ req := httptest.NewRequest("GET", "/arvados/v1/users/current", nil)
+ auth, err := s.handler.(*Handler).createAPItoken(req, arvadostest.ActiveUserUUID, nil)
+ c.Assert(err, check.IsNil)
+ c.Check(auth.Scopes, check.DeepEquals, []string{"all"})
+
+ user, err := s.handler.(*Handler).validateAPItoken(req, auth.TokenV2())
+ c.Assert(err, check.IsNil)
+ c.Check(user.Authorization.UUID, check.Equals, auth.UUID)
+ c.Check(user.Authorization.APIToken, check.Equals, auth.APIToken)
+ c.Check(user.Authorization.Scopes, check.DeepEquals, []string{"all"})
+ c.Check(user.UUID, check.Equals, arvadostest.ActiveUserUUID)
+ c.Check(user.Authorization.TokenV2(), check.Equals, auth.TokenV2())
+}
package controller
import (
- "context"
"io"
"net/http"
"net/url"
- "time"
"git.curoverse.com/arvados.git/sdk/go/httpserver"
)
type proxy struct {
- Name string // to use in Via header
- RequestTimeout time.Duration
+ Name string // to use in Via header
+}
+
+type HTTPError struct {
+ Message string
+ Code int
+}
+
+func (h HTTPError) Error() string {
+ return h.Message
}
// headers that shouldn't be forwarded when proxying. See
type ResponseFilter func(*http.Response, error) (*http.Response, error)
-// Do sends a request, passes the result to the filter (if provided)
-// and then if the result is not suppressed by the filter, sends the
-// request to the ResponseWriter. Returns true if a response was written,
-// false if not.
-func (p *proxy) Do(w http.ResponseWriter,
+// Forward a request to upstream service, and return response or error.
+func (p *proxy) Do(
reqIn *http.Request,
urlOut *url.URL,
- client *http.Client,
- filter ResponseFilter) bool {
+ client *http.Client) (*http.Response, error) {
// Copy headers from incoming request, then add/replace proxy
// headers like Via and X-Forwarded-For.
}
hdrOut.Add("Via", reqIn.Proto+" arvados-controller")
- ctx := reqIn.Context()
- if p.RequestTimeout > 0 {
- var cancel context.CancelFunc
- ctx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Duration(p.RequestTimeout)))
- defer cancel()
- }
-
reqOut := (&http.Request{
Method: reqIn.Method,
URL: urlOut,
Host: reqIn.Host,
Header: hdrOut,
Body: reqIn.Body,
- }).WithContext(ctx)
+ }).WithContext(reqIn.Context())
resp, err := client.Do(reqOut)
- if filter == nil && err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadGateway)
- return true
- }
-
- // make sure original response body gets closed
- var originalBody io.ReadCloser
- if resp != nil {
- originalBody = resp.Body
- if originalBody != nil {
- defer originalBody.Close()
- }
- }
-
- if filter != nil {
- resp, err = filter(resp, err)
+ return resp, err
+}
- if err != nil {
+// Copy a response (or error) to the downstream client
+func (p *proxy) ForwardResponse(w http.ResponseWriter, resp *http.Response, err error) (int64, error) {
+ if err != nil {
+ if he, ok := err.(HTTPError); ok {
+ httpserver.Error(w, he.Message, he.Code)
+ } else {
httpserver.Error(w, err.Error(), http.StatusBadGateway)
- return true
- }
- if resp == nil {
- // filter() returned a nil response, this means suppress
- // writing a response, for the case where there might
- // be multiple response writers.
- return false
- }
-
- // the filter gave us a new response body, make sure that gets closed too.
- if resp.Body != originalBody {
- defer resp.Body.Close()
}
+ return 0, nil
}
+ defer resp.Body.Close()
for k, v := range resp.Header {
for _, v := range v {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
- n, err := io.Copy(w, resp.Body)
- if err != nil {
- httpserver.Logger(reqIn).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
- }
- return true
+ return io.Copy(w, resp.Body)
}
// APIClientAuthorization is an arvados#apiClientAuthorization resource.
type APIClientAuthorization struct {
- UUID string `json:"uuid"`
- APIToken string `json:"api_token"`
+ UUID string `json:"uuid,omitempty"`
+ APIToken string `json:"api_token,omitempty"`
+ ExpiresAt string `json:"expires_at,omitempty"`
+ Scopes []string `json:"scopes,omitempty"`
}
// APIClientAuthorizationList is an arvados#apiClientAuthorizationList resource.
UseExisting bool `json:"use_existing"`
LogUUID string `json:"log_uuid"`
OutputUUID string `json:"output_uuid"`
+ RuntimeToken string `json:"runtime_token"`
}
// Mount is special behavior to attach to a filesystem path or device.
const (
SpectatorToken = "zw2f4gwx8hw8cjre7yp6v1zylhrhn3m5gvjq73rtpwhmknrybu"
ActiveToken = "3kg6k6lzmp9kj5cpkcoxie963cmvjahbt2fod9zru30k1jqdmi"
+ ActiveTokenUUID = "zzzzz-gj3su-077z32aux8dg2s1"
ActiveTokenV2 = "v2/zzzzz-gj3su-077z32aux8dg2s1/3kg6k6lzmp9kj5cpkcoxie963cmvjahbt2fod9zru30k1jqdmi"
AdminToken = "4axaw8zxe0qm22wa6urpp5nskcne8z88cvbupv653y1njyi05h"
AnonymousToken = "4kg6k6lzmp9kj4cpkcoxie964cmvjahbt4fod9zru44k4jqdmi"
"time"
)
+const (
+ HeaderRequestID = "X-Request-Id"
+)
+
// IDGenerator generates alphanumeric strings suitable for use as
// unique IDs (a given IDGenerator will never return the same ID
// twice).
func AddRequestIDs(h http.Handler) http.Handler {
gen := &IDGenerator{Prefix: "req-"}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
- if req.Header.Get("X-Request-Id") == "" {
+ if req.Header.Get(HeaderRequestID) == "" {
if req.Header == nil {
req.Header = http.Header{}
}
- req.Header.Set("X-Request-Id", gen.Next())
+ req.Header.Set(HeaderRequestID, gen.Next())
}
h.ServeHTTP(w, req)
})
if loc = Keep::Locator.parse(search_term)
loc.strip_hints!
coll_match = readable_by(*readers).where(portable_data_hash: loc.to_s).limit(1)
- return get_compatible_images(readers, pattern, coll_match)
+ if coll_match.any? or Rails.configuration.remote_hosts.length == 0
+ return get_compatible_images(readers, pattern, coll_match)
+ else
+ # Allow bare pdh that doesn't exist in the local database so
+ # that federated container requests which refer to remotely
+ # stored containers will validate.
+ return [Collection.new(portable_data_hash: loc.to_s)]
+ end
end
if search_tag.nil? and (n = search_term.index(":"))
return false
end
- if current_api_client_authorization.andand.uuid.andand == self.auth_uuid
- # The contained process itself can update progress indicators,
- # but can't change priority etc.
- permitted = permitted & (progress_attrs + final_attrs + [:state] - [:log])
+ if self.state == Running &&
+ !current_api_client_authorization.nil? &&
+ (current_api_client_authorization.uuid == self.auth_uuid ||
+ current_api_client_authorization.token == self.runtime_token)
+ # The contained process itself can write final attrs but can't
+ # change priority or log.
+ permitted.push *final_attrs
+ permitted = permitted - [:log, :priority]
elsif self.locked_by_uuid && self.locked_by_uuid != current_api_client_authorization.andand.uuid
# When locked, progress fields cannot be updated by the wrong
# dispatcher, even though it has admin privileges.
ready.pop
@remote_server = srv
@remote_host = "127.0.0.1:#{srv.config[:Port]}"
- Rails.configuration.remote_hosts['zbbbb'] = @remote_host
- Rails.configuration.remote_hosts['zbork'] = @remote_host
+ Rails.configuration.remote_hosts = Rails.configuration.remote_hosts.merge({'zbbbb' => @remote_host,
+ 'zbork' => @remote_host})
Arvados::V1::SchemaController.any_instance.stubs(:root_url).returns "https://#{@remote_host}"
@stub_status = 200
@stub_content = {
end
end
+ test "allow unrecognized container when there are remote_hosts" do
+ set_user_from_auth :active
+ Rails.configuration.remote_hosts = {"foooo" => "bar.com"}
+ Container.resolve_container_image('acbd18db4cc2f85cedef654fccc4a4d8+3')
+ end
+
test "migrated docker image" do
Rails.configuration.docker_image_formats = ['v2']
add_docker19_migration_link
assert_equal [logpdh_time2], Collection.where(uuid: [cr1log_uuid, cr2log_uuid]).to_a.collect(&:portable_data_hash).uniq
end
- test "auth_uuid can set output, progress, runtime_status, state on running container -- but not log" do
- set_user_from_auth :active
- c, _ = minimal_new
- set_user_from_auth :dispatch1
- c.lock
- c.update_attributes! state: Container::Running
-
- auth = ApiClientAuthorization.find_by_uuid(c.auth_uuid)
- Thread.current[:api_client_authorization] = auth
- Thread.current[:api_client] = auth.api_client
- Thread.current[:token] = auth.token
- Thread.current[:user] = auth.user
+ ["auth_uuid", "runtime_token"].each do |tok|
+ test "#{tok} can set output, progress, runtime_status, state on running container -- but not log" do
+ if tok == "runtime_token"
+ set_user_from_auth :spectator
+ c, _ = minimal_new(container_image: "9ae44d5792468c58bcf85ce7353c7027+124",
+ runtime_token: api_client_authorizations(:active).token)
+ else
+ set_user_from_auth :active
+ c, _ = minimal_new
+ end
+ set_user_from_auth :dispatch1
+ c.lock
+ c.update_attributes! state: Container::Running
+
+ if tok == "runtime_token"
+ auth = ApiClientAuthorization.validate(token: c.runtime_token)
+ Thread.current[:api_client_authorization] = auth
+ Thread.current[:api_client] = auth.api_client
+ Thread.current[:token] = auth.token
+ Thread.current[:user] = auth.user
+ else
+ auth = ApiClientAuthorization.find_by_uuid(c.auth_uuid)
+ Thread.current[:api_client_authorization] = auth
+ Thread.current[:api_client] = auth.api_client
+ Thread.current[:token] = auth.token
+ Thread.current[:user] = auth.user
+ end
- assert c.update_attributes(output: collections(:collection_owned_by_active).portable_data_hash)
- assert c.update_attributes(runtime_status: {'warning' => 'something happened'})
- assert c.update_attributes(progress: 0.5)
- refute c.update_attributes(log: collections(:real_log_collection).portable_data_hash)
- c.reload
- assert c.update_attributes(state: Container::Complete, exit_code: 0)
+ assert c.update_attributes(output: collections(:collection_owned_by_active).portable_data_hash)
+ assert c.update_attributes(runtime_status: {'warning' => 'something happened'})
+ assert c.update_attributes(progress: 0.5)
+ refute c.update_attributes(log: collections(:real_log_collection).portable_data_hash)
+ c.reload
+ assert c.update_attributes(state: Container::Complete, exit_code: 0)
+ end
end
test "not allowed to set output that is not readable by current user" do
'locator' => BAD_COLLECTION,
}.each_pair do |spec_type, image_spec|
test "Job validation fails with nonexistent Docker image #{spec_type}" do
+ Rails.configuration.remote_hosts = {}
job = Job.new job_attrs(runtime_constraints:
{'docker_image' => image_spec})
assert(job.invalid?, "nonexistent Docker image #{spec_type} was valid")
SigChan chan os.Signal
ArvMountExit chan error
SecretMounts map[string]arvados.Mount
- MkArvClient func(token string) (IArvadosClient, error)
+ MkArvClient func(token string) (IArvadosClient, IKeepClient, error)
finalState string
parentTemp string
runner.CrunchLog.Printf("Fetching Docker image from collection '%s'", runner.Container.ContainerImage)
+ tok, err := runner.ContainerToken()
+ if err != nil {
+ return fmt.Errorf("While getting container token (LoadImage): %v", err)
+ }
+ arvClient, kc, err := runner.MkArvClient(tok)
+ if err != nil {
+ return fmt.Errorf("While creating arv client (LoadImage): %v", err)
+ }
+
var collection arvados.Collection
- err = runner.ArvClient.Get("collections", runner.Container.ContainerImage, nil, &collection)
+ err = arvClient.Get("collections", runner.Container.ContainerImage, nil, &collection)
if err != nil {
return fmt.Errorf("While getting container image collection: %v", err)
}
runner.CrunchLog.Print("Loading Docker image from keep")
var readCloser io.ReadCloser
- readCloser, err = runner.Kc.ManifestFileReader(manifest, img)
+ readCloser, err = kc.ManifestFileReader(manifest, img)
if err != nil {
return fmt.Errorf("While creating ManifestFileReader for container image: %v", err)
}
runner.ContainerConfig.Image = imageID
- runner.Kc.ClearBlockCache()
+ kc.ClearBlockCache()
return nil
}
return fmt.Errorf("error getting container token: %v", err)
}
- containerClient, err := runner.MkArvClient(containerToken)
+ containerClient, _, err := runner.MkArvClient(containerToken)
if err != nil {
return fmt.Errorf("error creating container API client: %v", err)
}
}
return ps, nil
}
- cr.MkArvClient = func(token string) (IArvadosClient, error) {
+ cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
cl, err := arvadosclient.MakeArvadosClient()
if err != nil {
- return nil, err
+ return nil, nil, err
}
cl.ApiToken = token
- return cl, nil
+ kc, err := keepclient.MakeKeepClient(cl)
+ if err != nil {
+ return nil, nil, err
+ }
+ return cl, kc, nil
}
var err error
cr.LogCollection, err = (&arvados.Collection{}).FileSystem(cr.client, cr.Kc)
cr, err := NewContainerRunner(s.client, &ArvTestClient{}, kc, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
c.Assert(err, IsNil)
+ cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+ return &ArvTestClient{}, kc, nil
+ }
+
_, err = cr.Docker.ImageRemove(nil, hwImageId, dockertypes.ImageRemoveOptions{})
c.Check(err, IsNil)
}
// Call fails every request with the error "ArvError", except a "GET" on
// resource type "containers" with action "auth", which returns nil. The
// uuid, parameters and output arguments are never read or written.
func (ArvErrorTestClient) Call(method, resourceType, uuid, action string, parameters arvadosclient.Dict, output interface{}) error {
// NOTE(review): the auth exemption presumably lets the runner obtain a
// container token before the failing collection lookup is reached;
// confirm against ContainerToken.
+ if method == "GET" && resourceType == "containers" && action == "auth" {
+ return nil
+ }
return errors.New("ArvError")
}
// (1) Arvados error
kc := &KeepTestClient{}
defer kc.Close()
- cr, err := NewContainerRunner(s.client, ArvErrorTestClient{}, kc, nil, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
+ cr, err := NewContainerRunner(s.client, &ArvErrorTestClient{}, kc, nil, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
c.Assert(err, IsNil)
+
cr.Container.ContainerImage = hwPDH
+ cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+ return &ArvErrorTestClient{}, &KeepTestClient{}, nil
+ }
err = cr.LoadImage()
c.Check(err.Error(), Equals, "While getting container image collection: ArvError")
func (s *TestSuite) TestLoadImageKeepError(c *C) {
// (2) Keep error
- cr, err := NewContainerRunner(s.client, &ArvTestClient{}, &KeepErrorTestClient{}, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
+ kc := &KeepErrorTestClient{}
+ cr, err := NewContainerRunner(s.client, &ArvTestClient{}, kc, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
c.Assert(err, IsNil)
cr.Container.ContainerImage = hwPDH
+ cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+ return &ArvTestClient{}, kc, nil
+ }
err = cr.LoadImage()
c.Assert(err, NotNil)
func (s *TestSuite) TestLoadImageCollectionError(c *C) {
// (3) Collection doesn't contain image
- cr, err := NewContainerRunner(s.client, &ArvTestClient{}, &KeepReadErrorTestClient{}, nil, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
+ kc := &KeepReadErrorTestClient{}
+ cr, err := NewContainerRunner(s.client, &ArvTestClient{}, kc, nil, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
c.Assert(err, IsNil)
cr.Container.ContainerImage = otherPDH
+ cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+ return &ArvTestClient{}, kc, nil
+ }
err = cr.LoadImage()
c.Check(err.Error(), Equals, "First file in the container image collection does not end in .tar")
func (s *TestSuite) TestLoadImageKeepReadError(c *C) {
// (4) Keep read error
- cr, err := NewContainerRunner(s.client, &ArvTestClient{}, &KeepReadErrorTestClient{}, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
+ kc := &KeepReadErrorTestClient{}
+ cr, err := NewContainerRunner(s.client, &ArvTestClient{}, kc, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
c.Assert(err, IsNil)
cr.Container.ContainerImage = hwPDH
+ cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+ return &ArvTestClient{}, kc, nil
+ }
err = cr.LoadImage()
c.Check(err, NotNil)
cr, err := NewContainerRunner(s.client, &ArvTestClient{}, kc, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
c.Assert(err, IsNil)
+ cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+ return &ArvTestClient{}, kc, nil
+ }
+
var logs TestLogs
cr.NewLogWriter = logs.NewTestLoggingWriter
cr.Container.ContainerImage = hwPDH
}
return d, err
}
- cr.MkArvClient = func(token string) (IArvadosClient, error) {
- return &ArvTestClient{secretMounts: secretMounts}, nil
+ cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+ return &ArvTestClient{secretMounts: secretMounts}, &KeepTestClient{}, nil
}
if extraMounts != nil && len(extraMounts) > 0 {
cr, err := NewContainerRunner(s.client, api, kc, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
c.Assert(err, IsNil)
cr.RunArvMount = func([]string, string) (*exec.Cmd, error) { return nil, nil }
- cr.MkArvClient = func(token string) (IArvadosClient, error) {
- return &ArvTestClient{}, nil
+ cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+ return &ArvTestClient{}, &KeepTestClient{}, nil
}
setup(cr)
c.Assert(err, IsNil)
am := &ArvMountCmdLine{}
cr.RunArvMount = am.ArvMountTest
- cr.MkArvClient = func(token string) (IArvadosClient, error) {
- return &ArvTestClient{}, nil
+ cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+ return &ArvTestClient{}, &KeepTestClient{}, nil
}
err = cr.Run()
"revision": "d14ea06fba99483203c19d92cfcd13ebe73135f4",
"revisionTime": "2015-07-11T00:45:18Z"
},
+ {
+ "checksumSHA1": "khL6oKjx81rAZKW+36050b7f5As=",
+ "path": "github.com/jmcvetta/randutil",
+ "revision": "2bb1b664bcff821e02b2a0644cd29c7e824d54f8",
+ "revisionTime": "2015-08-17T12:26:01Z"
+ },
{
"checksumSHA1": "oX6jFQD74oOApvDIhOzW2dXpg5Q=",
"path": "github.com/kevinburke/ssh_config",