--- /dev/null
+package main
+
+import (
+ "crypto/md5"
+ "fmt"
+ "io"
+)
+
+// Compute the MD5 digest of a data block (consisting of buf1 + buf2 +
+// all bytes readable from rdr). If all data is read successfully,
+// return DiskHashError or CollisionError depending on whether it
+// matches expectMD5. If an error occurs while reading, return that
+// error.
+//
+// "content has expected MD5" is called a collision because this
+// function is used in cases where we have another block in hand with
+// the given MD5 but different content.
+func collisionOrCorrupt(expectMD5 string, buf1, buf2 []byte, rdr io.Reader) error {
+ outcome := make(chan error)
+ data := make(chan []byte, 1)
+ go func() {
+ h := md5.New()
+ for b := range data {
+ h.Write(b)
+ }
+ if fmt.Sprintf("%x", h.Sum(nil)) == expectMD5 {
+ outcome <- CollisionError
+ } else {
+ outcome <- DiskHashError
+ }
+ }()
+ data <- buf1
+ if buf2 != nil {
+ data <- buf2
+ }
+ var err error
+ for rdr != nil && err == nil {
+ buf := make([]byte, 1 << 18)
+ var n int
+ n, err = rdr.Read(buf)
+ data <- buf[:n]
+ }
+ close(data)
+ if rdr != nil && err != io.EOF {
+ <-outcome
+ return err
+ }
+ return <-outcome
+}
--- /dev/null
+package main
+
+import (
+ "bytes"
+ "testing"
+ "testing/iotest"
+
+ check "gopkg.in/check.v1"
+)
+
+// Gocheck boilerplate
+func Test(t *testing.T) {
+ check.TestingT(t)
+}
+
+var _ = check.Suite(&CollisionSuite{})
+
+type CollisionSuite struct{}
+
+func (s *CollisionSuite) TestCollisionOrCorrupt(c *check.C) {
+ fooMD5 := "acbd18db4cc2f85cedef654fccc4a4d8"
+
+ c.Check(collisionOrCorrupt(fooMD5, []byte{'f'}, []byte{'o'}, bytes.NewBufferString("o")),
+ check.Equals, CollisionError)
+ c.Check(collisionOrCorrupt(fooMD5, []byte{'f'}, nil, bytes.NewBufferString("oo")),
+ check.Equals, CollisionError)
+ c.Check(collisionOrCorrupt(fooMD5, []byte{'f'}, []byte{'o', 'o'}, nil),
+ check.Equals, CollisionError)
+ c.Check(collisionOrCorrupt(fooMD5, nil, []byte{}, bytes.NewBufferString("foo")),
+ check.Equals, CollisionError)
+ c.Check(collisionOrCorrupt(fooMD5, []byte{'f', 'o', 'o'}, nil, bytes.NewBufferString("")),
+ check.Equals, CollisionError)
+ c.Check(collisionOrCorrupt(fooMD5, nil, nil, iotest.NewReadLogger("foo: ", iotest.DataErrReader(iotest.OneByteReader(bytes.NewBufferString("foo"))))),
+ check.Equals, CollisionError)
+
+ c.Check(collisionOrCorrupt(fooMD5, []byte{'f', 'o', 'o'}, nil, bytes.NewBufferString("bar")),
+ check.Equals, DiskHashError)
+ c.Check(collisionOrCorrupt(fooMD5, []byte{'f', 'o'}, nil, nil),
+ check.Equals, DiskHashError)
+ c.Check(collisionOrCorrupt(fooMD5, []byte{}, nil, bytes.NewBufferString("")),
+ check.Equals, DiskHashError)
+
+ c.Check(collisionOrCorrupt(fooMD5, []byte{}, nil, iotest.TimeoutReader(iotest.OneByteReader(bytes.NewBufferString("foo")))),
+ check.Equals, iotest.ErrTimeout)
+}
-// A UnixVolume is a Volume backed by a locally mounted disk.
-//
package main
import (
// Compare returns nil if Get(loc) would return the same content as
// cmp. It is functionally equivalent to Get() followed by
// bytes.Compare(), but uses less memory.
-//
-// TODO(TC): Before returning CollisionError, compute the MD5 digest
-// of the data on disk (i.e., known-to-be-equal data in cmp +
-// remaining data on disk) and return DiskHashError instead of
-// CollisionError if it doesn't equal loc[:32].
-func (v *UnixVolume) Compare(loc string, cmp []byte) error {
+func (v *UnixVolume) Compare(loc string, expect []byte) error {
path := v.blockPath(loc)
stat, err := v.stat(path)
if err != nil {
if int64(bufLen) > stat.Size() {
bufLen = int(stat.Size())
}
+ cmp := expect
buf := make([]byte, bufLen)
return v.getFunc(path, func(rdr io.Reader) error {
// Loop invariants: all data read so far matched what
// reader.
for {
n, err := rdr.Read(buf)
- if n > len(cmp) {
- // file on disk is too long
- return CollisionError
- } else if n > 0 && bytes.Compare(cmp[:n], buf[:n]) != 0 {
- return CollisionError
+ if n > len(cmp) || bytes.Compare(cmp[:n], buf[:n]) != 0 {
+ return collisionOrCorrupt(loc[:32], expect[:len(expect)-len(cmp)], buf[:n], rdr)
}
cmp = cmp[n:]
if err == io.EOF {
if len(cmp) != 0 {
- // file on disk is too short
- return CollisionError
+ return collisionOrCorrupt(loc[:32], expect[:len(expect)-len(cmp)], nil, nil)
}
return nil
} else if err != nil {