7179: Tighten Put requirements when overwriting existing data.
[arvados.git] / services / keepstore / volume_unix.go
index 2ffa8faa39e2193b6e23540e6897c254a6262aa7..74bee52387c5f291f7edc8cc0c36e7ef5b48dd9e 100644 (file)
@@ -1,5 +1,3 @@
-// A UnixVolume is a Volume backed by a locally mounted disk.
-//
 package main
 
 import (
@@ -20,10 +18,12 @@ import (
 
 // A UnixVolume stores and retrieves blocks in a local directory.
 type UnixVolume struct {
-       root      string // path to the volume's root directory
-       serialize bool
-       readonly  bool
-       mutex     sync.Mutex
+       // path to the volume's root directory
+       root string
+       // something to lock during IO, typically a sync.Mutex (or nil
+       // to skip locking)
+       locker   sync.Locker
+       readonly bool
 }
 
 func (v *UnixVolume) Touch(loc string) error {
@@ -36,9 +36,9 @@ func (v *UnixVolume) Touch(loc string) error {
                return err
        }
        defer f.Close()
-       if v.serialize {
-               v.mutex.Lock()
-               defer v.mutex.Unlock()
+       if v.locker != nil {
+               v.locker.Lock()
+               defer v.locker.Unlock()
        }
        if e := lockfile(f); e != nil {
                return e
@@ -58,18 +58,18 @@ func (v *UnixVolume) Mtime(loc string) (time.Time, error) {
        }
 }
 
-// Open the given file, apply the serialize lock if enabled, and call
-// the given function if and when the file is ready to read.
+// Lock the locker (if one is in use), open the file for reading, and
+// call the given function if and when the file is ready to read.
 func (v *UnixVolume) getFunc(path string, fn func(io.Reader) error) error {
+       if v.locker != nil {
+               v.locker.Lock()
+               defer v.locker.Unlock()
+       }
        f, err := os.Open(path)
        if err != nil {
                return err
        }
        defer f.Close()
-       if v.serialize {
-               v.mutex.Lock()
-               defer v.mutex.Unlock()
-       }
        return fn(f)
 }
 
@@ -109,14 +109,9 @@ func (v *UnixVolume) Get(loc string) ([]byte, error) {
 }
 
 // Compare returns nil if Get(loc) would return the same content as
-// cmp. It is functionally equivalent to Get() followed by
+// expect. It is functionally equivalent to Get() followed by
 // bytes.Compare(), but uses less memory.
-//
-// TODO(TC): Before returning CollisionError, compute the MD5 digest
-// of the data on disk (i.e., known-to-be-equal data in cmp +
-// remaining data on disk) and return DiskHashError instead of
-// CollisionError if it doesn't equal loc[:32].
-func (v *UnixVolume) Compare(loc string, cmp []byte) error {
+func (v *UnixVolume) Compare(loc string, expect []byte) error {
        path := v.blockPath(loc)
        stat, err := v.stat(path)
        if err != nil {
@@ -126,6 +121,7 @@ func (v *UnixVolume) Compare(loc string, cmp []byte) error {
        if int64(bufLen) > stat.Size() {
                bufLen = int(stat.Size())
        }
+       cmp := expect
        buf := make([]byte, bufLen)
        return v.getFunc(path, func(rdr io.Reader) error {
                // Loop invariants: all data read so far matched what
@@ -134,17 +130,13 @@ func (v *UnixVolume) Compare(loc string, cmp []byte) error {
                // reader.
                for {
                        n, err := rdr.Read(buf)
-                       if n > len(cmp) {
-                               // file on disk is too long
-                               return CollisionError
-                       } else if n > 0 && bytes.Compare(cmp[:n], buf[:n]) != 0 {
-                               return CollisionError
+                       if n > len(cmp) || bytes.Compare(cmp[:n], buf[:n]) != 0 {
+                               return collisionOrCorrupt(loc[:32], expect[:len(expect)-len(cmp)], buf[:n], rdr)
                        }
                        cmp = cmp[n:]
                        if err == io.EOF {
                                if len(cmp) != 0 {
-                                       // file on disk is too short
-                                       return CollisionError
+                                       return collisionOrCorrupt(loc[:32], expect[:len(expect)-len(cmp)], nil, nil)
                                }
                                return nil
                        } else if err != nil {
@@ -179,9 +171,9 @@ func (v *UnixVolume) Put(loc string, block []byte) error {
        }
        bpath := v.blockPath(loc)
 
-       if v.serialize {
-               v.mutex.Lock()
-               defer v.mutex.Unlock()
+       if v.locker != nil {
+               v.locker.Lock()
+               defer v.locker.Unlock()
        }
        if _, err := tmpfile.Write(block); err != nil {
                log.Printf("%s: writing to %s: %s\n", v, bpath, err)
@@ -308,9 +300,9 @@ func (v *UnixVolume) Delete(loc string) error {
        if v.readonly {
                return MethodDisabledError
        }
-       if v.serialize {
-               v.mutex.Lock()
-               defer v.mutex.Unlock()
+       if v.locker != nil {
+               v.locker.Lock()
+               defer v.locker.Unlock()
        }
        p := v.blockPath(loc)
        f, err := os.OpenFile(p, os.O_RDWR|os.O_APPEND, 0644)