--- /dev/null
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package keepweb
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "sync/atomic"
+)
+
+// writeBuffer uses a ring buffer to implement an asynchronous write
+// buffer: Write() copies data into buf and returns, while a
+// background flush() goroutine writes the buffered data through to
+// out.
+//
+// rpos==wpos means the buffer is empty.
+//
+// rpos==(wpos+1)%size means the buffer is full. One slot is always
+// left unused so that "full" can be told apart from "empty", which
+// means at most size-1 bytes are buffered at a time.
+//
+// size<2 means the buffer is always empty and full, so in this case
+// writeBuffer writes through synchronously.
+//
+// wpos is stored only by Write() and rpos only by flush(); each side
+// reads the other's position atomically and uses wsignal/rsignal
+// (non-blocking sends) to wake the other side up.
+type writeBuffer struct {
+ out io.Writer // destination that flush() (or Write(), when size<2) writes to
+ buf []byte // ring buffer storage; len(buf) is the ring size
+ writesize int // max bytes flush() should write in a single out.Write()
+ wpos atomic.Int64 // index in buf where writer (Write()) will write to next
+ wsignal chan struct{} // receives a value after wpos or closed changes
+ rpos atomic.Int64 // index in buf where reader (flush()) will read from next
+ rsignal chan struct{} // receives a value after rpos or err changes
+ err error // error encountered by flush(); other methods read it only after flushed is closed
+ closed atomic.Bool // set by Close(); checked by flush() after each wsignal wakeup
+ flushed chan struct{} // closes when flush() is finished
+}
+
+// newWriteBuffer returns a writeBuffer that forwards data to w
+// through a ring buffer of the given size, and starts the background
+// flush() goroutine that drains it.
+//
+// Each write issued by flush() is capped at 1/64 of size (rounded
+// up). The caller must eventually call Close() so the flush()
+// goroutine can exit.
+func newWriteBuffer(w io.Writer, size int) *writeBuffer {
+	wb := new(writeBuffer)
+	wb.out = w
+	wb.buf = make([]byte, size)
+	// ceil(size/64): the largest chunk flush() hands to out.Write().
+	wb.writesize = (size + 63) / 64
+	// Capacity-1 channels let either side leave a pending wakeup
+	// without blocking.
+	wb.wsignal = make(chan struct{}, 1)
+	wb.rsignal = make(chan struct{}, 1)
+	wb.flushed = make(chan struct{})
+	go wb.flush()
+	return wb
+}
+
+// Close marks the buffer closed, wakes flush() so it can write out
+// whatever is still buffered, and blocks until flush() has exited.
+//
+// It returns the error (if any) that flush() recorded from the
+// underlying writer. Only one call ever closes the buffer: any later
+// or concurrent call returns an "already closed" error immediately,
+// without waiting for flush().
+func (wb *writeBuffer) Close() error {
+	// CompareAndSwap performs the "already closed?" check and the
+	// update as one atomic step. A separate Load followed by Store
+	// would let two racing callers both pass the check.
+	if !wb.closed.CompareAndSwap(false, true) {
+		return errors.New("writeBuffer: already closed")
+	}
+	// wake up flush(). The send is non-blocking: if a signal is
+	// already pending, flush() will re-read closed when it consumes
+	// that one.
+	select {
+	case wb.wsignal <- struct{}{}:
+	default:
+	}
+	// wait for flush() to finish. Once flushed is closed, flush()
+	// no longer writes wb.err, so it can be read without a race.
+	<-wb.flushed
+	return wb.err
+}
+
+// Write copies p into the ring buffer, blocking while the buffer is
+// full, and wakes flush() after each chunk it copies. It returns
+// len(p) once all of p has been buffered; at that point the bytes
+// have not necessarily reached the underlying writer yet.
+//
+// If the buffer is full and flush() has already exited, Write
+// returns 0 together with the error flush() recorded, or a "closed"
+// error if flush() recorded none. The returned count is 0 even when
+// earlier loop iterations already copied part of p into the buffer.
+//
+// When len(wb.buf) < 2, Write bypasses the buffer and calls
+// wb.out.Write(p) directly.
+//
+// NOTE(review): Write updates wpos without any locking, so it looks
+// like it expects at most one caller at a time -- confirm.
+func (wb *writeBuffer) Write(p []byte) (int, error) {
+ if len(wb.buf) < 2 {
+ // Our buffer logic doesn't work with size<2, and such
+ // a tiny buffer has no purpose anyway, so just write
+ // through unbuffered.
+ return wb.out.Write(p)
+ }
+ todo := p
+ wpos := int(wb.wpos.Load())
+ // rpos is a snapshot: it is refreshed only after an rsignal
+ // wakeup below. A stale value can only under-estimate the free
+ // space, because flush() only ever moves rpos forward.
+ rpos := int(wb.rpos.Load())
+ for len(todo) > 0 {
+ // wait until the buffer is not full.
+ for rpos == (wpos+1)%len(wb.buf) {
+ select {
+ case <-wb.flushed:
+ // flush() has exited, so the buffer will never
+ // drain. wb.err is safe to read here because
+ // flush() closes flushed only when returning.
+ if wb.err == nil {
+ return 0, errors.New("Write called on closed writeBuffer")
+ }
+ return 0, wb.err
+ case <-wb.rsignal:
+ // flush() sends on rsignal after storing rpos,
+ // so reload it and re-check the full condition.
+ rpos = int(wb.rpos.Load())
+ }
+ }
+ // determine next contiguous portion of buffer that is
+ // available.
+ var avail []byte
+ if rpos == 0 {
+ // The last slot must stay unused: filling it would
+ // wrap wpos to 0 == rpos, which means "empty".
+ avail = wb.buf[wpos : len(wb.buf)-1]
+ } else if wpos >= rpos {
+ // Free space runs to the end of buf. Wrapping wpos
+ // to 0 is fine here because rpos > 0.
+ avail = wb.buf[wpos:]
+ } else {
+ // wpos < rpos: stop one slot short of rpos so the
+ // buffer ends up "full" rather than "empty".
+ avail = wb.buf[wpos : rpos-1]
+ }
+ n := copy(avail, todo)
+ wpos = (wpos + n) % len(wb.buf)
+ // Publish the new wpos before signalling, so flush() sees
+ // the copied bytes when it wakes up.
+ wb.wpos.Store(int64(wpos))
+ // wake up flush()
+ select {
+ case wb.wsignal <- struct{}{}:
+ default:
+ }
+ todo = todo[n:]
+ }
+ return len(p), nil
+}
+
+// flush is the reader side of the ring buffer; newWriteBuffer starts
+// it in its own goroutine. It repeatedly waits for buffered data,
+// writes it to wb.out in chunks of at most wb.writesize bytes,
+// advances rpos, and signals rsignal so a blocked Write() can
+// continue.
+//
+// flush returns, closing wb.flushed on the way out, in two cases:
+// when wb.out.Write returns an error (which is kept in wb.err), or
+// when the buffer is empty and closed has been observed.
+func (wb *writeBuffer) flush() {
+ defer close(wb.flushed)
+ // Local copies of the positions; wpos and closed are refreshed
+ // only after a wsignal wakeup.
+ rpos := 0
+ wpos := 0
+ closed := false
+ for {
+ // wait until buffer is not empty.
+ for rpos == wpos {
+ if closed {
+ return
+ }
+ <-wb.wsignal
+ // closed is loaded before wpos: if closed is observed
+ // here, the wpos loaded next already includes every
+ // Write() that completed before Close() was called.
+ closed = wb.closed.Load()
+ wpos = int(wb.wpos.Load())
+ }
+ // determine next contiguous portion of buffer that is
+ // ready to write through.
+ var ready []byte
+ if rpos < wpos {
+ ready = wb.buf[rpos:wpos]
+ } else {
+ // The data wraps around: write up to the end of buf
+ // now; a later iteration continues from index 0.
+ ready = wb.buf[rpos:]
+ }
+ if len(ready) > wb.writesize {
+ ready = ready[:wb.writesize]
+ }
+ // The byte count is discarded; this relies on the io.Writer
+ // contract that a short write returns a non-nil error.
+ _, wb.err = wb.out.Write(ready)
+ if wb.err != nil {
+ // No rsignal is sent on this path; a blocked Write()
+ // is released by the deferred close(wb.flushed).
+ return
+ }
+ rpos = (rpos + len(ready)) % len(wb.buf)
+ // Publish the new rpos before signalling, so Write() sees
+ // the freed space when it wakes up.
+ wb.rpos.Store(int64(rpos))
+ select {
+ case wb.rsignal <- struct{}{}:
+ default:
+ }
+ }
+}
+
+// responseWriter enables inserting an io.Writer-wrapper (like
+// *writeBuffer) into an http.ResponseWriter stack.
+//
+// It passes Write() calls to an io.Writer, and all other calls to an
+// http.ResponseWriter.
+//
+// Because the embedded field is the http.ResponseWriter interface
+// itself, only that interface's methods are promoted. Optional
+// interfaces (for example http.Flusher) implemented by the wrapped
+// value are not exposed through responseWriter.
+type responseWriter struct {
+ io.Writer // receives the response body via Write()
+ http.ResponseWriter // supplies the remaining http.ResponseWriter methods
+}
+
+// Write sends p to the embedded io.Writer rather than to the embedded
+// http.ResponseWriter.
+//
+// Both embedded fields provide a Write method at the same depth, so
+// neither would be promoted on its own; this explicit method resolves
+// the ambiguity in favor of the io.Writer.
+func (rw responseWriter) Write(p []byte) (int, error) {
+	dst := rw.Writer
+	return dst.Write(p)
+}