Refactor the multi-host salt install page.
[arvados.git] / services / keepstore / collision.go
index be26514a00ce6ae9092bf12981f5f55818a083f1..16f2d0923244b138a64eb970fef1a70dc477532e 100644 (file)
@@ -1,6 +1,12 @@
-package main
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package keepstore
 
 import (
+       "bytes"
+       "context"
        "crypto/md5"
        "fmt"
        "io"
@@ -47,3 +53,48 @@ func collisionOrCorrupt(expectMD5 string, buf1, buf2 []byte, rdr io.Reader) erro
        }
        return <-outcome
 }
+
+func compareReaderWithBuf(ctx context.Context, rdr io.Reader, expect []byte, hash string) error {
+       bufLen := 1 << 20
+       if bufLen > len(expect) && len(expect) > 0 {
+               // No need for bufLen to be longer than
+               // expect, except that len(buf)==0 would
+               // prevent us from handling empty readers the
+               // same way as non-empty readers: reading 0
+               // bytes at a time never reaches EOF.
+               bufLen = len(expect)
+       }
+       buf := make([]byte, bufLen)
+       cmp := expect
+
+       // Loop invariants: all data read so far matched what
+       // we expected, and the first N bytes of cmp are
+       // expected to equal the next N bytes read from
+       // rdr.
+       for {
+               ready := make(chan bool)
+               var n int
+               var err error
+               go func() {
+                       n, err = rdr.Read(buf)
+                       close(ready)
+               }()
+               select {
+               case <-ready:
+               case <-ctx.Done():
+                       return ctx.Err()
+               }
+               if n > len(cmp) || bytes.Compare(cmp[:n], buf[:n]) != 0 {
+                       return collisionOrCorrupt(hash, expect[:len(expect)-len(cmp)], buf[:n], rdr)
+               }
+               cmp = cmp[n:]
+               if err == io.EOF {
+                       if len(cmp) != 0 {
+                               return collisionOrCorrupt(hash, expect[:len(expect)-len(cmp)], nil, nil)
+                       }
+                       return nil
+               } else if err != nil {
+                       return err
+               }
+       }
+}