14406: Merge branch 'master'
author Tom Clegg <tclegg@veritasgenetics.com>
Tue, 6 Nov 2018 18:58:53 +0000 (13:58 -0500)
committer Tom Clegg <tclegg@veritasgenetics.com>
Tue, 6 Nov 2018 18:58:53 +0000 (13:58 -0500)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tclegg@veritasgenetics.com>

21 files changed:
lib/controller/fed_collections.go [new file with mode: 0644]
lib/controller/fed_containers.go [new file with mode: 0644]
lib/controller/fed_generic.go [new file with mode: 0644]
lib/controller/federation.go
lib/controller/federation_test.go
lib/controller/handler.go
lib/controller/handler_test.go
lib/controller/proxy.go
sdk/go/arvados/api_client_authorization.go
sdk/go/arvados/container.go
sdk/go/arvadostest/fixtures.go
sdk/go/httpserver/id_generator.go
services/api/app/models/collection.rb
services/api/app/models/container.rb
services/api/test/integration/remote_user_test.rb
services/api/test/unit/container_request_test.rb
services/api/test/unit/container_test.rb
services/api/test/unit/job_test.rb
services/crunch-run/crunchrun.go
services/crunch-run/crunchrun_test.go
vendor/vendor.json

diff --git a/lib/controller/fed_collections.go b/lib/controller/fed_collections.go
new file mode 100644 (file)
index 0000000..b9cd205
--- /dev/null
@@ -0,0 +1,300 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package controller
+
+import (
+       "bufio"
+       "bytes"
+       "context"
+       "crypto/md5"
+       "encoding/json"
+       "fmt"
+       "io"
+       "io/ioutil"
+       "net/http"
+       "strings"
+       "sync"
+
+       "git.curoverse.com/arvados.git/sdk/go/arvados"
+       "git.curoverse.com/arvados.git/sdk/go/httpserver"
+       "git.curoverse.com/arvados.git/sdk/go/keepclient"
+)
+
+// collectionFederatedRequestHandler is the http.Handler (see its
+// ServeHTTP method) used for collection requests.
+//
+// next is the handler a request is passed on to when this handler
+// does not deal with it itself.  handler provides the cluster
+// configuration and the helpers used to send requests to the local
+// and remote clusters.
+type collectionFederatedRequestHandler struct {
+       next    http.Handler
+       handler *Handler
+}
+
+// rewriteSignatures post-processes the response to a collection GET
+// request that was sent to cluster clusterID.
+//
+// If requestError is non-nil, or the response status is not 200 OK,
+// resp and requestError are returned untouched.  Otherwise the body
+// is decoded as a collection record, every token in manifest_text
+// that matches keepclient.SignedLocatorRe has its signature replaced
+// by a remote signature hint ("+R<clusterID>-<signature>"), and the
+// response body, ContentLength and Content-Length header are
+// replaced to reflect the re-encoded record.
+//
+// While rewriting, an MD5 hash and byte count are accumulated over
+// the manifest with signed locators reduced to their first two
+// submatches (hash and size).  The resulting "<md5>+<size>" must
+// equal expectHash; if expectHash is empty, the record's own
+// portable_data_hash is used instead.  If expectHash is given it must
+// also equal the record's portable_data_hash.
+//
+// On any failure the returned response is nil and err describes the
+// problem.  The original response body is closed before returning
+// whenever decoding was attempted.
+func rewriteSignatures(clusterID string, expectHash string,
+       resp *http.Response, requestError error) (newResponse *http.Response, err error) {
+
+       if requestError != nil {
+               return resp, requestError
+       }
+
+       if resp.StatusCode != http.StatusOK {
+               return resp, nil
+       }
+
+       originalBody := resp.Body
+       defer originalBody.Close()
+
+       var col arvados.Collection
+       err = json.NewDecoder(resp.Body).Decode(&col)
+       if err != nil {
+               return nil, err
+       }
+
+       // rewriting signatures will make manifest text 5-10% bigger so calculate
+       // capacity accordingly
+       updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
+
+       hasher := md5.New()
+       mw := io.MultiWriter(hasher, updatedManifest)
+       sz := 0
+
+       // emit writes s to both the hasher and the rewritten manifest
+       // and counts its bytes towards the portable data hash size.
+       emit := func(s string) error {
+               n, err := mw.Write([]byte(s))
+               if err != nil {
+                       return fmt.Errorf("Error updating manifest: %v", err)
+               }
+               sz += n
+               return nil
+       }
+
+       scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
+       // A single stream (line) may be as long as the whole manifest.
+       scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
+       for scanner.Scan() {
+               line := scanner.Text()
+               tokens := strings.Split(line, " ")
+               if len(tokens) < 3 {
+                       return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
+               }
+
+               if err := emit(tokens[0]); err != nil {
+                       return nil, err
+               }
+               for _, token := range tokens[1:] {
+                       if err := emit(" "); err != nil {
+                               return nil, err
+                       }
+
+                       m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
+                       if m == nil {
+                               // Not a signed locator: copy through unchanged.
+                               if err := emit(token); err != nil {
+                                       return nil, err
+                               }
+                               continue
+                       }
+
+                       // Rewrite the block signature to be a remote
+                       // signature.  m[5][2:] drops the first two
+                       // characters of the matched signature hint
+                       // (presumably its "+A" prefix -- confirm
+                       // against keepclient.SignedLocatorRe).
+                       _, err := fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterID, m[5][2:], m[8])
+                       if err != nil {
+                               return nil, fmt.Errorf("Error updating manifest: %v", err)
+                       }
+
+                       // for hash checking, ignore signatures
+                       n, err := fmt.Fprintf(hasher, "%s%s", m[1], m[2])
+                       if err != nil {
+                               return nil, fmt.Errorf("Error updating manifest: %v", err)
+                       }
+                       sz += n
+               }
+               if err := emit("\n"); err != nil {
+                       return nil, err
+               }
+       }
+       // Without this check a scan failure would silently truncate
+       // the manifest and show up later as a confusing hash mismatch.
+       if err := scanner.Err(); err != nil {
+               return nil, fmt.Errorf("Error updating manifest: %v", err)
+       }
+
+       // Check that expected hash is consistent with
+       // portable_data_hash field of the returned record
+       if expectHash == "" {
+               expectHash = col.PortableDataHash
+       } else if expectHash != col.PortableDataHash {
+               return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q", col.PortableDataHash, expectHash)
+       }
+
+       // Certify that the computed hash of the manifest_text matches our expectation
+       sum := hasher.Sum(nil)
+       computedHash := fmt.Sprintf("%x+%v", sum, sz)
+       if computedHash != expectHash {
+               return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, expectHash)
+       }
+
+       col.ManifestText = updatedManifest.String()
+
+       newbody, err := json.Marshal(col)
+       if err != nil {
+               return nil, err
+       }
+
+       buf := bytes.NewBuffer(newbody)
+       resp.Body = ioutil.NopCloser(buf)
+       resp.ContentLength = int64(buf.Len())
+       resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
+
+       return resp, nil
+}
+
+func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
+       if requestError != nil {
+               return resp, requestError
+       }
+
+       if resp.StatusCode == http.StatusNotFound {
+               // Suppress returning this result, because we want to
+               // search the federation.
+               return nil, nil
+       }
+       return resp, nil
+}
+
+// searchRemoteClusterForPDH bundles a portable data hash and a remote
+// cluster ID with pointers to state (mutex, "response sent" flag,
+// context, cancel function, error list, status code) that can be
+// shared between several values of this type.
+//
+// NOTE(review): nothing in the visible part of this file refers to
+// this type -- confirm whether it is still needed.
+type searchRemoteClusterForPDH struct {
+       pdh           string
+       remoteID      string
+       mtx           *sync.Mutex
+       sentResponse  *bool
+       sharedContext *context.Context
+       cancelFunc    func()
+       errors        *[]string
+       statusCode    *int
+}
+
+// ServeHTTP handles GET requests for a single collection.
+//
+//   - Non-GET requests go to the next handler.
+//   - A request by UUID whose cluster prefix names another cluster is
+//     sent to that cluster, and the block signatures in the response
+//     are rewritten as remote signatures.
+//   - Any other non-PDH request goes to the next handler.
+//   - A request by portable data hash is tried on the local cluster
+//     first.  If the local cluster answers 404, every cluster in
+//     RemoteClusters is queried concurrently (bounded by the
+//     configured multi-cluster request concurrency) and the first
+//     response that passes rewriteSignatures is forwarded.  If no
+//     cluster succeeds, the collected errors are returned with status
+//     404, or 502 if any remote returned a non-404 HTTP error.
+func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+       if req.Method != "GET" {
+               // Only handle GET requests right now
+               h.next.ServeHTTP(w, req)
+               return
+       }
+
+       m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
+       if len(m) != 2 {
+               // Not a collection PDH GET request
+               m = collectionRe.FindStringSubmatch(req.URL.Path)
+               clusterID := ""
+
+               if len(m) > 0 {
+                       clusterID = m[2]
+               }
+
+               if clusterID != "" && clusterID != h.handler.Cluster.ClusterID {
+                       // request for remote collection by uuid
+                       resp, err := h.handler.remoteClusterRequest(clusterID, req)
+                       newResponse, err := rewriteSignatures(clusterID, "", resp, err)
+                       h.handler.proxy.ForwardResponse(w, newResponse, err)
+                       return
+               }
+               // not a collection UUID request, or it is a request
+               // for a local UUID, either way, continue down the
+               // handler stack.
+               h.next.ServeHTTP(w, req)
+               return
+       }
+
+       // Request for collection by PDH.  Search the federation.
+
+       // First, query the local cluster.
+       resp, err := h.handler.localClusterRequest(req)
+       newResp, err := filterLocalClusterResponse(resp, err)
+       if newResp != nil || err != nil {
+               h.handler.proxy.ForwardResponse(w, newResp, err)
+               return
+       }
+
+       // Create a goroutine for each cluster in the
+       // RemoteClusters map.  The first valid result gets
+       // returned to the client.  When that happens, all
+       // other outstanding requests are cancelled
+       sharedContext, cancelFunc := context.WithCancel(req.Context())
+       defer cancelFunc()
+       req = req.WithContext(sharedContext)
+       var wg sync.WaitGroup
+       pdh := m[1]
+
+       // None of these channels is ever closed: worker goroutines
+       // can outlive this function, and a send on a closed channel
+       // would panic.  Workers instead stop when sharedContext is
+       // cancelled.
+       success := make(chan *http.Response)
+       errorChan := make(chan error)
+
+       // use channel as a semaphore to limit the number of concurrent
+       // requests at a time
+       sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
+
+       // reportError hands err to the collecting loop below, unless
+       // the search is already over and nobody is listening.
+       reportError := func(err error) {
+               select {
+               case errorChan <- err:
+               case <-sharedContext.Done():
+               }
+       }
+
+       for remoteID := range h.handler.Cluster.RemoteClusters {
+               if remoteID == h.handler.Cluster.ClusterID {
+                       // No need to query local cluster again
+                       continue
+               }
+
+               wg.Add(1)
+               go func(remote string) {
+                       defer wg.Done()
+                       // Wait for a free semaphore slot, giving up
+                       // if the search finishes in the meantime.
+                       select {
+                       case sem <- true:
+                       case <-sharedContext.Done():
+                               return
+                       }
+                       // Release the slot on every exit path so
+                       // queued workers are never starved.
+                       defer func() { <-sem }()
+                       if sharedContext.Err() != nil {
+                               return
+                       }
+
+                       resp, err := h.handler.remoteClusterRequest(remote, req)
+                       wasSuccess := false
+                       defer func() {
+                               if resp != nil && !wasSuccess {
+                                       resp.Body.Close()
+                               }
+                       }()
+                       if err != nil {
+                               reportError(err)
+                               return
+                       }
+                       if resp.StatusCode != http.StatusOK {
+                               reportError(HTTPError{resp.Status, resp.StatusCode})
+                               return
+                       }
+                       if sharedContext.Err() != nil {
+                               return
+                       }
+
+                       newResponse, err := rewriteSignatures(remote, pdh, resp, nil)
+                       if err != nil {
+                               reportError(err)
+                               return
+                       }
+                       select {
+                       case <-sharedContext.Done():
+                       case success <- newResponse:
+                               wasSuccess = true
+                       }
+               }(remoteID)
+       }
+       go func() {
+               // Once every worker has finished without a success,
+               // cancelling the context wakes up the loop below.
+               wg.Wait()
+               cancelFunc()
+       }()
+
+       var errors []string
+       errorCode := http.StatusNotFound
+
+       for {
+               select {
+               case newResp = <-success:
+                       h.handler.proxy.ForwardResponse(w, newResp, nil)
+                       return
+               case err := <-errorChan:
+                       if httperr, ok := err.(HTTPError); ok {
+                               if httperr.Code != http.StatusNotFound {
+                                       errorCode = http.StatusBadGateway
+                               }
+                       }
+                       errors = append(errors, err.Error())
+               case <-sharedContext.Done():
+                       httpserver.Errors(w, errors, errorCode)
+                       return
+               }
+       }
+}
diff --git a/lib/controller/fed_containers.go b/lib/controller/fed_containers.go
new file mode 100644 (file)
index 0000000..5c5501d
--- /dev/null
@@ -0,0 +1,102 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package controller
+
+import (
+       "bytes"
+       "encoding/json"
+       "fmt"
+       "io/ioutil"
+       "net/http"
+       "strings"
+
+       "git.curoverse.com/arvados.git/sdk/go/auth"
+       "git.curoverse.com/arvados.git/sdk/go/httpserver"
+)
+
+// remoteContainerRequestCreate is a federatedRequestDelegate that
+// handles creation (POST with no uuid and no path remainder) of a
+// container request on a cluster other than the local one.
+//
+// It returns false, without touching the request, when the request
+// is not of that kind.  Otherwise it returns true after writing a
+// response: either an error, or the remote cluster's response.
+//
+// The JSON body is decoded; a "container_request" value that is
+// itself a JSON-encoded string is decoded in place.  If the container
+// request (or the top-level object when there is no
+// "container_request" key) has no "runtime_token", the caller's
+// token is validated, must have scope ["all"], and -- if the
+// authorization's UUID starts with the local cluster ID -- a newly
+// created token is stored as "runtime_token".  The re-encoded body
+// is then sent to the remote cluster.
+func remoteContainerRequestCreate(
+       h *genericFederatedRequestHandler,
+       effectiveMethod string,
+       clusterId *string,
+       uuid string,
+       remainder string,
+       w http.ResponseWriter,
+       req *http.Request) bool {
+
+       if effectiveMethod != "POST" || uuid != "" || remainder != "" ||
+               *clusterId == "" || *clusterId == h.handler.Cluster.ClusterID {
+               return false
+       }
+
+       if req.Header.Get("Content-Type") != "application/json" {
+               httpserver.Error(w, "Expected Content-Type: application/json, got "+req.Header.Get("Content-Type"), http.StatusBadRequest)
+               return true
+       }
+
+       originalBody := req.Body
+       defer originalBody.Close()
+       var request map[string]interface{}
+       err := json.NewDecoder(req.Body).Decode(&request)
+       if err != nil {
+               httpserver.Error(w, err.Error(), http.StatusBadRequest)
+               return true
+       }
+
+       crString, ok := request["container_request"].(string)
+       if ok {
+               // container_request was supplied as a JSON-encoded
+               // string; replace it with the decoded object.
+               var crJson map[string]interface{}
+               err := json.Unmarshal([]byte(crString), &crJson)
+               if err != nil {
+                       httpserver.Error(w, err.Error(), http.StatusBadRequest)
+                       return true
+               }
+
+               request["container_request"] = crJson
+       }
+
+       containerRequest, ok := request["container_request"].(map[string]interface{})
+       if !ok {
+               // Use toplevel object as the container_request object
+               containerRequest = request
+       }
+
+       // If runtime_token is not set, create a new token
+       if _, ok := containerRequest["runtime_token"]; !ok {
+               // First make sure supplied token is valid.
+               creds := auth.NewCredentials()
+               creds.LoadTokensFromHTTPRequest(req)
+
+               // Indexing Tokens[0] on a request that carries no
+               // token would panic, so reject it explicitly.
+               if len(creds.Tokens) == 0 {
+                       httpserver.Error(w, "no token provided", http.StatusUnauthorized)
+                       return true
+               }
+
+               currentUser, err := h.handler.validateAPItoken(req, creds.Tokens[0])
+               if err != nil {
+                       httpserver.Error(w, err.Error(), http.StatusForbidden)
+                       return true
+               }
+
+               if len(currentUser.Authorization.Scopes) != 1 || currentUser.Authorization.Scopes[0] != "all" {
+                       httpserver.Error(w, "Token scope is not [all]", http.StatusForbidden)
+                       return true
+               }
+
+               // Must be home cluster for this authorization
+               if strings.HasPrefix(currentUser.Authorization.UUID, h.handler.Cluster.ClusterID) {
+                       newtok, err := h.handler.createAPItoken(req, currentUser.UUID, nil)
+                       if err != nil {
+                               httpserver.Error(w, err.Error(), http.StatusForbidden)
+                               return true
+                       }
+                       containerRequest["runtime_token"] = newtok.TokenV2()
+               }
+       }
+
+       newbody, err := json.Marshal(request)
+       if err != nil {
+               // Do not forward an empty or partial body.
+               httpserver.Error(w, err.Error(), http.StatusInternalServerError)
+               return true
+       }
+       buf := bytes.NewBuffer(newbody)
+       req.Body = ioutil.NopCloser(buf)
+       req.ContentLength = int64(buf.Len())
+       req.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
+
+       resp, err := h.handler.remoteClusterRequest(*clusterId, req)
+       h.handler.proxy.ForwardResponse(w, resp, err)
+       return true
+}
diff --git a/lib/controller/fed_generic.go b/lib/controller/fed_generic.go
new file mode 100644 (file)
index 0000000..6c8135c
--- /dev/null
@@ -0,0 +1,347 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package controller
+
+import (
+       "bytes"
+       "encoding/json"
+       "fmt"
+       "io/ioutil"
+       "net/http"
+       "net/url"
+       "regexp"
+       "sync"
+
+       "git.curoverse.com/arvados.git/sdk/go/httpserver"
+)
+
+// federatedRequestDelegate is a hook that gets the chance to handle
+// a request before genericFederatedRequestHandler applies its default
+// routing.  ServeHTTP calls each delegate in turn and stops at the
+// first one that returns true; a delegate that returns true is
+// expected to have written the response itself.
+//
+// effectiveMethod is the request method after the POST-with-_method
+// override has been applied.  clusterId points at the target cluster
+// ID chosen so far.  uuid and remainder are submatches 1 and 3 of
+// the handler's matcher applied to the request path.
+type federatedRequestDelegate func(
+       h *genericFederatedRequestHandler,
+       effectiveMethod string,
+       clusterId *string,
+       uuid string,
+       remainder string,
+       w http.ResponseWriter,
+       req *http.Request) bool
+
+// genericFederatedRequestHandler routes requests for one resource
+// type either to the next local handler or to a remote cluster.
+//
+// next handles requests that stay on the local cluster.  handler
+// provides cluster configuration and the local/remote request
+// helpers.  matcher is applied to the request path to extract the
+// cluster ID (submatch 2).  delegates are tried in order before the
+// default routing.
+type genericFederatedRequestHandler struct {
+       next      http.Handler
+       handler   *Handler
+       matcher   *regexp.Regexp
+       delegates []federatedRequestDelegate
+}
+
+// remoteQueryUUIDs fetches the objects with the given uuids from
+// cluster clusterID (the local cluster if clusterID is the local
+// cluster's ID), using POST requests with _method=GET and a
+// "uuid in [...]" filter.
+//
+// Because one response may not contain every requested object, the
+// query is repeated for the uuids not yet seen, as long as each
+// round returns at least one item and the outstanding set shrinks.
+//
+// It returns the accumulated items and the "kind" of the last
+// response, or the first error encountered.
+func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
+       req *http.Request,
+       clusterID string, uuids []string) (rp []map[string]interface{}, kind string, err error) {
+
+       seen := make(map[string]bool)
+       // Start one above the real length so the first round passes
+       // the "outstanding set must shrink" test.
+       prevCount := len(uuids) + 1
+       for len(uuids) > 0 && len(uuids) < prevCount {
+               remoteReq := http.Request{
+                       Method: "POST",
+                       Header: req.Header,
+                       URL:    &url.URL{Path: req.URL.Path},
+               }
+
+               remoteParams := url.Values{
+                       "_method": {"GET"},
+                       "count":   {"none"},
+               }
+               if sel := req.Form.Get("select"); sel != "" {
+                       remoteParams.Set("select", sel)
+               }
+               content, marshalErr := json.Marshal(uuids)
+               if marshalErr != nil {
+                       return nil, "", marshalErr
+               }
+               remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
+               remoteReq.Body = ioutil.NopCloser(bytes.NewBufferString(remoteParams.Encode()))
+
+               collector := multiClusterQueryResponseCollector{clusterID: clusterID}
+
+               var (
+                       resp   *http.Response
+                       reqErr error
+               )
+               if clusterID == h.handler.Cluster.ClusterID {
+                       resp, reqErr = h.handler.localClusterRequest(&remoteReq)
+               } else {
+                       resp, reqErr = h.handler.remoteClusterRequest(clusterID, &remoteReq)
+               }
+               collector.collectResponse(resp, reqErr)
+
+               if collector.error != nil {
+                       return nil, "", collector.error
+               }
+
+               kind = collector.kind
+
+               if len(collector.responses) == 0 {
+                       // An empty page means asking again would
+                       // not help.
+                       return rp, kind, nil
+               }
+
+               rp = append(rp, collector.responses...)
+
+               // Record which uuids came back, then keep only the
+               // ones still missing for the next round.
+               for _, item := range collector.responses {
+                       if uuid, ok := item["uuid"].(string); ok {
+                               seen[uuid] = true
+                       }
+               }
+
+               remaining := []string{}
+               for _, u := range uuids {
+                       if seen[u] {
+                               continue
+                       }
+                       remaining = append(remaining, u)
+               }
+               prevCount = len(uuids)
+               uuids = remaining
+       }
+
+       return rp, kind, nil
+}
+
+// handleMultiClusterQuery handles a list request whose "filters"
+// select objects by uuid on more than one cluster.
+//
+// It returns false if the request is not such a query (any filter
+// that is not a 3-element ["uuid", "in"|"=", ...] filter, or uuids
+// from at most one cluster); the caller then continues with normal
+// routing.  Note that *clusterId is set to the prefix of the last
+// 27-character uuid seen, even when false is returned.
+//
+// It returns true after writing a response: a 400 error when the
+// request is malformed or uses unsupported parameters, a 502 with
+// the collected errors when any cluster query failed, or a 200 JSON
+// object with the merged "items" and their "kind".
+func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter,
+       req *http.Request, clusterId *string) bool {
+
+       var filters [][]interface{}
+       if err := json.Unmarshal([]byte(req.Form.Get("filters")), &filters); err != nil {
+               httpserver.Error(w, err.Error(), http.StatusBadRequest)
+               return true
+       }
+
+       // Split the list of uuids by prefix
+       queryClusters := make(map[string][]string)
+       expectCount := 0
+
+       // addUUID files v under its 5-character cluster prefix if it
+       // is a 27-character string; anything else is ignored.
+       addUUID := func(v interface{}) {
+               u, ok := v.(string)
+               if !ok || len(u) != 27 {
+                       return
+               }
+               prefix := u[0:5]
+               *clusterId = prefix
+               queryClusters[prefix] = append(queryClusters[prefix], u)
+               expectCount++
+       }
+
+       for _, filter := range filters {
+               if len(filter) != 3 {
+                       return false
+               }
+
+               lhs, ok := filter[0].(string)
+               if !ok || lhs != "uuid" {
+                       return false
+               }
+
+               op, ok := filter[1].(string)
+               if !ok {
+                       return false
+               }
+
+               switch op {
+               case "in":
+                       // A non-list operand leaves rhs nil, so the
+                       // loop simply does nothing.
+                       rhs, _ := filter[2].([]interface{})
+                       for _, item := range rhs {
+                               addUUID(item)
+                       }
+               case "=":
+                       addUUID(filter[2])
+               default:
+                       return false
+               }
+       }
+
+       if len(queryClusters) <= 1 {
+               // Query does not search for uuids across multiple
+               // clusters.
+               return false
+       }
+
+       // Validations
+       switch count := req.Form.Get("count"); count {
+       case "", `none`, `"none"`:
+       default:
+               httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
+               return true
+       }
+       for _, param := range []string{"limit", "offset", "order"} {
+               if req.Form.Get(param) != "" {
+                       httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
+                       return true
+               }
+       }
+       if expectCount > h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse() {
+               httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
+                       expectCount, h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse()), http.StatusBadRequest)
+               return true
+       }
+       if sel := req.Form.Get("select"); sel != "" {
+               var selects []string
+               if err := json.Unmarshal([]byte(sel), &selects); err != nil {
+                       httpserver.Error(w, err.Error(), http.StatusBadRequest)
+                       return true
+               }
+
+               // uuid must be selected so responses can be matched
+               // against the requested uuids.
+               foundUUID := false
+               for _, field := range selects {
+                       if field == "uuid" {
+                               foundUUID = true
+                               break
+                       }
+               }
+               if !foundUUID {
+                       httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
+                       return true
+               }
+       }
+
+       // Perform concurrent requests to each cluster
+
+       // use channel as a semaphore to limit the number of concurrent
+       // requests at a time
+       sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
+       defer close(sem)
+       var wg sync.WaitGroup
+
+       req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+       var (
+               mtx               sync.Mutex
+               errs              []error
+               completeResponses []map[string]interface{}
+               kind              string
+       )
+
+       for cluster, uuids := range queryClusters {
+               if len(uuids) == 0 {
+                       // Nothing to query
+                       continue
+               }
+
+               // blocks until it can put a value into the
+               // channel (which has a max queue capacity)
+               sem <- true
+               wg.Add(1)
+               go func(k string, v []string) {
+                       // Deferred calls run in reverse order:
+                       // wg.Done first, then the slot release.
+                       defer func() { <-sem }()
+                       defer wg.Done()
+
+                       rp, kn, err := h.remoteQueryUUIDs(w, req, k, v)
+                       mtx.Lock()
+                       if err != nil {
+                               errs = append(errs, err)
+                       } else {
+                               completeResponses = append(completeResponses, rp...)
+                               kind = kn
+                       }
+                       mtx.Unlock()
+               }(cluster, uuids)
+       }
+       wg.Wait()
+
+       if len(errs) > 0 {
+               var strerr []string
+               for _, e := range errs {
+                       strerr = append(strerr, e.Error())
+               }
+               httpserver.Errors(w, strerr, http.StatusBadGateway)
+               return true
+       }
+
+       w.Header().Set("Content-Type", "application/json")
+       w.WriteHeader(http.StatusOK)
+       itemList := map[string]interface{}{
+               "items": completeResponses,
+               "kind":  kind,
+       }
+       json.NewEncoder(w).Encode(itemList)
+
+       return true
+}
+
+func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+       m := h.matcher.FindStringSubmatch(req.URL.Path)
+       clusterId := ""
+
+       if len(m) > 0 && m[2] != "" {
+               clusterId = m[2]
+       }
+
+       // Get form parameters from URL and form body (if POST).
+       if err := loadParamsFromForm(req); err != nil {
+               httpserver.Error(w, err.Error(), http.StatusBadRequest)
+               return
+       }
+
+       // Check if the parameters have an explicit cluster_id
+       if req.Form.Get("cluster_id") != "" {
+               clusterId = req.Form.Get("cluster_id")
+       }
+
+       // Handle the POST-as-GET special case (workaround for large
+       // GET requests that potentially exceed maximum URL length,
+       // like multi-object queries where the filter has 100s of
+       // items)
+       effectiveMethod := req.Method
+       if req.Method == "POST" && req.Form.Get("_method") != "" {
+               effectiveMethod = req.Form.Get("_method")
+       }
+
+       if effectiveMethod == "GET" &&
+               clusterId == "" &&
+               req.Form.Get("filters") != "" &&
+               h.handleMultiClusterQuery(w, req, &clusterId) {
+               return
+       }
+
+       for _, d := range h.delegates {
+               if d(h, effectiveMethod, &clusterId, m[1], m[3], w, req) {
+                       return
+               }
+       }
+
+       if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
+               h.next.ServeHTTP(w, req)
+       } else {
+               resp, err := h.handler.remoteClusterRequest(clusterId, req)
+               h.handler.proxy.ForwardResponse(w, resp, err)
+       }
+}
+
+// multiClusterQueryResponseCollector holds the outcome of one list
+// request sent to cluster clusterID: either the decoded items and
+// their kind, or the error that occurred.
+type multiClusterQueryResponseCollector struct {
+       responses []map[string]interface{}
+       error     error
+       kind      string
+       clusterID string
+}
+
+// collectResponse records the result of a list request in c.
+//
+// A non-nil requestError is stored in c.error as is.  Otherwise the
+// response body is decoded as JSON and closed; a decoding failure or
+// a status other than 200 OK is stored in c.error (the latter along
+// with the "errors" reported in the body).  On success c.responses
+// and c.kind are set from "items" and "kind".
+//
+// Both return values are always nil; the outcome is reported through
+// the fields of c only.
+func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
+       requestError error) (newResponse *http.Response, err error) {
+       if requestError != nil {
+               c.error = requestError
+               return nil, nil
+       }
+       defer resp.Body.Close()
+
+       var page struct {
+               Kind   string                   `json:"kind"`
+               Items  []map[string]interface{} `json:"items"`
+               Errors []string                 `json:"errors"`
+       }
+       decodeErr := json.NewDecoder(resp.Body).Decode(&page)
+
+       switch {
+       case decodeErr != nil:
+               c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, decodeErr)
+       case resp.StatusCode != http.StatusOK:
+               c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, page.Errors)
+       default:
+               c.responses = page.Items
+               c.kind = page.Kind
+       }
+       return nil, nil
+}
index 5c6f6bf7ab9d503c395701688555359a9e925e6b..e08a1c16742a6d5ea9b251d2906b24f6d5b00e61 100644 (file)
@@ -5,10 +5,7 @@
 package controller
 
 import (
-       "bufio"
        "bytes"
-       "context"
-       "crypto/md5"
        "database/sql"
        "encoding/json"
        "fmt"
@@ -18,12 +15,10 @@ import (
        "net/url"
        "regexp"
        "strings"
-       "sync"
 
        "git.curoverse.com/arvados.git/sdk/go/arvados"
        "git.curoverse.com/arvados.git/sdk/go/auth"
-       "git.curoverse.com/arvados.git/sdk/go/httpserver"
-       "git.curoverse.com/arvados.git/sdk/go/keepclient"
+       "github.com/jmcvetta/randutil"
 )
 
 var pathPattern = `^/arvados/v1/%s(/([0-9a-z]{5})-%s-[0-9a-z]{15})?(.*)$`
@@ -33,44 +28,31 @@ var containerRequestsRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "container
 var collectionRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "collections", "4zz18"))
 var collectionByPDHRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-fA-F]{32}\+[0-9]+)+$`)
 
-type genericFederatedRequestHandler struct {
-       next    http.Handler
-       handler *Handler
-       matcher *regexp.Regexp
-}
-
-type collectionFederatedRequestHandler struct {
-       next    http.Handler
-       handler *Handler
-}
-
-func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, req *http.Request, filter ResponseFilter) {
+func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, error) {
        remote, ok := h.Cluster.RemoteClusters[remoteID]
        if !ok {
-               httpserver.Error(w, "no proxy available for cluster "+remoteID, http.StatusNotFound)
-               return
+               return nil, HTTPError{fmt.Sprintf("no proxy available for cluster %v", remoteID), http.StatusNotFound}
        }
        scheme := remote.Scheme
        if scheme == "" {
                scheme = "https"
        }
-       err := h.saltAuthToken(req, remoteID)
+       saltedReq, err := h.saltAuthToken(req, remoteID)
        if err != nil {
-               httpserver.Error(w, err.Error(), http.StatusBadRequest)
-               return
+               return nil, err
        }
        urlOut := &url.URL{
                Scheme:   scheme,
                Host:     remote.Host,
-               Path:     req.URL.Path,
-               RawPath:  req.URL.RawPath,
-               RawQuery: req.URL.RawQuery,
+               Path:     saltedReq.URL.Path,
+               RawPath:  saltedReq.URL.RawPath,
+               RawQuery: saltedReq.URL.RawQuery,
        }
        client := h.secureClient
        if remote.Insecure {
                client = h.insecureClient
        }
-       h.proxy.Do(w, req, urlOut, client, filter)
+       return h.proxy.Do(saltedReq, urlOut, client)
 }
 
 // Buffer request body, parse form parameters in request, and then
@@ -100,594 +82,20 @@ func loadParamsFromForm(req *http.Request) error {
        return nil
 }
 
-type multiClusterQueryResponseCollector struct {
-       responses []map[string]interface{}
-       error     error
-       kind      string
-       clusterID string
-}
-
-func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
-       requestError error) (newResponse *http.Response, err error) {
-       if requestError != nil {
-               c.error = requestError
-               return nil, nil
-       }
-
-       defer resp.Body.Close()
-       var loadInto struct {
-               Kind   string                   `json:"kind"`
-               Items  []map[string]interface{} `json:"items"`
-               Errors []string                 `json:"errors"`
-       }
-       err = json.NewDecoder(resp.Body).Decode(&loadInto)
-
-       if err != nil {
-               c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, err)
-               return nil, nil
-       }
-       if resp.StatusCode != http.StatusOK {
-               c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, loadInto.Errors)
-               return nil, nil
-       }
-
-       c.responses = loadInto.Items
-       c.kind = loadInto.Kind
-
-       return nil, nil
-}
-
-func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
-       req *http.Request,
-       clusterID string, uuids []string) (rp []map[string]interface{}, kind string, err error) {
-
-       found := make(map[string]bool)
-       prev_len_uuids := len(uuids) + 1
-       // Loop while
-       // (1) there are more uuids to query
-       // (2) we're making progress - on each iteration the set of
-       // uuids we are expecting for must shrink.
-       for len(uuids) > 0 && len(uuids) < prev_len_uuids {
-               var remoteReq http.Request
-               remoteReq.Header = req.Header
-               remoteReq.Method = "POST"
-               remoteReq.URL = &url.URL{Path: req.URL.Path}
-               remoteParams := make(url.Values)
-               remoteParams.Set("_method", "GET")
-               remoteParams.Set("count", "none")
-               if req.Form.Get("select") != "" {
-                       remoteParams.Set("select", req.Form.Get("select"))
-               }
-               content, err := json.Marshal(uuids)
-               if err != nil {
-                       return nil, "", err
-               }
-               remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
-               enc := remoteParams.Encode()
-               remoteReq.Body = ioutil.NopCloser(bytes.NewBufferString(enc))
-
-               rc := multiClusterQueryResponseCollector{clusterID: clusterID}
-
-               if clusterID == h.handler.Cluster.ClusterID {
-                       h.handler.localClusterRequest(w, &remoteReq,
-                               rc.collectResponse)
-               } else {
-                       h.handler.remoteClusterRequest(clusterID, w, &remoteReq,
-                               rc.collectResponse)
-               }
-               if rc.error != nil {
-                       return nil, "", rc.error
-               }
-
-               kind = rc.kind
-
-               if len(rc.responses) == 0 {
-                       // We got zero responses, no point in doing
-                       // another query.
-                       return rp, kind, nil
-               }
-
-               rp = append(rp, rc.responses...)
-
-               // Go through the responses and determine what was
-               // returned.  If there are remaining items, loop
-               // around and do another request with just the
-               // stragglers.
-               for _, i := range rc.responses {
-                       uuid, ok := i["uuid"].(string)
-                       if ok {
-                               found[uuid] = true
-                       }
-               }
-
-               l := []string{}
-               for _, u := range uuids {
-                       if !found[u] {
-                               l = append(l, u)
-                       }
-               }
-               prev_len_uuids = len(uuids)
-               uuids = l
-       }
-
-       return rp, kind, nil
-}
-
-func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter,
-       req *http.Request, clusterId *string) bool {
-
-       var filters [][]interface{}
-       err := json.Unmarshal([]byte(req.Form.Get("filters")), &filters)
-       if err != nil {
-               httpserver.Error(w, err.Error(), http.StatusBadRequest)
-               return true
-       }
-
-       // Split the list of uuids by prefix
-       queryClusters := make(map[string][]string)
-       expectCount := 0
-       for _, filter := range filters {
-               if len(filter) != 3 {
-                       return false
-               }
-
-               if lhs, ok := filter[0].(string); !ok || lhs != "uuid" {
-                       return false
-               }
-
-               op, ok := filter[1].(string)
-               if !ok {
-                       return false
-               }
-
-               if op == "in" {
-                       if rhs, ok := filter[2].([]interface{}); ok {
-                               for _, i := range rhs {
-                                       if u, ok := i.(string); ok {
-                                               *clusterId = u[0:5]
-                                               queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
-                                               expectCount += 1
-                                       }
-                               }
-                       }
-               } else if op == "=" {
-                       if u, ok := filter[2].(string); ok {
-                               *clusterId = u[0:5]
-                               queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
-                               expectCount += 1
-                       }
-               } else {
-                       return false
-               }
-
-       }
-
-       if len(queryClusters) <= 1 {
-               // Query does not search for uuids across multiple
-               // clusters.
-               return false
-       }
-
-       // Validations
-       count := req.Form.Get("count")
-       if count != "" && count != `none` && count != `"none"` {
-               httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
-               return true
-       }
-       if req.Form.Get("limit") != "" || req.Form.Get("offset") != "" || req.Form.Get("order") != "" {
-               httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
-               return true
-       }
-       if expectCount > h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse() {
-               httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
-                       expectCount, h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse()), http.StatusBadRequest)
-               return true
-       }
-       if req.Form.Get("select") != "" {
-               foundUUID := false
-               var selects []string
-               err := json.Unmarshal([]byte(req.Form.Get("select")), &selects)
-               if err != nil {
-                       httpserver.Error(w, err.Error(), http.StatusBadRequest)
-                       return true
-               }
-
-               for _, r := range selects {
-                       if r == "uuid" {
-                               foundUUID = true
-                               break
-                       }
-               }
-               if !foundUUID {
-                       httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
-                       return true
-               }
-       }
-
-       // Perform concurrent requests to each cluster
-
-       // use channel as a semaphore to limit the number of concurrent
-       // requests at a time
-       sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
-       defer close(sem)
-       wg := sync.WaitGroup{}
-
-       req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
-       mtx := sync.Mutex{}
-       errors := []error{}
-       var completeResponses []map[string]interface{}
-       var kind string
-
-       for k, v := range queryClusters {
-               if len(v) == 0 {
-                       // Nothing to query
-                       continue
-               }
-
-               // blocks until it can put a value into the
-               // channel (which has a max queue capacity)
-               sem <- true
-               wg.Add(1)
-               go func(k string, v []string) {
-                       rp, kn, err := h.remoteQueryUUIDs(w, req, k, v)
-                       mtx.Lock()
-                       if err == nil {
-                               completeResponses = append(completeResponses, rp...)
-                               kind = kn
-                       } else {
-                               errors = append(errors, err)
-                       }
-                       mtx.Unlock()
-                       wg.Done()
-                       <-sem
-               }(k, v)
-       }
-       wg.Wait()
-
-       if len(errors) > 0 {
-               var strerr []string
-               for _, e := range errors {
-                       strerr = append(strerr, e.Error())
-               }
-               httpserver.Errors(w, strerr, http.StatusBadGateway)
-               return true
-       }
-
-       w.Header().Set("Content-Type", "application/json")
-       w.WriteHeader(http.StatusOK)
-       itemList := make(map[string]interface{})
-       itemList["items"] = completeResponses
-       itemList["kind"] = kind
-       json.NewEncoder(w).Encode(itemList)
-
-       return true
-}
-
-func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
-       m := h.matcher.FindStringSubmatch(req.URL.Path)
-       clusterId := ""
-
-       if len(m) > 0 && m[2] != "" {
-               clusterId = m[2]
-       }
-
-       // Get form parameters from URL and form body (if POST).
-       if err := loadParamsFromForm(req); err != nil {
-               httpserver.Error(w, err.Error(), http.StatusBadRequest)
-               return
-       }
-
-       // Check if the parameters have an explicit cluster_id
-       if req.Form.Get("cluster_id") != "" {
-               clusterId = req.Form.Get("cluster_id")
-       }
-
-       // Handle the POST-as-GET special case (workaround for large
-       // GET requests that potentially exceed maximum URL length,
-       // like multi-object queries where the filter has 100s of
-       // items)
-       effectiveMethod := req.Method
-       if req.Method == "POST" && req.Form.Get("_method") != "" {
-               effectiveMethod = req.Form.Get("_method")
-       }
-
-       if effectiveMethod == "GET" &&
-               clusterId == "" &&
-               req.Form.Get("filters") != "" &&
-               h.handleMultiClusterQuery(w, req, &clusterId) {
-               return
-       }
-
-       if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
-               h.next.ServeHTTP(w, req)
-       } else {
-               h.handler.remoteClusterRequest(clusterId, w, req, nil)
-       }
-}
-
-type rewriteSignaturesClusterId struct {
-       clusterID  string
-       expectHash string
-}
-
-func (rw rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
-       if requestError != nil {
-               return resp, requestError
-       }
-
-       if resp.StatusCode != 200 {
-               return resp, nil
-       }
-
-       originalBody := resp.Body
-       defer originalBody.Close()
-
-       var col arvados.Collection
-       err = json.NewDecoder(resp.Body).Decode(&col)
-       if err != nil {
-               return nil, err
-       }
-
-       // rewriting signatures will make manifest text 5-10% bigger so calculate
-       // capacity accordingly
-       updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
-
-       hasher := md5.New()
-       mw := io.MultiWriter(hasher, updatedManifest)
-       sz := 0
-
-       scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
-       scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
-       for scanner.Scan() {
-               line := scanner.Text()
-               tokens := strings.Split(line, " ")
-               if len(tokens) < 3 {
-                       return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
-               }
-
-               n, err := mw.Write([]byte(tokens[0]))
-               if err != nil {
-                       return nil, fmt.Errorf("Error updating manifest: %v", err)
-               }
-               sz += n
-               for _, token := range tokens[1:] {
-                       n, err = mw.Write([]byte(" "))
-                       if err != nil {
-                               return nil, fmt.Errorf("Error updating manifest: %v", err)
-                       }
-                       sz += n
-
-                       m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
-                       if m != nil {
-                               // Rewrite the block signature to be a remote signature
-                               _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], rw.clusterID, m[5][2:], m[8])
-                               if err != nil {
-                                       return nil, fmt.Errorf("Error updating manifest: %v", err)
-                               }
-
-                               // for hash checking, ignore signatures
-                               n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2])
-                               if err != nil {
-                                       return nil, fmt.Errorf("Error updating manifest: %v", err)
-                               }
-                               sz += n
-                       } else {
-                               n, err = mw.Write([]byte(token))
-                               if err != nil {
-                                       return nil, fmt.Errorf("Error updating manifest: %v", err)
-                               }
-                               sz += n
-                       }
-               }
-               n, err = mw.Write([]byte("\n"))
-               if err != nil {
-                       return nil, fmt.Errorf("Error updating manifest: %v", err)
-               }
-               sz += n
-       }
-
-       // Check that expected hash is consistent with
-       // portable_data_hash field of the returned record
-       if rw.expectHash == "" {
-               rw.expectHash = col.PortableDataHash
-       } else if rw.expectHash != col.PortableDataHash {
-               return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", rw.expectHash, col.PortableDataHash)
-       }
-
-       // Certify that the computed hash of the manifest_text matches our expectation
-       sum := hasher.Sum(nil)
-       computedHash := fmt.Sprintf("%x+%v", sum, sz)
-       if computedHash != rw.expectHash {
-               return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, rw.expectHash)
-       }
-
-       col.ManifestText = updatedManifest.String()
-
-       newbody, err := json.Marshal(col)
-       if err != nil {
-               return nil, err
-       }
-
-       buf := bytes.NewBuffer(newbody)
-       resp.Body = ioutil.NopCloser(buf)
-       resp.ContentLength = int64(buf.Len())
-       resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
-
-       return resp, nil
-}
-
-func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
-       if requestError != nil {
-               return resp, requestError
-       }
-
-       if resp.StatusCode == 404 {
-               // Suppress returning this result, because we want to
-               // search the federation.
-               return nil, nil
-       }
-       return resp, nil
-}
-
-type searchRemoteClusterForPDH struct {
-       pdh           string
-       remoteID      string
-       mtx           *sync.Mutex
-       sentResponse  *bool
-       sharedContext *context.Context
-       cancelFunc    func()
-       errors        *[]string
-       statusCode    *int
-}
-
-func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
-       s.mtx.Lock()
-       defer s.mtx.Unlock()
-
-       if *s.sentResponse {
-               // Another request already returned a response
-               return nil, nil
-       }
-
-       if requestError != nil {
-               *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
-               // Record the error and suppress response
-               return nil, nil
-       }
-
-       if resp.StatusCode != 200 {
-               // Suppress returning unsuccessful result.  Maybe
-               // another request will find it.
-               // TODO collect and return error responses.
-               *s.errors = append(*s.errors, fmt.Sprintf("Response from %q: %v", s.remoteID, resp.Status))
-               if resp.StatusCode != 404 {
-                       // Got a non-404 error response, convert into BadGateway
-                       *s.statusCode = http.StatusBadGateway
-               }
-               return nil, nil
-       }
-
-       s.mtx.Unlock()
-
-       // This reads the response body.  We don't want to hold the
-       // lock while doing this because other remote requests could
-       // also have made it to this point, and we don't want a
-       // slow response holding the lock to block a faster response
-       // that is waiting on the lock.
-       newResponse, err = rewriteSignaturesClusterId{s.remoteID, s.pdh}.rewriteSignatures(resp, nil)
-
-       s.mtx.Lock()
-
-       if *s.sentResponse {
-               // Another request already returned a response
-               return nil, nil
-       }
-
-       if err != nil {
-               // Suppress returning unsuccessful result.  Maybe
-               // another request will be successful.
-               *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
-               return nil, nil
-       }
-
-       // We have a successful response.  Suppress/cancel all the
-       // other requests/responses.
-       *s.sentResponse = true
-       s.cancelFunc()
-
-       return newResponse, nil
-}
-
-func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
-       if req.Method != "GET" {
-               // Only handle GET requests right now
-               h.next.ServeHTTP(w, req)
-               return
-       }
-
-       m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
-       if len(m) != 2 {
-               // Not a collection PDH GET request
-               m = collectionRe.FindStringSubmatch(req.URL.Path)
-               clusterId := ""
-
-               if len(m) > 0 {
-                       clusterId = m[2]
-               }
-
-               if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
-                       // request for remote collection by uuid
-                       h.handler.remoteClusterRequest(clusterId, w, req,
-                               rewriteSignaturesClusterId{clusterId, ""}.rewriteSignatures)
-                       return
-               }
-               // not a collection UUID request, or it is a request
-               // for a local UUID, either way, continue down the
-               // handler stack.
-               h.next.ServeHTTP(w, req)
-               return
-       }
-
-       // Request for collection by PDH.  Search the federation.
-
-       // First, query the local cluster.
-       if h.handler.localClusterRequest(w, req, filterLocalClusterResponse) {
-               return
-       }
-
-       sharedContext, cancelFunc := context.WithCancel(req.Context())
-       defer cancelFunc()
-       req = req.WithContext(sharedContext)
-
-       // Create a goroutine for each cluster in the
-       // RemoteClusters map.  The first valid result gets
-       // returned to the client.  When that happens, all
-       // other outstanding requests are cancelled or
-       // suppressed.
-       sentResponse := false
-       mtx := sync.Mutex{}
-       wg := sync.WaitGroup{}
-       var errors []string
-       var errorCode int = 404
-
-       // use channel as a semaphore to limit the number of concurrent
-       // requests at a time
-       sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
-       defer close(sem)
-       for remoteID := range h.handler.Cluster.RemoteClusters {
-               // blocks until it can put a value into the
-               // channel (which has a max queue capacity)
-               sem <- true
-               if sentResponse {
-                       break
-               }
-               search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
-                       &sharedContext, cancelFunc, &errors, &errorCode}
-               wg.Add(1)
-               go func() {
-                       h.handler.remoteClusterRequest(search.remoteID, w, req, search.filterRemoteClusterResponse)
-                       wg.Done()
-                       <-sem
-               }()
-       }
-       wg.Wait()
-
-       if sentResponse {
-               return
-       }
-
-       // No successful responses, so return the error
-       httpserver.Errors(w, errors, errorCode)
-}
-
 func (h *Handler) setupProxyRemoteCluster(next http.Handler) http.Handler {
        mux := http.NewServeMux()
-       mux.Handle("/arvados/v1/workflows", &genericFederatedRequestHandler{next, h, wfRe})
-       mux.Handle("/arvados/v1/workflows/", &genericFederatedRequestHandler{next, h, wfRe})
-       mux.Handle("/arvados/v1/containers", &genericFederatedRequestHandler{next, h, containersRe})
-       mux.Handle("/arvados/v1/containers/", &genericFederatedRequestHandler{next, h, containersRe})
-       mux.Handle("/arvados/v1/container_requests", &genericFederatedRequestHandler{next, h, containerRequestsRe})
-       mux.Handle("/arvados/v1/container_requests/", &genericFederatedRequestHandler{next, h, containerRequestsRe})
+
+       wfHandler := &genericFederatedRequestHandler{next, h, wfRe, nil}
+       containersHandler := &genericFederatedRequestHandler{next, h, containersRe, nil}
+       containerRequestsHandler := &genericFederatedRequestHandler{next, h, containerRequestsRe,
+               []federatedRequestDelegate{remoteContainerRequestCreate}}
+
+       mux.Handle("/arvados/v1/workflows", wfHandler)
+       mux.Handle("/arvados/v1/workflows/", wfHandler)
+       mux.Handle("/arvados/v1/containers", containersHandler)
+       mux.Handle("/arvados/v1/containers/", containersHandler)
+       mux.Handle("/arvados/v1/container_requests", containerRequestsHandler)
+       mux.Handle("/arvados/v1/container_requests/", containerRequestsHandler)
        mux.Handle("/arvados/v1/collections", next)
        mux.Handle("/arvados/v1/collections/", &collectionFederatedRequestHandler{next, h})
        mux.Handle("/", next)
@@ -718,68 +126,157 @@ type CurrentUser struct {
        UUID          string
 }
 
-func (h *Handler) validateAPItoken(req *http.Request, user *CurrentUser) error {
+// validateAPItoken extracts the token from the provided http request,
+// checks it against the api_client_authorizations table in the database,
+// and fills in the token scope and user UUID.  Does not handle remote
+// tokens unless they are already in the database and not expired.
+func (h *Handler) validateAPItoken(req *http.Request, token string) (*CurrentUser, error) {
+       user := CurrentUser{Authorization: arvados.APIClientAuthorization{APIToken: token}}
        db, err := h.db(req)
        if err != nil {
-               return err
+               return nil, err
+       }
+
+       var uuid string
+       if strings.HasPrefix(token, "v2/") {
+               sp := strings.Split(token, "/")
+               uuid = sp[1]
+               token = sp[2]
+       }
+       user.Authorization.APIToken = token
+       var scopes string
+       err = db.QueryRowContext(req.Context(), `SELECT api_client_authorizations.uuid, api_client_authorizations.scopes, users.uuid FROM api_client_authorizations JOIN users on api_client_authorizations.user_id=users.id WHERE api_token=$1 AND (expires_at IS NULL OR expires_at > current_timestamp) LIMIT 1`, token).Scan(&user.Authorization.UUID, &scopes, &user.UUID)
+       if err != nil {
+               return nil, err
+       }
+       if uuid != "" && user.Authorization.UUID != uuid {
+               return nil, fmt.Errorf("UUID embedded in v2 token did not match record")
+       }
+       err = json.Unmarshal([]byte(scopes), &user.Authorization.Scopes)
+       if err != nil {
+               return nil, err
+       }
+       return &user, nil
+}
+
+func (h *Handler) createAPItoken(req *http.Request, userUUID string, scopes []string) (*arvados.APIClientAuthorization, error) {
+       db, err := h.db(req)
+       if err != nil {
+               return nil, err
+       }
+       rd, err := randutil.String(15, "abcdefghijklmnopqrstuvwxyz0123456789")
+       if err != nil {
+               return nil, err
+       }
+       uuid := fmt.Sprintf("%v-gj3su-%v", h.Cluster.ClusterID, rd)
+       token, err := randutil.String(50, "abcdefghijklmnopqrstuvwxyz0123456789")
+       if err != nil {
+               return nil, err
+       }
+       if len(scopes) == 0 {
+               scopes = append(scopes, "all")
+       }
+       scopesjson, err := json.Marshal(scopes)
+       if err != nil {
+               return nil, err
+       }
+       _, err = db.ExecContext(req.Context(),
+               `INSERT INTO api_client_authorizations
+(uuid, api_token, expires_at, scopes,
+user_id,
+api_client_id, created_at, updated_at)
+VALUES ($1, $2, CURRENT_TIMESTAMP + INTERVAL '2 weeks', $3,
+(SELECT id FROM users WHERE users.uuid=$4 LIMIT 1),
+0, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)`,
+               uuid, token, string(scopesjson), userUUID)
+
+       if err != nil {
+               return nil, err
        }
-       return db.QueryRowContext(req.Context(), `SELECT api_client_authorizations.uuid, users.uuid FROM api_client_authorizations JOIN users on api_client_authorizations.user_id=users.id WHERE api_token=$1 AND (expires_at IS NULL OR expires_at > current_timestamp) LIMIT 1`, user.Authorization.APIToken).Scan(&user.Authorization.UUID, &user.UUID)
+
+       return &arvados.APIClientAuthorization{
+               UUID:      uuid,
+               APIToken:  token,
+               ExpiresAt: "",
+               Scopes:    scopes}, nil
 }
 
 // Extract the auth token supplied in req, and replace it with a
 // salted token for the remote cluster.
-func (h *Handler) saltAuthToken(req *http.Request, remote string) error {
+func (h *Handler) saltAuthToken(req *http.Request, remote string) (updatedReq *http.Request, err error) {
+       updatedReq = (&http.Request{
+               Method:        req.Method,
+               URL:           req.URL,
+               Header:        req.Header,
+               Body:          req.Body,
+               ContentLength: req.ContentLength,
+               Host:          req.Host,
+       }).WithContext(req.Context())
+
        creds := auth.NewCredentials()
-       creds.LoadTokensFromHTTPRequest(req)
-       if len(creds.Tokens) == 0 && req.Header.Get("Content-Type") == "application/x-www-form-encoded" {
+       creds.LoadTokensFromHTTPRequest(updatedReq)
+       if len(creds.Tokens) == 0 && updatedReq.Header.Get("Content-Type") == "application/x-www-form-encoded" {
                // Override ParseForm's 10MiB limit by ensuring
                // req.Body is a *http.maxBytesReader.
-               req.Body = http.MaxBytesReader(nil, req.Body, 1<<28) // 256MiB. TODO: use MaxRequestSize from discovery doc or config.
-               if err := creds.LoadTokensFromHTTPRequestBody(req); err != nil {
-                       return err
+               updatedReq.Body = http.MaxBytesReader(nil, updatedReq.Body, 1<<28) // 256MiB. TODO: use MaxRequestSize from discovery doc or config.
+               if err := creds.LoadTokensFromHTTPRequestBody(updatedReq); err != nil {
+                       return nil, err
                }
                // Replace req.Body with a buffer that re-encodes the
                // form without api_token, in case we end up
                // forwarding the request.
-               if req.PostForm != nil {
-                       req.PostForm.Del("api_token")
+               if updatedReq.PostForm != nil {
+                       updatedReq.PostForm.Del("api_token")
                }
-               req.Body = ioutil.NopCloser(bytes.NewBufferString(req.PostForm.Encode()))
+               updatedReq.Body = ioutil.NopCloser(bytes.NewBufferString(updatedReq.PostForm.Encode()))
        }
        if len(creds.Tokens) == 0 {
-               return nil
+               return updatedReq, nil
        }
+
        token, err := auth.SaltToken(creds.Tokens[0], remote)
+
        if err == auth.ErrObsoleteToken {
                // If the token exists in our own database, salt it
                // for the remote. Otherwise, assume it was issued by
                // the remote, and pass it through unmodified.
-               currentUser := CurrentUser{Authorization: arvados.APIClientAuthorization{APIToken: creds.Tokens[0]}}
-               err = h.validateAPItoken(req, &currentUser)
+               currentUser, err := h.validateAPItoken(req, creds.Tokens[0])
                if err == sql.ErrNoRows {
                        // Not ours; pass through unmodified.
-                       token = currentUser.Authorization.APIToken
+                       token = creds.Tokens[0]
                } else if err != nil {
-                       return err
+                       return nil, err
                } else {
                        // Found; make V2 version and salt it.
                        token, err = auth.SaltToken(currentUser.Authorization.TokenV2(), remote)
                        if err != nil {
-                               return err
+                               return nil, err
                        }
                }
        } else if err != nil {
-               return err
+               return nil, err
+       }
+       updatedReq.Header = http.Header{}
+       for k, v := range req.Header {
+               if k != "Authorization" {
+                       updatedReq.Header[k] = v
+               }
        }
-       req.Header.Set("Authorization", "Bearer "+token)
+       updatedReq.Header.Set("Authorization", "Bearer "+token)
 
        // Remove api_token=... from the the query string, in case we
        // end up forwarding the request.
-       if values, err := url.ParseQuery(req.URL.RawQuery); err != nil {
-               return err
+       if values, err := url.ParseQuery(updatedReq.URL.RawQuery); err != nil {
+               return nil, err
        } else if _, ok := values["api_token"]; ok {
                delete(values, "api_token")
-               req.URL.RawQuery = values.Encode()
+               updatedReq.URL = &url.URL{
+                       Scheme:   req.URL.Scheme,
+                       Host:     req.URL.Host,
+                       Path:     req.URL.Path,
+                       RawPath:  req.URL.RawPath,
+                       RawQuery: values.Encode(),
+               }
        }
-       return nil
+       return updatedReq, nil
 }
index 23d5d7ca768111efd050861757e07879f91d7b05..da640071c523bc388af98fa3214d7328cf715359 100644 (file)
@@ -5,8 +5,10 @@
 package controller
 
 import (
+       "bytes"
        "encoding/json"
        "fmt"
+       "io"
        "io/ioutil"
        "net/http"
        "net/http/httptest"
@@ -90,6 +92,10 @@ func (s *FederationSuite) SetUpTest(c *check.C) {
 }
 
 func (s *FederationSuite) remoteMockHandler(w http.ResponseWriter, req *http.Request) {
+       b := &bytes.Buffer{}
+       io.Copy(b, req.Body)
+       req.Body.Close()
+       req.Body = ioutil.NopCloser(b)
        s.remoteMockRequests = append(s.remoteMockRequests, *req)
 }
 
@@ -567,6 +573,76 @@ func (s *FederationSuite) TestCreateRemoteContainerRequest(c *check.C) {
        c.Check(strings.HasPrefix(cr.UUID, "zzzzz-"), check.Equals, true)
 }
 
+func (s *FederationSuite) TestCreateRemoteContainerRequestCheckRuntimeToken(c *check.C) {
+       // Send request to zmock and check that outgoing request has
+       // runtime_token sent (because runtime_token isn't returned in
+       // the response).
+
+       defer s.localServiceReturns404(c).Close()
+       // pass cluster_id via query parameter, this allows arvados-controller
+       // to avoid parsing the body
+       req := httptest.NewRequest("POST", "/arvados/v1/container_requests?cluster_id=zmock",
+               strings.NewReader(`{
+  "container_request": {
+    "name": "hello world",
+    "state": "Uncommitted",
+    "output_path": "/",
+    "container_image": "123",
+    "command": ["abc"]
+  }
+}
+`))
+       req.Header.Set("Authorization", "Bearer "+arvadostest.ActiveToken)
+       req.Header.Set("Content-type", "application/json")
+
+       np := arvados.NodeProfile{
+               Controller: arvados.SystemServiceInstance{Listen: ":"},
+               RailsAPI: arvados.SystemServiceInstance{Listen: os.Getenv("ARVADOS_TEST_API_HOST"),
+                       TLS: true, Insecure: true}}
+       s.testHandler.Cluster.ClusterID = "zzzzz"
+       s.testHandler.Cluster.NodeProfiles["*"] = np
+       s.testHandler.NodeProfile = &np
+
+       resp := s.testRequest(req)
+       c.Check(resp.StatusCode, check.Equals, http.StatusOK)
+       var cr struct {
+               arvados.ContainerRequest `json:"container_request"`
+       }
+       c.Check(json.NewDecoder(s.remoteMockRequests[0].Body).Decode(&cr), check.IsNil)
+       c.Check(strings.HasPrefix(cr.ContainerRequest.RuntimeToken, "v2/"), check.Equals, true)
+}
+
+func (s *FederationSuite) TestCreateRemoteContainerRequestCheckSetRuntimeToken(c *check.C) {
+       // Send request to zmock with an explicit runtime_token and
+       // check that the outgoing request carries that same
+       // runtime_token unchanged.
+
+       defer s.localServiceReturns404(c).Close()
+       // pass cluster_id via query parameter, this allows arvados-controller
+       // to avoid parsing the body
+       req := httptest.NewRequest("POST", "/arvados/v1/container_requests?cluster_id=zmock",
+               strings.NewReader(`{
+  "container_request": {
+    "name": "hello world",
+    "state": "Uncommitted",
+    "output_path": "/",
+    "container_image": "123",
+    "command": ["abc"],
+    "runtime_token": "xyz"
+  }
+}
+`))
+       req.Header.Set("Authorization", "Bearer "+arvadostest.ActiveToken)
+       req.Header.Set("Content-type", "application/json")
+       resp := s.testRequest(req)
+       c.Check(resp.StatusCode, check.Equals, http.StatusOK)
+       var cr struct {
+               arvados.ContainerRequest `json:"container_request"`
+       }
+       c.Check(json.NewDecoder(s.remoteMockRequests[0].Body).Decode(&cr), check.IsNil)
+       c.Check(cr.ContainerRequest.RuntimeToken, check.Equals, "xyz")
+}
+
 func (s *FederationSuite) TestCreateRemoteContainerRequestError(c *check.C) {
        defer s.localServiceReturns404(c).Close()
        // pass cluster_id via query parameter, this allows arvados-controller
index 0c31815cba21f2869e7ae4ddf73c880bf4d0a5c8..295dde7ca42821b1c8f904eec42ac7e7764812fa 100644 (file)
@@ -5,6 +5,7 @@
 package controller
 
 import (
+       "context"
        "database/sql"
        "errors"
        "net"
@@ -49,6 +50,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
                        req.URL.Path = strings.Replace(req.URL.Path, "//", "/", -1)
                }
        }
+       if h.Cluster.HTTPRequestTimeout > 0 {
+               ctx, cancel := context.WithDeadline(req.Context(), time.Now().Add(time.Duration(h.Cluster.HTTPRequestTimeout)))
+               req = req.WithContext(ctx)
+               defer cancel()
+       }
+
        h.handlerStack.ServeHTTP(w, req)
 }
 
@@ -83,8 +90,7 @@ func (h *Handler) setup() {
        h.insecureClient = &ic
 
        h.proxy = &proxy{
-               Name:           "arvados-controller",
-               RequestTimeout: time.Duration(h.Cluster.HTTPRequestTimeout),
+               Name: "arvados-controller",
        }
 }
 
@@ -121,14 +127,10 @@ func prepend(next http.Handler, middleware middlewareFunc) http.Handler {
        })
 }
 
-// localClusterRequest sets up a request so it can be proxied to the
-// local API server using proxy.Do().  Returns true if a response was
-// written, false if not.
-func (h *Handler) localClusterRequest(w http.ResponseWriter, req *http.Request, filter ResponseFilter) bool {
+func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, error) {
        urlOut, insecure, err := findRailsAPI(h.Cluster, h.NodeProfile)
        if err != nil {
-               httpserver.Error(w, err.Error(), http.StatusInternalServerError)
-               return true
+               return nil, err
        }
        urlOut = &url.URL{
                Scheme:   urlOut.Scheme,
@@ -141,12 +143,14 @@ func (h *Handler) localClusterRequest(w http.ResponseWriter, req *http.Request,
        if insecure {
                client = h.insecureClient
        }
-       return h.proxy.Do(w, req, urlOut, client, filter)
+       return h.proxy.Do(req, urlOut, client)
 }
 
 func (h *Handler) proxyRailsAPI(w http.ResponseWriter, req *http.Request, next http.Handler) {
-       if !h.localClusterRequest(w, req, nil) && next != nil {
-               next.ServeHTTP(w, req)
+       resp, err := h.localClusterRequest(req)
+       n, err := h.proxy.ForwardResponse(w, resp, err)
+       if err != nil {
+               httpserver.Logger(req).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
        }
 }
 
index 963fd1159415e16d93e676401021e974c287ad69..746b9242f2198ee3c3000c808771047d4aa1c77c 100644 (file)
@@ -130,3 +130,39 @@ func (s *HandlerSuite) TestProxyRedirect(c *check.C) {
        c.Check(resp.Code, check.Equals, http.StatusFound)
        c.Check(resp.Header().Get("Location"), check.Matches, `https://0.0.0.0:1/auth/joshid\?return_to=foo&?`)
 }
