// Copyright (C) The Arvados Authors. All rights reserved. // // SPDX-License-Identifier: AGPL-3.0 package keepstore import ( "fmt" "hash" "io" ) type hashCheckWriter struct { writer io.Writer hash hash.Hash expectSize int64 expectDigest string offset int64 } // newHashCheckWriter returns a writer that writes through to w, but // stops short if the written content reaches expectSize bytes and // does not match expectDigest according to the given hash // function. // // It returns a write error if more than expectSize bytes are written. // // Thus, in case of a hash mismatch, fewer than expectSize will be // written through. func newHashCheckWriter(writer io.Writer, hash hash.Hash, expectSize int64, expectDigest string) io.Writer { return &hashCheckWriter{ writer: writer, hash: hash, expectSize: expectSize, expectDigest: expectDigest, } } func (hcw *hashCheckWriter) Write(p []byte) (int, error) { if todo := hcw.expectSize - hcw.offset - int64(len(p)); todo < 0 { // Writing beyond expected size returns a checksum // error without even checking the hash. return 0, errChecksum } else if todo > 0 { // This isn't the last write, so we pass it through. _, err := hcw.hash.Write(p) if err != nil { return 0, err } n, err := hcw.writer.Write(p) hcw.offset += int64(n) return n, err } else { // This is the last write, so we check the hash before // writing through. _, err := hcw.hash.Write(p) if err != nil { return 0, err } if digest := fmt.Sprintf("%x", hcw.hash.Sum(nil)); digest != hcw.expectDigest { return 0, errChecksum } // Ensure subsequent write will fail hcw.offset = hcw.expectSize + 1 return hcw.writer.Write(p) } }