Merge branch 'main' into 15814-wb2-secrets
[arvados.git] / services / keep-web / writebuffer.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package keepweb
6
7 import (
8         "errors"
9         "io"
10         "net/http"
11         "sync/atomic"
12 )
13
// writeBuffer uses a ring buffer to implement an asynchronous write
// buffer: Write copies data into buf and returns, while a background
// goroutine (flush) copies buffered data to out.
//
// rpos==wpos means the buffer is empty.
//
// rpos==(wpos+1)%size means the buffer is full, so at most size-1
// bytes are buffered at any time.
//
// size<2 means the buffer is always empty and full, so in this case
// writeBuffer writes through synchronously.
//
// There is a single writer position: Write keeps a local copy of wpos
// for the whole call, so Write must not be called concurrently with
// itself or with Close.
type writeBuffer struct {
	out       io.Writer
	buf       []byte
	writesize int           // max bytes flush() should write in a single out.Write()
	wpos      atomic.Int64  // index in buf where writer (Write()) will write to next
	wsignal   chan struct{} // receives a value after wpos or closed changes
	rpos      atomic.Int64  // index in buf where reader (flush()) will read from next
	rsignal   chan struct{} // receives a value after rpos changes
	err       error         // error encountered by flush; only read after flushed is closed
	closed    atomic.Bool   // set once by Close
	flushed   chan struct{} // closes when flush() is finished
}

// newWriteBuffer returns a writeBuffer that holds up to size-1 bytes in
// memory and copies them to w from a background goroutine. The caller
// must call Close to wait for buffered data to reach w and to stop that
// goroutine.
func newWriteBuffer(w io.Writer, size int) *writeBuffer {
	wb := &writeBuffer{
		out:       w,
		buf:       make([]byte, size),
		writesize: (size + 63) / 64,
		wsignal:   make(chan struct{}, 1),
		rsignal:   make(chan struct{}, 1),
		flushed:   make(chan struct{}),
	}
	go wb.flush()
	return wb
}

// Close waits for all buffered data to be written to out and for the
// background flush goroutine to exit. It returns the error out.Write
// returned, if any. It does not close out. Calling Close more than once
// returns an error.
func (wb *writeBuffer) Close() error {
	// CompareAndSwap makes the check-and-set atomic, so only one Close
	// call can proceed past this point.
	if !wb.closed.CompareAndSwap(false, true) {
		return errors.New("writeBuffer: already closed")
	}
	// wake up flush() so it notices closed
	select {
	case wb.wsignal <- struct{}{}:
	default:
	}
	// wait for flush() to drain the buffer and finish
	<-wb.flushed
	return wb.err
}

// Write copies p into the buffer, blocking while the buffer is full.
// It returns len(p), nil once all of p has been buffered. Data reaches
// out asynchronously. An error from out is reported by a subsequent
// Write (once flush has stopped) or by Close.
//
// Write returns an error if it is called after Close, or after flush
// has stopped because out.Write failed.
func (wb *writeBuffer) Write(p []byte) (int, error) {
	if wb.closed.Load() {
		// Without this check p would be copied into the buffer and
		// reported as written, although flush() has exited and will
		// never send it to out.
		return 0, errors.New("Write called on closed writeBuffer")
	}
	if len(wb.buf) < 2 {
		// Our buffer logic doesn't work with size<2, and such
		// a tiny buffer has no purpose anyway, so just write
		// through unbuffered.
		return wb.out.Write(p)
	}
	size := len(wb.buf)
	todo := p
	wpos := int(wb.wpos.Load())
	rpos := int(wb.rpos.Load())
	for len(todo) > 0 {
		// If flush() has stopped (out.Write failed), refuse data
		// that would never be written instead of buffering it.
		select {
		case <-wb.flushed:
			return 0, wb.stoppedErr()
		default:
		}
		// wait until the buffer is not full.
		for rpos == (wpos+1)%size {
			select {
			case <-wb.flushed:
				return 0, wb.stoppedErr()
			case <-wb.rsignal:
				rpos = int(wb.rpos.Load())
			}
		}
		// Determine the next contiguous portion of the buffer that is
		// available. The slot just before rpos is always left unused,
		// so that a full buffer stays distinguishable from an empty one.
		var avail []byte
		switch {
		case rpos == 0:
			avail = wb.buf[wpos : size-1]
		case wpos >= rpos:
			avail = wb.buf[wpos:]
		default:
			avail = wb.buf[wpos : rpos-1]
		}
		n := copy(avail, todo)
		wpos = (wpos + n) % size
		wb.wpos.Store(int64(wpos))
		// wake up flush()
		select {
		case wb.wsignal <- struct{}{}:
		default:
		}
		todo = todo[n:]
	}
	return len(p), nil
}

// stoppedErr returns the reason flush() stopped: the error from
// out.Write, or a generic error if it stopped because of Close. It must
// only be called after wb.flushed is closed. That close is what makes
// reading wb.err here race-free.
func (wb *writeBuffer) stoppedErr() error {
	if wb.err == nil {
		return errors.New("Write called on closed writeBuffer")
	}
	return wb.err
}

// flush runs in its own goroutine (started by newWriteBuffer). It copies
// data from the ring buffer to out, at most writesize bytes per
// out.Write call. It returns, closing flushed, in two cases:
//   - the buffer is empty and Close has been called;
//   - out.Write fails, in which case the error is left in wb.err.
func (wb *writeBuffer) flush() {
	defer close(wb.flushed)
	rpos := 0
	wpos := 0
	closed := false
	for {
		// wait until buffer is not empty.
		for rpos == wpos {
			if closed {
				return
			}
			<-wb.wsignal
			// closed is loaded before wpos. Assuming Write and Close
			// are not called concurrently, Close sets closed only after
			// the last Write stored wpos. So once closed is seen as
			// true, the wpos loaded next is final.
			closed = wb.closed.Load()
			wpos = int(wb.wpos.Load())
		}
		// determine next contiguous portion of buffer that is
		// ready to write through.
		var ready []byte
		if rpos < wpos {
			ready = wb.buf[rpos:wpos]
		} else {
			ready = wb.buf[rpos:]
		}
		if len(ready) > wb.writesize {
			ready = ready[:wb.writesize]
		}
		n, err := wb.out.Write(ready)
		if err == nil && n < len(ready) {
			// io.Writer requires a non-nil error on a short write.
			// Guard against writers that violate this instead of
			// silently dropping the unwritten tail.
			err = io.ErrShortWrite
		}
		if err != nil {
			wb.err = err
			return
		}
		rpos = (rpos + len(ready)) % len(wb.buf)
		wb.rpos.Store(int64(rpos))
		select {
		case wb.rsignal <- struct{}{}:
		default:
		}
	}
}
148
// responseWriter splices an io.Writer wrapper (for example a
// *writeBuffer) into an http.ResponseWriter chain.
//
// Write is sent to the embedded io.Writer. Header and WriteHeader are
// promoted from the embedded http.ResponseWriter. Because that field
// has interface type, only the http.ResponseWriter methods are promoted.
// Optional interfaces the underlying writer may implement (such as
// http.Flusher) are not exposed through this wrapper.
type responseWriter struct {
	io.Writer
	http.ResponseWriter
}

// Write forwards data to the embedded io.Writer. It also resolves the
// ambiguity between the two embedded Write methods.
func (rw responseWriter) Write(data []byte) (int, error) {
	dst := rw.Writer
	return dst.Write(data)
}