X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/d81ea65da05119d5c6480d373b5d42bbee8ae1ad..6fc3e8d44c1bf825f7c3727bab1fef81d2518288:/sdk/go/keepclient/hashcheck.go?ds=sidebyside diff --git a/sdk/go/keepclient/hashcheck.go b/sdk/go/keepclient/hashcheck.go index 1f696d95b6..0966e072ea 100644 --- a/sdk/go/keepclient/hashcheck.go +++ b/sdk/go/keepclient/hashcheck.go @@ -1,8 +1,7 @@ -// Lightweight implementation of io.ReadCloser that checks the contents read -// from the underlying io.Reader a against checksum hash. To avoid reading the -// entire contents into a buffer up front, the hash is updated with each read, -// and the actual checksum is not checked until the underlying reader returns -// EOF. +// Copyright (C) The Arvados Authors. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 + package keepclient import ( @@ -14,65 +13,75 @@ import ( var BadChecksum = errors.New("Reader failed checksum") +// HashCheckingReader is an io.ReadCloser that checks the contents +// read from the underlying io.Reader against the provided hash. type HashCheckingReader struct { // The underlying data source io.Reader - // The hashing function to use + // The hash function to use hash.Hash // The hash value to check against. Must be a hex-encoded lowercase string. Check string } -// Read from the underlying reader, update the hashing function, and pass the -// results through. Will return BadChecksum on the last read instead of EOF if -// the checksum doesn't match. -func (this HashCheckingReader) Read(p []byte) (n int, err error) { - n, err = this.Reader.Read(p) +// Reads from the underlying reader, update the hashing function, and +// pass the results through. Returns BadChecksum (instead of EOF) on +// the last read if the checksum doesn't match. +func (hcr HashCheckingReader) Read(p []byte) (n int, err error) { + n, err = hcr.Reader.Read(p) if n > 0 { - this.Hash.Write(p[:n]) + hcr.Hash.Write(p[:n]) } if err == io.EOF { - sum := this.Hash.Sum(make([]byte, 0, this.Hash.Size())) - if fmt.Sprintf("%x", sum) != this.Check { + sum := hcr.Hash.Sum(nil) + if fmt.Sprintf("%x", sum) != hcr.Check { err = BadChecksum } } return n, err } -// Write entire contents of this.Reader to 'dest'. Returns BadChecksum if the -// data written to 'dest' doesn't match the hash code of this.Check. -func (this HashCheckingReader) WriteTo(dest io.Writer) (written int64, err error) { - if writeto, ok := this.Reader.(io.WriterTo); ok { - written, err = writeto.WriteTo(io.MultiWriter(dest, this.Hash)) +// WriteTo writes the entire contents of hcr.Reader to dest. Returns +// BadChecksum if writing is successful but the checksum doesn't +// match. +func (hcr HashCheckingReader) WriteTo(dest io.Writer) (written int64, err error) { + if writeto, ok := hcr.Reader.(io.WriterTo); ok { + written, err = writeto.WriteTo(io.MultiWriter(dest, hcr.Hash)) } else { - written, err = io.Copy(io.MultiWriter(dest, this.Hash), this.Reader) + written, err = io.Copy(io.MultiWriter(dest, hcr.Hash), hcr.Reader) } - sum := this.Hash.Sum(make([]byte, 0, this.Hash.Size())) + if err != nil { + return written, err + } - if fmt.Sprintf("%x", sum) != this.Check { - err = BadChecksum + sum := hcr.Hash.Sum(nil) + if fmt.Sprintf("%x", sum) != hcr.Check { + return written, BadChecksum } - return written, err + return written, nil } -// Close() the underlying Reader if it is castable to io.ReadCloser. This will -// drain the underlying reader of any remaining data and check the checksum. -func (this HashCheckingReader) Close() (err error) { - _, err = io.Copy(this.Hash, this.Reader) +// Close reads all remaining data from the underlying Reader and +// returns BadChecksum if the checksum doesn't match. It also closes +// the underlying Reader if it implements io.ReadCloser. +func (hcr HashCheckingReader) Close() (err error) { + _, err = io.Copy(hcr.Hash, hcr.Reader) - if closer, ok := this.Reader.(io.ReadCloser); ok { - err = closer.Close() + if closer, ok := hcr.Reader.(io.Closer); ok { + closeErr := closer.Close() + if err == nil { + err = closeErr + } } - - sum := this.Hash.Sum(make([]byte, 0, this.Hash.Size())) - if fmt.Sprintf("%x", sum) != this.Check { - err = BadChecksum + if err != nil { + return err } - - return err + if fmt.Sprintf("%x", hcr.Hash.Sum(nil)) != hcr.Check { + return BadChecksum + } + return nil }