Merge master to output-tags branch and resolve conflict
authorJiayong Li <jiayong@math.mit.edu>
Tue, 15 Nov 2016 19:59:55 +0000 (14:59 -0500)
committerJiayong Li <jiayong@math.mit.edu>
Tue, 15 Nov 2016 19:59:55 +0000 (14:59 -0500)
53 files changed:
apps/workbench/app/controllers/application_controller.rb
apps/workbench/app/controllers/collections_controller.rb
apps/workbench/app/controllers/projects_controller.rb
apps/workbench/app/controllers/work_unit_templates_controller.rb
apps/workbench/app/controllers/work_units_controller.rb
apps/workbench/app/models/arvados_base.rb
apps/workbench/app/views/projects/_show_workflows.html.erb [new file with mode: 0644]
apps/workbench/test/controllers/disabled_api_test.rb [new file with mode: 0644]
apps/workbench/test/unit/disabled_api_test.rb [new file with mode: 0644]
build/run-build-packages.sh
build/run-tests.sh
sdk/cli/bin/arv-run-pipeline-instance
sdk/cwl/arvados_cwl/__init__.py
sdk/cwl/arvados_cwl/arvcontainer.py
sdk/cwl/arvados_cwl/arvjob.py
sdk/cwl/arvados_cwl/arvworkflow.py
sdk/cwl/arvados_cwl/pathmapper.py
sdk/cwl/arvados_cwl/runner.py
sdk/cwl/setup.py
sdk/cwl/tests/test_container.py
sdk/cwl/tests/test_make_output.py
sdk/go/arvados/container.go
services/api/app/controllers/arvados/v1/collections_controller.rb
services/api/app/models/container.rb
services/api/app/models/container_request.rb
services/api/app/models/node.rb
services/api/db/migrate/20161111143147_add_scheduling_parameters_to_container.rb [new file with mode: 0644]
services/api/db/structure.sql
services/api/test/fixtures/nodes.yml
services/api/test/functional/arvados/v1/collections_controller_test.rb
services/api/test/unit/container_request_test.rb
services/api/test/unit/node_test.rb
services/crunch-dispatch-slurm/crunch-dispatch-slurm.go
services/crunch-dispatch-slurm/crunch-dispatch-slurm_test.go
services/keepstore/azure_blob_volume.go
services/keepstore/azure_blob_volume_test.go
services/keepstore/collision.go
services/keepstore/config.go
services/keepstore/config_test.go [new file with mode: 0644]
services/keepstore/handler_test.go
services/keepstore/handlers.go
services/keepstore/handlers_with_generic_volume_test.go
services/keepstore/keepstore_test.go
services/keepstore/pull_worker.go
services/keepstore/s3_volume.go
services/keepstore/s3_volume_test.go
services/keepstore/trash_worker_test.go
services/keepstore/volume.go
services/keepstore/volume_generic_test.go
services/keepstore/volume_test.go
services/keepstore/volume_unix.go
services/keepstore/volume_unix_test.go
tools/keep-exercise/keep-exercise.go

index f68250ba15bbf77d9e25821632f3d37a0154e25c..c9ce8ce0b748a9473d2cd5f80739d070f1f8aef5 100644 (file)
@@ -13,6 +13,7 @@ class ApplicationController < ActionController::Base
   # Methods that don't require login should
   #   skip_around_filter :require_thread_api_token
   around_filter :require_thread_api_token, except: ERROR_ACTIONS
+  before_filter :ensure_arvados_api_exists, only: [:index, :show]
   before_filter :set_cache_buster
   before_filter :accept_uuid_as_id_param, except: ERROR_ACTIONS
   before_filter :check_user_agreements, except: ERROR_ACTIONS
@@ -213,6 +214,13 @@ class ApplicationController < ActionController::Base
     end
   end
 
+  def ensure_arvados_api_exists
+    if model_class.is_a?(Class) && model_class < ArvadosBase && !model_class.api_exists?(params['action'].to_sym)
+      @errors = ["#{params['action']} method is not supported for #{params['controller']}"]
+      return render_error(status: 404)
+    end
+  end
+
   def index
     find_objects_for_index if !@objects
     render_index
