refactor as procedural
[arvados.git] / services / keepstore / collision.go
1 package main
2
3 import (
4         "bytes"
5         "context"
6         "crypto/md5"
7         "fmt"
8         "io"
9 )
10
11 // Compute the MD5 digest of a data block (consisting of buf1 + buf2 +
12 // all bytes readable from rdr). If all data is read successfully,
13 // return DiskHashError or CollisionError depending on whether it
14 // matches expectMD5. If an error occurs while reading, return that
15 // error.
16 //
17 // "content has expected MD5" is called a collision because this
18 // function is used in cases where we have another block in hand with
19 // the given MD5 but different content.
20 func collisionOrCorrupt(expectMD5 string, buf1, buf2 []byte, rdr io.Reader) error {
21         outcome := make(chan error)
22         data := make(chan []byte, 1)
23         go func() {
24                 h := md5.New()
25                 for b := range data {
26                         h.Write(b)
27                 }
28                 if fmt.Sprintf("%x", h.Sum(nil)) == expectMD5 {
29                         outcome <- CollisionError
30                 } else {
31                         outcome <- DiskHashError
32                 }
33         }()
34         data <- buf1
35         if buf2 != nil {
36                 data <- buf2
37         }
38         var err error
39         for rdr != nil && err == nil {
40                 buf := make([]byte, 1<<18)
41                 var n int
42                 n, err = rdr.Read(buf)
43                 data <- buf[:n]
44         }
45         close(data)
46         if rdr != nil && err != io.EOF {
47                 <-outcome
48                 return err
49         }
50         return <-outcome
51 }
52
53 func compareReaderWithBuf(ctx context.Context, rdr io.Reader, expect []byte, hash string) error {
54         bufLen := 1 << 20
55         if bufLen > len(expect) && len(expect) > 0 {
56                 // No need for bufLen to be longer than
57                 // expect, except that len(buf)==0 would
58                 // prevent us from handling empty readers the
59                 // same way as non-empty readers: reading 0
60                 // bytes at a time never reaches EOF.
61                 bufLen = len(expect)
62         }
63         buf := make([]byte, bufLen)
64         cmp := expect
65
66         // Loop invariants: all data read so far matched what
67         // we expected, and the first N bytes of cmp are
68         // expected to equal the next N bytes read from
69         // rdr.
70         for {
71                 ready := make(chan bool)
72                 var n int
73                 var err error
74                 go func() {
75                         n, err = rdr.Read(buf)
76                         close(ready)
77                 }()
78                 select {
79                 case <-ready:
80                 case <-ctx.Done():
81                         return ctx.Err()
82                 }
83                 if n > len(cmp) || bytes.Compare(cmp[:n], buf[:n]) != 0 {
84                         return collisionOrCorrupt(hash, expect[:len(expect)-len(cmp)], buf[:n], rdr)
85                 }
86                 cmp = cmp[n:]
87                 if err == io.EOF {
88                         if len(cmp) != 0 {
89                                 return collisionOrCorrupt(hash, expect[:len(expect)-len(cmp)], nil, nil)
90                         }
91                         return nil
92                 } else if err != nil {
93                         return err
94                 }
95         }
96 }