+
+func (s *HandlerSuite) TestValidateV1APIToken(c *check.C) {
+       req := httptest.NewRequest("GET", "/arvados/v1/users/current", nil)
+       user, err := s.handler.(*Handler).validateAPItoken(req, arvadostest.ActiveToken)
+       c.Assert(err, check.IsNil)
+       c.Check(user.Authorization.UUID, check.Equals, arvadostest.ActiveTokenUUID)
+       c.Check(user.Authorization.APIToken, check.Equals, arvadostest.ActiveToken)
+       c.Check(user.Authorization.Scopes, check.DeepEquals, []string{"all"})
+       c.Check(user.UUID, check.Equals, arvadostest.ActiveUserUUID)
+}
+
+func (s *HandlerSuite) TestValidateV2APIToken(c *check.C) {
+       req := httptest.NewRequest("GET", "/arvados/v1/users/current", nil)
+       user, err := s.handler.(*Handler).validateAPItoken(req, arvadostest.ActiveTokenV2)
+       c.Assert(err, check.IsNil)
+       c.Check(user.Authorization.UUID, check.Equals, arvadostest.ActiveTokenUUID)
+       c.Check(user.Authorization.APIToken, check.Equals, arvadostest.ActiveToken)
+       c.Check(user.Authorization.Scopes, check.DeepEquals, []string{"all"})
+       c.Check(user.UUID, check.Equals, arvadostest.ActiveUserUUID)
+       c.Check(user.Authorization.TokenV2(), check.Equals, arvadostest.ActiveTokenV2)
+}
+
+func (s *HandlerSuite) TestCreateAPIToken(c *check.C) {
+       req := httptest.NewRequest("GET", "/arvados/v1/users/current", nil)
+       auth, err := s.handler.(*Handler).createAPItoken(req, arvadostest.ActiveUserUUID, nil)
+       c.Assert(err, check.IsNil)
+       c.Check(auth.Scopes, check.DeepEquals, []string{"all"})
+
+       user, err := s.handler.(*Handler).validateAPItoken(req, auth.TokenV2())
+       c.Assert(err, check.IsNil)
+       c.Check(user.Authorization.UUID, check.Equals, auth.UUID)
+       c.Check(user.Authorization.APIToken, check.Equals, auth.APIToken)
+       c.Check(user.Authorization.Scopes, check.DeepEquals, []string{"all"})
+       c.Check(user.UUID, check.Equals, arvadostest.ActiveUserUUID)
+       c.Check(user.Authorization.TokenV2(), check.Equals, auth.TokenV2())
+}
index 951cb9d25fe24ba74a5697d54187847cfc84ae1a..c01c152352e6b8f101179bf38add3b0574a00c5d 100644 (file)
@@ -5,18 +5,24 @@
 package controller
 
 import (
-       "context"
        "io"
        "net/http"
        "net/url"
-       "time"
 
        "git.curoverse.com/arvados.git/sdk/go/httpserver"
 )
 
 type proxy struct {
-       Name           string // to use in Via header
-       RequestTimeout time.Duration
+       Name string // to use in Via header
+}
+
+type HTTPError struct {
+       Message string // error text ForwardResponse sends to the client
+       Code    int    // HTTP status code ForwardResponse sends to the client
+}
+
+func (h HTTPError) Error() string {
+       return h.Message
 }
 
 // headers that shouldn't be forwarded when proxying. See
