From: Tom Clegg Date: Sat, 5 Nov 2016 21:17:28 +0000 (-0400) Subject: 10467: Abort S3 and release buffer if caller disconnects during S3 PUT request. X-Git-Tag: 1.1.0~610^2~12 X-Git-Url: https://git.arvados.org/arvados.git/commitdiff_plain/da13bb400f87fdd4157146e2d0b171b730fa3208 10467: Abort S3 and release buffer if caller disconnects during S3 PUT request. --- diff --git a/services/keepstore/azure_blob_volume.go b/services/keepstore/azure_blob_volume.go index b21f68d683..542a9ca690 100644 --- a/services/keepstore/azure_blob_volume.go +++ b/services/keepstore/azure_blob_volume.go @@ -302,7 +302,7 @@ func (v *AzureBlobVolume) Compare(loc string, expect []byte) error { } // Put stores a Keep block as a block blob in the container. -func (v *AzureBlobVolume) Put(loc string, block []byte) error { +func (v *AzureBlobVolume) Put(ctx context.Context, loc string, block []byte) error { if v.ReadOnly { return MethodDisabledError } diff --git a/services/keepstore/azure_blob_volume_test.go b/services/keepstore/azure_blob_volume_test.go index bb57dcd266..0123bfba5d 100644 --- a/services/keepstore/azure_blob_volume_test.go +++ b/services/keepstore/azure_blob_volume_test.go @@ -455,7 +455,7 @@ func TestAzureBlobVolumeRangeFenceposts(t *testing.T) { data[i] = byte((i + 7) & 0xff) } hash := fmt.Sprintf("%x", md5.Sum(data)) - err := v.Put(hash, data) + err := v.Put(context.TODO(), hash, data) if err != nil { t.Error(err) } @@ -501,7 +501,7 @@ func TestAzureBlobVolumeCreateBlobRace(t *testing.T) { allDone := make(chan struct{}) v.azHandler.race = make(chan chan struct{}) go func() { - err := v.Put(TestHash, TestBlock) + err := v.Put(context.TODO(), TestHash, TestBlock) if err != nil { t.Error(err) } diff --git a/services/keepstore/handler_test.go b/services/keepstore/handler_test.go index e254853a0b..1821383c85 100644 --- a/services/keepstore/handler_test.go +++ b/services/keepstore/handler_test.go @@ -49,7 +49,7 @@ func TestGetHandler(t *testing.T) { defer KeepVM.Close() vols := KeepVM.AllWritable() - if err := vols[0].Put(TestHash, TestBlock); err != nil { + if err := vols[0].Put(context.TODO(), TestHash, TestBlock); err != nil { t.Error(err) } @@ -289,10 +289,10 @@ func TestIndexHandler(t *testing.T) { defer KeepVM.Close() vols := KeepVM.AllWritable() - vols[0].Put(TestHash, TestBlock) - vols[1].Put(TestHash2, TestBlock2) - vols[0].Put(TestHash+".meta", []byte("metadata")) - vols[1].Put(TestHash2+".meta", []byte("metadata")) + vols[0].Put(context.TODO(), TestHash, TestBlock) + vols[1].Put(context.TODO(), TestHash2, TestBlock2) + vols[0].Put(context.TODO(), TestHash+".meta", []byte("metadata")) + vols[1].Put(context.TODO(), TestHash2+".meta", []byte("metadata")) theConfig.systemAuthToken = "DATA MANAGER TOKEN" @@ -478,7 +478,7 @@ func TestDeleteHandler(t *testing.T) { defer KeepVM.Close() vols := KeepVM.AllWritable() - vols[0].Put(TestHash, TestBlock) + vols[0].Put(context.TODO(), TestHash, TestBlock) // Explicitly set the BlobSignatureTTL to 0 for these // tests, to ensure the MockVolume deletes the blocks @@ -573,7 +573,7 @@ func TestDeleteHandler(t *testing.T) { // A DELETE request on a block newer than BlobSignatureTTL // should return success but leave the block on the volume. - vols[0].Put(TestHash, TestBlock) + vols[0].Put(context.TODO(), TestHash, TestBlock) theConfig.BlobSignatureTTL = arvados.Duration(time.Hour) response = IssueRequest(superuserExistingBlockReq) @@ -941,7 +941,7 @@ func TestGetHandlerClientDisconnect(t *testing.T) { KeepVM = MakeTestVolumeManager(2) defer KeepVM.Close() - if err := KeepVM.AllWritable()[0].Put(TestHash, TestBlock); err != nil { + if err := KeepVM.AllWritable()[0].Put(context.TODO(), TestHash, TestBlock); err != nil { t.Error(err) } @@ -986,7 +986,7 @@ func TestGetHandlerNoBufferLeak(t *testing.T) { defer KeepVM.Close() vols := KeepVM.AllWritable() - if err := vols[0].Put(TestHash, TestBlock); err != nil { + if err := vols[0].Put(context.TODO(), TestHash, TestBlock); err != nil { t.Error(err) } @@ -1041,7 +1041,7 @@ func TestUntrashHandler(t *testing.T) { KeepVM = MakeTestVolumeManager(2) defer KeepVM.Close() vols := KeepVM.AllWritable() - vols[0].Put(TestHash, TestBlock) + vols[0].Put(context.TODO(), TestHash, TestBlock) theConfig.systemAuthToken = "DATA MANAGER TOKEN" diff --git a/services/keepstore/handlers.go b/services/keepstore/handlers.go index ac2d71228f..5dc68df4aa 100644 --- a/services/keepstore/handlers.go +++ b/services/keepstore/handlers.go @@ -157,6 +157,8 @@ func getBufferForResponseWriter(resp http.ResponseWriter, bufs *bufferPool, bufS // PutBlockHandler is a HandleFunc to address Put block requests. func PutBlockHandler(resp http.ResponseWriter, req *http.Request) { + ctx := contextForResponse(context.TODO(), resp) + hash := mux.Vars(req)["hash"] // Detect as many error conditions as possible before reading @@ -191,7 +193,7 @@ func PutBlockHandler(resp http.ResponseWriter, req *http.Request) { return } - replication, err := PutBlock(buf, hash) + replication, err := PutBlock(ctx, buf, hash) bufs.Put(buf) if err != nil { @@ -611,7 +613,7 @@ func GetBlock(ctx context.Context, hash string, buf []byte, resp http.ResponseWr // PutBlock Stores the BLOCK (identified by the content id HASH) in Keep. // -// PutBlock(block, hash) +// PutBlock(ctx, block, hash) // Stores the BLOCK (identified by the content id HASH) in Keep. // // The MD5 checksum of the block must be identical to the content id HASH. @@ -636,7 +638,7 @@ func GetBlock(ctx context.Context, hash string, buf []byte, resp http.ResponseWr // all writes failed). The text of the error message should // provide as much detail as possible. // -func PutBlock(block []byte, hash string) (int, error) { +func PutBlock(ctx context.Context, block []byte, hash string) (int, error) { // Check that BLOCK's checksum matches HASH. blockhash := fmt.Sprintf("%x", md5.Sum(block)) if blockhash != hash { @@ -654,7 +656,7 @@ func PutBlock(block []byte, hash string) (int, error) { // Choose a Keep volume to write to. // If this volume fails, try all of the volumes in order. if vol := KeepVM.NextWritable(); vol != nil { - if err := vol.Put(hash, block); err == nil { + if err := vol.Put(context.TODO(), hash, block); err == nil { return vol.Replication(), nil // success! } } @@ -667,7 +669,12 @@ func PutBlock(block []byte, hash string) (int, error) { allFull := true for _, vol := range writables { - err := vol.Put(hash, block) + err := vol.Put(ctx, hash, block) + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + } if err == nil { return vol.Replication(), nil // success! } diff --git a/services/keepstore/handlers_with_generic_volume_test.go b/services/keepstore/handlers_with_generic_volume_test.go index 8abf8e07fc..2c273aec38 100644 --- a/services/keepstore/handlers_with_generic_volume_test.go +++ b/services/keepstore/handlers_with_generic_volume_test.go @@ -78,12 +78,12 @@ func testPutBlock(t TB, factory TestableVolumeManagerFactory, testHash string, t setupHandlersWithGenericVolumeTest(t, factory) // PutBlock - if _, err := PutBlock(testBlock, testHash); err != nil { + if _, err := PutBlock(context.TODO(), testBlock, testHash); err != nil { t.Fatalf("Error during PutBlock: %s", err) } // Check that PutBlock succeeds again even after CompareAndTouch - if _, err := PutBlock(testBlock, testHash); err != nil { + if _, err := PutBlock(context.TODO(), testBlock, testHash); err != nil { t.Fatalf("Error during PutBlock: %s", err) } @@ -107,7 +107,7 @@ func testPutBlockCorrupt(t TB, factory TestableVolumeManagerFactory, testableVolumes[1].PutRaw(testHash, badData) // Check that PutBlock with good data succeeds - if _, err := PutBlock(testBlock, testHash); err != nil { + if _, err := PutBlock(context.TODO(), testBlock, testHash); err != nil { t.Fatalf("Error during PutBlock for %q: %s", testHash, err) } diff --git a/services/keepstore/keepstore_test.go b/services/keepstore/keepstore_test.go index 8413b7d1c5..a2e8044837 100644 --- a/services/keepstore/keepstore_test.go +++ b/services/keepstore/keepstore_test.go @@ -62,7 +62,7 @@ func TestGetBlock(t *testing.T) { defer KeepVM.Close() vols := KeepVM.AllReadable() - if err := vols[1].Put(TestHash, TestBlock); err != nil { + if err := vols[1].Put(context.TODO(), TestHash, TestBlock); err != nil { t.Error(err) } @@ -107,7 +107,7 @@ func TestGetBlockCorrupt(t *testing.T) { defer KeepVM.Close() vols := KeepVM.AllReadable() - vols[0].Put(TestHash, BadBlock) + vols[0].Put(context.TODO(), TestHash, BadBlock) // Check that GetBlock returns failure. buf := make([]byte, BlockSize) @@ -132,7 +132,7 @@ func TestPutBlockOK(t *testing.T) { defer KeepVM.Close() // Check that PutBlock stores the data as expected. - if n, err := PutBlock(TestBlock, TestHash); err != nil || n < 1 { + if n, err := PutBlock(context.TODO(), TestBlock, TestHash); err != nil || n < 1 { t.Fatalf("PutBlock: n %d err %v", n, err) } @@ -163,7 +163,7 @@ func TestPutBlockOneVol(t *testing.T) { vols[0].(*MockVolume).Bad = true // Check that PutBlock stores the data as expected. - if n, err := PutBlock(TestBlock, TestHash); err != nil || n < 1 { + if n, err := PutBlock(context.TODO(), TestBlock, TestHash); err != nil || n < 1 { t.Fatalf("PutBlock: n %d err %v", n, err) } @@ -191,7 +191,7 @@ func TestPutBlockMD5Fail(t *testing.T) { // Check that PutBlock returns the expected error when the hash does // not match the block. - if _, err := PutBlock(BadBlock, TestHash); err != RequestHashError { + if _, err := PutBlock(context.TODO(), BadBlock, TestHash); err != RequestHashError { t.Errorf("Expected RequestHashError, got %v", err) } @@ -215,8 +215,8 @@ func TestPutBlockCorrupt(t *testing.T) { // Store a corrupted block under TestHash. vols := KeepVM.AllWritable() - vols[0].Put(TestHash, BadBlock) - if n, err := PutBlock(TestBlock, TestHash); err != nil || n < 1 { + vols[0].Put(context.TODO(), TestHash, BadBlock) + if n, err := PutBlock(context.TODO(), TestBlock, TestHash); err != nil || n < 1 { t.Errorf("PutBlock: n %d err %v", n, err) } @@ -247,10 +247,10 @@ func TestPutBlockCollision(t *testing.T) { // Store one block, then attempt to store the other. Confirm that // PutBlock reported a CollisionError. - if _, err := PutBlock(b1, locator); err != nil { + if _, err := PutBlock(context.TODO(), b1, locator); err != nil { t.Error(err) } - if _, err := PutBlock(b2, locator); err == nil { + if _, err := PutBlock(context.TODO(), b2, locator); err == nil { t.Error("PutBlock did not report a collision") } else if err != CollisionError { t.Errorf("PutBlock returned %v", err) @@ -272,7 +272,7 @@ func TestPutBlockTouchFails(t *testing.T) { // Store a block and then make the underlying volume bad, // so a subsequent attempt to update the file timestamp // will fail. - vols[0].Put(TestHash, BadBlock) + vols[0].Put(context.TODO(), TestHash, BadBlock) oldMtime, err := vols[0].Mtime(TestHash) if err != nil { t.Fatalf("vols[0].Mtime(%s): %s\n", TestHash, err) @@ -281,7 +281,7 @@ func TestPutBlockTouchFails(t *testing.T) { // vols[0].Touch will fail on the next call, so the volume // manager will store a copy on vols[1] instead. vols[0].(*MockVolume).Touchable = false - if n, err := PutBlock(TestBlock, TestHash); err != nil || n < 1 { + if n, err := PutBlock(context.TODO(), TestBlock, TestHash); err != nil || n < 1 { t.Fatalf("PutBlock: n %d err %v", n, err) } vols[0].(*MockVolume).Touchable = true @@ -401,11 +401,11 @@ func TestIndex(t *testing.T) { defer KeepVM.Close() vols := KeepVM.AllReadable() - vols[0].Put(TestHash, TestBlock) - vols[1].Put(TestHash2, TestBlock2) - vols[0].Put(TestHash3, TestBlock3) - vols[0].Put(TestHash+".meta", []byte("metadata")) - vols[1].Put(TestHash2+".meta", []byte("metadata")) + vols[0].Put(context.TODO(), TestHash, TestBlock) + vols[1].Put(context.TODO(), TestHash2, TestBlock2) + vols[0].Put(context.TODO(), TestHash3, TestBlock3) + vols[0].Put(context.TODO(), TestHash+".meta", []byte("metadata")) + vols[1].Put(context.TODO(), TestHash2+".meta", []byte("metadata")) buf := new(bytes.Buffer) vols[0].IndexTo("", buf) diff --git a/services/keepstore/pull_worker.go b/services/keepstore/pull_worker.go index d53d1060e7..e42b6e4b89 100644 --- a/services/keepstore/pull_worker.go +++ b/services/keepstore/pull_worker.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/rand" "fmt" "git.curoverse.com/arvados.git/sdk/go/keepclient" @@ -94,6 +95,6 @@ func GenerateRandomAPIToken() string { // Put block var PutContent = func(content []byte, locator string) (err error) { - _, err = PutBlock(content, locator) + _, err = PutBlock(context.TODO(), content, locator) return } diff --git a/services/keepstore/s3_volume.go b/services/keepstore/s3_volume.go index 6339cf8e28..7ef590f4a6 100644 --- a/services/keepstore/s3_volume.go +++ b/services/keepstore/s3_volume.go @@ -1,12 +1,14 @@ package main import ( + "bytes" "context" "encoding/base64" "encoding/hex" "flag" "fmt" "io" + "io/ioutil" "log" "net/http" "os" @@ -320,24 +322,64 @@ func (v *S3Volume) Compare(loc string, expect []byte) error { } // Put writes a block. -func (v *S3Volume) Put(loc string, block []byte) error { +func (v *S3Volume) Put(ctx context.Context, loc string, block []byte) error { if v.ReadOnly { return MethodDisabledError } var opts s3.Options - if len(block) > 0 { + size := len(block) + if size > 0 { md5, err := hex.DecodeString(loc) if err != nil { return err } opts.ContentMD5 = base64.StdEncoding.EncodeToString(md5) } - err := v.bucket.Put(loc, block, "application/octet-stream", s3ACL, opts) - if err != nil { - return v.translateError(err) + + // Send the block data through a pipe, so that (if we need to) + // we can close the pipe early and abandon our PutReader() + // goroutine, without worrying about PutReader() accessing our + // block buffer after we release it. + bufr, bufw := io.Pipe() + go func() { + io.Copy(bufw, bytes.NewReader(block)) + bufw.Close() + }() + + var err error + ready := make(chan bool) + go func() { + defer func() { + select { + case <-ctx.Done(): + theConfig.debugLogf("%s: abandoned PutReader goroutine finished with err: %s", v, err) + default: + } + }() + defer close(ready) + err = v.bucket.PutReader(loc, bufr, int64(size), "application/octet-stream", s3ACL, opts) + if err != nil { + err = v.translateError(err) + return + } + err = v.bucket.Put("recent/"+loc, nil, "application/octet-stream", s3ACL, s3.Options{}) + err = v.translateError(err) + }() + select { + case <-ctx.Done(): + theConfig.debugLogf("%s: taking PutReader's input away: %s", v, ctx.Err()) + // Our pipe might be stuck in Write(), waiting for + // io.Copy() to read. If so, un-stick it. This means + // PutReader will get corrupt data, but that's OK: the + // size and MD5 won't match, so the write will fail. + go io.Copy(ioutil.Discard, bufr) + // CloseWithError() will return once pending I/O is done. + bufw.CloseWithError(ctx.Err()) + theConfig.debugLogf("%s: abandoning PutReader goroutine", v) + return ctx.Err() + case <-ready: + return err } - err = v.bucket.Put("recent/"+loc, nil, "application/octet-stream", s3ACL, s3.Options{}) - return v.translateError(err) } // Touch sets the timestamp for the given locator to the current time. diff --git a/services/keepstore/s3_volume_test.go b/services/keepstore/s3_volume_test.go index db3f4c6f6b..b720777f76 100644 --- a/services/keepstore/s3_volume_test.go +++ b/services/keepstore/s3_volume_test.go @@ -270,7 +270,7 @@ func (s *StubbedS3Suite) TestBackendStates(c *check.C) { // Check for current Mtime after Put (applies to all // scenarios) loc, blk = setupScenario() - err = v.Put(loc, blk) + err = v.Put(context.TODO(), loc, blk) c.Check(err, check.IsNil) t, err := v.Mtime(loc) c.Check(err, check.IsNil) diff --git a/services/keepstore/trash_worker_test.go b/services/keepstore/trash_worker_test.go index 267175d6d2..857f86a790 100644 --- a/services/keepstore/trash_worker_test.go +++ b/services/keepstore/trash_worker_test.go @@ -220,15 +220,15 @@ func performTrashWorkerTest(testData TrashWorkerTestData, t *testing.T) { // Put test content vols := KeepVM.AllWritable() if testData.CreateData { - vols[0].Put(testData.Locator1, testData.Block1) - vols[0].Put(testData.Locator1+".meta", []byte("metadata")) + vols[0].Put(context.TODO(), testData.Locator1, testData.Block1) + vols[0].Put(context.TODO(), testData.Locator1+".meta", []byte("metadata")) if testData.CreateInVolume1 { - vols[0].Put(testData.Locator2, testData.Block2) - vols[0].Put(testData.Locator2+".meta", []byte("metadata")) + vols[0].Put(context.TODO(), testData.Locator2, testData.Block2) + vols[0].Put(context.TODO(), testData.Locator2+".meta", []byte("metadata")) } else { - vols[1].Put(testData.Locator2, testData.Block2) - vols[1].Put(testData.Locator2+".meta", []byte("metadata")) + vols[1].Put(context.TODO(), testData.Locator2, testData.Block2) + vols[1].Put(context.TODO(), testData.Locator2+".meta", []byte("metadata")) } } diff --git a/services/keepstore/volume.go b/services/keepstore/volume.go index 19a59960e1..01bb6e28ca 100644 --- a/services/keepstore/volume.go +++ b/services/keepstore/volume.go @@ -85,7 +85,7 @@ type Volume interface { // // Put should not verify that loc==hash(block): this is the // caller's responsibility. - Put(loc string, block []byte) error + Put(ctx context.Context, loc string, block []byte) error // Touch sets the timestamp for the given locator to the // current time. diff --git a/services/keepstore/volume_generic_test.go b/services/keepstore/volume_generic_test.go index a0fd3e1fa3..4c26335082 100644 --- a/services/keepstore/volume_generic_test.go +++ b/services/keepstore/volume_generic_test.go @@ -187,12 +187,12 @@ func testPutBlockWithSameContent(t TB, factory TestableVolumeFactory, testHash s return } - err := v.Put(testHash, testData) + err := v.Put(context.TODO(), testHash, testData) if err != nil { t.Errorf("Got err putting block %q: %q, expected nil", TestBlock, err) } - err = v.Put(testHash, testData) + err = v.Put(context.TODO(), testHash, testData) if err != nil { t.Errorf("Got err putting block second time %q: %q, expected nil", TestBlock, err) } @@ -210,7 +210,7 @@ func testPutBlockWithDifferentContent(t TB, factory TestableVolumeFactory, testH v.PutRaw(testHash, testDataA) - putErr := v.Put(testHash, testDataB) + putErr := v.Put(context.TODO(), testHash, testDataB) buf := make([]byte, BlockSize) n, getErr := v.Get(context.TODO(), testHash, buf) if putErr == nil { @@ -239,17 +239,17 @@ func testPutMultipleBlocks(t TB, factory TestableVolumeFactory) { return } - err := v.Put(TestHash, TestBlock) + err := v.Put(context.TODO(), TestHash, TestBlock) if err != nil { t.Errorf("Got err putting block %q: %q, expected nil", TestBlock, err) } - err = v.Put(TestHash2, TestBlock2) + err = v.Put(context.TODO(), TestHash2, TestBlock2) if err != nil { t.Errorf("Got err putting block %q: %q, expected nil", TestBlock2, err) } - err = v.Put(TestHash3, TestBlock3) + err = v.Put(context.TODO(), TestHash3, TestBlock3) if err != nil { t.Errorf("Got err putting block %q: %q, expected nil", TestBlock3, err) } @@ -295,7 +295,7 @@ func testPutAndTouch(t TB, factory TestableVolumeFactory) { return } - if err := v.Put(TestHash, TestBlock); err != nil { + if err := v.Put(context.TODO(), TestHash, TestBlock); err != nil { t.Error(err) } @@ -315,7 +315,7 @@ func testPutAndTouch(t TB, factory TestableVolumeFactory) { } // Write the same block again. - if err := v.Put(TestHash, TestBlock); err != nil { + if err := v.Put(context.TODO(), TestHash, TestBlock); err != nil { t.Error(err) } @@ -438,7 +438,7 @@ func testDeleteNewBlock(t TB, factory TestableVolumeFactory) { return } - v.Put(TestHash, TestBlock) + v.Put(context.TODO(), TestHash, TestBlock) if err := v.Trash(TestHash); err != nil { t.Error(err) @@ -464,7 +464,7 @@ func testDeleteOldBlock(t TB, factory TestableVolumeFactory) { return } - v.Put(TestHash, TestBlock) + v.Put(context.TODO(), TestHash, TestBlock) v.TouchWithDate(TestHash, time.Now().Add(-2*theConfig.BlobSignatureTTL.Duration())) if err := v.Trash(TestHash); err != nil { @@ -560,7 +560,7 @@ func testUpdateReadOnly(t TB, factory TestableVolumeFactory) { } // Put a new block to read-only volume should result in error - err = v.Put(TestHash2, TestBlock2) + err = v.Put(context.TODO(), TestHash2, TestBlock2) if err == nil { t.Errorf("Expected error when putting block in a read-only volume") } @@ -582,7 +582,7 @@ func testUpdateReadOnly(t TB, factory TestableVolumeFactory) { } // Overwriting an existing block in read-only volume should result in error - err = v.Put(TestHash, TestBlock) + err = v.Put(context.TODO(), TestHash, TestBlock) if err == nil { t.Errorf("Expected error when putting block in a read-only volume") } @@ -653,7 +653,7 @@ func testPutConcurrent(t TB, factory TestableVolumeFactory) { sem := make(chan int) go func(sem chan int) { - err := v.Put(TestHash, TestBlock) + err := v.Put(context.TODO(), TestHash, TestBlock) if err != nil { t.Errorf("err1: %v", err) } @@ -661,7 +661,7 @@ func testPutConcurrent(t TB, factory TestableVolumeFactory) { }(sem) go func(sem chan int) { - err := v.Put(TestHash2, TestBlock2) + err := v.Put(context.TODO(), TestHash2, TestBlock2) if err != nil { t.Errorf("err2: %v", err) } @@ -669,7 +669,7 @@ func testPutConcurrent(t TB, factory TestableVolumeFactory) { }(sem) go func(sem chan int) { - err := v.Put(TestHash3, TestBlock3) + err := v.Put(context.TODO(), TestHash3, TestBlock3) if err != nil { t.Errorf("err3: %v", err) } @@ -721,7 +721,7 @@ func testPutFullBlock(t TB, factory TestableVolumeFactory) { wdata[0] = 'a' wdata[BlockSize-1] = 'z' hash := fmt.Sprintf("%x", md5.Sum(wdata)) - err := v.Put(hash, wdata) + err := v.Put(context.TODO(), hash, wdata) if err != nil { t.Fatal(err) } diff --git a/services/keepstore/volume_test.go b/services/keepstore/volume_test.go index 917942e787..acbd7c9bd3 100644 --- a/services/keepstore/volume_test.go +++ b/services/keepstore/volume_test.go @@ -126,7 +126,7 @@ func (v *MockVolume) Get(ctx context.Context, loc string, buf []byte) (int, erro return 0, os.ErrNotExist } -func (v *MockVolume) Put(loc string, block []byte) error { +func (v *MockVolume) Put(ctx context.Context, loc string, block []byte) error { v.gotCall("Put") <-v.Gate if v.Bad { diff --git a/services/keepstore/volume_unix.go b/services/keepstore/volume_unix.go index 02f0f9f3d1..1c676b12e1 100644 --- a/services/keepstore/volume_unix.go +++ b/services/keepstore/volume_unix.go @@ -246,7 +246,7 @@ func (v *UnixVolume) Compare(loc string, expect []byte) error { // "loc". It returns nil on success. If the volume is full, it // returns a FullError. If the write fails due to some other error, // that error is returned. -func (v *UnixVolume) Put(loc string, block []byte) error { +func (v *UnixVolume) Put(ctx context.Context, loc string, block []byte) error { if v.ReadOnly { return MethodDisabledError } @@ -271,6 +271,11 @@ func (v *UnixVolume) Put(loc string, block []byte) error { v.locker.Lock() defer v.locker.Unlock() } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } if _, err := tmpfile.Write(block); err != nil { log.Printf("%s: writing to %s: %s\n", v, bpath, err) tmpfile.Close() diff --git a/services/keepstore/volume_unix_test.go b/services/keepstore/volume_unix_test.go index 72fa819fe8..fad1f12164 100644 --- a/services/keepstore/volume_unix_test.go +++ b/services/keepstore/volume_unix_test.go @@ -46,7 +46,7 @@ func (v *TestableUnixVolume) PutRaw(locator string, data []byte) { v.ReadOnly = orig }(v.ReadOnly) v.ReadOnly = false - err := v.Put(locator, data) + err := v.Put(context.TODO(), locator, data) if err != nil { v.t.Fatal(err) } @@ -118,7 +118,7 @@ func TestReplicationDefault1(t *testing.T) { func TestGetNotFound(t *testing.T) { v := NewTestableUnixVolume(t, false, false) defer v.Teardown() - v.Put(TestHash, TestBlock) + v.Put(context.TODO(), TestHash, TestBlock) buf := make([]byte, BlockSize) n, err := v.Get(context.TODO(), TestHash2, buf) @@ -136,7 +136,7 @@ func TestPut(t *testing.T) { v := NewTestableUnixVolume(t, false, false) defer v.Teardown() - err := v.Put(TestHash, TestBlock) + err := v.Put(context.TODO(), TestHash, TestBlock) if err != nil { t.Error(err) } @@ -154,7 +154,7 @@ func TestPutBadVolume(t *testing.T) { defer v.Teardown() os.Chmod(v.Root, 000) - err := v.Put(TestHash, TestBlock) + err := v.Put(context.TODO(), TestHash, TestBlock) if err == nil { t.Error("Write should have failed") } @@ -172,7 +172,7 @@ func TestUnixVolumeReadonly(t *testing.T) { t.Errorf("got err %v, expected nil", err) } - err = v.Put(TestHash, TestBlock) + err = v.Put(context.TODO(), TestHash, TestBlock) if err != MethodDisabledError { t.Errorf("got err %v, expected MethodDisabledError", err) } @@ -232,7 +232,7 @@ func TestUnixVolumeGetFuncWorkerError(t *testing.T) { v := NewTestableUnixVolume(t, false, false) defer v.Teardown() - v.Put(TestHash, TestBlock) + v.Put(context.TODO(), TestHash, TestBlock) mockErr := errors.New("Mock error") err := v.getFunc(v.blockPath(TestHash), func(rdr io.Reader) error { return mockErr @@ -263,7 +263,7 @@ func TestUnixVolumeGetFuncWorkerWaitsOnMutex(t *testing.T) { v := NewTestableUnixVolume(t, false, false) defer v.Teardown() - v.Put(TestHash, TestBlock) + v.Put(context.TODO(), TestHash, TestBlock) mtx := NewMockMutex() v.locker = mtx @@ -298,7 +298,7 @@ func TestUnixVolumeCompare(t *testing.T) { v := NewTestableUnixVolume(t, false, false) defer v.Teardown() - v.Put(TestHash, TestBlock) + v.Put(context.TODO(), TestHash, TestBlock) err := v.Compare(TestHash, TestBlock) if err != nil { t.Errorf("Got err %q, expected nil", err) @@ -309,7 +309,7 @@ func TestUnixVolumeCompare(t *testing.T) { t.Errorf("Got err %q, expected %q", err, CollisionError) } - v.Put(TestHash, []byte("baddata")) + v.Put(context.TODO(), TestHash, []byte("baddata")) err = v.Compare(TestHash, TestBlock) if err != DiskHashError { t.Errorf("Got err %q, expected %q", err, DiskHashError)