2960: Buffer reads when serialize enabled on unix volume.
[arvados.git] / services / keepstore / streamwriterat.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package keepstore
6
7 import (
8         "errors"
9         "fmt"
10         "io"
11         "sync"
12 )
13
14 // streamWriterAt translates random-access writes to sequential
15 // writes. The caller is expected to use an arbitrary sequence of
16 // non-overlapping WriteAt calls covering all positions between 0 and
17 // N, for any N < len(buf), then call Close.
18 //
19 // streamWriterAt writes the data to the provided io.Writer in
20 // sequential order.
21 //
22 // streamWriterAt can also be used as an asynchronous buffer: the
23 // caller can use the io.Writer interface to write into a memory
24 // buffer and return without waiting for the wrapped writer to catch
25 // up.
26 //
27 // Close returns when all data has been written through.
28 type streamWriterAt struct {
29         writer     io.Writer
30         buf        []byte
31         writepos   int         // target offset if Write is called
32         partsize   int         // size of each part written through to writer
33         endpos     int         // portion of buf actually used, judging by WriteAt calls so far
34         partfilled []int       // number of bytes written to each part so far
35         partready  chan []byte // parts of buf fully written / waiting for writer goroutine
36         partnext   int         // index of next part we will send to partready when it's ready
37         wroteAt    int         // bytes we copied to buf in WriteAt
38         wrote      int         // bytes successfully written through to writer
39         errWrite   chan error  // final outcome of writer goroutine
40         closed     bool        // streamWriterAt has been closed
41         mtx        sync.Mutex  // guard internal fields during concurrent calls to WriteAt and Close
42 }
43
44 // newStreamWriterAt creates a new streamWriterAt.
45 func newStreamWriterAt(w io.Writer, partsize int, buf []byte) *streamWriterAt {
46         if partsize == 0 {
47                 partsize = 65536
48         }
49         nparts := (len(buf) + partsize - 1) / partsize
50         swa := &streamWriterAt{
51                 writer:     w,
52                 partsize:   partsize,
53                 buf:        buf,
54                 partfilled: make([]int, nparts),
55                 partready:  make(chan []byte, nparts),
56                 errWrite:   make(chan error, 1),
57         }
58         go swa.writeToWriter()
59         return swa
60 }
61
62 // Wrote returns the number of bytes written through to the
63 // io.Writer.
64 //
65 // Wrote must not be called until after Close.
66 func (swa *streamWriterAt) Wrote() int {
67         return swa.wrote
68 }
69
70 // Wrote returns the number of bytes passed to WriteAt, regardless of
71 // whether they were written through to the io.Writer.
72 func (swa *streamWriterAt) WroteAt() int {
73         swa.mtx.Lock()
74         defer swa.mtx.Unlock()
75         return swa.wroteAt
76 }
77
78 func (swa *streamWriterAt) writeToWriter() {
79         defer close(swa.errWrite)
80         for p := range swa.partready {
81                 n, err := swa.writer.Write(p)
82                 if err != nil {
83                         swa.errWrite <- err
84                         return
85                 }
86                 swa.wrote += n
87         }
88 }
89
90 // Write implements io.Writer.
91 func (swa *streamWriterAt) Write(p []byte) (int, error) {
92         n, err := swa.WriteAt(p, int64(swa.writepos))
93         swa.writepos += n
94         return n, err
95 }
96
97 // WriteAt implements io.WriterAt.
98 func (swa *streamWriterAt) WriteAt(p []byte, offset int64) (int, error) {
99         pos := int(offset)
100         n := 0
101         if pos <= len(swa.buf) {
102                 n = copy(swa.buf[pos:], p)
103         }
104         if n < len(p) {
105                 return n, fmt.Errorf("write beyond end of buffer: offset %d len %d buf %d", offset, len(p), len(swa.buf))
106         }
107         endpos := pos + n
108
109         swa.mtx.Lock()
110         defer swa.mtx.Unlock()
111         swa.wroteAt += len(p)
112         if swa.endpos < endpos {
113                 swa.endpos = endpos
114         }
115         if swa.closed {
116                 return 0, errors.New("invalid use of closed streamWriterAt")
117         }
118         // Track the number of bytes that landed in each of our
119         // (output) parts.
120         for i := pos; i < endpos; {
121                 j := i + swa.partsize - (i % swa.partsize)
122                 if j > endpos {
123                         j = endpos
124                 }
125                 pf := swa.partfilled[i/swa.partsize]
126                 pf += j - i
127                 if pf > swa.partsize {
128                         return 0, errors.New("streamWriterAt: overlapping WriteAt calls")
129                 }
130                 swa.partfilled[i/swa.partsize] = pf
131                 i = j
132         }
133         // Flush filled parts to partready.
134         for swa.partnext < len(swa.partfilled) && swa.partfilled[swa.partnext] == swa.partsize {
135                 offset := swa.partnext * swa.partsize
136                 swa.partready <- swa.buf[offset : offset+swa.partsize]
137                 swa.partnext++
138         }
139         return len(p), nil
140 }
141
142 // Close flushes all buffered data through to the io.Writer.
143 func (swa *streamWriterAt) Close() error {
144         swa.mtx.Lock()
145         defer swa.mtx.Unlock()
146         if swa.closed {
147                 return errors.New("invalid use of closed streamWriterAt")
148         }
149         swa.closed = true
150         // Flush last part if needed. If the input doesn't end on a
151         // part boundary, the last part never appears "filled" when we
152         // check in WriteAt.  But here, we know endpos is the end of
153         // the stream, so we can check whether the last part is ready.
154         if offset := swa.partnext * swa.partsize; offset < swa.endpos && offset+swa.partfilled[swa.partnext] == swa.endpos {
155                 swa.partready <- swa.buf[offset:swa.endpos]
156                 swa.partnext++
157         }
158         close(swa.partready)
159         err := <-swa.errWrite
160         if err != nil {
161                 return err
162         }
163         if swa.wrote != swa.wroteAt {
164                 return fmt.Errorf("streamWriterAt: detected hole in input: wrote %d but flushed %d", swa.wroteAt, swa.wrote)
165         }
166         return nil
167 }