2960: Refactor keepstore into a streaming server.
[arvados.git] / services / keepstore / hashcheckwriter.go
diff --git a/services/keepstore/hashcheckwriter.go b/services/keepstore/hashcheckwriter.go
new file mode 100644 (file)
index 0000000..f191c98
--- /dev/null
@@ -0,0 +1,68 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package keepstore
+
+import (
+       "fmt"
+       "hash"
+       "io"
+)
+
+type hashCheckWriter struct {
+       writer       io.Writer
+       hash         hash.Hash
+       expectSize   int64
+       expectDigest string
+
+       offset int64
+}
+
+// newHashCheckWriter returns a writer that writes through to w, but
+// stops short if the written content reaches expectSize bytes and
+// does not match expectDigest according to the given hash
+// function.
+//
+// It returns a write error if more than expectSize bytes are written.
+//
+// Thus, in case of a hash mismatch, fewer than expectSize will be
+// written through.
+func newHashCheckWriter(writer io.Writer, hash hash.Hash, expectSize int64, expectDigest string) io.Writer {
+       return &hashCheckWriter{
+               writer:       writer,
+               hash:         hash,
+               expectSize:   expectSize,
+               expectDigest: expectDigest,
+       }
+}
+
+func (hcw *hashCheckWriter) Write(p []byte) (int, error) {
+       if todo := hcw.expectSize - hcw.offset - int64(len(p)); todo < 0 {
+               // Writing beyond expected size returns a checksum
+               // error without even checking the hash.
+               return 0, errChecksum
+       } else if todo > 0 {
+               // This isn't the last write, so we pass it through.
+               _, err := hcw.hash.Write(p)
+               if err != nil {
+                       return 0, err
+               }
+               n, err := hcw.writer.Write(p)
+               hcw.offset += int64(n)
+               return n, err
+       } else {
+               // This is the last write, so we check the hash before
+               // writing through.
+               _, err := hcw.hash.Write(p)
+               if err != nil {
+                       return 0, err
+               }
+               if digest := fmt.Sprintf("%x", hcw.hash.Sum(nil)); digest != hcw.expectDigest {
+                       return 0, errChecksum
+               }
+               // Ensure subsequent write will fail
+               hcw.offset = hcw.expectSize + 1
+               return hcw.writer.Write(p)
+       }
+}