// Copyright (C) The Arvados Authors. All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0

package keepweb

import (
	"errors"
	"io"
	"net/http"
	"sync/atomic"
)

// writeBuffer is an asynchronous write buffer. Write() copies data
// into a ring buffer (buf), and a separate flush() goroutine copies it
// from there to the underlying io.Writer (out).
//
// Ring positions: rpos==wpos means the buffer is empty, and
// rpos==(wpos+1)%size means the buffer is full. One byte of buf is
// therefore always unused, so the two states can be told apart.
//
// With size<2 the buffer would be empty and full at the same time, so
// in that case writeBuffer writes through synchronously.
type writeBuffer struct {
	// out is the underlying writer that flush() writes to.
	out io.Writer
	// buf is the ring buffer storage.
	buf []byte
	// writesize is the maximum number of bytes flush() passes to a
	// single out.Write() call.
	writesize int

	// wpos is the index in buf where the writer (Write()) will
	// write to next.
	wpos atomic.Int64
	// wsignal receives a value after wpos or closed changes.
	wsignal chan struct{}

	// rpos is the index in buf where the reader (flush()) will
	// read from next.
	rpos atomic.Int64
	// rsignal receives a value after rpos changes.
	rsignal chan struct{}

	// err is the error encountered by flush(), if any. It is
	// assigned by flush() and read by others only after flushed
	// has been closed.
	err error
	// closed is set by Close().
	closed atomic.Bool
	// flushed is closed when flush() is finished.
	flushed chan struct{}
}

// newWriteBuffer returns a writeBuffer with a size-byte ring buffer
// whose contents are copied to w by a background flush() goroutine.
// That goroutine keeps running until Close is called or a write to w
// fails.
//
// If size < 2, Write passes data straight through to w.
func newWriteBuffer(w io.Writer, size int) *writeBuffer {
	wb := new(writeBuffer)
	wb.out = w
	wb.buf = make([]byte, size)
	// flush() hands at most 1/64 of the buffer (rounded up) to
	// each out.Write() call.
	wb.writesize = (size + 63) / 64
	// One-slot signal channels let the sender post a wake-up
	// without blocking when the other side is not waiting yet.
	wb.wsignal = make(chan struct{}, 1)
	wb.rsignal = make(chan struct{}, 1)
	wb.flushed = make(chan struct{})
	go wb.flush()
	return wb
}

// Close tells flush() to finish, waits for it to write out the data
// buffered so far, and returns the error (if any) that flush() got
// from the underlying writer. If flush() hit an error, it stopped
// early and the remaining buffered data was not written.
//
// Only the first call to Close proceeds. Any other call, whether
// sequential or concurrent, returns an "already closed" error without
// waiting.
//
// NOTE(review): callers presumably stop calling Write before Close.
// flush() stops once it has seen closed and drained the buffer, so
// data written concurrently with Close may not be flushed.
func (wb *writeBuffer) Close() error {
	// CompareAndSwap makes the "already closed?" check and the state
	// change one atomic step. A separate Load() and Store() would let
	// two concurrent callers both get past the check.
	if !wb.closed.CompareAndSwap(false, true) {
		return errors.New("writeBuffer: already closed")
	}
	// Wake up flush() so it sees closed==true. The send is
	// non-blocking: if a wake-up is already pending, that one is
	// enough.
	select {
	case wb.wsignal <- struct{}{}:
	default:
	}
	// Wait for flush() to finish. Receiving from the closed channel
	// also makes flush()'s assignment to wb.err visible here.
	<-wb.flushed
	return wb.err
}

