Merge branch '21356-clean-imports'
[arvados.git] / services / keepstore / hashcheckwriter.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package keepstore
6
7 import (
8         "fmt"
9         "hash"
10         "io"
11 )
12
13 type hashCheckWriter struct {
14         writer       io.Writer
15         hash         hash.Hash
16         expectSize   int64
17         expectDigest string
18
19         offset int64
20 }
21
22 // newHashCheckWriter returns a writer that writes through to w, but
23 // stops short if the written content reaches expectSize bytes and
24 // does not match expectDigest according to the given hash
25 // function.
26 //
27 // It returns a write error if more than expectSize bytes are written.
28 //
29 // Thus, in case of a hash mismatch, fewer than expectSize will be
30 // written through.
31 func newHashCheckWriter(writer io.Writer, hash hash.Hash, expectSize int64, expectDigest string) io.Writer {
32         return &hashCheckWriter{
33                 writer:       writer,
34                 hash:         hash,
35                 expectSize:   expectSize,
36                 expectDigest: expectDigest,
37         }
38 }
39
40 func (hcw *hashCheckWriter) Write(p []byte) (int, error) {
41         if todo := hcw.expectSize - hcw.offset - int64(len(p)); todo < 0 {
42                 // Writing beyond expected size returns a checksum
43                 // error without even checking the hash.
44                 return 0, errChecksum
45         } else if todo > 0 {
46                 // This isn't the last write, so we pass it through.
47                 _, err := hcw.hash.Write(p)
48                 if err != nil {
49                         return 0, err
50                 }
51                 n, err := hcw.writer.Write(p)
52                 hcw.offset += int64(n)
53                 return n, err
54         } else {
55                 // This is the last write, so we check the hash before
56                 // writing through.
57                 _, err := hcw.hash.Write(p)
58                 if err != nil {
59                         return 0, err
60                 }
61                 if digest := fmt.Sprintf("%x", hcw.hash.Sum(nil)); digest != hcw.expectDigest {
62                         return 0, errChecksum
63                 }
64                 // Ensure subsequent write will fail
65                 hcw.offset = hcw.expectSize + 1
66                 return hcw.writer.Write(p)
67         }
68 }