--- /dev/null
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package keepstore
+
+import (
+ "fmt"
+ "hash"
+ "io"
+)
+
+type hashCheckWriter struct {
+ writer io.Writer
+ hash hash.Hash
+ expectSize int64
+ expectDigest string
+
+ offset int64
+}
+
+// newHashCheckWriter returns a writer that writes through to w, but
+// stops short if the written content reaches expectSize bytes and
+// does not match expectDigest according to the given hash
+// function.
+//
+// It returns a write error if more than expectSize bytes are written.
+//
+// Thus, in case of a hash mismatch, fewer than expectSize will be
+// written through.
+func newHashCheckWriter(writer io.Writer, hash hash.Hash, expectSize int64, expectDigest string) io.Writer {
+ return &hashCheckWriter{
+ writer: writer,
+ hash: hash,
+ expectSize: expectSize,
+ expectDigest: expectDigest,
+ }
+}
+
+func (hcw *hashCheckWriter) Write(p []byte) (int, error) {
+ if todo := hcw.expectSize - hcw.offset - int64(len(p)); todo < 0 {
+ // Writing beyond expected size returns a checksum
+ // error without even checking the hash.
+ return 0, errChecksum
+ } else if todo > 0 {
+ // This isn't the last write, so we pass it through.
+ _, err := hcw.hash.Write(p)
+ if err != nil {
+ return 0, err
+ }
+ n, err := hcw.writer.Write(p)
+ hcw.offset += int64(n)
+ return n, err
+ } else {
+ // This is the last write, so we check the hash before
+ // writing through.
+ _, err := hcw.hash.Write(p)
+ if err != nil {
+ return 0, err
+ }
+ if digest := fmt.Sprintf("%x", hcw.hash.Sum(nil)); digest != hcw.expectDigest {
+ return 0, errChecksum
+ }
+ // Ensure subsequent write will fail
+ hcw.offset = hcw.expectSize + 1
+ return hcw.writer.Write(p)
+ }
+}