21720:
[arvados.git] / services / keepstore / streamwriterat.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package keepstore
6
7 import (
8         "errors"
9         "fmt"
10         "io"
11         "sync"
12 )
13
14 // streamWriterAt translates random-access writes to sequential
15 // writes. The caller is expected to use an arbitrary sequence of
16 // non-overlapping WriteAt calls covering all positions between 0 and
17 // N, for any N < len(buf), then call Close.
18 //
19 // streamWriterAt writes the data to the provided io.Writer in
20 // sequential order.
21 //
22 // streamWriterAt can also be wrapped with an io.OffsetWriter to
23 // provide an asynchronous buffer: the caller can use the io.Writer
24 // interface to write into a memory buffer and return without waiting
25 // for the wrapped writer to catch up.
26 //
27 // Close returns when all data has been written through.
28 type streamWriterAt struct {
29         writer     io.Writer
30         buf        []byte
31         writepos   int         // target offset if Write is called
32         partsize   int         // size of each part written through to writer
33         endpos     int         // portion of buf actually used, judging by WriteAt calls so far
34         partfilled []int       // number of bytes written to each part so far
35         partready  chan []byte // parts of buf fully written / waiting for writer goroutine
36         partnext   int         // index of next part we will send to partready when it's ready
37         wroteAt    int         // bytes we copied to buf in WriteAt
38         wrote      int         // bytes successfully written through to writer
39         errWrite   chan error  // final outcome of writer goroutine
40         closed     bool        // streamWriterAt has been closed
41         mtx        sync.Mutex  // guard internal fields during concurrent calls to WriteAt and Close
42 }
43
44 // newStreamWriterAt creates a new streamWriterAt.
45 func newStreamWriterAt(w io.Writer, partsize int, buf []byte) *streamWriterAt {
46         if partsize == 0 {
47                 partsize = 65536
48         }
49         nparts := (len(buf) + partsize - 1) / partsize
50         swa := &streamWriterAt{
51                 writer:     w,
52                 partsize:   partsize,
53                 buf:        buf,
54                 partfilled: make([]int, nparts),
55                 partready:  make(chan []byte, nparts),
56                 errWrite:   make(chan error, 1),
57         }
58         go swa.writeToWriter()
59         return swa
60 }
61
62 // Wrote returns the number of bytes written through to the
63 // io.Writer.
64 //
65 // Wrote must not be called until after Close.
66 func (swa *streamWriterAt) Wrote() int {
67         return swa.wrote
68 }
69
70 // Wrote returns the number of bytes passed to WriteAt, regardless of
71 // whether they were written through to the io.Writer.
72 func (swa *streamWriterAt) WroteAt() int {
73         swa.mtx.Lock()
74         defer swa.mtx.Unlock()
75         return swa.wroteAt
76 }
77
78 func (swa *streamWriterAt) writeToWriter() {
79         defer close(swa.errWrite)
80         for p := range swa.partready {
81                 n, err := swa.writer.Write(p)
82                 if err != nil {
83                         swa.errWrite <- err
84                         return
85                 }
86                 swa.wrote += n
87         }
88 }
89
90 // WriteAt implements io.WriterAt. WriteAt is goroutine-safe.
91 func (swa *streamWriterAt) WriteAt(p []byte, offset int64) (int, error) {
92         pos := int(offset)
93         n := 0
94         if pos <= len(swa.buf) {
95                 n = copy(swa.buf[pos:], p)
96         }
97         if n < len(p) {
98                 return n, fmt.Errorf("write beyond end of buffer: offset %d len %d buf %d", offset, len(p), len(swa.buf))
99         }
100         endpos := pos + n
101
102         swa.mtx.Lock()
103         defer swa.mtx.Unlock()
104         swa.wroteAt += len(p)
105         if swa.endpos < endpos {
106                 swa.endpos = endpos
107         }
108         if swa.closed {
109                 return 0, errors.New("invalid use of closed streamWriterAt")
110         }
111         // Track the number of bytes that landed in each of our
112         // (output) parts.
113         for i := pos; i < endpos; {
114                 j := i + swa.partsize - (i % swa.partsize)
115                 if j > endpos {
116                         j = endpos
117                 }
118                 pf := swa.partfilled[i/swa.partsize]
119                 pf += j - i
120                 if pf > swa.partsize {
121                         return 0, errors.New("streamWriterAt: overlapping WriteAt calls")
122                 }
123                 swa.partfilled[i/swa.partsize] = pf
124                 i = j
125         }
126         // Flush filled parts to partready.
127         for swa.partnext < len(swa.partfilled) && swa.partfilled[swa.partnext] == swa.partsize {
128                 offset := swa.partnext * swa.partsize
129                 swa.partready <- swa.buf[offset : offset+swa.partsize]
130                 swa.partnext++
131         }
132         return len(p), nil
133 }
134
135 // Close flushes all buffered data through to the io.Writer.
136 func (swa *streamWriterAt) Close() error {
137         swa.mtx.Lock()
138         defer swa.mtx.Unlock()
139         if swa.closed {
140                 return errors.New("invalid use of closed streamWriterAt")
141         }
142         swa.closed = true
143         // Flush last part if needed. If the input doesn't end on a
144         // part boundary, the last part never appears "filled" when we
145         // check in WriteAt.  But here, we know endpos is the end of
146         // the stream, so we can check whether the last part is ready.
147         if offset := swa.partnext * swa.partsize; offset < swa.endpos && offset+swa.partfilled[swa.partnext] == swa.endpos {
148                 swa.partready <- swa.buf[offset:swa.endpos]
149                 swa.partnext++
150         }
151         close(swa.partready)
152         err := <-swa.errWrite
153         if err != nil {
154                 return err
155         }
156         if swa.wrote != swa.wroteAt {
157                 return fmt.Errorf("streamWriterAt: detected hole in input: wrote %d but flushed %d", swa.wroteAt, swa.wrote)
158         }
159         return nil
160 }