@@ -36,15 +42,11 @@ var dropHeaders = map[string]bool{
 
 type ResponseFilter func(*http.Response, error) (*http.Response, error)
 
-// Do sends a request, passes the result to the filter (if provided)
-// and then if the result is not suppressed by the filter, sends the
-// request to the ResponseWriter.  Returns true if a response was written,
-// false if not.
-func (p *proxy) Do(w http.ResponseWriter,
+// Do forwards a request to the upstream service and returns the response or error.
+func (p *proxy) Do(
        reqIn *http.Request,
        urlOut *url.URL,
-       client *http.Client,
-       filter ResponseFilter) bool {
+       client *http.Client) (*http.Response, error) {
 
        // Copy headers from incoming request, then add/replace proxy
        // headers like Via and X-Forwarded-For.
@@ -64,65 +66,35 @@ func (p *proxy) Do(w http.ResponseWriter,
        }
        hdrOut.Add("Via", reqIn.Proto+" arvados-controller")
 
-       ctx := reqIn.Context()
-       if p.RequestTimeout > 0 {
-               var cancel context.CancelFunc
-               ctx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Duration(p.RequestTimeout)))
-               defer cancel()
-       }
-
        reqOut := (&http.Request{
                Method: reqIn.Method,
                URL:    urlOut,
                Host:   reqIn.Host,
                Header: hdrOut,
                Body:   reqIn.Body,
-       }).WithContext(ctx)
+       }).WithContext(reqIn.Context())
 
        resp, err := client.Do(reqOut)
