10467: local directory driver: skip disk IO if client disconnects before lock is...
authorTom Clegg <tom@curoverse.com>
Wed, 28 Dec 2016 21:44:32 +0000 (16:44 -0500)
committerTom Clegg <tom@curoverse.com>
Wed, 28 Dec 2016 21:44:32 +0000 (16:44 -0500)
services/keepstore/pipe_adapters.go [new file with mode: 0644]
services/keepstore/volume_unix.go
services/keepstore/volume_unix_test.go

diff --git a/services/keepstore/pipe_adapters.go b/services/keepstore/pipe_adapters.go
new file mode 100644 (file)
index 0000000..91aa270
--- /dev/null
@@ -0,0 +1,110 @@
+package main
+
+import (
+       "bytes"
+       "context"
+       "io"
+       "io/ioutil"
+       "sync"
+)
+
+// getWithPipe invokes getter and copies the resulting data into
+// buf. If ctx is done before all data is copied, getWithPipe closes
+// the pipe with an error, and returns early with an error.
+func getWithPipe(ctx context.Context, loc string, buf []byte, getter func(context.Context, string, *io.PipeWriter)) (int, error) {
+       piper, pipew := io.Pipe()
+       go getter(ctx, loc, pipew)
+       done := make(chan struct{})
+       var size int
+       var err error
+       go func() {
+               size, err = io.ReadFull(piper, buf)
+               if err == io.EOF || err == io.ErrUnexpectedEOF {
+                       err = nil
+               }
+               close(done)
+       }()
+       select {
+       case <-ctx.Done():
+               piper.CloseWithError(ctx.Err())
+               return 0, ctx.Err()
+       case <-done:
+               piper.Close()
+               return size, err
+       }
+}
+
+type errorReadCloser struct {
+       *io.PipeReader
+       err error
+       mtx sync.Mutex
+}
+
+func (erc *errorReadCloser) Close() error {
+       erc.mtx.Lock()
+       defer erc.mtx.Unlock()
+       erc.PipeReader.Close()
+       return erc.err
+}
+
+func (erc *errorReadCloser) SetError(err error) {
+       erc.mtx.Lock()
+       defer erc.mtx.Unlock()
+       erc.err = err
+}
+
+// putWithPipe invokes putter with a new pipe, and and copies data
+// from buf into the pipe. If ctx is done before all data is copied,
+// putWithPipe closes the pipe with an error, and returns early with
+// an error.
+func putWithPipe(ctx context.Context, loc string, buf []byte, putter func(context.Context, string, io.ReadCloser) error) error {
+       piper, pipew := io.Pipe()
+       copyErr := make(chan error)
+       go func() {
+               _, err := io.Copy(pipew, bytes.NewReader(buf))
+               copyErr <- err
+               close(copyErr)
+       }()
+
+       erc := errorReadCloser{
+               PipeReader: piper,
+               err:        nil,
+       }
+       putErr := make(chan error, 1)
+       go func() {
+               putErr <- putter(ctx, loc, &erc)
+               close(putErr)
+       }()
+
+       var err error
+       select {
+       case err = <-copyErr:
+       case err = <-putErr:
+       case <-ctx.Done():
+               err = ctx.Err()
+       }
+
+       // Ensure io.Copy goroutine isn't blocked writing to pipew
+       // (otherwise, io.Copy is still using buf so it isn't safe to
+       // return). This can cause pipew to receive corrupt data, so
+       // we first ensure putter() will get an error when calling
+       // erc.Close().
+       erc.SetError(err)
+       go pipew.CloseWithError(err)
+       go io.Copy(ioutil.Discard, piper)
+       <-copyErr
+
+       // Note: io.Copy() is finished now, but putter() might still
+       // be running. If we encounter an error before putter()
+       // returns, we return right away without waiting for putter().
+
+       if err != nil {
+               return err
+       }
+       select {
+       case <-ctx.Done():
+               return ctx.Err()
+       case err = <-putErr:
+               return err
+       }
+}
index fff02aac260f59a6fc46fc24cbebea57b27e5743..459e73a28f501da5145f46259657fcf5d33175ee 100644 (file)
@@ -160,10 +160,10 @@ func (v *UnixVolume) Touch(loc string) error {
                return err
        }
        defer f.Close()