@@ -760,7 +768,11 @@ class ApplicationController < ActionController::Base
   }
 
   @@notification_tests.push lambda { |controller, current_user|
-    PipelineInstance.limit(1).where(created_by: current_user.uuid).each do
+    if PipelineInstance.api_exists?(:index)
+      PipelineInstance.limit(1).where(created_by: current_user.uuid).each do
+        return nil
+      end
+    else
       return nil
     end
     return lambda { |view|
@@ -856,12 +868,14 @@ class ApplicationController < ActionController::Base
   def recent_processes lim
     lim = 12 if lim.nil?
 
-    cols = %w(uuid owner_uuid created_at modified_at pipeline_template_uuid name state started_at finished_at)
-    pipelines = PipelineInstance.select(cols).limit(lim).order(["created_at desc"])
+    procs = {}
+    if PipelineInstance.api_exists?(:index)
+      cols = %w(uuid owner_uuid created_at modified_at pipeline_template_uuid name state started_at finished_at)
+      pipelines = PipelineInstance.select(cols).limit(lim).order(["created_at desc"])
+      pipelines.results.each { |pi| procs[pi] = pi.created_at }
+    end
 
     crs = ContainerRequest.limit(lim).order(["created_at desc"]).filter([["requesting_container_uuid", "=", nil]])
-    procs = {}
-    pipelines.results.each { |pi| procs[pi] = pi.created_at }
     crs.results.each { |c| procs[c] = c.created_at }
 
     Hash[procs.sort_by {|key, value| value}].keys.reverse.first(lim)
index 20b227c3c7277d491c74b96c0f5de7bc415c0f4c..46dcab6dce38487a7bb938472e91f209f4e926e2 100644 (file)
@@ -239,12 +239,15 @@ class CollectionsController < ApplicationController
         render 'hash_matches'
         return
       else
-        jobs_with = lambda do |conds|
-          Job.limit(RELATION_LIMIT).where(conds)
-            .results.sort_by { |j| j.finished_at || j.created_at }
+        if Job.api_exists?(:index)
+          jobs_with = lambda do |conds|
+            Job.limit(RELATION_LIMIT).where(conds)
+              .results.sort_by { |j| j.finished_at || j.created_at }
+          end
+          @output_of = jobs_with.call(output: @object.portable_data_hash)
+          @log_of = jobs_with.call(log: @object.portable_data_hash)
         end
-        @output_of = jobs_with.call(output: @object.portable_data_hash)
-        @log_of = jobs_with.call(log: @object.portable_data_hash)
+
         @project_links = Link.limit(RELATION_LIMIT).order("modified_at DESC")
           .where(head_uuid: @object.uuid, link_class: 'name').results
         project_hash = Group.where(uuid: @project_links.map(&:tail_uuid)).to_hash
index 16212a8d0ad489b381aa3619d69d72443905cfcb..0a2044a0e23e96b741d77658dfa91057fe57bdfa 100644 (file)
@@ -53,6 +53,19 @@ class ProjectsController < ApplicationController
   # It also seems to me that something like these could be used to configure the contents of the panes.
   def show_pane_list
     pane_list = []
+
+    procs = ["arvados#containerRequest"]
+    if PipelineInstance.api_exists?(:index)
+      procs << "arvados#pipelineInstance"
+    end
+
+    workflows = ["arvados#workflow"]
+    workflows_pane_name = 'Workflows'
+    if PipelineTemplate.api_exists?(:index)
+      workflows << "arvados#pipelineTemplate"
+      workflows_pane_name = 'Pipeline_templates'
+    end
+
     if @object.uuid != current_user.andand.uuid
       pane_list << 'Description'
     end
@@ -64,12 +77,12 @@ class ProjectsController < ApplicationController
     pane_list <<
       {
         :name => 'Pipelines_and_processes',
-        :filters => [%w(uuid is_a) + [%w(arvados#containerRequest arvados#pipelineInstance)]]
+        :filters => [%w(uuid is_a) + [procs]]
       }
     pane_list <<
       {
-        :name => 'Pipeline_templates',
-        :filters => [%w(uuid is_a) + [%w(arvados#pipelineTemplate arvados#workflow)]]
+        :name => workflows_pane_name,
+        :filters => [%w(uuid is_a) + [workflows]]
       }
     pane_list <<
       {
@@ -213,6 +226,10 @@ class ProjectsController < ApplicationController
       @name_link_for = {}
       kind_filters.each do |attr,op,val|
         (val.is_a?(Array) ? val : [val]).each do |type|
+          klass = type.split('#')[-1]
+          klass[0] = klass[0].capitalize
+          next if(!Object.const_get(klass).api_exists?(:index))
+
           filters = @filters - kind_filters + [['uuid', 'is_a', type]]
           if type == 'arvados#containerRequest'
             filters = filters + [['container_requests.requesting_container_uuid', '=', nil]]
index 6b5f114a66fda426aa2353fda68c48daf87c7b3e..fe53ac403c3faccf50cc17e2159be32a120e585e 100644 (file)
@@ -6,8 +6,10 @@ class WorkUnitTemplatesController < ApplicationController
     @filters = @filters || []
 
     # get next page of pipeline_templates
-    filters = @filters + [["uuid", "is_a", ["arvados#pipelineTemplate"]]]
-    pipelines = PipelineTemplate.limit(@limit).order(["created_at desc"]).filter(filters)
+    if PipelineTemplate.api_exists?(:index)
+      filters = @filters + [["uuid", "is_a", ["arvados#pipelineTemplate"]]]
+      pipelines = PipelineTemplate.limit(@limit).order(["created_at desc"]).filter(filters)
+    end
 
     # get next page of workflows
     filters = @filters + [["uuid", "is_a", ["arvados#workflow"]]]
index fe6bff1cee4dfd7fa42a8b487376713808cbbd29..3b611aa25b74e28663d9b7ecc2b0647670f066c8 100644 (file)
@@ -14,12 +14,16 @@ class WorkUnitsController < ApplicationController
     @filters = @filters || []
 
     # get next page of pipeline_instances
-    filters = @filters + [["uuid", "is_a", ["arvados#pipelineInstance"]]]
-    pipelines = PipelineInstance.limit(@limit).order(["created_at desc"]).filter(filters)
+    if PipelineInstance.api_exists?(:index)
+      filters = @filters + [["uuid", "is_a", ["arvados#pipelineInstance"]]]
+      pipelines = PipelineInstance.limit(@limit).order(["created_at desc"]).filter(filters)
+    end
 
     # get next page of jobs
-    filters = @filters + [["uuid", "is_a", ["arvados#job"]]]
-    jobs = Job.limit(@limit).order(["created_at desc"]).filter(filters)
+    if Job.api_exists?(:index)
+      filters = @filters + [["uuid", "is_a", ["arvados#job"]]]
+      jobs = Job.limit(@limit).order(["created_at desc"]).filter(filters)
+    end
 
     # get next page of container_requests
     filters = @filters + [["uuid", "is_a", ["arvados#containerRequest"]]]
index b02db7a6b63b5fad8c75ea5f107baa74a58d7151..6250daa06a3d0c65d2233f51c33588a9de3855a5 100644 (file)
@@ -334,7 +334,7 @@ class ArvadosBase < ActiveRecord::Base
   end
 
   def self.creatable?
-    current_user.andand.is_active
+    current_user.andand.is_active && api_exists?(:create)
   end
 
   def self.goes_in_projects?
@@ -361,6 +361,10 @@ class ArvadosBase < ActiveRecord::Base
     editable?
   end
 
+  def self.api_exists?(method)
+    arvados_api_client.discovery[:resources][self.to_s.underscore.pluralize.to_sym].andand[:methods].andand[method]
+  end
+
   # Array of strings that are the names of attributes that can be edited
   # with X-Editable.
   def editable_attributes
diff --git a/apps/workbench/app/views/projects/_show_workflows.html.erb b/apps/workbench/app/views/projects/_show_workflows.html.erb
new file mode 100644 (file)
index 0000000..133fddc
--- /dev/null
@@ -0,0 +1,5 @@
+<%= render_pane 'tab_contents', to_string: true, locals: {
+    limit: 50,
+    filters: [['uuid', 'is_a', ["arvados#workflow"]]],
+       sortable_columns: { 'name' => 'workflows.name', 'description' => 'workflows.description' }
+    }.merge(local_assigns) %>
diff --git a/apps/workbench/test/controllers/disabled_api_test.rb b/apps/workbench/test/controllers/disabled_api_test.rb
new file mode 100644 (file)
index 0000000..a41d87f
--- /dev/null
@@ -0,0 +1,63 @@
+require 'test_helper'
+require 'helpers/share_object_helper'
+
+class DisabledApiTest < ActionController::TestCase
+  test "dashboard recent processes when pipeline_instance index API is disabled" do
+    @controller = ProjectsController.new
+
+    dd = ArvadosApiClient.new_or_current.discovery.deep_dup
+    dd[:resources][:pipeline_instances][:methods].delete(:index)
+    ArvadosApiClient.any_instance.stubs(:discovery).returns(dd)
+
+    get :index, {}, session_for(:active)
+    assert_includes @response.body, "zzzzz-xvhdp-cr4runningcntnr" # expect crs
+    assert_not_includes @response.body, "zzzzz-d1hrv-"   # expect no pipelines
+  end
+
+  [
+    [:jobs, JobsController.new],
+    [:job_tasks, JobTasksController.new],
+    [:pipeline_instances, PipelineInstancesController.new],
+    [:pipeline_templates, PipelineTemplatesController.new],
+  ].each do |ctrl_name, ctrl|
+    test "#{ctrl_name} index page when API is disabled" do
+      @controller = ctrl
+
+      dd = ArvadosApiClient.new_or_current.discovery.deep_dup
+      dd[:resources][ctrl_name][:methods].delete(:index)
+      ArvadosApiClient.any_instance.stubs(:discovery).returns(dd)
+
+      get :index, {}, session_for(:active)
+      assert_response 404
+    end
+  end
+
+  [
+    :active,
+    nil,
+  ].each do |user|
+    test "project tabs as user #{user} when pipeline related index APIs are disabled" do
+      @controller = ProjectsController.new
+
+      Rails.configuration.anonymous_user_token = api_fixture('api_client_authorizations')['anonymous']['api_token']
+
+      dd = ArvadosApiClient.new_or_current.discovery.deep_dup
+      dd[:resources][:pipeline_templates][:methods].delete(:index)
+      ArvadosApiClient.any_instance.stubs(:discovery).returns(dd)
+
+      proj_uuid = api_fixture('groups')['anonymously_accessible_project']['uuid']
+
+      if user
+        get(:show, {id: proj_uuid}, session_for(user))
+      else
+        get(:show, {id: proj_uuid})
+      end
+
+      resp = @response.body
+      assert_includes resp, "href=\"#Data_collections\""
+      assert_includes resp, "href=\"#Pipelines_and_processes\""
+      assert_includes resp, "href=\"#Workflows\""
+      assert_not_includes resp, "href=\"#Pipeline_templates\""
+    end
+  end
+end
diff --git a/apps/workbench/test/unit/disabled_api_test.rb b/apps/workbench/test/unit/disabled_api_test.rb
new file mode 100644 (file)
index 0000000..52e3bd1
--- /dev/null
@@ -0,0 +1,15 @@
+require 'test_helper'
+
+class DisabledApiTest < ActiveSupport::TestCase
+  test 'Job.creatable? reflects whether jobs.create API is enabled' do
+    use_token(:active) do
+      assert(Job.creatable?)
+    end
+    dd = ArvadosApiClient.new_or_current.discovery.deep_dup
+    dd[:resources][:jobs][:methods].delete(:create)
+    ArvadosApiClient.any_instance.stubs(:discovery).returns(dd)
+    use_token(:active) do
+      refute(Job.creatable?)
+    end
+  end
+end
index 12c92607de51e3fd4aa7b6ec129e438b9427b1ec..320f9d445c3a052a62bf5b8560b2080c98b06904 100755 (executable)
@@ -431,6 +431,8 @@ package_go_binary tools/keep-block-check keep-block-check \
     "Verify that all data from one set of Keep servers to another was copied"
 package_go_binary tools/keep-rsync keep-rsync \
     "Copy all data from one set of Keep servers to another"
+package_go_binary tools/keep-exercise keep-exercise \
+    "Performance testing tool for Arvados Keep"
 
 # The Python SDK
 # Please resist the temptation to add --no-python-fix-name to the fpm call here
@@ -476,7 +478,7 @@ fpm_build ruamel.yaml "" "" python 0.12.4 --python-setup-py-arguments "--single-
 fpm_build cwltest "" "" python 1.0.20160907111242
 
 # And for cwltool we have the same problem as for schema_salad. Ward, 2016-03-17
-fpm_build cwltool "" "" python 1.0.20161007181528
+fpm_build cwltool "" "" python 1.0.20161107145355
 
 # FPM eats the trailing .0 in the python-rdflib-jsonld package when built with 'rdflib-jsonld>=0.3.0'. Force the version. Ward, 2016-03-25
 fpm_build rdflib-jsonld "" "" python 0.3.0
index 2797ec31093fc5183289123aa916efcaf051533f..8959cfbe09c3ea7ac6ded2142b626259787d2121 100755 (executable)
@@ -93,6 +93,7 @@ sdk/go/streamer
 sdk/go/crunchrunner
 sdk/cwl
 tools/crunchstat-summary
+tools/keep-exercise
 tools/keep-rsync
 tools/keep-block-check
 
@@ -158,8 +159,8 @@ sanity_checks() {
     echo -n 'go: '
     go version \
         || fatal "No go binary. See http://golang.org/doc/install"
-    [[ $(go version) =~ go1.([0-9]+) ]] && [[ ${BASH_REMATCH[1]} -ge 6 ]] \
-        || fatal "Go >= 1.6 required. See http://golang.org/doc/install"
+    [[ $(go version) =~ go1.([0-9]+) ]] && [[ ${BASH_REMATCH[1]} -ge 7 ]] \
+        || fatal "Go >= 1.7 required. See http://golang.org/doc/install"
     echo -n 'gcc: '
     gcc --version | egrep ^gcc \
         || fatal "No gcc. Try: apt-get install build-essential"
@@ -764,8 +765,9 @@ gostuff=(
     services/crunch-dispatch-local
     services/crunch-dispatch-slurm
     services/crunch-run
-    tools/keep-rsync
     tools/keep-block-check
+    tools/keep-exercise
+    tools/keep-rsync
     )
 for g in "${gostuff[@]}"
 do
index bcb11d1d706d1fc6be68b340d0038daf6cc43266..960d7848de23b998ee4ce3d47edb38d35af54ea5 100755 (executable)
@@ -380,6 +380,8 @@ class WhRunPipelineInstance
           value = params[parametername.to_s]
         elsif parameter.has_key?(:default)
           value = parameter[:default]
+        elsif [false, 'false', 0, '0'].index(parameter[:required])
+          value = nil
         else
           errors << [componentname, parametername, "required parameter is missing"]
           next
index 6778eb0222a2fb404e0754578d698df91e792df0..b3d47dd8d05e5981ae4f645fa9c968ee7707e747 100644 (file)
@@ -202,14 +202,28 @@ class ArvCwlRunner(object):
 
         srccollections = {}
         for k,v in generatemapper.items():
+            if k.startswith("_:"):
+                if v.type == "Directory":
+                    continue
+                if v.type == "CreateFile":
+                    with final.open(v.target, "wb") as f:
+                        f.write(v.resolved.encode("utf-8"))
+                    continue
+
+            if not k.startswith("keep:"):
+                raise Exception("Output source is not in keep or a literal")
             sp = k.split("/")
             srccollection = sp[0][5:]
             if srccollection not in srccollections:
-                srccollections[srccollection] = arvados.collection.CollectionReader(
-                    srccollection,
-                    api_client=self.api,
-                    keep_client=self.keep_client,
-                    num_retries=self.num_retries)
+                try:
+                    srccollections[srccollection] = arvados.collection.CollectionReader(
+                        srccollection,
+                        api_client=self.api,
+                        keep_client=self.keep_client,
+                        num_retries=self.num_retries)
+                except arvados.errors.ArgumentError as e:
+                    logger.error("Creating CollectionReader for '%s' '%s': %s", k, v, e)
+                    raise
             reader = srccollections[srccollection]
             try:
                 srcpath = "/".join(sp[1:]) if len(sp) > 1 else "."
@@ -219,7 +233,7 @@ class ArvCwlRunner(object):
 
         def rewrite(fileobj):
             fileobj["location"] = generatemapper.mapper(fileobj["location"]).target
-            for k in ("basename", "size", "listing"):
+            for k in ("basename", "listing", "contents"):
                 if k in fileobj:
                     del fileobj[k]
 
@@ -242,7 +256,13 @@ class ArvCwlRunner(object):
                 "head_uuid": final_uuid, "link_class": "tag", "name": tag
                 }).execute(num_retries=self.num_retries)
 
-        self.final_output_collection = final
+        def finalcollection(fileobj):
+            fileobj["location"] = "keep:%s/%s" % (final.portable_data_hash(), fileobj["location"])
+
+        adjustDirObjs(outputObj, finalcollection)
+        adjustFileObjs(outputObj, finalcollection)
+
+        return (outputObj, final)
 
     def set_crunch_output(self):
         if self.work_api == "containers":
@@ -400,7 +420,7 @@ class ArvCwlRunner(object):
                 self.output_name = "Output of %s" % (shortname(tool.tool["id"]))
             if self.output_tags is None:
                 self.output_tags = ""
-            self.make_output_collection(self.output_name, self.output_tags, self.final_output)
+            self.final_output, self.final_output_collection = self.make_output_collection(self.output_name, self.output_tags, self.final_output)
             self.set_crunch_output()
 
         if self.final_status != "success":
index afcc29a21a424f16aaa3f7aad96f1644fd43b308..e7cd617baee8d063da303aac89a93f434234c551 100644 (file)
@@ -42,6 +42,7 @@ class ArvadosContainer(object):
                 "kind": "tmp"
             }
         }
+        scheduling_parameters = {}
 
         dirs = set()
         for f in self.pathmapper.files():
@@ -102,11 +103,12 @@ class ArvadosContainer(object):
 
         partition_req, _ = get_feature(self, "http://arvados.org/cwl#PartitionRequirement")
         if partition_req:
-            runtime_constraints["partition"] = aslist(partition_req["partition"])
+            scheduling_parameters["partitions"] = aslist(partition_req["partition"])
 
         container_request["mounts"] = mounts
         container_request["runtime_constraints"] = runtime_constraints
         container_request["use_existing"] = kwargs.get("enable_reuse", True)
+        container_request["scheduling_parameters"] = scheduling_parameters
 
         try:
             response = self.arvrunner.api.container_requests().create(
index f48d8bbe11c02a7293586840c144fa8b7558f9da..4db23b98a961904675727a13c33bf91cd3aa1f55 100644 (file)
@@ -85,6 +85,8 @@ class ArvadosJob(object):
         with Perf(metrics, "arv_docker_get_image %s" % self.name):
             (docker_req, docker_is_req) = get_feature(self, "DockerRequirement")
             if docker_req and kwargs.get("use_container") is not False:
+                if docker_req.get("dockerOutputDirectory"):
+                    raise UnsupportedRequirement("Option 'dockerOutputDirectory' of DockerRequirement not supported.")
                 runtime_constraints["docker_image"] = arv_docker_get_image(self.arvrunner.api, docker_req, pull_image, self.arvrunner.project_uuid)
             else:
                 runtime_constraints["docker_image"] = arvados_jobs_image(self.arvrunner)
index 8eb8fe6fee50e0722ee7066171ff7b7bbc4a10c5..ce633d43285a537268f3bc96dc446696d17d06a6 100644 (file)
@@ -87,6 +87,8 @@ class ArvadosWorkflow(Workflow):
                 joborder_keepmount = copy.deepcopy(joborder)
 
                 def keepmount(obj):
+                    if "location" not in obj:
+                        raise WorkflowException("%s object is missing required 'location' field: %s" % (obj["class"], obj))
                     if obj["location"].startswith("keep:"):
                         obj["location"] = "/keep/" + obj["location"][5:]
                         if "listing" in obj:
index 73c81ceb0fcdb033203c1b7e5425b3875ea121d6..58500d3a993ddb74327c419925c2aed2b769a1b6 100644 (file)
@@ -150,27 +150,31 @@ class ArvPathMapper(PathMapper):
         else:
             return super(ArvPathMapper, self).reversemap(target)
 
-class InitialWorkDirPathMapper(PathMapper):
+class StagingPathMapper(PathMapper):
+    _follow_dirs = True
 
     def visit(self, obj, stagedir, basedir, copy=False):
         # type: (Dict[unicode, Any], unicode, unicode, bool) -> None
         loc = obj["location"]
+        tgt = os.path.join(stagedir, obj["basename"])
         if obj["class"] == "Directory":
-            self._pathmap[loc] = MapperEnt(obj["location"], stagedir, "Directory")
-            self.visitlisting(obj.get("listing", []), stagedir, basedir)
+            self._pathmap[loc] = MapperEnt(loc, tgt, "Directory")
+            if loc.startswith("_:") or self._follow_dirs:
+                self.visitlisting(obj.get("listing", []), tgt, basedir)
         elif obj["class"] == "File":
             if loc in self._pathmap:
                 return
-            tgt = os.path.join(stagedir, obj["basename"])
-            if "contents" in obj and obj["location"].startswith("_:"):
+            if "contents" in obj and loc.startswith("_:"):
                 self._pathmap[loc] = MapperEnt(obj["contents"], tgt, "CreateFile")
             else:
                 if copy:
-                    self._pathmap[loc] = MapperEnt(obj["path"], tgt, "WritableFile")
+                    self._pathmap[loc] = MapperEnt(loc, tgt, "WritableFile")
                 else:
-                    self._pathmap[loc] = MapperEnt(obj["path"], tgt, "File")
+                    self._pathmap[loc] = MapperEnt(loc, tgt, "File")
                 self.visitlisting(obj.get("secondaryFiles", []), stagedir, basedir)
 
+
+class InitialWorkDirPathMapper(StagingPathMapper):
     def setup(self, referenced_files, basedir):
         # type: (List[Any], unicode) -> None
 
@@ -183,19 +187,8 @@ class InitialWorkDirPathMapper(PathMapper):
                 self._pathmap[path] = MapperEnt("$(task.keep)/%s" % ab[5:], tgt, type)
 
 
-class FinalOutputPathMapper(PathMapper):
-    def visit(self, obj, stagedir, basedir, copy=False):
-        # type: (Dict[unicode, Any], unicode, unicode, bool) -> None
-        loc = obj["location"]
-        if obj["class"] == "Directory":
-            self._pathmap[loc] = MapperEnt(loc, stagedir, "Directory")
-        elif obj["class"] == "File":
-            if loc in self._pathmap:
-                return
-            tgt = os.path.join(stagedir, obj["basename"])
-            self._pathmap[loc] = MapperEnt(loc, tgt, "File")
-            self.visitlisting(obj.get("secondaryFiles", []), stagedir, basedir)
-
+class FinalOutputPathMapper(StagingPathMapper):
+    _follow_dirs = False
     def setup(self, referenced_files, basedir):
         # type: (List[Any], unicode) -> None
         self.visitlisting(referenced_files, self.stagedir, basedir)
index 2b5d186843fc8b3162ffaf0f21ad2b07b68dfb86..5cc447e9a3bad9202d9e77fb53919dcc66b804c8 100644 (file)
@@ -9,9 +9,11 @@ from cStringIO import StringIO
 import cwltool.draft2tool
 from cwltool.draft2tool import CommandLineTool
 import cwltool.workflow
-from cwltool.process import get_feature, scandeps, UnsupportedRequirement, normalizeFilesDirs
+from cwltool.process import get_feature, scandeps, UnsupportedRequirement, normalizeFilesDirs, shortname
 from cwltool.load_tool import fetch_document
 from cwltool.pathmapper import adjustFileObjs, adjustDirObjs
+from cwltool.utils import aslist
+from cwltool.builder import substitute
 
 import arvados.collection
 import ruamel.yaml as yaml
@@ -108,6 +110,9 @@ def upload_docker(arvrunner, tool):
     if isinstance(tool, CommandLineTool):
         (docker_req, docker_is_req) = get_feature(tool, "DockerRequirement")
         if docker_req:
+            if docker_req.get("dockerOutputDirectory"):
+                # TODO: can be supported by containers API, but not jobs API.
+                raise UnsupportedRequirement("Option 'dockerOutputDirectory' of DockerRequirement not supported.")
             arv_docker_get_image(arvrunner.api, docker_req, True, arvrunner.project_uuid)
     elif isinstance(tool, cwltool.workflow.Workflow):
         for s in tool.steps:
@@ -116,6 +121,19 @@ def upload_docker(arvrunner, tool):
 def upload_instance(arvrunner, name, tool, job_order):
         upload_docker(arvrunner, tool)
 
+        for t in tool.tool["inputs"]:
+            def setSecondary(fileobj):
+                if isinstance(fileobj, dict) and fileobj.get("class") == "File":
+                    if "secondaryFiles" not in fileobj:
+                        fileobj["secondaryFiles"] = [{"location": substitute(fileobj["location"], sf), "class": "File"} for sf in t["secondaryFiles"]]
+
+                if isinstance(fileobj, list):
+                    for e in fileobj:
+                        setSecondary(e)
+
+            if shortname(t["id"]) in job_order and t.get("secondaryFiles"):
+                setSecondary(job_order[shortname(t["id"])])
+
         workflowmapper = upload_dependencies(arvrunner,
                                              name,
                                              tool.doc_loader,
index d1c8f9b567839bb6aaf1e78db2d6855b9a6038c2..9d9a1e1a7acf99f46d61d96de384681da114925a 100644 (file)
@@ -48,7 +48,7 @@ setup(name='arvados-cwl-runner',
       # Make sure to update arvados/build/run-build-packages.sh as well
       # when updating the cwltool version pin.
       install_requires=[
-          'cwltool==1.0.20161007181528',
+          'cwltool==1.0.20161107145355',
           'arvados-python-client>=0.1.20160826210445'
       ],
       data_files=[
index 93100ae9f76026c745fd5fdc3b011af550ac0079..bb4bac31dd1767081cdc12a313496a4bb13b4546 100644 (file)
@@ -68,7 +68,8 @@ class TestContainer(unittest.TestCase):
                         'output_path': '/var/spool/cwl',
                         'container_image': '99999999999999999999999999999993+99',
                         'command': ['ls', '/var/spool/cwl'],
-                        'cwd': '/var/spool/cwl'
+                        'cwd': '/var/spool/cwl',
+                        'scheduling_parameters': {}
                     })
 
     # The test passes some fields in builder.resources
@@ -113,8 +114,9 @@ class TestContainer(unittest.TestCase):
                              make_fs_access=make_fs_access, tmpdir="/tmp"):
             j.run()
 
-        runner.api.container_requests().create.assert_called_with(
-            body={
+        call_args, call_kwargs = runner.api.container_requests().create.call_args
+
+        call_body_expected = {
                 'environment': {
                     'HOME': '/var/spool/cwl',
                     'TMPDIR': '/tmp'
@@ -124,8 +126,7 @@ class TestContainer(unittest.TestCase):
                     'vcpus': 3,
                     'ram': 3145728000,
                     'keep_cache_ram': 512,
-                    'API': True,
-                    'partition': ['blurb']
+                    'API': True
                 },
                 'use_existing': True,
                 'priority': 1,
@@ -137,8 +138,16 @@ class TestContainer(unittest.TestCase):
                 'output_path': '/var/spool/cwl',
                 'container_image': '99999999999999999999999999999993+99',
                 'command': ['ls'],
-                'cwd': '/var/spool/cwl'
-            })
+                'cwd': '/var/spool/cwl',
+                'scheduling_parameters': {
+                    'partitions': ['blurb']
+                }
+        }
+
+        call_body = call_kwargs.get('body', None)
+        self.assertNotEqual(None, call_body)
+        for key in call_body:
+            self.assertEqual(call_body_expected.get(key), call_body.get(key))
 
     @mock.patch("arvados.collection.Collection")
     def test_done(self, col):
index a1cb605bfc83f78605c72b8eb10c47c73e1c95de..53f379f1a5ac0cc488af5157080e78933d543367 100644 (file)
@@ -35,7 +35,7 @@ class TestMakeOutput(unittest.TestCase):
         final.open.return_value = openmock
         openmock.__enter__.return_value = cwlout
 
-        runner.make_output_collection("Test output", "tag0,tag1,tag2", {
+        _, runner.final_output_collection = runner.make_output_collection("Test output", "tag0,tag1,tag2", {
             "foo": {
                 "class": "File",
                 "location": "keep:99999999999999999999999999999991+99/foo.txt",
@@ -45,7 +45,8 @@ class TestMakeOutput(unittest.TestCase):
             "bar": {
                 "class": "File",
                 "location": "keep:99999999999999999999999999999992+99/bar.txt",
-                "basename": "baz.txt"
+                "basename": "baz.txt",
+                "size": 4
             }
         })
 
@@ -55,11 +56,13 @@ class TestMakeOutput(unittest.TestCase):
         self.assertEqual("""{
     "bar": {
         "class": "File",
-        "location": "baz.txt"
+        "location": "baz.txt",
+        "size": 4
     },
     "foo": {
         "class": "File",
-        "location": "foo.txt"
+        "location": "foo.txt",
+        "size": 3
     }
 }""", cwlout.getvalue())
 
index 6a76f1f396a32c89544f55030cb586ae413d0c0b..61c14ea0b6c1d445bb2a26fb83a57614e0b240f9 100644 (file)
@@ -2,18 +2,19 @@ package arvados
 
 // Container is an arvados#container resource.
 type Container struct {
-       UUID               string             `json:"uuid"`
-       Command            []string           `json:"command"`
-       ContainerImage     string             `json:"container_image"`
-       Cwd                string             `json:"cwd"`
-       Environment        map[string]string  `json:"environment"`
-       LockedByUUID       string             `json:"locked_by_uuid"`
-       Mounts             map[string]Mount   `json:"mounts"`
-       Output             string             `json:"output"`
-       OutputPath         string             `json:"output_path"`
-       Priority           int                `json:"priority"`
-       RuntimeConstraints RuntimeConstraints `json:"runtime_constraints"`
-       State              ContainerState     `json:"state"`
+       UUID                 string               `json:"uuid"`
+       Command              []string             `json:"command"`
+       ContainerImage       string               `json:"container_image"`
+       Cwd                  string               `json:"cwd"`
+       Environment          map[string]string    `json:"environment"`
+       LockedByUUID         string               `json:"locked_by_uuid"`
+       Mounts               map[string]Mount     `json:"mounts"`
+       Output               string               `json:"output"`
+       OutputPath           string               `json:"output_path"`
+       Priority             int                  `json:"priority"`
+       RuntimeConstraints   RuntimeConstraints   `json:"runtime_constraints"`
+       State                ContainerState       `json:"state"`
+       SchedulingParameters SchedulingParameters `json:"scheduling_parameters"`
 }
 
 // Mount is special behavior to attach to a filesystem path or device.
@@ -31,10 +32,15 @@ type Mount struct {
 // CPU) and network connectivity.
 type RuntimeConstraints struct {
        API          *bool
-       RAM          int      `json:"ram"`
-       VCPUs        int      `json:"vcpus"`
-       KeepCacheRAM int      `json:"keep_cache_ram"`
-       Partition    []string `json:"partition"`
+       RAM          int `json:"ram"`
+       VCPUs        int `json:"vcpus"`
+       KeepCacheRAM int `json:"keep_cache_ram"`
+}
+
+// SchedulingParameters specify a container's scheduling parameters
+// such as Partitions
+type SchedulingParameters struct {
+       Partitions []string `json:"partitions"`
 }
 
 // ContainerList is an arvados#containerList resource.
index 44733cdfb82ff1c21c4ca379a723110ebcaf5721..922cf7dac16b87741013c23e4073d4070a6fbe43 100644 (file)
@@ -182,10 +182,10 @@ class Arvados::V1::CollectionsController < ApplicationController
   protected
 
   def load_limit_offset_order_params *args
+    super
     if action_name == 'index'
       # Omit manifest_text from index results unless expressly selected.
       @select ||= model_class.selectable_attributes - ["manifest_text"]
     end
-    super
   end
 end
index b1ea9bd230a47e2382dbb12ac0c0d6bee6929588..52f1cba723ed5744af3bf226265f9bb600d4f61f 100644 (file)
@@ -11,6 +11,7 @@ class Container < ArvadosModel
   serialize :mounts, Hash
   serialize :runtime_constraints, Hash
   serialize :command, Array
+  serialize :scheduling_parameters, Hash
 
   before_validation :fill_field_defaults, :if => :new_record?
   before_validation :set_timestamps
@@ -44,6 +45,7 @@ class Container < ArvadosModel
     t.add :started_at
     t.add :state
     t.add :auth_uuid
+    t.add :scheduling_parameters
   end
 
   # Supported states for a container
@@ -180,6 +182,7 @@ class Container < ArvadosModel
     self.mounts ||= {}
     self.cwd ||= "."
     self.priority ||= 1
+    self.scheduling_parameters ||= {}
   end
 
   def permission_to_create
@@ -222,7 +225,7 @@ class Container < ArvadosModel
     if self.new_record?
       permitted.push(:owner_uuid, :command, :container_image, :cwd,
                      :environment, :mounts, :output_path, :priority,
-                     :runtime_constraints)
+                     :runtime_constraints, :scheduling_parameters)
     end
 
     case self.state
@@ -326,6 +329,9 @@ class Container < ArvadosModel
     if self.runtime_constraints_changed?
       self.runtime_constraints = self.class.deep_sort_hash(self.runtime_constraints)
     end
+    if self.scheduling_parameters_changed?
+      self.scheduling_parameters = self.class.deep_sort_hash(self.scheduling_parameters)
+    end
   end
 
   def handle_completed
@@ -348,7 +354,8 @@ class Container < ArvadosModel
             output_path: self.output_path,
             container_image: self.container_image,
             mounts: self.mounts,
-            runtime_constraints: self.runtime_constraints
+            runtime_constraints: self.runtime_constraints,
+            scheduling_parameters: self.scheduling_parameters
           }
           c = Container.create! c_attrs
           retryable_requests.each do |cr|
index 05738de81e50627654e62e3f3a34d6cc46754460..7dcfbe378b6700a88b41ae11b6a15e8a28a6fe20 100644 (file)
@@ -11,9 +11,11 @@ class ContainerRequest < ArvadosModel
   serialize :mounts, Hash
   serialize :runtime_constraints, Hash
   serialize :command, Array
+  serialize :scheduling_parameters, Hash
 
   before_validation :fill_field_defaults, :if => :new_record?
   before_validation :validate_runtime_constraints
+  before_validation :validate_scheduling_parameters
   before_validation :set_container
   validates :command, :container_image, :output_path, :cwd, :presence => true
   validate :validate_state_change
@@ -42,6 +44,7 @@ class ContainerRequest < ArvadosModel
     t.add :runtime_constraints
     t.add :state
     t.add :use_existing
+    t.add :scheduling_parameters
   end
 
   # Supported states for a container request
@@ -105,6 +108,7 @@ class ContainerRequest < ArvadosModel
     self.mounts ||= {}
     self.cwd ||= "."
     self.container_count_max ||= Rails.configuration.container_count_max
+    self.scheduling_parameters ||= {}
   end
 
   # Create a new container (or find an existing one) to satisfy this
@@ -126,6 +130,7 @@ class ContainerRequest < ArvadosModel
       if not reusable.nil?
         reusable
       else
+        c_attrs[:scheduling_parameters] = self.scheduling_parameters
         Container.create!(c_attrs)
       end
     end
@@ -234,6 +239,17 @@ class ContainerRequest < ArvadosModel
     end
   end
 
+  def validate_scheduling_parameters
+    if self.state == Committed
+      if scheduling_parameters.include? 'partitions' and
+         (!scheduling_parameters['partitions'].is_a?(Array) ||
+          scheduling_parameters['partitions'].reject{|x| !x.is_a?(String)}.size !=
+            scheduling_parameters['partitions'].size)
+            errors.add :scheduling_parameters, "partitions must be an array of strings"
+      end
+    end
+  end
+
   def validate_change
     permitted = [:owner_uuid]
 
@@ -244,7 +260,7 @@ class ContainerRequest < ArvadosModel
                      :container_image, :cwd, :description, :environment,
                      :filters, :mounts, :name, :output_path, :priority,
                      :properties, :requesting_container_uuid, :runtime_constraints,
-                     :state, :container_uuid, :use_existing
+                     :state, :container_uuid, :use_existing, :scheduling_parameters
 
     when Committed
       if container_uuid.nil?
@@ -263,7 +279,7 @@ class ContainerRequest < ArvadosModel
         permitted.push :command, :container_image, :cwd, :description, :environment,
                        :filters, :mounts, :name, :output_path, :properties,
                        :requesting_container_uuid, :runtime_constraints,
-                       :state, :container_uuid
+                       :state, :container_uuid, :scheduling_parameters
       end
 
     when Final
index abb46fdc661128f5321a55b186d54afd142ed5f3..e470e4c2bd9c47a45b395a4c90f4814edf89a417 100644 (file)
@@ -13,6 +13,8 @@ class Node < ArvadosModel
   belongs_to(:job, foreign_key: :job_uuid, primary_key: :uuid)
   attr_accessor :job_readable
 
+  UNUSED_NODE_IP = '127.40.4.0'
+
   api_accessible :user, :extend => :common do |t|
     t.add :hostname
     t.add :domain
@@ -133,20 +135,22 @@ class Node < ArvadosModel
   end
 
   def dns_server_update
-    if self.hostname_changed? or self.ip_address_changed?
-      if not self.ip_address.nil?
-        stale_conflicting_nodes = Node.where('id != ? and ip_address = ? and last_ping_at < ?',self.id,self.ip_address,10.minutes.ago)
-        if not stale_conflicting_nodes.empty?
-          # One or more stale compute node records have the same IP address as the new node.
-          # Clear the ip_address field on the stale nodes.
-          stale_conflicting_nodes.each do |stale_node|
-            stale_node.ip_address = nil
-            stale_node.save!
-          end
+    if hostname_changed? && hostname_was
+      self.class.dns_server_update(hostname_was, UNUSED_NODE_IP)
+    end
+    if hostname_changed? or ip_address_changed?
+      if ip_address
+        Node.where('id != ? and ip_address = ? and last_ping_at < ?',
+                   id, ip_address, 10.minutes.ago).each do |stale_node|
+          # One or more stale compute node records have the same IP
+          # address as the new node.  Clear the ip_address field on
+          # the stale nodes.
+          stale_node.ip_address = nil
+          stale_node.save!
         end
       end
-      if self.hostname and self.ip_address
-        self.class.dns_server_update(self.hostname, self.ip_address)
+      if hostname
+        self.class.dns_server_update(hostname, ip_address || UNUSED_NODE_IP)
       end
     end
   end
@@ -225,7 +229,7 @@ class Node < ArvadosModel
       if !File.exists? hostfile
         n = Node.where(:slot_number => slot_number).first
         if n.nil? or n.ip_address.nil?
-          dns_server_update(hostname, '127.40.4.0')
+          dns_server_update(hostname, UNUSED_NODE_IP)
         else
           dns_server_update(hostname, n.ip_address)
         end
diff --git a/services/api/db/migrate/20161111143147_add_scheduling_parameters_to_container.rb b/services/api/db/migrate/20161111143147_add_scheduling_parameters_to_container.rb
new file mode 100644 (file)
index 0000000..1b317cf
--- /dev/null
@@ -0,0 +1,6 @@
+class AddSchedulingParametersToContainer < ActiveRecord::Migration
+  def change
+    add_column :containers, :scheduling_parameters, :text
+    add_column :container_requests, :scheduling_parameters, :text
+  end
+end
index 0db782af69484e6a8e0c476620891702055f36c7..1d3d238c837611e2858dbfd0959cfc7373a41917 100644 (file)
@@ -291,7 +291,8 @@ CREATE TABLE container_requests (
     filters text,
     updated_at timestamp without time zone NOT NULL,
     container_count integer DEFAULT 0,
-    use_existing boolean DEFAULT true
+    use_existing boolean DEFAULT true,
+    scheduling_parameters text
 );
 
 
@@ -343,7 +344,8 @@ CREATE TABLE containers (
     updated_at timestamp without time zone NOT NULL,
     exit_code integer,
     auth_uuid character varying(255),
-    locked_by_uuid character varying(255)
+    locked_by_uuid character varying(255),
+    scheduling_parameters text
 );
 
 
@@ -2694,4 +2696,6 @@ INSERT INTO schema_migrations (version) VALUES ('20160909181442');
 
 INSERT INTO schema_migrations (version) VALUES ('20160926194129');
 
-INSERT INTO schema_migrations (version) VALUES ('20161019171346');
\ No newline at end of file
+INSERT INTO schema_migrations (version) VALUES ('20161019171346');
+
+INSERT INTO schema_migrations (version) VALUES ('20161111143147');
\ No newline at end of file
index 489bb1d6605f86d622c824260d96ef89f63bd026..c5516cc38b34dc9329fc759b1c61b2877681f1af 100644 (file)
@@ -47,7 +47,7 @@ was_idle_now_down:
   hostname: compute3
   slot_number: ~
   domain: ""
-  ip_address: 172.17.2.173
+  ip_address: 172.17.2.174
   last_ping_at: <%= 1.hour.ago.to_s(:db) %>
   first_ping_at: <%= 23.hour.ago.to_s(:db) %>
   job_uuid: ~
@@ -62,7 +62,7 @@ new_with_no_hostname:
   owner_uuid: zzzzz-tpzed-000000000000000
   hostname: ~
   slot_number: ~
-  ip_address: 172.17.2.173
+  ip_address: 172.17.2.175
   last_ping_at: ~
   first_ping_at: ~
   job_uuid: ~
@@ -74,7 +74,7 @@ new_with_custom_hostname:
   owner_uuid: zzzzz-tpzed-000000000000000
   hostname: custom1
   slot_number: 23
-  ip_address: 172.17.2.173
+  ip_address: 172.17.2.176
   last_ping_at: ~
   first_ping_at: ~
   job_uuid: ~
index a8583be12bb70d915585c8c48aba0bc06aa32d3e..c85cc1979f99482ff36ba6dc38ba5790ec7bf591 100644 (file)
@@ -46,6 +46,49 @@ class Arvados::V1::CollectionsControllerTest < ActionController::TestCase
     end
   end
 
+  test 'index without select returns everything except manifest' do
+    authorize_with :active
+    get :index
+    assert_response :success
+    assert json_response['items'].any?
+    json_response['items'].each do |coll|
+      assert_includes(coll.keys, 'uuid')
+      assert_includes(coll.keys, 'name')
+      assert_includes(coll.keys, 'created_at')
+      refute_includes(coll.keys, 'manifest_text')
+    end
+  end
+
+  ['', nil, false, 'null'].each do |select|
+    test "index with select=#{select.inspect} returns everything except manifest" do
+      authorize_with :active
+      get :index, select: select
+      assert_response :success
+      assert json_response['items'].any?
+      json_response['items'].each do |coll|
+        assert_includes(coll.keys, 'uuid')
+        assert_includes(coll.keys, 'name')
+        assert_includes(coll.keys, 'created_at')
+        refute_includes(coll.keys, 'manifest_text')
+      end
+    end
+  end
+
+  [["uuid"],
+   ["uuid", "manifest_text"],
+   '["uuid"]',
+   '["uuid", "manifest_text"]'].each do |select|
+    test "index with select=#{select.inspect} returns no name" do
+      authorize_with :active
+      get :index, select: select
+      assert_response :success
+      assert json_response['items'].any?
+      json_response['items'].each do |coll|
+        refute_includes(coll.keys, 'name')
+      end
+    end
+  end
+
   [0,1,2].each do |limit|
     test "get index with limit=#{limit}" do
       authorize_with :active
index 34aa442c0938381e5fb56ebb2c6379616ad93976..1465c7180ad7b59cb947b4cfb825d2957ea17844 100644 (file)
@@ -552,4 +552,36 @@ class ContainerRequestTest < ActiveSupport::TestCase
       end
     end
   end
+
+  [
+    [{"partitions" => ["fastcpu","vfastcpu", 100]}, ContainerRequest::Committed, ActiveRecord::RecordInvalid],
+    [{"partitions" => ["fastcpu","vfastcpu", 100]}, ContainerRequest::Uncommitted],
+    [{"partitions" => "fastcpu"}, ContainerRequest::Committed, ActiveRecord::RecordInvalid],
+    [{"partitions" => "fastcpu"}, ContainerRequest::Uncommitted],
+    [{"partitions" => ["fastcpu","vfastcpu"]}, ContainerRequest::Committed],
+  ].each do |sp, state, expected|
+    test "create container request with scheduling_parameters #{sp} in state #{state} and verify #{expected}" do
+      common_attrs = {cwd: "test",
+                      priority: 1,
+                      command: ["echo", "hello"],
+                      output_path: "test",
+                      scheduling_parameters: sp,
+                      mounts: {"test" => {"kind" => "json"}}}
+      set_user_from_auth :active
+
+      if expected == ActiveRecord::RecordInvalid
+        assert_raises(ActiveRecord::RecordInvalid) do
+          create_minimal_req!(common_attrs.merge({state: state}))
+        end
+      else
+        cr = create_minimal_req!(common_attrs.merge({state: state}))
+        assert_equal sp, cr.scheduling_parameters
+
+        if state == ContainerRequest::Committed
+          c = Container.find_by_uuid(cr.container_uuid)
+          assert_equal sp, c.scheduling_parameters
+        end
+      end
+    end
+  end
 end
index e5b88354fb128e1308c1a00a7c9e297928f191dd..6eb1df56d129f0279c2e86323b865d13fd09817c 100644 (file)
@@ -125,4 +125,31 @@ class NodeTest < ActiveSupport::TestCase
     refute_nil node2.slot_number
     assert_equal "custom1", node2.hostname
   end
+
+  test "update dns when nodemanager clears hostname and ip_address" do
+    act_as_system_user do
+      node = ping_node(:new_with_custom_hostname, {})
+      Node.expects(:dns_server_update).with(node.hostname, Node::UNUSED_NODE_IP)
+      node.update_attributes(hostname: nil, ip_address: nil)
+    end
+  end
+
+  test "update dns when hostname changes" do
+    act_as_system_user do
+      node = ping_node(:new_with_custom_hostname, {})
+
+      Node.expects(:dns_server_update).with(node.hostname, Node::UNUSED_NODE_IP)
+      Node.expects(:dns_server_update).with('foo0', node.ip_address)
+      node.update_attributes!(hostname: 'foo0')
+
+      Node.expects(:dns_server_update).with('foo0', Node::UNUSED_NODE_IP)
+      node.update_attributes!(hostname: nil, ip_address: nil)
+
+      Node.expects(:dns_server_update).with('foo0', '10.11.12.13')
+      node.update_attributes!(hostname: 'foo0', ip_address: '10.11.12.13')
+
+      Node.expects(:dns_server_update).with('foo0', '10.11.12.14')
+      node.update_attributes!(hostname: 'foo0', ip_address: '10.11.12.14')
+    end
+  end
 end
index 0c1ce49592a6b08223271d440dca41f3a5d8fd46..f28d4c2826dcd9aff2069d749dab0c706520590d 100644 (file)
@@ -127,8 +127,8 @@ func sbatchFunc(container arvados.Container) *exec.Cmd {
        sbatchArgs = append(sbatchArgs, fmt.Sprintf("--job-name=%s", container.UUID))
        sbatchArgs = append(sbatchArgs, fmt.Sprintf("--mem-per-cpu=%d", int(memPerCPU)))
        sbatchArgs = append(sbatchArgs, fmt.Sprintf("--cpus-per-task=%d", container.RuntimeConstraints.VCPUs))
-       if container.RuntimeConstraints.Partition != nil {
-               sbatchArgs = append(sbatchArgs, fmt.Sprintf("--partition=%s", strings.Join(container.RuntimeConstraints.Partition, ",")))
+       if container.SchedulingParameters.Partitions != nil {
+               sbatchArgs = append(sbatchArgs, fmt.Sprintf("--partition=%s", strings.Join(container.SchedulingParameters.Partitions, ",")))
        }
 
        return exec.Command("sbatch", sbatchArgs...)
index c9208a6943924a1604c7b15735536229cde68104..fbea48e548a59f78718cb0afa419b5a84a1cd89b 100644 (file)
@@ -318,7 +318,7 @@ func testSbatchFuncWithArgs(c *C, args []string) {
 
 func (s *MockArvadosServerSuite) TestSbatchPartition(c *C) {
        theConfig.SbatchArguments = nil
-       container := arvados.Container{UUID: "123", RuntimeConstraints: arvados.RuntimeConstraints{RAM: 250000000, VCPUs: 1, Partition: []string{"blurb", "b2"}}}
+       container := arvados.Container{UUID: "123", RuntimeConstraints: arvados.RuntimeConstraints{RAM: 250000000, VCPUs: 1}, SchedulingParameters: arvados.SchedulingParameters{Partitions: []string{"blurb", "b2"}}}
        sbatchCmd := sbatchFunc(container)
 
        var expected []string
index d2163f6b490376768383b260444d6be90a9ca1ed..6ca31c38329ec7347631a03c802b4216bd05f167 100644 (file)
@@ -2,12 +2,14 @@ package main
 
 import (
        "bytes"
+       "context"
        "errors"
        "flag"
        "fmt"
        "io"
        "io/ioutil"
        "log"
+       "net/http"
        "os"
        "regexp"
        "strconv"
@@ -15,9 +17,12 @@ import (
        "sync"
        "time"
 
+       "git.curoverse.com/arvados.git/sdk/go/arvados"
        "github.com/curoverse/azure-sdk-for-go/storage"
 )
 
+const azureDefaultRequestTimeout = arvados.Duration(10 * time.Minute)
+
 var (
        azureMaxGetBytes           int
        azureStorageAccountName    string
@@ -95,6 +100,7 @@ type AzureBlobVolume struct {
        ContainerName         string
        AzureReplication      int
        ReadOnly              bool
+       RequestTimeout        arvados.Duration
 
        azClient storage.Client
        bsClient storage.BlobStorageClient
@@ -108,6 +114,7 @@ func (*AzureBlobVolume) Examples() []Volume {
                        StorageAccountKeyFile: "/etc/azure_storage_account_key.txt",
                        ContainerName:         "example-container-name",
                        AzureReplication:      3,
+                       RequestTimeout:        azureDefaultRequestTimeout,
                },
        }
 }
@@ -133,6 +140,13 @@ func (v *AzureBlobVolume) Start() error {
        if err != nil {
                return fmt.Errorf("creating Azure storage client: %s", err)
        }
+
+       if v.RequestTimeout == 0 {
+               v.RequestTimeout = azureDefaultRequestTimeout
+       }
+       v.azClient.HTTPClient = &http.Client{
+               Timeout: time.Duration(v.RequestTimeout),
+       }
        v.bsClient = v.azClient.GetBlobService()
 
        ok, err := v.bsClient.ContainerExists(v.ContainerName)
@@ -163,7 +177,7 @@ func (v *AzureBlobVolume) checkTrashed(loc string) (bool, map[string]string, err
 // If the block is younger than azureWriteRaceInterval and is
 // unexpectedly empty, assume a PutBlob operation is in progress, and
 // wait for it to finish writing.
-func (v *AzureBlobVolume) Get(loc string, buf []byte) (int, error) {
+func (v *AzureBlobVolume) Get(ctx context.Context, loc string, buf []byte) (int, error) {
        trashed, _, err := v.checkTrashed(loc)
        if err != nil {
                return 0, err
@@ -271,7 +285,7 @@ func (v *AzureBlobVolume) get(loc string, buf []byte) (int, error) {
 }
 
 // Compare the given data with existing stored data.
-func (v *AzureBlobVolume) Compare(loc string, expect []byte) error {
+func (v *AzureBlobVolume) Compare(ctx context.Context, loc string, expect []byte) error {
        trashed, _, err := v.checkTrashed(loc)
        if err != nil {
                return err
@@ -284,11 +298,11 @@ func (v *AzureBlobVolume) Compare(loc string, expect []byte) error {
                return v.translateError(err)
        }
        defer rdr.Close()
-       return compareReaderWithBuf(rdr, expect, loc[:32])
+       return compareReaderWithBuf(ctx, rdr, expect, loc[:32])
 }
 
 // Put stores a Keep block as a block blob in the container.
-func (v *AzureBlobVolume) Put(loc string, block []byte) error {
+func (v *AzureBlobVolume) Put(ctx context.Context, loc string, block []byte) error {
        if v.ReadOnly {
                return MethodDisabledError
        }
index c8c898fe2da3957e3efdf069c5370e146ac5d693..d636a5ee86887806372a14e2f291e5c4f2c11b33 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "bytes"
+       "context"
        "crypto/md5"
        "encoding/base64"
        "encoding/xml"
@@ -454,12 +455,12 @@ func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
                        data[i] = byte((i + 7) & 0xff)
                }
                hash := fmt.Sprintf("%x", md5.Sum(data))
-               err := v.Put(hash, data)
+               err := v.Put(context.Background(), hash, data)
                if err != nil {
                        t.Error(err)
                }
                gotData := make([]byte, len(data))
-               gotLen, err := v.Get(hash, gotData)
+               gotLen, err := v.Get(context.Background(), hash, gotData)
                if err != nil {
                        t.Error(err)
                }
@@ -500,7 +501,7 @@ func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
        allDone := make(chan struct{})
        v.azHandler.race = make(chan chan struct{})
        go func() {
-               err := v.Put(TestHash, TestBlock)
+               err := v.Put(context.Background(), TestHash, TestBlock)
                if err != nil {
                        t.Error(err)
                }
@@ -510,7 +511,7 @@ func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
        v.azHandler.race <- continuePut
        go func() {
                buf := make([]byte, len(TestBlock))
-               _, err := v.Get(TestHash, buf)
+               _, err := v.Get(context.Background(), TestHash, buf)
                if err != nil {
                        t.Error(err)
                }
@@ -553,7 +554,7 @@ func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
        go func() {
                defer close(allDone)
                buf := make([]byte, BlockSize)
-               n, err := v.Get(TestHash, buf)
+               n, err := v.Get(context.Background(), TestHash, buf)
                if err != nil {
                        t.Error(err)
                        return
index a4af563729b3cf0e72686a2913fd3664e497e24d..82cb789eb954c289172ac478dc04392a14c6a989 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "bytes"
+       "context"
        "crypto/md5"
        "fmt"
        "io"
@@ -49,7 +50,7 @@ func collisionOrCorrupt(expectMD5 string, buf1, buf2 []byte, rdr io.Reader) erro
        return <-outcome
 }
 
-func compareReaderWithBuf(rdr io.Reader, expect []byte, hash string) error {
+func compareReaderWithBuf(ctx context.Context, rdr io.Reader, expect []byte, hash string) error {
        bufLen := 1 << 20
        if bufLen > len(expect) && len(expect) > 0 {
                // No need for bufLen to be longer than
@@ -67,7 +68,18 @@ func compareReaderWithBuf(rdr io.Reader, expect []byte, hash string) error {
        // expected to equal the next N bytes read from
        // rdr.
        for {
-               n, err := rdr.Read(buf)
+               ready := make(chan bool)
+               var n int
+               var err error
+               go func() {
+                       n, err = rdr.Read(buf)
+                       close(ready)
+               }()
+               select {
+               case <-ready:
+               case <-ctx.Done():
+                       return ctx.Err()
+               }
                if n > len(cmp) || bytes.Compare(cmp[:n], buf[:n]) != 0 {
                        return collisionOrCorrupt(hash, expect[:len(expect)-len(cmp)], buf[:n], rdr)
                }
index 9c318d1245abcb285634ed46eaf17ddcabe300b2..dc06ef549877ba0316294e4a0e7767393ef4436d 100644 (file)
@@ -13,6 +13,7 @@ import (
 )
 
 type Config struct {
+       Debug  bool
        Listen string
 
        PIDFile string
@@ -32,6 +33,7 @@ type Config struct {
 
        blobSigningKey  []byte
        systemAuthToken string
+       debugLogf       func(string, ...interface{})
 }
 
 var theConfig = DefaultConfig()
@@ -52,6 +54,13 @@ func DefaultConfig() *Config {
 // Start should be called exactly once: after setting all public
 // fields, and before using the config.
 func (cfg *Config) Start() error {
+       if cfg.Debug {
+               cfg.debugLogf = log.Printf
+               cfg.debugLogf("debugging enabled")
+       } else {
+               cfg.debugLogf = func(string, ...interface{}) {}
+       }
+
        if cfg.MaxBuffers < 0 {
                return fmt.Errorf("MaxBuffers must be greater than zero")
        }
diff --git a/services/keepstore/config_test.go b/services/keepstore/config_test.go
new file mode 100644 (file)
index 0000000..eaa0904
--- /dev/null
@@ -0,0 +1,9 @@
+package main
+
+import (
+       "log"
+)
+
+func init() {
+       theConfig.debugLogf = log.Printf
+}
index dc9bcb117f0508e748a97ff3cb2a736aa5c00178..9708b4e6be32f96645d500dfcd4319972f213d47 100644 (file)
@@ -11,6 +11,7 @@ package main
 
 import (
        "bytes"
+       "context"
        "encoding/json"
        "fmt"
        "net/http"
@@ -48,7 +49,7 @@ func TestGetHandler(t *testing.T) {
        defer KeepVM.Close()
 
        vols := KeepVM.AllWritable()
-       if err := vols[0].Put(TestHash, TestBlock); err != nil {
+       if err := vols[0].Put(context.Background(), TestHash, TestBlock); err != nil {
                t.Error(err)
        }
 
@@ -288,10 +289,10 @@ func TestIndexHandler(t *testing.T) {
        defer KeepVM.Close()
 
        vols := KeepVM.AllWritable()
-       vols[0].Put(TestHash, TestBlock)
-       vols[1].Put(TestHash2, TestBlock2)
-       vols[0].Put(TestHash+".meta", []byte("metadata"))
-       vols[1].Put(TestHash2+".meta", []byte("metadata"))
+       vols[0].Put(context.Background(), TestHash, TestBlock)
+       vols[1].Put(context.Background(), TestHash2, TestBlock2)
+       vols[0].Put(context.Background(), TestHash+".meta", []byte("metadata"))
+       vols[1].Put(context.Background(), TestHash2+".meta", []byte("metadata"))
 
        theConfig.systemAuthToken = "DATA MANAGER TOKEN"
 
@@ -477,7 +478,7 @@ func TestDeleteHandler(t *testing.T) {
        defer KeepVM.Close()
 
        vols := KeepVM.AllWritable()
-       vols[0].Put(TestHash, TestBlock)
+       vols[0].Put(context.Background(), TestHash, TestBlock)
 
        // Explicitly set the BlobSignatureTTL to 0 for these
        // tests, to ensure the MockVolume deletes the blocks
@@ -564,7 +565,7 @@ func TestDeleteHandler(t *testing.T) {
        }
        // Confirm the block has been deleted
        buf := make([]byte, BlockSize)
-       _, err := vols[0].Get(TestHash, buf)
+       _, err := vols[0].Get(context.Background(), TestHash, buf)
        var blockDeleted = os.IsNotExist(err)
        if !blockDeleted {
                t.Error("superuserExistingBlockReq: block not deleted")
@@ -572,7 +573,7 @@ func TestDeleteHandler(t *testing.T) {
 
        // A DELETE request on a block newer than BlobSignatureTTL
        // should return success but leave the block on the volume.
-       vols[0].Put(TestHash, TestBlock)
+       vols[0].Put(context.Background(), TestHash, TestBlock)
        theConfig.BlobSignatureTTL = arvados.Duration(time.Hour)
 
        response = IssueRequest(superuserExistingBlockReq)
@@ -588,7 +589,7 @@ func TestDeleteHandler(t *testing.T) {
                        expectedDc, responseDc)
        }
        // Confirm the block has NOT been deleted.
-       _, err = vols[0].Get(TestHash, buf)
+       _, err = vols[0].Get(context.Background(), TestHash, buf)
        if err != nil {
                t.Errorf("testing delete on new block: %s\n", err)
        }
@@ -940,7 +941,7 @@ func TestGetHandlerClientDisconnect(t *testing.T) {
        KeepVM = MakeTestVolumeManager(2)
        defer KeepVM.Close()
 
-       if err := KeepVM.AllWritable()[0].Put(TestHash, TestBlock); err != nil {
+       if err := KeepVM.AllWritable()[0].Put(context.Background(), TestHash, TestBlock); err != nil {
                t.Error(err)
        }
 
@@ -985,7 +986,7 @@ func TestGetHandlerNoBufferLeak(t *testing.T) {
        defer KeepVM.Close()
 
        vols := KeepVM.AllWritable()
-       if err := vols[0].Put(TestHash, TestBlock); err != nil {
+       if err := vols[0].Put(context.Background(), TestHash, TestBlock); err != nil {
                t.Error(err)
        }
 
@@ -1040,7 +1041,7 @@ func TestUntrashHandler(t *testing.T) {
        KeepVM = MakeTestVolumeManager(2)
        defer KeepVM.Close()
        vols := KeepVM.AllWritable()
-       vols[0].Put(TestHash, TestBlock)
+       vols[0].Put(context.Background(), TestHash, TestBlock)
 
        theConfig.systemAuthToken = "DATA MANAGER TOKEN"
 
index 54b8b485e1dc99d491bd94d4b5b888b60b990b13..289dce15a06168572f5269d7fed82bdb31a75075 100644 (file)
@@ -9,6 +9,7 @@ package main
 
 import (
        "container/list"
+       "context"
        "crypto/md5"
        "encoding/json"
        "fmt"
@@ -71,6 +72,9 @@ func BadRequestHandler(w http.ResponseWriter, r *http.Request) {
 
 // GetBlockHandler is a HandleFunc to address Get block requests.
 func GetBlockHandler(resp http.ResponseWriter, req *http.Request) {
+       ctx, cancel := contextForResponse(context.TODO(), resp)
+       defer cancel()
+
        if theConfig.RequireSignatures {
                locator := req.URL.Path[1:] // strip leading slash
                if err := VerifySignature(locator, GetAPIToken(req)); err != nil {
@@ -86,14 +90,14 @@ func GetBlockHandler(resp http.ResponseWriter, req *http.Request) {
        // isn't here, we can return 404 now instead of waiting for a
        // buffer.
 
-       buf, err := getBufferForResponseWriter(resp, bufs, BlockSize)
+       buf, err := getBufferWithContext(ctx, bufs, BlockSize)
        if err != nil {
                http.Error(resp, err.Error(), http.StatusServiceUnavailable)
                return
        }
        defer bufs.Put(buf)
 
-       size, err := GetBlock(mux.Vars(req)["hash"], buf, resp)
+       size, err := GetBlock(ctx, mux.Vars(req)["hash"], buf, resp)
        if err != nil {
                code := http.StatusInternalServerError
                if err, ok := err.(*KeepError); ok {
@@ -108,24 +112,33 @@ func GetBlockHandler(resp http.ResponseWriter, req *http.Request) {
        resp.Write(buf[:size])
 }
 
+// Return a new context that gets cancelled by resp's CloseNotifier.
+func contextForResponse(parent context.Context, resp http.ResponseWriter) (context.Context, context.CancelFunc) {
+       ctx, cancel := context.WithCancel(parent)
+       if cn, ok := resp.(http.CloseNotifier); ok {
+               go func(c <-chan bool) {
+                       select {
+                       case <-c:
+                               theConfig.debugLogf("cancel context")
+                               cancel()
+                       case <-ctx.Done():
+                       }
+               }(cn.CloseNotify())
+       }
+       return ctx, cancel
+}
+
 // Get a buffer from the pool -- but give up and return a non-nil
-// error if resp implements http.CloseNotifier and tells us that the
-// client has disconnected before we get a buffer.
-func getBufferForResponseWriter(resp http.ResponseWriter, bufs *bufferPool, bufSize int) ([]byte, error) {
-       var closeNotifier <-chan bool
-       if resp, ok := resp.(http.CloseNotifier); ok {
-               closeNotifier = resp.CloseNotify()
-       }
-       var buf []byte
+// error if ctx ends before we get a buffer.
+func getBufferWithContext(ctx context.Context, bufs *bufferPool, bufSize int) ([]byte, error) {
        bufReady := make(chan []byte)
        go func() {
                bufReady <- bufs.Get(bufSize)
-               close(bufReady)
        }()
        select {
-       case buf = <-bufReady:
+       case buf := <-bufReady:
                return buf, nil
-       case <-closeNotifier:
+       case <-ctx.Done():
                go func() {
                        // Even if closeNotifier happened first, we
                        // need to keep waiting for our buf so we can
@@ -138,6 +151,9 @@ func getBufferForResponseWriter(resp http.ResponseWriter, bufs *bufferPool, bufS
 
 // PutBlockHandler is a HandleFunc to address Put block requests.
 func PutBlockHandler(resp http.ResponseWriter, req *http.Request) {
+       ctx, cancel := contextForResponse(context.TODO(), resp)
+       defer cancel()
+
        hash := mux.Vars(req)["hash"]
 
        // Detect as many error conditions as possible before reading
@@ -159,7 +175,7 @@ func PutBlockHandler(resp http.ResponseWriter, req *http.Request) {
                return
        }
 
-       buf, err := getBufferForResponseWriter(resp, bufs, int(req.ContentLength))
+       buf, err := getBufferWithContext(ctx, bufs, int(req.ContentLength))
        if err != nil {
                http.Error(resp, err.Error(), http.StatusServiceUnavailable)
                return
@@ -172,12 +188,15 @@ func PutBlockHandler(resp http.ResponseWriter, req *http.Request) {
                return
        }
 
-       replication, err := PutBlock(buf, hash)
+       replication, err := PutBlock(ctx, buf, hash)
        bufs.Put(buf)
 
        if err != nil {
-               ke := err.(*KeepError)
-               http.Error(resp, ke.Error(), ke.HTTPCode)
+               code := http.StatusInternalServerError
+               if err, ok := err.(*KeepError); ok {
+                       code = err.HTTPCode
+               }
+               http.Error(resp, err.Error(), code)
                return
        }
 
@@ -548,12 +567,17 @@ func UntrashHandler(resp http.ResponseWriter, req *http.Request) {
 // If the block found does not have the correct MD5 hash, returns
 // DiskHashError.
 //
-func GetBlock(hash string, buf []byte, resp http.ResponseWriter) (int, error) {
+func GetBlock(ctx context.Context, hash string, buf []byte, resp http.ResponseWriter) (int, error) {
        // Attempt to read the requested hash from a keep volume.
        errorToCaller := NotFoundError
 
        for _, vol := range KeepVM.AllReadable() {
-               size, err := vol.Get(hash, buf)
+               size, err := vol.Get(ctx, hash, buf)
+               select {
+               case <-ctx.Done():
+                       return 0, ErrClientDisconnect
+               default:
+               }
                if err != nil {
                        // IsNotExist is an expected error and may be
                        // ignored. All other errors are logged. In
@@ -587,7 +611,7 @@ func GetBlock(hash string, buf []byte, resp http.ResponseWriter) (int, error) {
 
 // PutBlock Stores the BLOCK (identified by the content id HASH) in Keep.
 //
-// PutBlock(block, hash)
+// PutBlock(ctx, block, hash)
 //   Stores the BLOCK (identified by the content id HASH) in Keep.
 //
 //   The MD5 checksum of the block must be identical to the content id HASH.
@@ -612,7 +636,7 @@ func GetBlock(hash string, buf []byte, resp http.ResponseWriter) (int, error) {
 //          all writes failed). The text of the error message should
 //          provide as much detail as possible.
 //
-func PutBlock(block []byte, hash string) (int, error) {
+func PutBlock(ctx context.Context, block []byte, hash string) (int, error) {
        // Check that BLOCK's checksum matches HASH.
        blockhash := fmt.Sprintf("%x", md5.Sum(block))
        if blockhash != hash {
@@ -623,16 +647,21 @@ func PutBlock(block []byte, hash string) (int, error) {
        // If we already have this data, it's intact on disk, and we
        // can update its timestamp, return success. If we have
        // different data with the same hash, return failure.
-       if n, err := CompareAndTouch(hash, block); err == nil || err == CollisionError {
+       if n, err := CompareAndTouch(ctx, hash, block); err == nil || err == CollisionError {
                return n, err
+       } else if ctx.Err() != nil {
+               return 0, ErrClientDisconnect
        }
 
        // Choose a Keep volume to write to.
        // If this volume fails, try all of the volumes in order.
        if vol := KeepVM.NextWritable(); vol != nil {
-               if err := vol.Put(hash, block); err == nil {
+               if err := vol.Put(ctx, hash, block); err == nil {
                        return vol.Replication(), nil // success!
                }
+               if ctx.Err() != nil {
+                       return 0, ErrClientDisconnect
+               }
        }
 
        writables := KeepVM.AllWritable()
@@ -643,7 +672,10 @@ func PutBlock(block []byte, hash string) (int, error) {
 
        allFull := true
        for _, vol := range writables {
-               err := vol.Put(hash, block)
+               err := vol.Put(ctx, hash, block)
+               if ctx.Err() != nil {
+                       return 0, ErrClientDisconnect
+               }
                if err == nil {
                        return vol.Replication(), nil // success!
                }
@@ -669,10 +701,13 @@ func PutBlock(block []byte, hash string) (int, error) {
 // the relevant block's modification time in order to protect it from
 // premature garbage collection. Otherwise, it returns a non-nil
 // error.
-func CompareAndTouch(hash string, buf []byte) (int, error) {
+func CompareAndTouch(ctx context.Context, hash string, buf []byte) (int, error) {
        var bestErr error = NotFoundError
        for _, vol := range KeepVM.AllWritable() {
-               if err := vol.Compare(hash, buf); err == CollisionError {
+               err := vol.Compare(ctx, hash, buf)
+               if ctx.Err() != nil {
+                       return 0, ctx.Err()
+               } else if err == CollisionError {
                        // Stop if we have a block with same hash but
                        // different content. (It will be impossible
                        // to tell which one is wanted if we have
index dda7edcec3509e683465a93d5eb775bf18f16d19..181e651d3b4bbef40fa146b8fe06e6b93c206b05 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "bytes"
+       "context"
 )
 
 // A TestableVolumeManagerFactory creates a volume manager with at least two TestableVolume instances.
@@ -46,7 +47,7 @@ func testGetBlock(t TB, factory TestableVolumeManagerFactory, testHash string, t
 
        // Get should pass
        buf := make([]byte, len(testBlock))
-       n, err := GetBlock(testHash, buf, nil)
+       n, err := GetBlock(context.Background(), testHash, buf, nil)
        if err != nil {
                t.Fatalf("Error while getting block %s", err)
        }
@@ -66,7 +67,7 @@ func testPutRawBadDataGetBlock(t TB, factory TestableVolumeManagerFactory,
 
        // Get should fail
        buf := make([]byte, BlockSize)
-       size, err := GetBlock(testHash, buf, nil)
+       size, err := GetBlock(context.Background(), testHash, buf, nil)
        if err == nil {
                t.Fatalf("Got %+q, expected error while getting corrupt block %v", buf[:size], testHash)
        }
@@ -77,18 +78,18 @@ func testPutBlock(t TB, factory TestableVolumeManagerFactory, testHash string, t
        setupHandlersWithGenericVolumeTest(t, factory)
 
        // PutBlock
-       if _, err := PutBlock(testBlock, testHash); err != nil {
+       if _, err := PutBlock(context.Background(), testBlock, testHash); err != nil {
                t.Fatalf("Error during PutBlock: %s", err)
        }
 
        // Check that PutBlock succeeds again even after CompareAndTouch
-       if _, err := PutBlock(testBlock, testHash); err != nil {
+       if _, err := PutBlock(context.Background(), testBlock, testHash); err != nil {
                t.Fatalf("Error during PutBlock: %s", err)
        }
 
        // Check that PutBlock stored the data as expected
        buf := make([]byte, BlockSize)
-       size, err := GetBlock(testHash, buf, nil)
+       size, err := GetBlock(context.Background(), testHash, buf, nil)
        if err != nil {
                t.Fatalf("Error during GetBlock for %q: %s", testHash, err)
        } else if bytes.Compare(buf[:size], testBlock) != 0 {
@@ -106,14 +107,14 @@ func testPutBlockCorrupt(t TB, factory TestableVolumeManagerFactory,
        testableVolumes[1].PutRaw(testHash, badData)
 
        // Check that PutBlock with good data succeeds
-       if _, err := PutBlock(testBlock, testHash); err != nil {
+       if _, err := PutBlock(context.Background(), testBlock, testHash); err != nil {
                t.Fatalf("Error during PutBlock for %q: %s", testHash, err)
        }
 
        // Put succeeded and overwrote the badData in one volume,
        // and Get should return the testBlock now, ignoring the bad data.
        buf := make([]byte, BlockSize)
-       size, err := GetBlock(testHash, buf, nil)
+       size, err := GetBlock(context.Background(), testHash, buf, nil)
        if err != nil {
                t.Fatalf("Error during GetBlock for %q: %s", testHash, err)
        } else if bytes.Compare(buf[:size], testBlock) != 0 {
index dc6af0fa0d651c79cd6694a006fd3ad83ba2d677..e1d1dc5cb3cf2eb6ed3bf0b1da0a18b154f03328 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "bytes"
+       "context"
        "fmt"
        "io/ioutil"
        "os"
@@ -61,13 +62,13 @@ func TestGetBlock(t *testing.T) {
        defer KeepVM.Close()
 
        vols := KeepVM.AllReadable()
-       if err := vols[1].Put(TestHash, TestBlock); err != nil {
+       if err := vols[1].Put(context.Background(), TestHash, TestBlock); err != nil {
                t.Error(err)
        }
 
        // Check that GetBlock returns success.
        buf := make([]byte, BlockSize)
-       size, err := GetBlock(TestHash, buf, nil)
+       size, err := GetBlock(context.Background(), TestHash, buf, nil)
        if err != nil {
                t.Errorf("GetBlock error: %s", err)
        }
@@ -88,7 +89,7 @@ func TestGetBlockMissing(t *testing.T) {
 
        // Check that GetBlock returns failure.
        buf := make([]byte, BlockSize)
-       size, err := GetBlock(TestHash, buf, nil)
+       size, err := GetBlock(context.Background(), TestHash, buf, nil)
        if err != NotFoundError {
                t.Errorf("Expected NotFoundError, got %v, err %v", buf[:size], err)
        }
@@ -106,11 +107,11 @@ func TestGetBlockCorrupt(t *testing.T) {
        defer KeepVM.Close()
 
        vols := KeepVM.AllReadable()
-       vols[0].Put(TestHash, BadBlock)
+       vols[0].Put(context.Background(), TestHash, BadBlock)
 
        // Check that GetBlock returns failure.
        buf := make([]byte, BlockSize)
-       size, err := GetBlock(TestHash, buf, nil)
+       size, err := GetBlock(context.Background(), TestHash, buf, nil)
        if err != DiskHashError {
                t.Errorf("Expected DiskHashError, got %v (buf: %v)", err, buf[:size])
        }
@@ -131,13 +132,13 @@ func TestPutBlockOK(t *testing.T) {
        defer KeepVM.Close()
 
        // Check that PutBlock stores the data as expected.
-       if n, err := PutBlock(TestBlock, TestHash); err != nil || n < 1 {
+       if n, err := PutBlock(context.Background(), TestBlock, TestHash); err != nil || n < 1 {
                t.Fatalf("PutBlock: n %d err %v", n, err)
        }
 
        vols := KeepVM.AllReadable()
        buf := make([]byte, BlockSize)
-       n, err := vols[1].Get(TestHash, buf)
+       n, err := vols[1].Get(context.Background(), TestHash, buf)
        if err != nil {
                t.Fatalf("Volume #0 Get returned error: %v", err)
        }
@@ -162,12 +163,12 @@ func TestPutBlockOneVol(t *testing.T) {
        vols[0].(*MockVolume).Bad = true
 
        // Check that PutBlock stores the data as expected.
-       if n, err := PutBlock(TestBlock, TestHash); err != nil || n < 1 {
+       if n, err := PutBlock(context.Background(), TestBlock, TestHash); err != nil || n < 1 {
                t.Fatalf("PutBlock: n %d err %v", n, err)
        }
 
        buf := make([]byte, BlockSize)
-       size, err := GetBlock(TestHash, buf, nil)
+       size, err := GetBlock(context.Background(), TestHash, buf, nil)
        if err != nil {
                t.Fatalf("GetBlock: %v", err)
        }
@@ -190,12 +191,12 @@ func TestPutBlockMD5Fail(t *testing.T) {
 
        // Check that PutBlock returns the expected error when the hash does
        // not match the block.
-       if _, err := PutBlock(BadBlock, TestHash); err != RequestHashError {
+       if _, err := PutBlock(context.Background(), BadBlock, TestHash); err != RequestHashError {
                t.Errorf("Expected RequestHashError, got %v", err)
        }
 
        // Confirm that GetBlock fails to return anything.
-       if result, err := GetBlock(TestHash, make([]byte, BlockSize), nil); err != NotFoundError {
+       if result, err := GetBlock(context.Background(), TestHash, make([]byte, BlockSize), nil); err != NotFoundError {
                t.Errorf("GetBlock succeeded after a corrupt block store (result = %s, err = %v)",
                        string(result), err)
        }
@@ -214,14 +215,14 @@ func TestPutBlockCorrupt(t *testing.T) {
 
        // Store a corrupted block under TestHash.
        vols := KeepVM.AllWritable()
-       vols[0].Put(TestHash, BadBlock)
-       if n, err := PutBlock(TestBlock, TestHash); err != nil || n < 1 {
+       vols[0].Put(context.Background(), TestHash, BadBlock)
+       if n, err := PutBlock(context.Background(), TestBlock, TestHash); err != nil || n < 1 {
                t.Errorf("PutBlock: n %d err %v", n, err)
        }
 
        // The block on disk should now match TestBlock.
        buf := make([]byte, BlockSize)
-       if size, err := GetBlock(TestHash, buf, nil); err != nil {
+       if size, err := GetBlock(context.Background(), TestHash, buf, nil); err != nil {
                t.Errorf("GetBlock: %v", err)
        } else if bytes.Compare(buf[:size], TestBlock) != 0 {
                t.Errorf("Got %+q, expected %+q", buf[:size], TestBlock)
@@ -246,10 +247,10 @@ func TestPutBlockCollision(t *testing.T) {
 
        // Store one block, then attempt to store the other. Confirm that
        // PutBlock reported a CollisionError.
-       if _, err := PutBlock(b1, locator); err != nil {
+       if _, err := PutBlock(context.Background(), b1, locator); err != nil {
                t.Error(err)
        }
-       if _, err := PutBlock(b2, locator); err == nil {
+       if _, err := PutBlock(context.Background(), b2, locator); err == nil {
                t.Error("PutBlock did not report a collision")
        } else if err != CollisionError {
                t.Errorf("PutBlock returned %v", err)
@@ -271,7 +272,7 @@ func TestPutBlockTouchFails(t *testing.T) {
        // Store a block and then make the underlying volume bad,
        // so a subsequent attempt to update the file timestamp
        // will fail.
-       vols[0].Put(TestHash, BadBlock)
+       vols[0].Put(context.Background(), TestHash, BadBlock)
        oldMtime, err := vols[0].Mtime(TestHash)
        if err != nil {
                t.Fatalf("vols[0].Mtime(%s): %s\n", TestHash, err)
@@ -280,7 +281,7 @@ func TestPutBlockTouchFails(t *testing.T) {
        // vols[0].Touch will fail on the next call, so the volume
        // manager will store a copy on vols[1] instead.
        vols[0].(*MockVolume).Touchable = false
-       if n, err := PutBlock(TestBlock, TestHash); err != nil || n < 1 {
+       if n, err := PutBlock(context.Background(), TestBlock, TestHash); err != nil || n < 1 {
                t.Fatalf("PutBlock: n %d err %v", n, err)
        }
        vols[0].(*MockVolume).Touchable = true
@@ -296,7 +297,7 @@ func TestPutBlockTouchFails(t *testing.T) {
                        oldMtime, newMtime)
        }
        buf := make([]byte, BlockSize)
-       n, err := vols[1].Get(TestHash, buf)
+       n, err := vols[1].Get(context.Background(), TestHash, buf)
        if err != nil {
                t.Fatalf("vols[1]: %v", err)
        }
@@ -400,11 +401,11 @@ func TestIndex(t *testing.T) {
        defer KeepVM.Close()
 
        vols := KeepVM.AllReadable()
-       vols[0].Put(TestHash, TestBlock)
-       vols[1].Put(TestHash2, TestBlock2)
-       vols[0].Put(TestHash3, TestBlock3)
-       vols[0].Put(TestHash+".meta", []byte("metadata"))
-       vols[1].Put(TestHash2+".meta", []byte("metadata"))
+       vols[0].Put(context.Background(), TestHash, TestBlock)
+       vols[1].Put(context.Background(), TestHash2, TestBlock2)
+       vols[0].Put(context.Background(), TestHash3, TestBlock3)
+       vols[0].Put(context.Background(), TestHash+".meta", []byte("metadata"))
+       vols[1].Put(context.Background(), TestHash2+".meta", []byte("metadata"))
 
        buf := new(bytes.Buffer)
        vols[0].IndexTo("", buf)
index d53d1060e743e07d9d2bfba6b90c67376a1006ab..12860bb662d91a1e31191fed2bccace97bc2ac30 100644 (file)
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "context"
        "crypto/rand"
        "fmt"
        "git.curoverse.com/arvados.git/sdk/go/keepclient"
@@ -94,6 +95,6 @@ func GenerateRandomAPIToken() string {
 
 // Put block
 var PutContent = func(content []byte, locator string) (err error) {
-       _, err = PutBlock(content, locator)
+       _, err = PutBlock(context.Background(), content, locator)
        return
 }
index caed35b670e9484e9978e771e0e8c2aea2532e52..17923f807dc8a8f11bc77ce8dc0732001a4a8ba8 100644 (file)
@@ -1,11 +1,14 @@
 package main
 
 import (
+       "bytes"
+       "context"
        "encoding/base64"
        "encoding/hex"
        "flag"
        "fmt"
        "io"
+       "io/ioutil"
        "log"
        "net/http"
        "os"
@@ -19,6 +22,11 @@ import (
        "github.com/AdRoll/goamz/s3"
 )
 
+const (
+       s3DefaultReadTimeout    = arvados.Duration(10 * time.Minute)
+       s3DefaultConnectTimeout = arvados.Duration(time.Minute)
+)
+
 var (
        // ErrS3TrashDisabled is returned by Trash if that operation
        // is impossible with the current config.
@@ -134,6 +142,8 @@ type S3Volume struct {
        LocationConstraint bool
        IndexPageSize      int
        S3Replication      int
+       ConnectTimeout     arvados.Duration
+       ReadTimeout        arvados.Duration
        RaceWindow         arvados.Duration
        ReadOnly           bool
        UnsafeDelete       bool
@@ -147,24 +157,28 @@ type S3Volume struct {
 func (*S3Volume) Examples() []Volume {
        return []Volume{
                &S3Volume{
-                       AccessKeyFile: "/etc/aws_s3_access_key.txt",
-                       SecretKeyFile: "/etc/aws_s3_secret_key.txt",
-                       Endpoint:      "",
-                       Region:        "us-east-1",
-                       Bucket:        "example-bucket-name",
-                       IndexPageSize: 1000,
-                       S3Replication: 2,
-                       RaceWindow:    arvados.Duration(24 * time.Hour),
+                       AccessKeyFile:  "/etc/aws_s3_access_key.txt",
+                       SecretKeyFile:  "/etc/aws_s3_secret_key.txt",
+                       Endpoint:       "",
+                       Region:         "us-east-1",
+                       Bucket:         "example-bucket-name",
+                       IndexPageSize:  1000,
+                       S3Replication:  2,
+                       RaceWindow:     arvados.Duration(24 * time.Hour),
+                       ConnectTimeout: arvados.Duration(time.Minute),
+                       ReadTimeout:    arvados.Duration(5 * time.Minute),
                },
                &S3Volume{
-                       AccessKeyFile: "/etc/gce_s3_access_key.txt",
-                       SecretKeyFile: "/etc/gce_s3_secret_key.txt",
-                       Endpoint:      "https://storage.googleapis.com",
-                       Region:        "",
-                       Bucket:        "example-bucket-name",
-                       IndexPageSize: 1000,
-                       S3Replication: 2,
-                       RaceWindow:    arvados.Duration(24 * time.Hour),
+                       AccessKeyFile:  "/etc/gce_s3_access_key.txt",
+                       SecretKeyFile:  "/etc/gce_s3_secret_key.txt",
+                       Endpoint:       "https://storage.googleapis.com",
+                       Region:         "",
+                       Bucket:         "example-bucket-name",
+                       IndexPageSize:  1000,
+                       S3Replication:  2,
+                       RaceWindow:     arvados.Duration(24 * time.Hour),
+                       ConnectTimeout: arvados.Duration(time.Minute),
+                       ReadTimeout:    arvados.Duration(5 * time.Minute),
                },
        }
 }
@@ -203,13 +217,47 @@ func (v *S3Volume) Start() error {
        if err != nil {
                return err
        }
+
+       // Zero timeouts mean "wait forever", which is a bad
+       // default. Default to long timeouts instead.
+       if v.ConnectTimeout == 0 {
+               v.ConnectTimeout = s3DefaultConnectTimeout
+       }
+       if v.ReadTimeout == 0 {
+               v.ReadTimeout = s3DefaultReadTimeout
+       }
+
+       client := s3.New(auth, region)
+       client.ConnectTimeout = time.Duration(v.ConnectTimeout)
+       client.ReadTimeout = time.Duration(v.ReadTimeout)
        v.bucket = &s3.Bucket{
-               S3:   s3.New(auth, region),
+               S3:   client,
                Name: v.Bucket,
        }
        return nil
 }
 
+func (v *S3Volume) getReaderWithContext(ctx context.Context, loc string) (rdr io.ReadCloser, err error) {
+       ready := make(chan bool)
+       go func() {
+               rdr, err = v.getReader(loc)
+               close(ready)
+       }()
+       select {
+       case <-ready:
+               return
+       case <-ctx.Done():
+               theConfig.debugLogf("s3: abandoning getReader(): %s", ctx.Err())
+               go func() {
+                       <-ready
+                       if err == nil {
+                               rdr.Close()
+                       }
+               }()
+               return nil, ctx.Err()
+       }
+}
+
 // getReader wraps (Bucket)GetReader.
 //
 // In situations where (Bucket)GetReader would fail because the block
@@ -242,50 +290,106 @@ func (v *S3Volume) getReader(loc string) (rdr io.ReadCloser, err error) {
 
 // Get a block: copy the block data into buf, and return the number of
 // bytes copied.
-func (v *S3Volume) Get(loc string, buf []byte) (int, error) {
-       rdr, err := v.getReader(loc)
+func (v *S3Volume) Get(ctx context.Context, loc string, buf []byte) (int, error) {
+       rdr, err := v.getReaderWithContext(ctx, loc)
        if err != nil {
                return 0, err
        }
-       defer rdr.Close()
-       n, err := io.ReadFull(rdr, buf)
-       switch err {
-       case nil, io.EOF, io.ErrUnexpectedEOF:
-               return n, nil
-       default:
-               return 0, v.translateError(err)
+
+       var n int
+       ready := make(chan bool)
+       go func() {
+               defer close(ready)
+
+               defer rdr.Close()
+               n, err = io.ReadFull(rdr, buf)
+
+               switch err {
+               case nil, io.EOF, io.ErrUnexpectedEOF:
+                       err = nil
+               default:
+                       err = v.translateError(err)
+               }
+       }()
+       select {
+       case <-ctx.Done():
+               theConfig.debugLogf("s3: interrupting ReadFull() with Close() because %s", ctx.Err())
+               rdr.Close()
+               // Must wait for ReadFull to return, to ensure it
+               // doesn't write to buf after we return.
+               theConfig.debugLogf("s3: waiting for ReadFull() to fail")
+               <-ready
+               return 0, ctx.Err()
+       case <-ready:
+               return n, err
        }
 }
 
 // Compare the given data with the stored data.
-func (v *S3Volume) Compare(loc string, expect []byte) error {
-       rdr, err := v.getReader(loc)
+func (v *S3Volume) Compare(ctx context.Context, loc string, expect []byte) error {
+       rdr, err := v.getReaderWithContext(ctx, loc)
        if err != nil {
                return err
        }
        defer rdr.Close()
-       return v.translateError(compareReaderWithBuf(rdr, expect, loc[:32]))
+       return v.translateError(compareReaderWithBuf(ctx, rdr, expect, loc[:32]))
 }
 
 // Put writes a block.
-func (v *S3Volume) Put(loc string, block []byte) error {
+func (v *S3Volume) Put(ctx context.Context, loc string, block []byte) error {
        if v.ReadOnly {
                return MethodDisabledError
        }
        var opts s3.Options
-       if len(block) > 0 {
+       size := len(block)
+       if size > 0 {
                md5, err := hex.DecodeString(loc)
                if err != nil {
                        return err
                }
                opts.ContentMD5 = base64.StdEncoding.EncodeToString(md5)
        }
-       err := v.bucket.Put(loc, block, "application/octet-stream", s3ACL, opts)
-       if err != nil {
+
+       // Send the block data through a pipe, so that (if we need to)
+       // we can close the pipe early and abandon our PutReader()
+       // goroutine, without worrying about PutReader() accessing our
+       // block buffer after we release it.
+       bufr, bufw := io.Pipe()
+       go func() {
+               io.Copy(bufw, bytes.NewReader(block))
+               bufw.Close()
+       }()
+
+       var err error
+       ready := make(chan bool)
+       go func() {
+               defer func() {
+                       if ctx.Err() != nil {
+                               theConfig.debugLogf("%s: abandoned PutReader goroutine finished with err: %s", v, err)
+                       }
+               }()
+               defer close(ready)
+               err = v.bucket.PutReader(loc, bufr, int64(size), "application/octet-stream", s3ACL, opts)
+               if err != nil {
+                       return
+               }
+               err = v.bucket.Put("recent/"+loc, nil, "application/octet-stream", s3ACL, s3.Options{})
+       }()
+       select {
+       case <-ctx.Done():
+               theConfig.debugLogf("%s: taking PutReader's input away: %s", v, ctx.Err())
+               // Our pipe might be stuck in Write(), waiting for
+               // io.Copy() to read. If so, un-stick it. This means
+               // PutReader will get corrupt data, but that's OK: the
+               // size and MD5 won't match, so the write will fail.
+               go io.Copy(ioutil.Discard, bufr)
+               // CloseWithError() will return once pending I/O is done.
+               bufw.CloseWithError(ctx.Err())
+               theConfig.debugLogf("%s: abandoning PutReader goroutine", v)
+               return ctx.Err()
+       case <-ready:
                return v.translateError(err)
        }
-       err = v.bucket.Put("recent/"+loc, nil, "application/octet-stream", s3ACL, s3.Options{})
-       return v.translateError(err)
 }
 
 // Touch sets the timestamp for the given locator to the current time.
index 76dcbc9f9ea2f8fb680a25a31b84735f991b1b51..63b186220c30a562e900722d63d38be50bde05d6 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "bytes"
+       "context"
        "crypto/md5"
        "fmt"
        "io/ioutil"
@@ -223,7 +224,7 @@ func (s *StubbedS3Suite) TestBackendStates(c *check.C) {
                // Check canGet
                loc, blk := setupScenario()
                buf := make([]byte, len(blk))
-               _, err := v.Get(loc, buf)
+               _, err := v.Get(context.Background(), loc, buf)
                c.Check(err == nil, check.Equals, scenario.canGet)
                if err != nil {
                        c.Check(os.IsNotExist(err), check.Equals, true)
@@ -233,7 +234,7 @@ func (s *StubbedS3Suite) TestBackendStates(c *check.C) {
                loc, blk = setupScenario()
                err = v.Trash(loc)
                c.Check(err == nil, check.Equals, scenario.canTrash)
-               _, err = v.Get(loc, buf)
+               _, err = v.Get(context.Background(), loc, buf)
                c.Check(err == nil, check.Equals, scenario.canGetAfterTrash)
                if err != nil {
                        c.Check(os.IsNotExist(err), check.Equals, true)
@@ -248,7 +249,7 @@ func (s *StubbedS3Suite) TestBackendStates(c *check.C) {
                        // should be able to Get after Untrash --
                        // regardless of timestamps, errors, race
                        // conditions, etc.
-                       _, err = v.Get(loc, buf)
+                       _, err = v.Get(context.Background(), loc, buf)
                        c.Check(err, check.IsNil)
                }
 
@@ -269,7 +270,7 @@ func (s *StubbedS3Suite) TestBackendStates(c *check.C) {
                // Check for current Mtime after Put (applies to all
                // scenarios)
                loc, blk = setupScenario()
-               err = v.Put(loc, blk)
+               err = v.Put(context.Background(), loc, blk)
                c.Check(err, check.IsNil)
                t, err := v.Mtime(loc)
                c.Check(err, check.IsNil)
index 5ec413d1bde899d606a6792f40ffd3afe65f3615..04b034a97976980c5ce66d59c361b44936dabbca 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "container/list"
+       "context"
        "testing"
        "time"
 )
@@ -219,15 +220,15 @@ func performTrashWorkerTest(testData TrashWorkerTestData, t *testing.T) {
        // Put test content
        vols := KeepVM.AllWritable()
        if testData.CreateData {
-               vols[0].Put(testData.Locator1, testData.Block1)
-               vols[0].Put(testData.Locator1+".meta", []byte("metadata"))
+               vols[0].Put(context.Background(), testData.Locator1, testData.Block1)
+               vols[0].Put(context.Background(), testData.Locator1+".meta", []byte("metadata"))
 
                if testData.CreateInVolume1 {
-                       vols[0].Put(testData.Locator2, testData.Block2)
-                       vols[0].Put(testData.Locator2+".meta", []byte("metadata"))
+                       vols[0].Put(context.Background(), testData.Locator2, testData.Block2)
+                       vols[0].Put(context.Background(), testData.Locator2+".meta", []byte("metadata"))
                } else {
-                       vols[1].Put(testData.Locator2, testData.Block2)
-                       vols[1].Put(testData.Locator2+".meta", []byte("metadata"))
+                       vols[1].Put(context.Background(), testData.Locator2, testData.Block2)
+                       vols[1].Put(context.Background(), testData.Locator2+".meta", []byte("metadata"))
                }
        }
 
@@ -291,7 +292,7 @@ func performTrashWorkerTest(testData TrashWorkerTestData, t *testing.T) {
 
        // Verify Locator1 to be un/deleted as expected
        buf := make([]byte, BlockSize)
-       size, err := GetBlock(testData.Locator1, buf, nil)
+       size, err := GetBlock(context.Background(), testData.Locator1, buf, nil)
        if testData.ExpectLocator1 {
                if size == 0 || err != nil {
                        t.Errorf("Expected Locator1 to be still present: %s", testData.Locator1)
@@ -304,7 +305,7 @@ func performTrashWorkerTest(testData TrashWorkerTestData, t *testing.T) {
 
        // Verify Locator2 to be un/deleted as expected
        if testData.Locator1 != testData.Locator2 {
-               size, err = GetBlock(testData.Locator2, buf, nil)
+               size, err = GetBlock(context.Background(), testData.Locator2, buf, nil)
                if testData.ExpectLocator2 {
                        if size == 0 || err != nil {
                                t.Errorf("Expected Locator2 to be still present: %s", testData.Locator2)
@@ -323,7 +324,7 @@ func performTrashWorkerTest(testData TrashWorkerTestData, t *testing.T) {
                locatorFoundIn := 0
                for _, volume := range KeepVM.AllReadable() {
                        buf := make([]byte, BlockSize)
-                       if _, err := volume.Get(testData.Locator1, buf); err == nil {
+                       if _, err := volume.Get(context.Background(), testData.Locator1, buf); err == nil {
                                locatorFoundIn = locatorFoundIn + 1
                        }
                }
index 6e01e75b879b339232603d38f93cb040ecc6d86c..57e18aba9f691ceb43f32a928a0e3a95e9d505ec 100644 (file)
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "context"
        "io"
        "sync/atomic"
        "time"
@@ -47,14 +48,14 @@ type Volume interface {
        // any of the data.
        //
        // len(buf) will not exceed BlockSize.
-       Get(loc string, buf []byte) (int, error)
+       Get(ctx context.Context, loc string, buf []byte) (int, error)
 
        // Compare the given data with the stored data (i.e., what Get
        // would return). If equal, return nil. If not, return
        // CollisionError or DiskHashError (depending on whether the
        // data on disk matches the expected hash), or whatever error
        // was encountered opening/reading the stored data.
-       Compare(loc string, data []byte) error
+       Compare(ctx context.Context, loc string, data []byte) error
 
        // Put writes a block to an underlying storage device.
        //
@@ -84,7 +85,7 @@ type Volume interface {
        //
        // Put should not verify that loc==hash(block): this is the
        // caller's responsibility.
-       Put(loc string, block []byte) error
+       Put(ctx context.Context, loc string, block []byte) error
 
        // Touch sets the timestamp for the given locator to the
        // current time.
index 1738fe9b513bb4d86482ceede86a04539d29d418..7e72a8f246ee60410e7110417de3284ae4263ca3 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "bytes"
+       "context"
        "crypto/md5"
        "fmt"
        "os"
@@ -92,7 +93,7 @@ func testGet(t TB, factory TestableVolumeFactory) {
        v.PutRaw(TestHash, TestBlock)
 
        buf := make([]byte, BlockSize)
-       n, err := v.Get(TestHash, buf)
+       n, err := v.Get(context.Background(), TestHash, buf)
        if err != nil {
                t.Fatal(err)
        }
@@ -109,7 +110,7 @@ func testGetNoSuchBlock(t TB, factory TestableVolumeFactory) {
        defer v.Teardown()
 
        buf := make([]byte, BlockSize)
-       if _, err := v.Get(TestHash2, buf); err == nil {
+       if _, err := v.Get(context.Background(), TestHash2, buf); err == nil {
                t.Errorf("Expected error while getting non-existing block %v", TestHash2)
        }
 }
@@ -121,7 +122,7 @@ func testCompareNonexistent(t TB, factory TestableVolumeFactory) {
        v := factory(t)
        defer v.Teardown()
 
-       err := v.Compare(TestHash, TestBlock)
+       err := v.Compare(context.Background(), TestHash, TestBlock)
        if err != os.ErrNotExist {
                t.Errorf("Got err %T %q, expected os.ErrNotExist", err, err)
        }
@@ -136,7 +137,7 @@ func testCompareSameContent(t TB, factory TestableVolumeFactory, testHash string
        v.PutRaw(testHash, testData)
 
        // Compare the block locator with same content
-       err := v.Compare(testHash, testData)
+       err := v.Compare(context.Background(), testHash, testData)
        if err != nil {
                t.Errorf("Got err %q, expected nil", err)
        }
@@ -154,7 +155,7 @@ func testCompareWithCollision(t TB, factory TestableVolumeFactory, testHash stri
        v.PutRaw(testHash, testDataA)
 
        // Compare the block locator with different content; collision
-       err := v.Compare(TestHash, testDataB)
+       err := v.Compare(context.Background(), TestHash, testDataB)
        if err == nil {
                t.Errorf("Got err nil, expected error due to collision")
        }
@@ -170,7 +171,7 @@ func testCompareWithCorruptStoredData(t TB, factory TestableVolumeFactory, testH
 
        v.PutRaw(TestHash, testDataB)
 
-       err := v.Compare(testHash, testDataA)
+       err := v.Compare(context.Background(), testHash, testDataA)
        if err == nil || err == CollisionError {
                t.Errorf("Got err %+v, expected non-collision error", err)
        }
@@ -186,12 +187,12 @@ func testPutBlockWithSameContent(t TB, factory TestableVolumeFactory, testHash s
                return
        }
 
-       err := v.Put(testHash, testData)
+       err := v.Put(context.Background(), testHash, testData)
        if err != nil {
                t.Errorf("Got err putting block %q: %q, expected nil", TestBlock, err)
        }
 
-       err = v.Put(testHash, testData)
+       err = v.Put(context.Background(), testHash, testData)
        if err != nil {
                t.Errorf("Got err putting block second time %q: %q, expected nil", TestBlock, err)
        }
@@ -209,9 +210,9 @@ func testPutBlockWithDifferentContent(t TB, factory TestableVolumeFactory, testH
 
        v.PutRaw(testHash, testDataA)
 
-       putErr := v.Put(testHash, testDataB)
+       putErr := v.Put(context.Background(), testHash, testDataB)
        buf := make([]byte, BlockSize)
-       n, getErr := v.Get(testHash, buf)
+       n, getErr := v.Get(context.Background(), testHash, buf)
        if putErr == nil {
                // Put must not return a nil error unless it has
                // overwritten the existing data.
@@ -238,23 +239,23 @@ func testPutMultipleBlocks(t TB, factory TestableVolumeFactory) {
                return
        }
 
-       err := v.Put(TestHash, TestBlock)
+       err := v.Put(context.Background(), TestHash, TestBlock)
        if err != nil {
                t.Errorf("Got err putting block %q: %q, expected nil", TestBlock, err)
        }
 
-       err = v.Put(TestHash2, TestBlock2)
+       err = v.Put(context.Background(), TestHash2, TestBlock2)
        if err != nil {
                t.Errorf("Got err putting block %q: %q, expected nil", TestBlock2, err)
        }
 
-       err = v.Put(TestHash3, TestBlock3)
+       err = v.Put(context.Background(), TestHash3, TestBlock3)
        if err != nil {
                t.Errorf("Got err putting block %q: %q, expected nil", TestBlock3, err)
        }
 
        data := make([]byte, BlockSize)
-       n, err := v.Get(TestHash, data)
+       n, err := v.Get(context.Background(), TestHash, data)
        if err != nil {
                t.Error(err)
        } else {
@@ -263,7 +264,7 @@ func testPutMultipleBlocks(t TB, factory TestableVolumeFactory) {
                }
        }
 
-       n, err = v.Get(TestHash2, data)
+       n, err = v.Get(context.Background(), TestHash2, data)
        if err != nil {
                t.Error(err)
        } else {
@@ -272,7 +273,7 @@ func testPutMultipleBlocks(t TB, factory TestableVolumeFactory) {
                }
        }
 
-       n, err = v.Get(TestHash3, data)
+       n, err = v.Get(context.Background(), TestHash3, data)
        if err != nil {
                t.Error(err)
        } else {
@@ -294,7 +295,7 @@ func testPutAndTouch(t TB, factory TestableVolumeFactory) {
                return
        }
 
-       if err := v.Put(TestHash, TestBlock); err != nil {
+       if err := v.Put(context.Background(), TestHash, TestBlock); err != nil {
                t.Error(err)
        }
 
@@ -314,7 +315,7 @@ func testPutAndTouch(t TB, factory TestableVolumeFactory) {
        }
 
        // Write the same block again.
-       if err := v.Put(TestHash, TestBlock); err != nil {
+       if err := v.Put(context.Background(), TestHash, TestBlock); err != nil {
                t.Error(err)
        }
 
@@ -437,13 +438,13 @@ func testDeleteNewBlock(t TB, factory TestableVolumeFactory) {
                return
        }
 
-       v.Put(TestHash, TestBlock)
+       v.Put(context.Background(), TestHash, TestBlock)
 
        if err := v.Trash(TestHash); err != nil {
                t.Error(err)
        }
        data := make([]byte, BlockSize)
-       n, err := v.Get(TestHash, data)
+       n, err := v.Get(context.Background(), TestHash, data)
        if err != nil {
                t.Error(err)
        } else if bytes.Compare(data[:n], TestBlock) != 0 {
@@ -463,14 +464,14 @@ func testDeleteOldBlock(t TB, factory TestableVolumeFactory) {
                return
        }
 
-       v.Put(TestHash, TestBlock)
+       v.Put(context.Background(), TestHash, TestBlock)
        v.TouchWithDate(TestHash, time.Now().Add(-2*theConfig.BlobSignatureTTL.Duration()))
 
        if err := v.Trash(TestHash); err != nil {
                t.Error(err)
        }
        data := make([]byte, BlockSize)
-       if _, err := v.Get(TestHash, data); err == nil || !os.IsNotExist(err) {
+       if _, err := v.Get(context.Background(), TestHash, data); err == nil || !os.IsNotExist(err) {
                t.Errorf("os.IsNotExist(%v) should have been true", err)
        }
 
@@ -479,7 +480,7 @@ func testDeleteOldBlock(t TB, factory TestableVolumeFactory) {
                t.Fatalf("os.IsNotExist(%v) should have been true", err)
        }
 
-       err = v.Compare(TestHash, TestBlock)
+       err = v.Compare(context.Background(), TestHash, TestBlock)
        if err == nil || !os.IsNotExist(err) {
                t.Fatalf("os.IsNotExist(%v) should have been true", err)
        }
@@ -553,17 +554,17 @@ func testUpdateReadOnly(t TB, factory TestableVolumeFactory) {
        buf := make([]byte, BlockSize)
 
        // Get from read-only volume should succeed
-       _, err := v.Get(TestHash, buf)
+       _, err := v.Get(context.Background(), TestHash, buf)
        if err != nil {
                t.Errorf("got err %v, expected nil", err)
        }
 
        // Put a new block to read-only volume should result in error
-       err = v.Put(TestHash2, TestBlock2)
+       err = v.Put(context.Background(), TestHash2, TestBlock2)
        if err == nil {
                t.Errorf("Expected error when putting block in a read-only volume")
        }
-       _, err = v.Get(TestHash2, buf)
+       _, err = v.Get(context.Background(), TestHash2, buf)
        if err == nil {
                t.Errorf("Expected error when getting block whose put in read-only volume failed")
        }
@@ -581,7 +582,7 @@ func testUpdateReadOnly(t TB, factory TestableVolumeFactory) {
        }
 
        // Overwriting an existing block in read-only volume should result in error
-       err = v.Put(TestHash, TestBlock)
+       err = v.Put(context.Background(), TestHash, TestBlock)
        if err == nil {
                t.Errorf("Expected error when putting block in a read-only volume")
        }
@@ -600,7 +601,7 @@ func testGetConcurrent(t TB, factory TestableVolumeFactory) {
        sem := make(chan int)
        go func() {
                buf := make([]byte, BlockSize)
-               n, err := v.Get(TestHash, buf)
+               n, err := v.Get(context.Background(), TestHash, buf)
                if err != nil {
                        t.Errorf("err1: %v", err)
                }
@@ -612,7 +613,7 @@ func testGetConcurrent(t TB, factory TestableVolumeFactory) {
 
        go func() {
                buf := make([]byte, BlockSize)
-               n, err := v.Get(TestHash2, buf)
+               n, err := v.Get(context.Background(), TestHash2, buf)
                if err != nil {
                        t.Errorf("err2: %v", err)
                }
@@ -624,7 +625,7 @@ func testGetConcurrent(t TB, factory TestableVolumeFactory) {
 
        go func() {
                buf := make([]byte, BlockSize)
-               n, err := v.Get(TestHash3, buf)
+               n, err := v.Get(context.Background(), TestHash3, buf)
                if err != nil {
                        t.Errorf("err3: %v", err)
                }
@@ -652,7 +653,7 @@ func testPutConcurrent(t TB, factory TestableVolumeFactory) {
 
        sem := make(chan int)
        go func(sem chan int) {
-               err := v.Put(TestHash, TestBlock)
+               err := v.Put(context.Background(), TestHash, TestBlock)
                if err != nil {
                        t.Errorf("err1: %v", err)
                }
@@ -660,7 +661,7 @@ func testPutConcurrent(t TB, factory TestableVolumeFactory) {
        }(sem)
 
        go func(sem chan int) {
-               err := v.Put(TestHash2, TestBlock2)
+               err := v.Put(context.Background(), TestHash2, TestBlock2)
                if err != nil {
                        t.Errorf("err2: %v", err)
                }
@@ -668,7 +669,7 @@ func testPutConcurrent(t TB, factory TestableVolumeFactory) {
        }(sem)
 
        go func(sem chan int) {
-               err := v.Put(TestHash3, TestBlock3)
+               err := v.Put(context.Background(), TestHash3, TestBlock3)
                if err != nil {
                        t.Errorf("err3: %v", err)
                }
@@ -682,7 +683,7 @@ func testPutConcurrent(t TB, factory TestableVolumeFactory) {
 
        // Double check that we actually wrote the blocks we expected to write.
        buf := make([]byte, BlockSize)
-       n, err := v.Get(TestHash, buf)
+       n, err := v.Get(context.Background(), TestHash, buf)
        if err != nil {
                t.Errorf("Get #1: %v", err)
        }
@@ -690,7 +691,7 @@ func testPutConcurrent(t TB, factory TestableVolumeFactory) {
                t.Errorf("Get #1: expected %s, got %s", string(TestBlock), string(buf[:n]))
        }
 
-       n, err = v.Get(TestHash2, buf)
+       n, err = v.Get(context.Background(), TestHash2, buf)
        if err != nil {
                t.Errorf("Get #2: %v", err)
        }
@@ -698,7 +699,7 @@ func testPutConcurrent(t TB, factory TestableVolumeFactory) {
                t.Errorf("Get #2: expected %s, got %s", string(TestBlock2), string(buf[:n]))
        }
 
-       n, err = v.Get(TestHash3, buf)
+       n, err = v.Get(context.Background(), TestHash3, buf)
        if err != nil {
                t.Errorf("Get #3: %v", err)
        }
@@ -720,12 +721,12 @@ func testPutFullBlock(t TB, factory TestableVolumeFactory) {
        wdata[0] = 'a'
        wdata[BlockSize-1] = 'z'
        hash := fmt.Sprintf("%x", md5.Sum(wdata))
-       err := v.Put(hash, wdata)
+       err := v.Put(context.Background(), hash, wdata)
        if err != nil {
                t.Fatal(err)
        }
        buf := make([]byte, BlockSize)
-       n, err := v.Get(hash, buf)
+       n, err := v.Get(context.Background(), hash, buf)
        if err != nil {
                t.Error(err)
        }
@@ -752,7 +753,7 @@ func testTrashUntrash(t TB, factory TestableVolumeFactory) {
        v.TouchWithDate(TestHash, time.Now().Add(-2*theConfig.BlobSignatureTTL.Duration()))
 
        buf := make([]byte, BlockSize)
-       n, err := v.Get(TestHash, buf)
+       n, err := v.Get(context.Background(), TestHash, buf)
        if err != nil {
                t.Fatal(err)
        }
@@ -771,7 +772,7 @@ func testTrashUntrash(t TB, factory TestableVolumeFactory) {
                        t.Fatal(err)
                }
        } else {
-               _, err = v.Get(TestHash, buf)
+               _, err = v.Get(context.Background(), TestHash, buf)
                if err == nil || !os.IsNotExist(err) {
                        t.Errorf("os.IsNotExist(%v) should have been true", err)
                }
@@ -784,7 +785,7 @@ func testTrashUntrash(t TB, factory TestableVolumeFactory) {
        }
 
        // Get the block - after trash and untrash sequence
-       n, err = v.Get(TestHash, buf)
+       n, err = v.Get(context.Background(), TestHash, buf)
        if err != nil {
                t.Fatal(err)
        }
@@ -802,7 +803,7 @@ func testTrashEmptyTrashUntrash(t TB, factory TestableVolumeFactory) {
 
        checkGet := func() error {
                buf := make([]byte, BlockSize)
-               n, err := v.Get(TestHash, buf)
+               n, err := v.Get(context.Background(), TestHash, buf)
                if err != nil {
                        return err
                }
@@ -815,7 +816,7 @@ func testTrashEmptyTrashUntrash(t TB, factory TestableVolumeFactory) {
                        return err
                }
 
-               err = v.Compare(TestHash, TestBlock)
+               err = v.Compare(context.Background(), TestHash, TestBlock)
                if err != nil {
                        return err
                }
index 6ab386aec4fcc7774af90c4fa5ca879258ac9404..931c10e69044c4715eb35ccab4d33872a848db5d 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "bytes"
+       "context"
        "crypto/md5"
        "errors"
        "fmt"
@@ -95,7 +96,7 @@ func (v *MockVolume) gotCall(method string) {
        }
 }
 
-func (v *MockVolume) Compare(loc string, buf []byte) error {
+func (v *MockVolume) Compare(ctx context.Context, loc string, buf []byte) error {
        v.gotCall("Compare")
        <-v.Gate
        if v.Bad {
@@ -113,7 +114,7 @@ func (v *MockVolume) Compare(loc string, buf []byte) error {
        }
 }
 
-func (v *MockVolume) Get(loc string, buf []byte) (int, error) {
+func (v *MockVolume) Get(ctx context.Context, loc string, buf []byte) (int, error) {
        v.gotCall("Get")
        <-v.Gate
        if v.Bad {
@@ -125,7 +126,7 @@ func (v *MockVolume) Get(loc string, buf []byte) (int, error) {
        return 0, os.ErrNotExist
 }
 
-func (v *MockVolume) Put(loc string, block []byte) error {
+func (v *MockVolume) Put(ctx context.Context, loc string, block []byte) error {
        v.gotCall("Put")
        <-v.Gate
        if v.Bad {
index b5753dec04638927162a328d2a43f2fd4e567a50..5239ed37402c93f25af0d6c65c03f6a953597cda 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "bufio"
+       "context"
        "flag"
        "fmt"
        "io"
@@ -182,11 +183,14 @@ func (v *UnixVolume) Mtime(loc string) (time.Time, error) {
 
 // Lock the locker (if one is in use), open the file for reading, and
 // call the given function if and when the file is ready to read.
-func (v *UnixVolume) getFunc(path string, fn func(io.Reader) error) error {
+func (v *UnixVolume) getFunc(ctx context.Context, path string, fn func(io.Reader) error) error {
        if v.locker != nil {
                v.locker.Lock()
                defer v.locker.Unlock()
        }
+       if ctx.Err() != nil {
+               return ctx.Err()
+       }
        f, err := os.Open(path)
        if err != nil {
                return err
@@ -210,7 +214,7 @@ func (v *UnixVolume) stat(path string) (os.FileInfo, error) {
 
 // Get retrieves a block, copies it to the given slice, and returns
 // the number of bytes copied.
-func (v *UnixVolume) Get(loc string, buf []byte) (int, error) {
+func (v *UnixVolume) Get(ctx context.Context, loc string, buf []byte) (int, error) {
        path := v.blockPath(loc)
        stat, err := v.stat(path)
        if err != nil {
@@ -221,7 +225,7 @@ func (v *UnixVolume) Get(loc string, buf []byte) (int, error) {
        }
        var read int
        size := int(stat.Size())
-       err = v.getFunc(path, func(rdr io.Reader) error {
+       err = v.getFunc(ctx, path, func(rdr io.Reader) error {
                read, err = io.ReadFull(rdr, buf[:size])
                return err
        })
@@ -231,13 +235,13 @@ func (v *UnixVolume) Get(loc string, buf []byte) (int, error) {
 // Compare returns nil if Get(loc) would return the same content as
 // expect. It is functionally equivalent to Get() followed by
 // bytes.Compare(), but uses less memory.
-func (v *UnixVolume) Compare(loc string, expect []byte) error {
+func (v *UnixVolume) Compare(ctx context.Context, loc string, expect []byte) error {
        path := v.blockPath(loc)
        if _, err := v.stat(path); err != nil {
                return v.translateError(err)
        }
-       return v.getFunc(path, func(rdr io.Reader) error {
-               return compareReaderWithBuf(rdr, expect, loc[:32])
+       return v.getFunc(ctx, path, func(rdr io.Reader) error {
+               return compareReaderWithBuf(ctx, rdr, expect, loc[:32])
        })
 }
 
@@ -245,7 +249,7 @@ func (v *UnixVolume) Compare(loc string, expect []byte) error {
 // "loc".  It returns nil on success.  If the volume is full, it
 // returns a FullError.  If the write fails due to some other error,
 // that error is returned.
-func (v *UnixVolume) Put(loc string, block []byte) error {
+func (v *UnixVolume) Put(ctx context.Context, loc string, block []byte) error {
        if v.ReadOnly {
                return MethodDisabledError
        }
@@ -270,6 +274,11 @@ func (v *UnixVolume) Put(loc string, block []byte) error {
                v.locker.Lock()
                defer v.locker.Unlock()
        }
+       select {
+       case <-ctx.Done():
+               return ctx.Err()
+       default:
+       }
        if _, err := tmpfile.Write(block); err != nil {
                log.Printf("%s: writing to %s: %s\n", v, bpath, err)
                tmpfile.Close()
index 887247d3c3956e9475edf8437c913b3e1fc922c9..3021d6bd362724e7136d1054095e49bb53778199 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "bytes"
+       "context"
        "errors"
        "fmt"
        "io"
@@ -45,7 +46,7 @@ func (v *TestableUnixVolume) PutRaw(locator string, data []byte) {
                v.ReadOnly = orig
        }(v.ReadOnly)
        v.ReadOnly = false
-       err := v.Put(locator, data)
+       err := v.Put(context.Background(), locator, data)
        if err != nil {
                v.t.Fatal(err)
        }
@@ -117,10 +118,10 @@ func TestReplicationDefault1(t *testing.T) {
 func TestGetNotFound(t *testing.T) {
        v := NewTestableUnixVolume(t, false, false)
        defer v.Teardown()
-       v.Put(TestHash, TestBlock)
+       v.Put(context.Background(), TestHash, TestBlock)
 
        buf := make([]byte, BlockSize)
-       n, err := v.Get(TestHash2, buf)
+       n, err := v.Get(context.Background(), TestHash2, buf)
        switch {
        case os.IsNotExist(err):
                break
@@ -135,7 +136,7 @@ func TestPut(t *testing.T) {
        v := NewTestableUnixVolume(t, false, false)
        defer v.Teardown()
 
-       err := v.Put(TestHash, TestBlock)
+       err := v.Put(context.Background(), TestHash, TestBlock)
        if err != nil {
                t.Error(err)
        }
@@ -153,7 +154,7 @@ func TestPutBadVolume(t *testing.T) {
        defer v.Teardown()
 
        os.Chmod(v.Root, 000)
-       err := v.Put(TestHash, TestBlock)
+       err := v.Put(context.Background(), TestHash, TestBlock)
        if err == nil {
                t.Error("Write should have failed")
        }
@@ -166,12 +167,12 @@ func TestUnixVolumeReadonly(t *testing.T) {
        v.PutRaw(TestHash, TestBlock)
 
        buf := make([]byte, BlockSize)
-       _, err := v.Get(TestHash, buf)
+       _, err := v.Get(context.Background(), TestHash, buf)
        if err != nil {
                t.Errorf("got err %v, expected nil", err)
        }
 
-       err = v.Put(TestHash, TestBlock)
+       err = v.Put(context.Background(), TestHash, TestBlock)
        if err != MethodDisabledError {
                t.Errorf("got err %v, expected MethodDisabledError", err)
        }
@@ -231,9 +232,9 @@ func TestUnixVolumeGetFuncWorkerError(t *testing.T) {
        v := NewTestableUnixVolume(t, false, false)
        defer v.Teardown()
 
-       v.Put(TestHash, TestBlock)
+       v.Put(context.Background(), TestHash, TestBlock)
        mockErr := errors.New("Mock error")
-       err := v.getFunc(v.blockPath(TestHash), func(rdr io.Reader) error {
+       err := v.getFunc(context.Background(), v.blockPath(TestHash), func(rdr io.Reader) error {
                return mockErr
        })
        if err != mockErr {
@@ -246,7 +247,7 @@ func TestUnixVolumeGetFuncFileError(t *testing.T) {
        defer v.Teardown()
 
        funcCalled := false
-       err := v.getFunc(v.blockPath(TestHash), func(rdr io.Reader) error {
+       err := v.getFunc(context.Background(), v.blockPath(TestHash), func(rdr io.Reader) error {
                funcCalled = true
                return nil
        })
@@ -262,13 +263,13 @@ func TestUnixVolumeGetFuncWorkerWaitsOnMutex(t *testing.T) {
        v := NewTestableUnixVolume(t, false, false)
        defer v.Teardown()
 
-       v.Put(TestHash, TestBlock)
+       v.Put(context.Background(), TestHash, TestBlock)
 
        mtx := NewMockMutex()
        v.locker = mtx
 
        funcCalled := make(chan struct{})
-       go v.getFunc(v.blockPath(TestHash), func(rdr io.Reader) error {
+       go v.getFunc(context.Background(), v.blockPath(TestHash), func(rdr io.Reader) error {
                funcCalled <- struct{}{}
                return nil
        })
@@ -297,26 +298,26 @@ func TestUnixVolumeCompare(t *testing.T) {
        v := NewTestableUnixVolume(t, false, false)
        defer v.Teardown()
 
-       v.Put(TestHash, TestBlock)
-       err := v.Compare(TestHash, TestBlock)
+       v.Put(context.Background(), TestHash, TestBlock)
+       err := v.Compare(context.Background(), TestHash, TestBlock)
        if err != nil {
                t.Errorf("Got err %q, expected nil", err)
        }
 
-       err = v.Compare(TestHash, []byte("baddata"))
+       err = v.Compare(context.Background(), TestHash, []byte("baddata"))
        if err != CollisionError {
                t.Errorf("Got err %q, expected %q", err, CollisionError)
        }
 
-       v.Put(TestHash, []byte("baddata"))
-       err = v.Compare(TestHash, TestBlock)
+       v.Put(context.Background(), TestHash, []byte("baddata"))
+       err = v.Compare(context.Background(), TestHash, TestBlock)
        if err != DiskHashError {
                t.Errorf("Got err %q, expected %q", err, DiskHashError)
        }
 
        p := fmt.Sprintf("%s/%s/%s", v.Root, TestHash[:3], TestHash)
        os.Chmod(p, 000)
-       err = v.Compare(TestHash, TestBlock)
+       err = v.Compare(context.Background(), TestHash, TestBlock)
        if err == nil || strings.Index(err.Error(), "permission denied") < 0 {
                t.Errorf("Got err %q, expected %q", err, "permission denied")
        }
index 9dc8f9425a8e4707bc4538842911511928428095..6d791bf9876a5b84a2b1b642025b730771f76da2 100644 (file)
@@ -47,7 +47,7 @@ func main() {
        if err != nil {
                log.Fatal(err)
        }
-       kc, err := keepclient.MakeKeepClient(&arv)
+       kc, err := keepclient.MakeKeepClient(arv)
        if err != nil {
                log.Fatal(err)
        }
@@ -56,11 +56,11 @@ func main() {
 
        overrideServices(kc)
 
-       nextBuf := make(chan []byte, *WriteThreads)
        nextLocator := make(chan string, *ReadThreads+*WriteThreads)
 
        go countBeans(nextLocator)
        for i := 0; i < *WriteThreads; i++ {
+               nextBuf := make(chan []byte, 1)
                go makeBufs(nextBuf, i)
                go doWrites(kc, nextBuf, nextLocator)
        }
@@ -106,23 +106,28 @@ func countBeans(nextLocator chan string) {
        }
 }
 
-func makeBufs(nextBuf chan []byte, threadID int) {
+func makeBufs(nextBuf chan<- []byte, threadID int) {
        buf := make([]byte, *BlockSize)
        if *VaryThread {
                binary.PutVarint(buf, int64(threadID))
        }
+       randSize := 524288
+       if randSize > *BlockSize {
+               randSize = *BlockSize
+       }
        for {
                if *VaryRequest {
-                       buf = make([]byte, *BlockSize)
-                       if _, err := io.ReadFull(rand.Reader, buf); err != nil {
+                       rnd := make([]byte, randSize)
+                       if _, err := io.ReadFull(rand.Reader, rnd); err != nil {
                                log.Fatal(err)
                        }
+                       buf = append(rnd, buf[randSize:]...)
                }
                nextBuf <- buf
        }
 }
 
-func doWrites(kc *keepclient.KeepClient, nextBuf chan []byte, nextLocator chan string) {
+func doWrites(kc *keepclient.KeepClient, nextBuf <-chan []byte, nextLocator chan<- string) {
        for buf := range nextBuf {
                locator, _, err := kc.PutB(buf)
                if err != nil {
@@ -139,7 +144,7 @@ func doWrites(kc *keepclient.KeepClient, nextBuf chan []byte, nextLocator chan s
        }
 }
 
-func doReads(kc *keepclient.KeepClient, nextLocator chan string) {
+func doReads(kc *keepclient.KeepClient, nextLocator <-chan string) {
        for locator := range nextLocator {
                rdr, size, url, err := kc.Get(locator)
                if err != nil {