-       if filter == nil && err != nil {
-               httpserver.Error(w, err.Error(), http.StatusBadGateway)
-               return true
-       }
-
-       // make sure original response body gets closed
-       var originalBody io.ReadCloser
-       if resp != nil {
-               originalBody = resp.Body
-               if originalBody != nil {
-                       defer originalBody.Close()
-               }
-       }
-
-       if filter != nil {
-               resp, err = filter(resp, err)
+       return resp, err
+}
 
-               if err != nil {
+// ForwardResponse copies a response (or error) to the downstream client.
+func (p *proxy) ForwardResponse(w http.ResponseWriter, resp *http.Response, err error) (int64, error) {
+       if err != nil {
+               if he, ok := err.(HTTPError); ok {
+                       httpserver.Error(w, he.Message, he.Code)
+               } else {
                        httpserver.Error(w, err.Error(), http.StatusBadGateway)
-                       return true
-               }
-               if resp == nil {
-                       // filter() returned a nil response, this means suppress
-                       // writing a response, for the case where there might
-                       // be multiple response writers.
-                       return false
-               }
-
-               // the filter gave us a new response body, make sure that gets closed too.
-               if resp.Body != originalBody {
-                       defer resp.Body.Close()
                }
+               return 0, nil
        }
 
+       defer resp.Body.Close()
        for k, v := range resp.Header {
                for _, v := range v {
                        w.Header().Add(k, v)
                }
        }
        w.WriteHeader(resp.StatusCode)
-       n, err := io.Copy(w, resp.Body)
-       if err != nil {
-               httpserver.Logger(reqIn).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
-       }
-       return true
+       return io.Copy(w, resp.Body)
 }
index ec0239eb37bf0a45bb715b35eab757c6c94850d5..17cff235db82fba55fa12c6ff08fe0a114dff27b 100644 (file)
@@ -6,8 +6,10 @@ package arvados
 
 // APIClientAuthorization is an arvados#apiClientAuthorization resource.
 type APIClientAuthorization struct {
-       UUID     string `json:"uuid"`
-       APIToken string `json:"api_token"`
+       UUID      string   `json:"uuid,omitempty"`
+       APIToken  string   `json:"api_token,omitempty"`
+       ExpiresAt string   `json:"expires_at,omitempty"`
+       Scopes    []string `json:"scopes,omitempty"`
 }
 
 // APIClientAuthorizationList is an arvados#apiClientAuthorizationList resource.
index 2622c137030aada559bfc47f5768e7d918b6d816..b70b4ac917672f363096a810cd35e3689f5132f9 100644 (file)
@@ -56,6 +56,7 @@ type ContainerRequest struct {
        UseExisting             bool                   `json:"use_existing"`
        LogUUID                 string                 `json:"log_uuid"`
        OutputUUID              string                 `json:"output_uuid"`
+       RuntimeToken            string                 `json:"runtime_token"`
 }
 
 // Mount is special behavior to attach to a filesystem path or device.
index 114faf17b74e245aeaacf72aeaaf5bb6f8e5046a..e0f2483131a98a64856116bda8c14b4de7bd7051 100644 (file)
@@ -8,6 +8,7 @@ package arvadostest
 const (
        SpectatorToken          = "zw2f4gwx8hw8cjre7yp6v1zylhrhn3m5gvjq73rtpwhmknrybu"
        ActiveToken             = "3kg6k6lzmp9kj5cpkcoxie963cmvjahbt2fod9zru30k1jqdmi"
+       ActiveTokenUUID         = "zzzzz-gj3su-077z32aux8dg2s1"
        ActiveTokenV2           = "v2/zzzzz-gj3su-077z32aux8dg2s1/3kg6k6lzmp9kj5cpkcoxie963cmvjahbt2fod9zru30k1jqdmi"
        AdminToken              = "4axaw8zxe0qm22wa6urpp5nskcne8z88cvbupv653y1njyi05h"
        AnonymousToken          = "4kg6k6lzmp9kj4cpkcoxie964cmvjahbt4fod9zru44k4jqdmi"
index 6452136d85eede6896f1dca1648e00b4ba6ae8e7..14d89873b60f7d902a39a6b337eea78e8040d0c3 100644 (file)
@@ -12,6 +12,10 @@ import (
        "time"
 )
 
+const (
+       HeaderRequestID = "X-Request-Id"
+)
+
 // IDGenerator generates alphanumeric strings suitable for use as
 // unique IDs (a given IDGenerator will never return the same ID
 // twice).
@@ -44,11 +48,11 @@ func (g *IDGenerator) Next() string {
 func AddRequestIDs(h http.Handler) http.Handler {
        gen := &IDGenerator{Prefix: "req-"}
        return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-               if req.Header.Get("X-Request-Id") == "" {
+               if req.Header.Get(HeaderRequestID) == "" {
                        if req.Header == nil {
                                req.Header = http.Header{}
                        }
-                       req.Header.Set("X-Request-Id", gen.Next())
+                       req.Header.Set(HeaderRequestID, gen.Next())
                }
                h.ServeHTTP(w, req)
        })
index 718ffc0d0a51416440ff75ec98c442cfe64423b9..487043ee3549d8afe915f9abeeaeab2c8f252707 100644 (file)
@@ -496,7 +496,14 @@ class Collection < ArvadosModel
     if loc = Keep::Locator.parse(search_term)
       loc.strip_hints!
       coll_match = readable_by(*readers).where(portable_data_hash: loc.to_s).limit(1)
-      return get_compatible_images(readers, pattern, coll_match)
+      if coll_match.any? or Rails.configuration.remote_hosts.length == 0
+        return get_compatible_images(readers, pattern, coll_match)
+      else
+        # Allow bare pdh that doesn't exist in the local database so
+        # that federated container requests which refer to remotely
+        # stored container images will validate.
+        return [Collection.new(portable_data_hash: loc.to_s)]
+      end
     end
 
     if search_tag.nil? and (n = search_term.index(":"))
index 0d8453174e205e85ab3f79e01a32cc530478a4a1..cd763a8e7e1eb0d851f08517b730ddf9f230a113 100644 (file)
@@ -493,10 +493,14 @@ class Container < ArvadosModel
       return false
     end
 
-    if current_api_client_authorization.andand.uuid.andand == self.auth_uuid
-      # The contained process itself can update progress indicators,
-      # but can't change priority etc.
-      permitted = permitted & (progress_attrs + final_attrs + [:state] - [:log])
+    if self.state == Running &&
+       !current_api_client_authorization.nil? &&
+       (current_api_client_authorization.uuid == self.auth_uuid ||
+        current_api_client_authorization.token == self.runtime_token)
+      # The contained process itself can write final attrs but can't
+      # change priority or log.
+      permitted.push *final_attrs
+      permitted = permitted - [:log, :priority]
     elsif self.locked_by_uuid && self.locked_by_uuid != current_api_client_authorization.andand.uuid
       # When locked, progress fields cannot be updated by the wrong
       # dispatcher, even though it has admin privileges.
index 0e61db7bcd9d5cc0cb185c4766a2e597c6d6ed4a..44737524e5f583cb76bb62a6aa0ff8af5ca91319 100644 (file)
@@ -63,8 +63,8 @@ class RemoteUsersTest < ActionDispatch::IntegrationTest
     ready.pop
     @remote_server = srv
     @remote_host = "127.0.0.1:#{srv.config[:Port]}"
-    Rails.configuration.remote_hosts['zbbbb'] = @remote_host
-    Rails.configuration.remote_hosts['zbork'] = @remote_host
+    Rails.configuration.remote_hosts = Rails.configuration.remote_hosts.merge({'zbbbb' => @remote_host,
+                                                                               'zbork' => @remote_host})
     Arvados::V1::SchemaController.any_instance.stubs(:root_url).returns "https://#{@remote_host}"
     @stub_status = 200
     @stub_content = {
index 8ff216e28caf8a598c5b6fbbf46a9d342e4a7c35..0fafb990366de4c309fc67475542e7e570401005 100644 (file)
@@ -512,6 +512,12 @@ class ContainerRequestTest < ActiveSupport::TestCase
     end
   end
 
+  test "allow unrecognized container when there are remote_hosts" do
+    set_user_from_auth :active
+    Rails.configuration.remote_hosts = {"foooo" => "bar.com"}
+    Container.resolve_container_image('acbd18db4cc2f85cedef654fccc4a4d8+3')
+  end
+
   test "migrated docker image" do
     Rails.configuration.docker_image_formats = ['v2']
     add_docker19_migration_link
index 491022ad8d5a9cd6e47e1cf7727a5cba92d54ce4..90b4f13bf597b5b9ea306dec04b698e75fb98ae3 100644 (file)
@@ -777,25 +777,41 @@ class ContainerTest < ActiveSupport::TestCase
     assert_equal [logpdh_time2], Collection.where(uuid: [cr1log_uuid, cr2log_uuid]).to_a.collect(&:portable_data_hash).uniq
   end
 
-  test "auth_uuid can set output, progress, runtime_status, state on running container -- but not log" do
-    set_user_from_auth :active
-    c, _ = minimal_new
-    set_user_from_auth :dispatch1
-    c.lock
-    c.update_attributes! state: Container::Running
-
-    auth = ApiClientAuthorization.find_by_uuid(c.auth_uuid)
-    Thread.current[:api_client_authorization] = auth
-    Thread.current[:api_client] = auth.api_client
-    Thread.current[:token] = auth.token
-    Thread.current[:user] = auth.user
+  ["auth_uuid", "runtime_token"].each do |tok|
+    test "#{tok} can set output, progress, runtime_status, state on running container -- but not log" do
+      if tok == "runtime_token"
+        set_user_from_auth :spectator
+        c, _ = minimal_new(container_image: "9ae44d5792468c58bcf85ce7353c7027+124",
+                           runtime_token: api_client_authorizations(:active).token)
+      else
+        set_user_from_auth :active
+        c, _ = minimal_new
+      end
+      set_user_from_auth :dispatch1
+      c.lock
+      c.update_attributes! state: Container::Running
+
+      if tok == "runtime_token"
+        auth = ApiClientAuthorization.validate(token: c.runtime_token)
+        Thread.current[:api_client_authorization] = auth
+        Thread.current[:api_client] = auth.api_client
+        Thread.current[:token] = auth.token
+        Thread.current[:user] = auth.user
+      else
+        auth = ApiClientAuthorization.find_by_uuid(c.auth_uuid)
+        Thread.current[:api_client_authorization] = auth
+        Thread.current[:api_client] = auth.api_client
+        Thread.current[:token] = auth.token
+        Thread.current[:user] = auth.user
+      end
 
-    assert c.update_attributes(output: collections(:collection_owned_by_active).portable_data_hash)
-    assert c.update_attributes(runtime_status: {'warning' => 'something happened'})
-    assert c.update_attributes(progress: 0.5)
-    refute c.update_attributes(log: collections(:real_log_collection).portable_data_hash)
-    c.reload
-    assert c.update_attributes(state: Container::Complete, exit_code: 0)
+      assert c.update_attributes(output: collections(:collection_owned_by_active).portable_data_hash)
+      assert c.update_attributes(runtime_status: {'warning' => 'something happened'})
+      assert c.update_attributes(progress: 0.5)
+      refute c.update_attributes(log: collections(:real_log_collection).portable_data_hash)
+      c.reload
+      assert c.update_attributes(state: Container::Complete, exit_code: 0)
+    end
   end
 
   test "not allowed to set output that is not readable by current user" do
index fc6a97cf7480c645206c867e3449822bfcfa41a5..41e2adb9c3d35a2a6d52f9244b666913eff3e1d5 100644 (file)
@@ -127,6 +127,7 @@ class JobTest < ActiveSupport::TestCase
     'locator' => BAD_COLLECTION,
   }.each_pair do |spec_type, image_spec|
     test "Job validation fails with nonexistent Docker image #{spec_type}" do
+      Rails.configuration.remote_hosts = {}
       job = Job.new job_attrs(runtime_constraints:
                               {'docker_image' => image_spec})
       assert(job.invalid?, "nonexistent Docker image #{spec_type} was valid")
index 800556866a43e985c9107370ff83d9691158aef1..1deb74031667d7ade04968344d3b262b3ccf1dd1 100644 (file)
@@ -123,7 +123,7 @@ type ContainerRunner struct {
        SigChan         chan os.Signal
        ArvMountExit    chan error
        SecretMounts    map[string]arvados.Mount
-       MkArvClient     func(token string) (IArvadosClient, error)
+       MkArvClient     func(token string) (IArvadosClient, IKeepClient, error)
        finalState      string
        parentTemp      string
 
@@ -238,8 +238,17 @@ func (runner *ContainerRunner) LoadImage() (err error) {
 
        runner.CrunchLog.Printf("Fetching Docker image from collection '%s'", runner.Container.ContainerImage)
 
+       tok, err := runner.ContainerToken()
+       if err != nil {
+               return fmt.Errorf("While getting container token (LoadImage): %v", err)
+       }
+       arvClient, kc, err := runner.MkArvClient(tok)
+       if err != nil {
+               return fmt.Errorf("While creating arv client (LoadImage): %v", err)
+       }
+
        var collection arvados.Collection
-       err = runner.ArvClient.Get("collections", runner.Container.ContainerImage, nil, &collection)
+       err = arvClient.Get("collections", runner.Container.ContainerImage, nil, &collection)
        if err != nil {
                return fmt.Errorf("While getting container image collection: %v", err)
        }
@@ -260,7 +269,7 @@ func (runner *ContainerRunner) LoadImage() (err error) {
                runner.CrunchLog.Print("Loading Docker image from keep")
 
                var readCloser io.ReadCloser
-               readCloser, err = runner.Kc.ManifestFileReader(manifest, img)
+               readCloser, err = kc.ManifestFileReader(manifest, img)
                if err != nil {
                        return fmt.Errorf("While creating ManifestFileReader for container image: %v", err)
                }
@@ -282,7 +291,7 @@ func (runner *ContainerRunner) LoadImage() (err error) {
 
        runner.ContainerConfig.Image = imageID
 
-       runner.Kc.ClearBlockCache()
+       kc.ClearBlockCache()
 
        return nil
 }
@@ -1691,7 +1700,7 @@ func (runner *ContainerRunner) fetchContainerRecord() error {
                return fmt.Errorf("error getting container token: %v", err)
        }
 
-       containerClient, err := runner.MkArvClient(containerToken)
+       containerClient, _, err := runner.MkArvClient(containerToken)
        if err != nil {
                return fmt.Errorf("error creating container API client: %v", err)
        }
@@ -1731,13 +1740,17 @@ func NewContainerRunner(client *arvados.Client, api IArvadosClient, kc IKeepClie
                }
                return ps, nil
        }
-       cr.MkArvClient = func(token string) (IArvadosClient, error) {
+       cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
                cl, err := arvadosclient.MakeArvadosClient()
                if err != nil {
-                       return nil, err
+                       return nil, nil, err
                }
                cl.ApiToken = token
-               return cl, nil
+               kc, err := keepclient.MakeKeepClient(cl)
+               if err != nil {
+                       return nil, nil, err
+               }
+               return cl, kc, nil
        }
        var err error
        cr.LogCollection, err = (&arvados.Collection{}).FileSystem(cr.client, cr.Kc)
index ab6afc77cbcf82855e4cd11920589d94d9968504..0df048cc8b95000fbb214dd88cabd83c6b9f71d1 100644 (file)
@@ -447,6 +447,10 @@ func (s *TestSuite) TestLoadImage(c *C) {
        cr, err := NewContainerRunner(s.client, &ArvTestClient{}, kc, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
        c.Assert(err, IsNil)
 
+       cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+               return &ArvTestClient{}, kc, nil
+       }
+
        _, err = cr.Docker.ImageRemove(nil, hwImageId, dockertypes.ImageRemoveOptions{})
        c.Check(err, IsNil)
 
@@ -492,6 +496,9 @@ func (ArvErrorTestClient) Create(resourceType string,
 }
 
 func (ArvErrorTestClient) Call(method, resourceType, uuid, action string, parameters arvadosclient.Dict, output interface{}) error {
+       if method == "GET" && resourceType == "containers" && action == "auth" {
+               return nil
+       }
        return errors.New("ArvError")
 }
 
@@ -556,9 +563,13 @@ func (s *TestSuite) TestLoadImageArvError(c *C) {
        // (1) Arvados error
        kc := &KeepTestClient{}
        defer kc.Close()
-       cr, err := NewContainerRunner(s.client, ArvErrorTestClient{}, kc, nil, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
+       cr, err := NewContainerRunner(s.client, &ArvErrorTestClient{}, kc, nil, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
        c.Assert(err, IsNil)
+
        cr.Container.ContainerImage = hwPDH
+       cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+               return &ArvErrorTestClient{}, &KeepTestClient{}, nil
+       }
 
        err = cr.LoadImage()
        c.Check(err.Error(), Equals, "While getting container image collection: ArvError")
@@ -566,9 +577,13 @@ func (s *TestSuite) TestLoadImageArvError(c *C) {
 
 func (s *TestSuite) TestLoadImageKeepError(c *C) {
        // (2) Keep error
-       cr, err := NewContainerRunner(s.client, &ArvTestClient{}, &KeepErrorTestClient{}, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
+       kc := &KeepErrorTestClient{}
+       cr, err := NewContainerRunner(s.client, &ArvTestClient{}, kc, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
        c.Assert(err, IsNil)
        cr.Container.ContainerImage = hwPDH
+       cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+               return &ArvTestClient{}, kc, nil
+       }
 
        err = cr.LoadImage()
        c.Assert(err, NotNil)
@@ -577,9 +592,13 @@ func (s *TestSuite) TestLoadImageKeepError(c *C) {
 
 func (s *TestSuite) TestLoadImageCollectionError(c *C) {
        // (3) Collection doesn't contain image
-       cr, err := NewContainerRunner(s.client, &ArvTestClient{}, &KeepReadErrorTestClient{}, nil, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
+       kc := &KeepReadErrorTestClient{}
+       cr, err := NewContainerRunner(s.client, &ArvTestClient{}, kc, nil, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
        c.Assert(err, IsNil)
        cr.Container.ContainerImage = otherPDH
+       cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+               return &ArvTestClient{}, kc, nil
+       }
 
        err = cr.LoadImage()
        c.Check(err.Error(), Equals, "First file in the container image collection does not end in .tar")
@@ -587,9 +606,13 @@ func (s *TestSuite) TestLoadImageCollectionError(c *C) {
 
 func (s *TestSuite) TestLoadImageKeepReadError(c *C) {
        // (4) Collection doesn't contain image
-       cr, err := NewContainerRunner(s.client, &ArvTestClient{}, &KeepReadErrorTestClient{}, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
+       kc := &KeepReadErrorTestClient{}
+       cr, err := NewContainerRunner(s.client, &ArvTestClient{}, kc, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
        c.Assert(err, IsNil)
        cr.Container.ContainerImage = hwPDH
+       cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+               return &ArvTestClient{}, kc, nil
+       }
 
        err = cr.LoadImage()
        c.Check(err, NotNil)
@@ -637,6 +660,10 @@ func (s *TestSuite) TestRunContainer(c *C) {
        cr, err := NewContainerRunner(s.client, &ArvTestClient{}, kc, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
        c.Assert(err, IsNil)
 
+       cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+               return &ArvTestClient{}, kc, nil
+       }
+
        var logs TestLogs
        cr.NewLogWriter = logs.NewTestLoggingWriter
        cr.Container.ContainerImage = hwPDH
@@ -780,8 +807,8 @@ func (s *TestSuite) fullRunHelper(c *C, record string, extraMounts []string, exi
                }
                return d, err
        }
-       cr.MkArvClient = func(token string) (IArvadosClient, error) {
-               return &ArvTestClient{secretMounts: secretMounts}, nil
+       cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+               return &ArvTestClient{secretMounts: secretMounts}, &KeepTestClient{}, nil
        }
 
        if extraMounts != nil && len(extraMounts) > 0 {
@@ -1077,8 +1104,8 @@ func (s *TestSuite) testStopContainer(c *C, setup func(cr *ContainerRunner)) {
        cr, err := NewContainerRunner(s.client, api, kc, s.docker, "zzzzz-zzzzz-zzzzzzzzzzzzzzz")
        c.Assert(err, IsNil)
        cr.RunArvMount = func([]string, string) (*exec.Cmd, error) { return nil, nil }
-       cr.MkArvClient = func(token string) (IArvadosClient, error) {
-               return &ArvTestClient{}, nil
+       cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+               return &ArvTestClient{}, &KeepTestClient{}, nil
        }
        setup(cr)
 
@@ -1561,8 +1588,8 @@ func (s *TestSuite) stdoutErrorRunHelper(c *C, record string, fn func(t *TestDoc
        c.Assert(err, IsNil)
        am := &ArvMountCmdLine{}
        cr.RunArvMount = am.ArvMountTest
-       cr.MkArvClient = func(token string) (IArvadosClient, error) {
-               return &ArvTestClient{}, nil
+       cr.MkArvClient = func(token string) (IArvadosClient, IKeepClient, error) {
+               return &ArvTestClient{}, &KeepTestClient{}, nil
        }
 
        err = cr.Run()
index aa6b2d773dfb1e47969794dd816a81b179539163..9abb9bb15e0ae0824533c812f1302d93cf270722 100644 (file)
                        "revision": "d14ea06fba99483203c19d92cfcd13ebe73135f4",
                        "revisionTime": "2015-07-11T00:45:18Z"
                },
+               {
+                       "checksumSHA1": "khL6oKjx81rAZKW+36050b7f5As=",
+                       "path": "github.com/jmcvetta/randutil",
+                       "revision": "2bb1b664bcff821e02b2a0644cd29c7e824d54f8",
+                       "revisionTime": "2015-08-17T12:26:01Z"
+               },
                {
                        "checksumSHA1": "oX6jFQD74oOApvDIhOzW2dXpg5Q=",
                        "path": "github.com/kevinburke/ssh_config",