7121: Return DiskHashError instead of CollisionError from Compare() where appropriate.
author Tom Clegg <tom@curoverse.com>
Thu, 3 Sep 2015 17:42:32 +0000 (13:42 -0400)
committer Tom Clegg <tom@curoverse.com>
Mon, 7 Sep 2015 20:42:36 +0000 (16:42 -0400)
services/keepstore/collision.go [new file with mode: 0644]
services/keepstore/collision_test.go [new file with mode: 0644]
services/keepstore/volume_unix.go

diff --git a/services/keepstore/collision.go b/services/keepstore/collision.go
new file mode 100644 (file)
index 0000000..210286a
--- /dev/null
@@ -0,0 +1,49 @@
+package main
+
+import (
+       "crypto/md5"
+       "fmt"
+       "io"
+)
+
+// collisionOrCorrupt computes the MD5 digest of a data block
+// consisting of buf1, then buf2, then all bytes readable from rdr
+// (buf1, buf2, and rdr may each be nil). If all data is read
+// successfully, it returns CollisionError if the digest matches
+// expectMD5 and DiskHashError otherwise. If reading from rdr fails
+// with any error other than io.EOF, that error is returned instead.
+//
+// "Content has expected MD5" is called a collision because this
+// function is used in cases where we have another block in hand with
+// the given MD5 but different content.
+func collisionOrCorrupt(expectMD5 string, buf1, buf2 []byte, rdr io.Reader) error {
+       h := md5.New()
+       // Writes to a hash.Hash never return an error.
+       h.Write(buf1)
+       h.Write(buf2)
+       if rdr != nil {
+               // Hash the remainder in place: io.Copy reuses a single
+               // buffer rather than allocating a new one per Read, and
+               // reports a clean end of stream (io.EOF) as a nil error.
+               if _, err := io.Copy(h, rdr); err != nil {
+                       return err
+               }
+       }
+       if fmt.Sprintf("%x", h.Sum(nil)) == expectMD5 {
+               return CollisionError
+       }
+       return DiskHashError
+}
diff --git a/services/keepstore/collision_test.go b/services/keepstore/collision_test.go
new file mode 100644 (file)
index 0000000..e6cfd16
--- /dev/null
@@ -0,0 +1,45 @@
+package main
+
+import (
+       "bytes"
+       "testing"
+       "testing/iotest"
+
+       check "gopkg.in/check.v1"
+)
+
+// CollisionSuite holds the gocheck tests for collisionOrCorrupt.
+type CollisionSuite struct{}
+
+// Register CollisionSuite with gocheck during package initialization.
+var _ = check.Suite(&CollisionSuite{})
+
+// Test bridges "go test" and gocheck: it hands the *testing.T to
+// gocheck, which then runs the registered suites.
+func Test(gt *testing.T) {
+       check.TestingT(gt)
+}
+
+func (s *CollisionSuite) TestCollisionOrCorrupt(c *check.C) {
+       fooMD5 := "acbd18db4cc2f85cedef654fccc4a4d8"
+
+       // Each trial splits some content across buf1, buf2 and rdr in a
+       // different way. Expected outcome: CollisionError if the content
+       // is "foo", DiskHashError if it is anything else, or the reader's
+       // error if rdr fails before reaching EOF.
+       for _, trial := range []struct {
+               name   string
+               got    error
+               expect error
+       }{
+               {"buf1=f, buf2=o, rdr=o",
+                       collisionOrCorrupt(fooMD5, []byte{'f'}, []byte{'o'}, bytes.NewBufferString("o")),
+                       CollisionError},
+               {"buf1=f, rdr=oo",
+                       collisionOrCorrupt(fooMD5, []byte{'f'}, nil, bytes.NewBufferString("oo")),
+                       CollisionError},
+               {"buf1=f, buf2=oo, no rdr",
+                       collisionOrCorrupt(fooMD5, []byte{'f'}, []byte{'o', 'o'}, nil),
+                       CollisionError},
+               {"nil buf1, empty buf2, rdr=foo",
+                       collisionOrCorrupt(fooMD5, nil, []byte{}, bytes.NewBufferString("foo")),
+                       CollisionError},
+               {"buf1=foo, empty rdr",
+                       collisionOrCorrupt(fooMD5, []byte{'f', 'o', 'o'}, nil, bytes.NewBufferString("")),
+                       CollisionError},
+               {"one byte per Read, EOF returned with last byte",
+                       collisionOrCorrupt(fooMD5, nil, nil, iotest.DataErrReader(iotest.OneByteReader(bytes.NewBufferString("foo")))),
+                       CollisionError},
+               {"content too long",
+                       collisionOrCorrupt(fooMD5, []byte{'f', 'o', 'o'}, nil, bytes.NewBufferString("bar")),
+                       DiskHashError},
+               {"content too short",
+                       collisionOrCorrupt(fooMD5, []byte{'f', 'o'}, nil, nil),
+                       DiskHashError},
+               {"empty content",
+                       collisionOrCorrupt(fooMD5, []byte{}, nil, bytes.NewBufferString("")),
+                       DiskHashError},
+               {"read error mid-block",
+                       collisionOrCorrupt(fooMD5, []byte{}, nil, iotest.TimeoutReader(iotest.OneByteReader(bytes.NewBufferString("foo")))),
+                       iotest.ErrTimeout},
+               {"read error wins even after all expected bytes",
+                       collisionOrCorrupt(fooMD5, []byte{'f', 'o'}, nil, iotest.TimeoutReader(iotest.OneByteReader(bytes.NewBufferString("o")))),
+                       iotest.ErrTimeout},
+       } {
+               c.Check(trial.got, check.Equals, trial.expect, check.Commentf("%s", trial.name))
+       }
+}
index 2ffa8faa39e2193b6e23540e6897c254a6262aa7..368ddc55f07a9c8f4b78c09ba0cd3917bbdf091a 100644 (file)
@@ -1,5 +1,3 @@
-// A UnixVolume is a Volume backed by a locally mounted disk.
-//
 package main
 
 import (
@@ -111,12 +109,7 @@ func (v *UnixVolume) Get(loc string) ([]byte, error) {
 // Compare returns nil if Get(loc) would return the same content as
 // cmp. It is functionally equivalent to Get() followed by
 // bytes.Compare(), but uses less memory.
-//
-// TODO(TC): Before returning CollisionError, compute the MD5 digest
-// of the data on disk (i.e., known-to-be-equal data in cmp +
-// remaining data on disk) and return DiskHashError instead of
-// CollisionError if it doesn't equal loc[:32].
-func (v *UnixVolume) Compare(loc string, cmp []byte) error {
+func (v *UnixVolume) Compare(loc string, expect []byte) error {
        path := v.blockPath(loc)
        stat, err := v.stat(path)
        if err != nil {
@@ -126,6 +119,7 @@ func (v *UnixVolume) Compare(loc string, cmp []byte) error {
        if int64(bufLen) > stat.Size() {
                bufLen = int(stat.Size())
        }
+       cmp := expect
        buf := make([]byte, bufLen)
        return v.getFunc(path, func(rdr io.Reader) error {
                // Loop invariants: all data read so far matched what
@@ -134,17 +128,13 @@ func (v *UnixVolume) Compare(loc string, cmp []byte) error {
                // reader.
                for {
                        n, err := rdr.Read(buf)
-                       if n > len(cmp) {
-                               // file on disk is too long
-                               return CollisionError
-                       } else if n > 0 && bytes.Compare(cmp[:n], buf[:n]) != 0 {
-                               return CollisionError
+                       if n > len(cmp) || bytes.Compare(cmp[:n], buf[:n]) != 0 {
+                               return collisionOrCorrupt(loc[:32], expect[:len(expect)-len(cmp)], buf[:n], rdr)
                        }
                        cmp = cmp[n:]
                        if err == io.EOF {
                                if len(cmp) != 0 {
-                                       // file on disk is too short
-                                       return CollisionError
+                                       return collisionOrCorrupt(loc[:32], expect[:len(expect)-len(cmp)], nil, nil)
                                }
                                return nil
                        } else if err != nil {