-       if v.locker != nil {
-               v.locker.Lock()
-               defer v.locker.Unlock()
+       if err := v.lock(context.TODO()); err != nil {
+               return err
        }
+       defer v.unlock()
        if e := lockfile(f); e != nil {
                return e
        }
@@ -185,13 +185,10 @@ func (v *UnixVolume) Mtime(loc string) (time.Time, error) {
 // Lock the locker (if one is in use), open the file for reading, and
 // call the given function if and when the file is ready to read.
 func (v *UnixVolume) getFunc(ctx context.Context, path string, fn func(io.Reader) error) error {
-       if v.locker != nil {
-               v.locker.Lock()
-               defer v.locker.Unlock()
-       }
-       if ctx.Err() != nil {
-               return ctx.Err()
+       if err := v.lock(ctx); err != nil {
+               return err
        }
+       defer v.unlock()
        f, err := os.Open(path)
        if err != nil {
                return err
@@ -216,21 +213,24 @@ func (v *UnixVolume) stat(path string) (os.FileInfo, error) {
 // Get retrieves a block, copies it to the given slice, and returns
 // the number of bytes copied.
 func (v *UnixVolume) Get(ctx context.Context, loc string, buf []byte) (int, error) {
+       return getWithPipe(ctx, loc, buf, v.get)
+}
+
+func (v *UnixVolume) get(ctx context.Context, loc string, w *io.PipeWriter) {
        path := v.blockPath(loc)
        stat, err := v.stat(path)
        if err != nil {
-               return 0, v.translateError(err)
+               w.CloseWithError(v.translateError(err))
+               return
        }
-       if stat.Size() > int64(len(buf)) {
-               return 0, TooLongError
-       }
-       var read int
-       size := int(stat.Size())
        err = v.getFunc(ctx, path, func(rdr io.Reader) error {
-               read, err = io.ReadFull(rdr, buf[:size])
+               n, err := io.Copy(w, rdr)
+               if err == nil && n != stat.Size() {
+                       err = io.ErrUnexpectedEOF
+               }
                return err
        })
-       return read, err
+       w.CloseWithError(err)
 }
 
 // Compare returns nil if Get(loc) would return the same content as
@@ -251,6 +251,10 @@ func (v *UnixVolume) Compare(ctx context.Context, loc string, expect []byte) err
 // returns a FullError.  If the write fails due to some other error,
 // that error is returned.
 func (v *UnixVolume) Put(ctx context.Context, loc string, block []byte) error {
+       return putWithPipe(ctx, loc, block, v.put)
+}
+
+func (v *UnixVolume) put(ctx context.Context, loc string, rdr io.ReadCloser) error {
        if v.ReadOnly {
                return MethodDisabledError
        }
@@ -269,23 +273,25 @@ func (v *UnixVolume) Put(ctx context.Context, loc string, block []byte) error {
                log.Printf("ioutil.TempFile(%s, tmp%s): %s", bdir, loc, tmperr)
                return tmperr
        }
+
        bpath := v.blockPath(loc)
 
-       if v.locker != nil {
-               v.locker.Lock()
-               defer v.locker.Unlock()
-       }
-       select {
-       case <-ctx.Done():
-               return ctx.Err()
-       default:
+       if err := v.lock(ctx); err != nil {
+               log.Println("lock err:", err)
+               return err
        }
-       if _, err := tmpfile.Write(block); err != nil {
+       defer v.unlock()
+       if _, err := io.Copy(tmpfile, rdr); err != nil {
                log.Printf("%s: writing to %s: %s\n", v, bpath, err)
                tmpfile.Close()
                os.Remove(tmpfile.Name())
                return err
        }
+       if err := rdr.Close(); err != nil {
+               tmpfile.Close()
+               os.Remove(tmpfile.Name())
+               return err
+       }
        if err := tmpfile.Close(); err != nil {
                log.Printf("closing %s: %s\n", tmpfile.Name(), err)
                os.Remove(tmpfile.Name())
@@ -418,10 +424,10 @@ func (v *UnixVolume) Trash(loc string) error {
        if v.ReadOnly {
                return MethodDisabledError
        }
-       if v.locker != nil {
-               v.locker.Lock()
-               defer v.locker.Unlock()
+       if err := v.lock(context.TODO()); err != nil {
+               return err
        }
+       defer v.unlock()
        p := v.blockPath(loc)
        f, err := os.OpenFile(p, os.O_RDWR|os.O_APPEND, 0644)
        if err != nil {
@@ -559,6 +565,42 @@ func (v *UnixVolume) Replication() int {
        return v.DirectoryReplication
 }
 
+// lock acquires the serialize lock, if one is in use. If ctx is done
+// before the lock is acquired, lock returns ctx.Err() instead of
+// acquiring the lock.
+func (v *UnixVolume) lock(ctx context.Context) error {
+       if v.locker == nil {
+               return nil
+       }
+       locked := make(chan struct{})
+       go func() {
+               v.locker.Lock()
+               close(locked)
+       }()
+       select {
+       case <-ctx.Done():
+               log.Print("ctx Done")
+               go func() {
+                       log.Print("waiting <-locked")
+                       <-locked
+                       log.Print("unlocking")
+                       v.locker.Unlock()
+               }()
+               return ctx.Err()
+       case <-locked:
+               log.Print("got lock")
+               return nil
+       }
+}
+
+// unlock releases the serialize lock, if one is in use.
+func (v *UnixVolume) unlock() {
+       if v.locker == nil {
+               return
+       }
+       v.locker.Unlock()
+}
+
 // lockfile and unlockfile use flock(2) to manage kernel file locks.
 func lockfile(f *os.File) error {
        return syscall.Flock(int(f.Fd()), syscall.LOCK_EX)
index 3021d6bd362724e7136d1054095e49bb53778199..7b02a15e0108c58c2318bcad141a44e39cc1310c 100644 (file)
@@ -323,14 +323,42 @@ func TestUnixVolumeCompare(t *testing.T) {
        }
 }
 
-// TODO(twp): show that the underlying Read/Write operations executed
-// serially and not concurrently. The easiest way to do this is
-// probably to activate verbose or debug logging, capture log output
-// and examine it to confirm that Reads and Writes did not overlap.
-//
-// TODO(twp): a proper test of I/O serialization requires that a
-// second request start while the first one is still underway.
-// Guaranteeing that the test behaves this way requires some tricky
-// synchronization and mocking.  For now we'll just launch a bunch of
-// requests simultaenously in goroutines and demonstrate that they
-// return accurate results.
+func TestUnixVolumeContextCancelPut(t *testing.T) {
+       v := NewTestableUnixVolume(t, true, false)
+       defer v.Teardown()
+       v.locker.Lock()
+       ctx, cancel := context.WithCancel(context.Background())
+       go func() {
+               time.Sleep(50 * time.Millisecond)
+               cancel()
+               time.Sleep(50 * time.Millisecond)
+               v.locker.Unlock()
+       }()
+       err := v.Put(ctx, TestHash, TestBlock)
+       if err != context.Canceled {
+               t.Errorf("Put() returned %s -- expected short read / canceled", err)
+       }
+}
+
+func TestUnixVolumeContextCancelGet(t *testing.T) {
+       v := NewTestableUnixVolume(t, false, false)
+       defer v.Teardown()
+       bpath := v.blockPath(TestHash)
+       v.PutRaw(TestHash, TestBlock)
+       os.Remove(bpath)
+       err := syscall.Mkfifo(bpath, 0600)
+       if err != nil {
+               t.Fatalf("Mkfifo %s: %s", bpath, err)
+       }
+       defer os.Remove(bpath)
+       ctx, cancel := context.WithCancel(context.Background())
+       go func() {
+               time.Sleep(50 * time.Millisecond)
+               cancel()
+       }()
+       buf := make([]byte, len(TestBlock))
+       n, err := v.Get(ctx, TestHash, buf)
+       if n == len(TestBlock) || err != context.Canceled {
+               t.Errorf("Get() returned %d, %s -- expected short read / canceled", n, err)
+       }
+}