Merge branch 'main' from workbench2.git
[arvados.git] / sdk / go / keepclient / hashcheck.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package keepclient
6
7 import (
8         "errors"
9         "fmt"
10         "hash"
11         "io"
12 )
13
14 var BadChecksum = errors.New("Reader failed checksum")
15
16 // HashCheckingReader is an io.ReadCloser that checks the contents
17 // read from the underlying io.Reader against the provided hash.
18 type HashCheckingReader struct {
19         // The underlying data source
20         io.Reader
21
22         // The hash function to use
23         hash.Hash
24
25         // The hash value to check against.  Must be a hex-encoded lowercase string.
26         Check string
27 }
28
29 // Reads from the underlying reader, update the hashing function, and
30 // pass the results through. Returns BadChecksum (instead of EOF) on
31 // the last read if the checksum doesn't match.
32 func (hcr HashCheckingReader) Read(p []byte) (n int, err error) {
33         n, err = hcr.Reader.Read(p)
34         if n > 0 {
35                 hcr.Hash.Write(p[:n])
36         }
37         if err == io.EOF {
38                 sum := hcr.Hash.Sum(nil)
39                 if fmt.Sprintf("%x", sum) != hcr.Check {
40                         err = BadChecksum
41                 }
42         }
43         return n, err
44 }
45
46 // WriteTo writes the entire contents of hcr.Reader to dest. Returns
47 // BadChecksum if writing is successful but the checksum doesn't
48 // match.
49 func (hcr HashCheckingReader) WriteTo(dest io.Writer) (written int64, err error) {
50         if writeto, ok := hcr.Reader.(io.WriterTo); ok {
51                 written, err = writeto.WriteTo(io.MultiWriter(dest, hcr.Hash))
52         } else {
53                 written, err = io.Copy(io.MultiWriter(dest, hcr.Hash), hcr.Reader)
54         }
55
56         if err != nil {
57                 return written, err
58         }
59
60         sum := hcr.Hash.Sum(nil)
61         if fmt.Sprintf("%x", sum) != hcr.Check {
62                 return written, BadChecksum
63         }
64
65         return written, nil
66 }
67
68 // Close reads all remaining data from the underlying Reader and
69 // returns BadChecksum if the checksum doesn't match. It also closes
70 // the underlying Reader if it implements io.ReadCloser.
71 func (hcr HashCheckingReader) Close() (err error) {
72         _, err = io.Copy(hcr.Hash, hcr.Reader)
73
74         if closer, ok := hcr.Reader.(io.Closer); ok {
75                 closeErr := closer.Close()
76                 if err == nil {
77                         err = closeErr
78                 }
79         }
80         if err != nil {
81                 return err
82         }
83         if fmt.Sprintf("%x", hcr.Hash.Sum(nil)) != hcr.Check {
84                 return BadChecksum
85         }
86         return nil
87 }