21606: Add output buffer for webdav downloads.
author Tom Clegg <tom@curii.com>
Tue, 19 Mar 2024 19:37:32 +0000 (15:37 -0400)
committer Tom Clegg <tom@curii.com>
Thu, 11 Apr 2024 19:46:10 +0000 (15:46 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

lib/config/config.default.yml
lib/config/export.go
sdk/go/arvados/config.go
services/keep-web/handler.go
services/keep-web/writebuffer.go [new file with mode: 0644]
services/keep-web/writebuffer_test.go [new file with mode: 0644]

index a3ae4fd56bbc179f67fc1f21e6c9cdb2db5c43df..eb9173ec55e6d2853aa9f844bf33329186a1e042 100644 (file)
@@ -801,6 +801,10 @@ Clusters:
       # load on the API server and you don't need it.
       WebDAVLogEvents: true
 
+      # Per-request output buffer for WebDAV downloads (0 = unbuffered).
+      # May improve throughput for large files.
+      WebDAVOutputBuffer: 1M
+
     Login:
       # One of the following mechanisms (Google, PAM, LDAP, or
       # LoginCluster) should be enabled; see
index 4b6c142ff2e29f41bcf2b843ac6479b54dd436aa..f511ebbcb16b1a238f0a5b77fdc85c2de6518367 100644 (file)
@@ -122,6 +122,7 @@ var whitelist = map[string]bool{
        "Collections.TrustAllContent":              true,
        "Collections.WebDAVCache":                  false,
        "Collections.WebDAVLogEvents":              false,
+       "Collections.WebDAVOutputBuffer":           false,
        "Collections.WebDAVPermission":             false,
        "Containers":                               true,
        "Containers.AlwaysUsePreemptibleInstances": true,
index 698ee20d8c6bcc58119a02e0330f19ca0e7a64ee..116051b09e3717f0d24c86aa967ae9452eb7877e 100644 (file)
@@ -159,6 +159,7 @@ type Cluster struct {
                KeepproxyPermission UploadDownloadRolePermissions
                WebDAVPermission    UploadDownloadRolePermissions
                WebDAVLogEvents     bool
+               WebDAVOutputBuffer  ByteSize
        }
        Git struct {
                GitCommand   string
index e0da14e774525d9b860e6c92c62a010653e25d06..cdd51f0bb7c7fcee60dd21970192fbd3a65aa6d0 100644 (file)
@@ -178,7 +178,12 @@ func (h *handler) ServeHTTP(wOrig http.ResponseWriter, r *http.Request) {
                r.URL.Scheme = xfp
        }
 
-       w := httpserver.WrapResponseWriter(wOrig)
+       wbuffer := newWriteBuffer(wOrig, int(h.Cluster.Collections.WebDAVOutputBuffer))
+       defer wbuffer.Close()
+       w := httpserver.WrapResponseWriter(responseWriter{
+               Writer:         wbuffer,
+               ResponseWriter: wOrig,
+       })
 
        if r.Method == "OPTIONS" && ServeCORSPreflight(w, r.Header) {
                return
diff --git a/services/keep-web/writebuffer.go b/services/keep-web/writebuffer.go
new file mode 100644 (file)
index 0000000..f309b69
--- /dev/null
@@ -0,0 +1,141 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package keepweb
+
+import (
+       "errors"
+       "io"
+       "net/http"
+       "sync/atomic"
+)
+
+// writeBuffer is an io.Writer that accepts data into a fixed-size ring
+// buffer and relays it to an underlying writer from a separate
+// goroutine (flush), so Write calls can return before earlier data has
+// reached the underlying writer.
+//
+// Write and Close are not safe for concurrent use; call them from a
+// single goroutine. One slot of buf is always left unused, so
+// wpos == rpos means "empty" and (wpos+1)%len(buf) == rpos means
+// "full".
+type writeBuffer struct {
+       out       io.Writer     // underlying writer
+       buf       []byte        // ring buffer shared by Write() and flush()
+       writesize int           // maximum size of each out.Write() call made by flush()
+       wpos      atomic.Int64  // index in buf where writer (Write()) will write to next
+       wsignal   chan struct{} // receives a value after wpos or closed changes
+       rpos      atomic.Int64  // index in buf where reader (flush()) will read from next
+       rsignal   chan struct{} // receives a value after rpos changes
+       err       error         // error encountered by flush; read only after flushed is closed
+       closed    atomic.Bool   // set by Close()
+       flushed   chan struct{} // closes when flush() is finished
+}
+
+// newWriteBuffer returns a writeBuffer that relays data to w through a
+// buffer of the given size, and starts its flush goroutine; the caller
+// must call Close to drain the buffer and stop that goroutine. Sizes
+// below 2 (including negative sizes) disable buffering: Write then
+// passes data straight through to w.
+func newWriteBuffer(w io.Writer, size int) *writeBuffer {
+       if size < 0 {
+               // make() would panic on a negative size.
+               size = 0
+       }
+       // flush() writes to w in chunks of about 1/64 of the buffer
+       // size, so space is released to Write() incrementally.
+       wb := &writeBuffer{
+               out:       w,
+               buf:       make([]byte, size),
+               writesize: (size + 63) / 64,
+               wsignal:   make(chan struct{}, 1),
+               rsignal:   make(chan struct{}, 1),
+               flushed:   make(chan struct{}),
+       }
+       go wb.flush()
+       return wb
+}
+
+// Close waits until all buffered data has been written to the
+// underlying writer and the flush goroutine has exited. It returns the
+// error returned by the underlying writer, if any. Calling Close more
+// than once returns an error.
+func (wb *writeBuffer) Close() error {
+       if !wb.closed.CompareAndSwap(false, true) {
+               return errors.New("writeBuffer: already closed")
+       }
+       // wake up flush()
+       select {
+       case wb.wsignal <- struct{}{}:
+       default:
+       }
+       // wait for flush() to finish
+       <-wb.flushed
+       return wb.err
+}
+
+// Write copies p into the buffer, waiting for flush() to make room
+// whenever the buffer is full, and returns once all of p has been
+// buffered (not necessarily written to the underlying writer). If
+// flush() has stopped -- because Close was called or the underlying
+// writer returned an error -- Write returns that error (or an error
+// indicating the buffer is closed) along with the number of bytes of p
+// that were buffered before the problem was noticed.
+func (wb *writeBuffer) Write(p []byte) (int, error) {
+       if err := wb.stopped(); err != nil {
+               return 0, err
+       }
+       if len(wb.buf) < 2 {
+               // Our buffer logic doesn't work with size<2, and such
+               // a tiny buffer has no purpose anyway, so just write
+               // through unbuffered.
+               return wb.out.Write(p)
+       }
+       todo := p
+       wpos := int(wb.wpos.Load())
+       rpos := int(wb.rpos.Load())
+       for len(todo) > 0 {
+               for rpos == (wpos+1)%len(wb.buf) {
+                       // Buffer is full: wait for flush() to free
+                       // some space, or to give up.
+                       select {
+                       case <-wb.flushed:
+                               return len(p) - len(todo), wb.stopped()
+                       case <-wb.rsignal:
+                               rpos = int(wb.rpos.Load())
+                       }
+               }
+               // Fill the contiguous free region starting at wpos,
+               // never filling the slot just before rpos.
+               var avail []byte
+               if rpos == 0 {
+                       avail = wb.buf[wpos : len(wb.buf)-1]
+               } else if wpos >= rpos {
+                       avail = wb.buf[wpos:]
+               } else {
+                       avail = wb.buf[wpos : rpos-1]
+               }
+               n := copy(avail, todo)
+               wpos = (wpos + n) % len(wb.buf)
+               wb.wpos.Store(int64(wpos))
+               // wake up flush()
+               select {
+               case wb.wsignal <- struct{}{}:
+               default:
+               }
+               todo = todo[n:]
+               // Stop accepting data if flush() has failed (or the
+               // buffer was closed), rather than buffering data that
+               // will never be written.
+               if len(todo) > 0 {
+                       if err := wb.stopped(); err != nil {
+                               return len(p) - len(todo), err
+                       }
+               }
+       }
+       return len(p), nil
+}
+
+// stopped returns nil while flush() is still running. Once flush() has
+// returned, it returns the error from the underlying writer or, if
+// there was none (meaning Close was called), an error indicating the
+// buffer is closed.
+func (wb *writeBuffer) stopped() error {
+       select {
+       case <-wb.flushed:
+       default:
+               return nil
+       }
+       if wb.err != nil {
+               return wb.err
+       }
+       return errors.New("Write called on closed writeBuffer")
+}
+
+// flush runs in its own goroutine (started by newWriteBuffer). It
+// copies data from buf to out, at most writesize bytes per out.Write
+// call, as Write makes it available. It returns once Close has been
+// called and everything buffered has been written, or as soon as out
+// returns an error, which it saves in wb.err. Either way, it closes
+// wb.flushed on return.
+func (wb *writeBuffer) flush() {
+       defer close(wb.flushed)
+       rpos := 0
+       wpos := 0
+       closed := false
+       for {
+               for rpos == wpos {
+                       // Buffer is empty: exit if Close was called,
+                       // otherwise wait for more data or Close.
+                       if closed {
+                               return
+                       }
+                       <-wb.wsignal
+                       // Load closed before wpos: Close is called
+                       // after the final Write, so if closed is true,
+                       // the wpos loaded next covers all the data.
+                       closed = wb.closed.Load()
+                       wpos = int(wb.wpos.Load())
+               }
+               // Write the contiguous run of data starting at rpos
+               // (up to the end of buf if the data wraps around).
+               var ready []byte
+               if rpos < wpos {
+                       ready = wb.buf[rpos:wpos]
+               } else {
+                       ready = wb.buf[rpos:]
+               }
+               if len(ready) > wb.writesize {
+                       ready = ready[:wb.writesize]
+               }
+               // io.Writer must return a non-nil error on a short
+               // write, so the byte count can be ignored here.
+               _, wb.err = wb.out.Write(ready)
+               if wb.err != nil {
+                       return
+               }
+               rpos = (rpos + len(ready)) % len(wb.buf)
+               wb.rpos.Store(int64(rpos))
+               // wake up Write() if it is waiting for space
+               select {
+               case wb.rsignal <- struct{}{}:
+               default:
+               }
+       }
+}
+
+// responseWriter is an http.ResponseWriter whose body bytes go to
+// Writer, while Header and WriteHeader go to the embedded
+// ResponseWriter. NOTE(review): optional interfaces implemented by the
+// underlying ResponseWriter (e.g., http.Flusher) are not exposed.
+type responseWriter struct {
+       io.Writer
+       http.ResponseWriter
+}
+
+// Write sends p to the Writer field. Both embedded fields provide a
+// Write method, so this explicit method resolves the ambiguity.
+func (rw responseWriter) Write(p []byte) (int, error) {
+       dst := rw.Writer
+       return dst.Write(p)
+}
diff --git a/services/keep-web/writebuffer_test.go b/services/keep-web/writebuffer_test.go
new file mode 100644 (file)
index 0000000..589dc24
--- /dev/null
@@ -0,0 +1,98 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package keepweb
+
+import (
+       "bytes"
+       "io"
+       "math/rand"
+       "time"
+
+       . "gopkg.in/check.v1"
+)
+
+// writeBufferSuite holds the gocheck tests and benchmarks for writeBuffer.
+type writeBufferSuite struct{}
+
+var _ = Suite(&writeBufferSuite{})
+
+// 1000 / 96.3 ns/op = 10.384 GB/s
+func (s *writeBufferSuite) Benchmark_1KBWrites(c *C) {
+       wb := newWriteBuffer(io.Discard, 1<<20)
+       in := make([]byte, 1000)
+       for i := 0; i < c.N; i++ {
+               wb.Write(in)
+       }
+       wb.Close()
+}
+
+func (s *writeBufferSuite) TestRandomizedSpeedsAndSizes(c *C) {
+       for i := 0; i < 20; i++ {
+               insize := rand.Intn(1 << 26)
+               bufsize := rand.Intn(1 << 26)
+               if i < 2 {
+                       // make sure to test edge cases
+                       bufsize = i
+               } else if insize/bufsize > 1000 {
+                       // don't waste too much time testing tiny
+                       // buffer / huge content
+                       insize = bufsize*1000 + 123
+               }
+               c.Logf("%s: insize %d bufsize %d", c.TestName(), insize, bufsize)
+
+               in := make([]byte, insize)
+               b := byte(0)
+               for i := range in {
+                       in[i] = b
+                       b++
+               }
+
+               out := &bytes.Buffer{}
+               done := make(chan struct{})
+               pr, pw := io.Pipe()
+               go func() {
+                       n, err := slowCopy(out, pr, rand.Intn(8192)+1)
+                       c.Check(err, IsNil)
+                       c.Check(n, Equals, int64(insize))
+                       close(done)
+               }()
+               wb := newWriteBuffer(pw, bufsize)
+               n, err := slowCopy(wb, bytes.NewBuffer(in), rand.Intn(8192)+1)
+               c.Check(err, IsNil)
+               c.Check(n, Equals, int64(insize))
+               c.Check(wb.Close(), IsNil)
+               c.Check(pw.Close(), IsNil)
+               <-done
+               c.Check(out.Len(), Equals, insize)
+               for i := 0; i < out.Len() && i < len(in); i++ {
+                       if out.Bytes()[i] != in[i] {
+                               c.Errorf("content mismatch at byte %d", i)
+                               break
+                       }
+               }
+       }
+}
+
+// slowCopy copies src to dst in chunks of at most bufsize bytes,
+// sleeping for a random 1-100ns before each Read to perturb timing. It
+// returns the number of bytes successfully written to dst; reaching
+// io.EOF on src ends the copy without an error.
+func slowCopy(dst io.Writer, src io.Reader, bufsize int) (int64, error) {
+       var total int64
+       chunk := make([]byte, bufsize)
+       for {
+               time.Sleep(time.Duration(1 + rand.Intn(100)))
+               nr, rerr := src.Read(chunk)
+               if nr > 0 {
+                       nw, werr := dst.Write(chunk[:nr])
+                       total += int64(nw)
+                       if werr != nil {
+                               return total, werr
+                       }
+               }
+               switch {
+               case rerr == io.EOF:
+                       return total, nil
+               case rerr != nil:
+                       return total, rerr
+               }
+       }
+}