7179: Tighten Put requirements when overwriting existing data.
[arvados.git] / services / keepstore / volume_unix.go
index a7ad6f9e499c80439c27cb1beed33060674ed776..74bee52387c5f291f7edc8cc0c36e7ef5b48dd9e 100644 (file)
@@ -1,8 +1,7 @@
-// A UnixVolume is a Volume backed by a locally mounted disk.
-//
 package main
 
 import (
+       "bytes"
        "fmt"
        "io"
        "io/ioutil"
@@ -19,10 +18,12 @@ import (
 
 // A UnixVolume stores and retrieves blocks in a local directory.
 type UnixVolume struct {
-       root      string // path to the volume's root directory
-       serialize bool
-       readonly  bool
-       mutex     sync.Mutex
+       // path to the volume's root directory
+       root string
+       // something to lock during IO, typically a sync.Mutex (or nil
+       // to skip locking)
+       locker   sync.Locker
+       readonly bool
 }
 
 func (v *UnixVolume) Touch(loc string) error {
@@ -35,9 +36,9 @@ func (v *UnixVolume) Touch(loc string) error {
                return err
        }
        defer f.Close()
-       if v.serialize {
-               v.mutex.Lock()
-               defer v.mutex.Unlock()
+       if v.locker != nil {
+               v.locker.Lock()
+               defer v.locker.Unlock()
        }
        if e := lockfile(f); e != nil {
                return e
@@ -57,35 +58,49 @@ func (v *UnixVolume) Mtime(loc string) (time.Time, error) {
        }
 }
 
+// Lock the locker (if one is in use), open the file for reading, and
+// call the given function if and when the file is ready to read.
+func (v *UnixVolume) getFunc(path string, fn func(io.Reader) error) error {
+       if v.locker != nil {
+               v.locker.Lock()
+               defer v.locker.Unlock()
+       }
+       f, err := os.Open(path)
+       if err != nil {
+               return err
+       }
+       defer f.Close()
+       return fn(f)
+}
+
+// stat is os.Stat() with some extra sanity checks.
+func (v *UnixVolume) stat(path string) (os.FileInfo, error) {
+       stat, err := os.Stat(path)
+       if err == nil {
+               if stat.Size() < 0 {
+                       err = os.ErrInvalid
+               } else if stat.Size() > BLOCKSIZE {
+                       err = TooLongError
+               }
+       }
+       return stat, err
+}
+
 // Get retrieves a block identified by the locator string "loc", and
 // returns its contents as a byte slice.
 //
-// If the block could not be found, opened, or read, Get returns a nil
-// slice and whatever non-nil error was returned by Stat or ReadFile.
+// Get returns a nil buffer IFF it returns a non-nil error.
 func (v *UnixVolume) Get(loc string) ([]byte, error) {
        path := v.blockPath(loc)
-       stat, err := os.Stat(path)
+       stat, err := v.stat(path)
        if err != nil {
                return nil, err
        }
-       if stat.Size() < 0 {
-               return nil, os.ErrInvalid
-       } else if stat.Size() == 0 {
-               return bufs.Get(0), nil
-       } else if stat.Size() > BLOCKSIZE {
-               return nil, TooLongError
-       }
-       f, err := os.Open(path)
-       if err != nil {
-               return nil, err
-       }
-       defer f.Close()
        buf := bufs.Get(int(stat.Size()))
-       if v.serialize {
-               v.mutex.Lock()
-               defer v.mutex.Unlock()
-       }
-       _, err = io.ReadFull(f, buf)
+       err = v.getFunc(path, func(rdr io.Reader) error {
+               _, err = io.ReadFull(rdr, buf)
+               return err
+       })
        if err != nil {
                bufs.Put(buf)
                return nil, err
@@ -93,6 +108,44 @@ func (v *UnixVolume) Get(loc string) ([]byte, error) {
        return buf, nil
 }
 
+// Compare returns nil if Get(loc) would return the same content as
+// expect. It is functionally equivalent to Get() followed by
+// bytes.Compare(), but uses less memory.
+func (v *UnixVolume) Compare(loc string, expect []byte) error {
+       path := v.blockPath(loc)
+       stat, err := v.stat(path)
+       if err != nil {
+               return err
+       }
+       bufLen := 1 << 20
+       if int64(bufLen) > stat.Size() {
+               bufLen = int(stat.Size())
+       }
+       cmp := expect
+       buf := make([]byte, bufLen)
+       return v.getFunc(path, func(rdr io.Reader) error {
+               // Loop invariants: all data read so far matched what
+               // we expected, and the first N bytes of cmp are
+               // expected to equal the next N bytes read from
+               // reader.
+               for {
+                       n, err := rdr.Read(buf)
+                       if n > len(cmp) || bytes.Compare(cmp[:n], buf[:n]) != 0 {
+                               return collisionOrCorrupt(loc[:32], expect[:len(expect)-len(cmp)], buf[:n], rdr)
+                       }
+                       cmp = cmp[n:]
+                       if err == io.EOF {
+                               if len(cmp) != 0 {
+                                       return collisionOrCorrupt(loc[:32], expect[:len(expect)-len(cmp)], nil, nil)
+                               }
+                               return nil
+                       } else if err != nil {
+                               return err
+                       }
+               }
+       })
+}
+
 // Put stores a block of data identified by the locator string
 // "loc".  It returns nil on success.  If the volume is full, it
 // returns a FullError.  If the write fails due to some other error,
@@ -118,9 +171,9 @@ func (v *UnixVolume) Put(loc string, block []byte) error {
        }
        bpath := v.blockPath(loc)
 
-       if v.serialize {
-               v.mutex.Lock()
-               defer v.mutex.Unlock()
+       if v.locker != nil {
+               v.locker.Lock()
+               defer v.locker.Unlock()
        }
        if _, err := tmpfile.Write(block); err != nil {
                log.Printf("%s: writing to %s: %s\n", v, bpath, err)
@@ -247,9 +300,9 @@ func (v *UnixVolume) Delete(loc string) error {
        if v.readonly {
                return MethodDisabledError
        }
-       if v.serialize {
-               v.mutex.Lock()
-               defer v.mutex.Unlock()
+       if v.locker != nil {
+               v.locker.Lock()
+               defer v.locker.Unlock()
        }
        p := v.blockPath(loc)
        f, err := os.OpenFile(p, os.O_RDWR|os.O_APPEND, 0644)