// Write copies p into the ring buffer, blocking while the buffer is
// full, and wakes up flush() each time more data is available. It
// returns len(p), nil once all of p has been buffered. The data
// reaches the underlying writer asynchronously.
//
// If flush() has exited, Write stops instead of buffering data that
// will never be written out. flush() exits either because Close was
// called or because a write to the underlying writer failed. In that
// case Write returns the number of bytes of p it had already buffered,
// together with the underlying writer's error (or a "closed" error if
// there was none).
//
// If the buffer is smaller than 2 bytes, Write writes p directly to
// the underlying writer.
//
// Write is not safe for concurrent use by multiple goroutines: it
// keeps a private copy of wpos and stores it back unsynchronized.
func (wb *writeBuffer) Write(p []byte) (int, error) {
	if len(wb.buf) < 2 {
		// Our buffer logic doesn't work with size<2, and such
		// a tiny buffer has no purpose anyway, so just write
		// through unbuffered.
		return wb.out.Write(p)
	}
	todo := p
	// Local copies of the ring positions. Only Write stores wpos.
	// rpos may go stale, which is safe: flush() only ever advances
	// it, so a stale value can only understate the free space.
	wpos := int(wb.wpos.Load())
	rpos := int(wb.rpos.Load())
	for len(todo) > 0 {
		// Stop as soon as flush() has exited (after Close or an
		// output error) rather than silently buffering more data.
		select {
		case <-wb.flushed:
			return len(p) - len(todo), wb.stoppedErr()
		default:
		}
		// wait until the buffer is not full.
		for rpos == (wpos+1)%len(wb.buf) {
			select {
			case <-wb.flushed:
				return len(p) - len(todo), wb.stoppedErr()
			case <-wb.rsignal:
				rpos = int(wb.rpos.Load())
			}
		}
		// Determine the next contiguous portion of the buffer that
		// is available. The slot just before rpos is never filled,
		// because wpos==rpos would then mean "empty".
		var avail []byte
		switch {
		case rpos == 0:
			// Filling the last slot would wrap wpos to 0 == rpos.
			avail = wb.buf[wpos : len(wb.buf)-1]
		case wpos >= rpos:
			avail = wb.buf[wpos:]
		default:
			avail = wb.buf[wpos : rpos-1]
		}
		n := copy(avail, todo)
		wpos = (wpos + n) % len(wb.buf)
		wb.wpos.Store(int64(wpos))
		// wake up flush(); if a wake-up is already pending, that
		// one is enough.
		select {
		case wb.wsignal <- struct{}{}:
		default:
		}
		todo = todo[n:]
	}
	return len(p), nil
}

// stoppedErr returns the error Write reports once flush() has exited.
// That is the error flush() got from the underlying writer, or, if
// there was none (flush() exited because of Close), a "closed" error.
// It must only be called after a receive from wb.flushed has returned;
// that receive is what makes reading wb.err race-free.
func (wb *writeBuffer) stoppedErr() error {
	if wb.err != nil {
		return wb.err
	}
	return errors.New("Write called on closed writeBuffer")
}

// flush runs in its own goroutine (started by newWriteBuffer) and
// copies buffered bytes from wb.buf to wb.out. It returns in two cases:
// when the buffer is empty after Close() has been called, or as soon
// as wb.out.Write() fails, in which case the error is left in wb.err.
// Either way, wb.flushed is closed on return.
//
// Each wb.out.Write() call gets one contiguous run of buffered bytes,
// capped at wb.writesize.
func (wb *writeBuffer) flush() {
	defer close(wb.flushed)
	var (
		head     int  // our read position; published via wb.rpos
		tail     int  // most recently observed wb.wpos
		stopping bool // most recently observed wb.closed
	)
	for {
		// Sleep while there is nothing to write. On each wake-up,
		// closed is loaded *before* wpos. If Close() is called
		// after the last Write() returns, then seeing closed==true
		// guarantees the wpos loaded next covers all written data.
		for head == tail {
			if stopping {
				return
			}
			<-wb.wsignal
			stopping = wb.closed.Load()
			tail = int(wb.wpos.Load())
		}
		// The next ready run ends at the write position, or at the
		// end of buf when the buffered data wraps around.
		end := tail
		if head > tail {
			end = len(wb.buf)
		}
		if end-head > wb.writesize {
			end = head + wb.writesize
		}
		if _, err := wb.out.Write(wb.buf[head:end]); err != nil {
			wb.err = err
			return
		}
		head = end % len(wb.buf)
		wb.rpos.Store(int64(head))
		// Tell a Write() that may be waiting for free space that
		// rpos has moved; one pending wake-up is enough.
		select {
		case wb.rsignal <- struct{}{}:
		default:
		}
	}
}

// responseWriter makes it possible to insert an io.Writer wrapper
// (such as *writeBuffer) into an http.ResponseWriter stack.
//
// Write() calls go to the embedded io.Writer. The remaining methods of
// the http.ResponseWriter interface (Header, WriteHeader) go to the
// embedded http.ResponseWriter. Only methods of that interface are
// promoted: optional interfaces implemented by the underlying
// ResponseWriter value (e.g. http.Flusher) are not visible through this
// type.
type responseWriter struct {
	io.Writer
	http.ResponseWriter
}

// Write sends p to the embedded io.Writer. Both embedded types have a
// Write method at the same depth. Without this explicit method the
// selector would be ambiguous, and responseWriter would implement
// neither io.Writer nor http.ResponseWriter.
func (rw responseWriter) Write(p []byte) (int, error) {
	w := rw.Writer
	return w.Write(p)
}