2960: Refactor keepstore into a streaming server.
[arvados.git] / services / keepstore / streamwriterat.go
diff --git a/services/keepstore/streamwriterat.go b/services/keepstore/streamwriterat.go
new file mode 100644 (file)
index 0000000..365b55f
--- /dev/null
@@ -0,0 +1,154 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package keepstore
+
+import (
+       "errors"
+       "fmt"
+       "io"
+       "sync"
+)
+
+// streamWriterAt translates random-access writes to sequential
+// writes. The caller is expected to use an arbitrary sequence of
+// non-overlapping WriteAt calls covering all positions between 0 and
+// N, for any N < len(buf), then call Close.
+//
+// streamWriterAt writes the data to the provided io.Writer in
+// sequential order.
+//
+// Close returns when all data has been written through.
+type streamWriterAt struct {
+       writer     io.Writer
+       buf        []byte
+       partsize   int         // size of each part written through to writer
+       endpos     int         // portion of buf actually used, judging by WriteAt calls so far
+       partfilled []int       // number of bytes written to each part so far
+       partready  chan []byte // parts of buf fully written / waiting for writer goroutine
+       partnext   int         // index of next part we will send to partready when it's ready
+       wroteAt    int         // bytes we copied to buf in WriteAt
+       wrote      int         // bytes successfully written through to writer
+       errWrite   chan error  // final outcome of writer goroutine
+       closed     bool        // streamWriterAt has been closed
+       mtx        sync.Mutex  // guard internal fields during concurrent calls to WriteAt and Close
+}
+
+// newStreamWriterAt creates a new streamWriterAt.
+func newStreamWriterAt(w io.Writer, partsize int, buf []byte) *streamWriterAt {
+       if partsize == 0 {
+               partsize = 65536
+       }
+       nparts := (len(buf) + partsize - 1) / partsize
+       swa := &streamWriterAt{
+               writer:     w,
+               partsize:   partsize,
+               buf:        buf,
+               partfilled: make([]int, nparts),
+               partready:  make(chan []byte, nparts),
+               errWrite:   make(chan error, 1),
+       }
+       go swa.writeToWriter()
+       return swa
+}
+
+// Wrote returns the number of bytes written through to the
+// io.Writer.
+//
+// Wrote must not be called until after Close.
+func (swa *streamWriterAt) Wrote() int {
+       return swa.wrote
+}
+
+// Wrote returns the number of bytes passed to WriteAt, regardless of
+// whether they were written through to the io.Writer.
+func (swa *streamWriterAt) WroteAt() int {
+       swa.mtx.Lock()
+       defer swa.mtx.Unlock()
+       return swa.wroteAt
+}
+
+func (swa *streamWriterAt) writeToWriter() {
+       defer close(swa.errWrite)
+       for p := range swa.partready {
+               n, err := swa.writer.Write(p)
+               if err != nil {
+                       swa.errWrite <- err
+                       return
+               }
+               swa.wrote += n
+       }
+}
+
+// WriteAt implements io.WriterAt.
+func (swa *streamWriterAt) WriteAt(p []byte, offset int64) (int, error) {
+       pos := int(offset)
+       n := 0
+       if pos <= len(swa.buf) {
+               n = copy(swa.buf[pos:], p)
+       }
+       if n < len(p) {
+               return n, fmt.Errorf("write beyond end of buffer: offset %d len %d buf %d", offset, len(p), len(swa.buf))
+       }
+       endpos := pos + n
+
+       swa.mtx.Lock()
+       defer swa.mtx.Unlock()
+       swa.wroteAt += len(p)
+       if swa.endpos < endpos {
+               swa.endpos = endpos
+       }
+       if swa.closed {
+               return 0, errors.New("invalid use of closed streamWriterAt")
+       }
+       // Track the number of bytes that landed in each of our
+       // (output) parts.
+       for i := pos; i < endpos; {
+               j := i + swa.partsize - (i % swa.partsize)
+               if j > endpos {
+                       j = endpos
+               }
+               pf := swa.partfilled[i/swa.partsize]
+               pf += j - i
+               if pf > swa.partsize {
+                       return 0, errors.New("streamWriterAt: overlapping WriteAt calls")
+               }
+               swa.partfilled[i/swa.partsize] = pf
+               i = j
+       }
+       // Flush filled parts to partready.
+       for swa.partnext < len(swa.partfilled) && swa.partfilled[swa.partnext] == swa.partsize {
+               offset := swa.partnext * swa.partsize
+               swa.partready <- swa.buf[offset : offset+swa.partsize]
+               swa.partnext++
+       }
+       return len(p), nil
+}
+
+// Close flushes all buffered data through to the io.Writer.
+func (swa *streamWriterAt) Close() error {
+       swa.mtx.Lock()
+       defer swa.mtx.Unlock()
+       if swa.closed {
+               return errors.New("invalid use of closed streamWriterAt")
+       }
+       swa.closed = true
+       // Flush last part if needed. If the input doesn't end on a
+       // part boundary, the last part never appears "filled" when we
+       // check in WriteAt.  But here, we know endpos is the end of
+       // the stream, so we can check whether the last part is ready.
+       if offset := swa.partnext * swa.partsize; offset < swa.endpos && offset+swa.partfilled[swa.partnext] == swa.endpos {
+               swa.partready <- swa.buf[offset:swa.endpos]
+               swa.partnext++
+       }
+       close(swa.partready)
+       err := <-swa.errWrite
+       if err != nil {
+               return err
+       }
+       if swa.wrote != swa.wroteAt {
+               return fmt.Errorf("streamWriterAt: detected hole in input: wrote %d but flushed %d", swa.wroteAt, swa.wrote)
+       }
+       return nil
+}