// Copyright (C) The Arvados Authors. All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0

package main

import (
	"bytes"
	"context"
	"crypto/md5"
	"fmt"
	"io"
)

// Compute the MD5 digest of a data block (consisting of buf1 + buf2 +
// all bytes readable from rdr). If all data is read successfully,
// return DiskHashError or CollisionError depending on whether it
// matches expectMD5. If an error occurs while reading, return that
// error.
//
// "content has expected MD5" is called a collision because this
// function is used in cases where we have another block in hand with
// the given MD5 but different content.
func collisionOrCorrupt(expectMD5 string, buf1, buf2 []byte, rdr io.Reader) error {
	outcome := make(chan error)
	data := make(chan []byte, 1)
	go func() {
		h := md5.New()
		for b := range data {
			h.Write(b)
		}
		if fmt.Sprintf("%x", h.Sum(nil)) == expectMD5 {
			outcome <- CollisionError
		} else {
			outcome <- DiskHashError
		}
	}()
	data <- buf1
	if buf2 != nil {
		data <- buf2
	}
	var err error
	for rdr != nil && err == nil {
		buf := make([]byte, 1<<18)
		var n int
		n, err = rdr.Read(buf)
		data <- buf[:n]
	}
	close(data)
	if rdr != nil && err != io.EOF {
		<-outcome
		return err
	}
	return <-outcome
}

func compareReaderWithBuf(ctx context.Context, rdr io.Reader, expect []byte, hash string) error {
	bufLen := 1 << 20
	if bufLen > len(expect) && len(expect) > 0 {
		// No need for bufLen to be longer than
		// expect, except that len(buf)==0 would
		// prevent us from handling empty readers the
		// same way as non-empty readers: reading 0
		// bytes at a time never reaches EOF.
		bufLen = len(expect)
	}
	buf := make([]byte, bufLen)
	cmp := expect

	// Loop invariants: all data read so far matched what
	// we expected, and the first N bytes of cmp are
	// expected to equal the next N bytes read from
	// rdr.
	for {
		ready := make(chan bool)
		var n int
		var err error
		go func() {
			n, err = rdr.Read(buf)
			close(ready)
		}()
		select {
		case <-ready:
		case <-ctx.Done():
			return ctx.Err()
		}
		if n > len(cmp) || bytes.Compare(cmp[:n], buf[:n]) != 0 {
			return collisionOrCorrupt(hash, expect[:len(expect)-len(cmp)], buf[:n], rdr)
		}
		cmp = cmp[n:]
		if err == io.EOF {
			if len(cmp) != 0 {
				return collisionOrCorrupt(hash, expect[:len(expect)-len(cmp)], nil, nil)
			}
			return nil
		} else if err != nil {
			return err
		}
	}
}