Merge branch '20559-dav-concurrent-writes'
author     Tom Clegg <tom@curii.com>
           Fri, 7 Jul 2023 13:49:24 +0000 (09:49 -0400)
committer  Tom Clegg <tom@curii.com>
           Fri, 7 Jul 2023 13:49:24 +0000 (09:49 -0400)
refs #20559

Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

services/keep-web/cache.go
services/keep-web/handler.go
services/keep-web/handler_test.go

services/keep-web/cache.go
index c77a1b4bb6dca0390b79793225f713459db76ce3..604efd29d9139e2fe4963709f64503a0d8d7e6f3 100644 (file)
@@ -7,14 +7,13 @@ package keepweb
 import (
        "errors"
        "net/http"
+       "sort"
        "sync"
-       "sync/atomic"
        "time"
 
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/arvadosclient"
        "git.arvados.org/arvados.git/sdk/go/keepclient"
-       lru "github.com/hashicorp/golang-lru"
        "github.com/prometheus/client_golang/prometheus"
        "github.com/sirupsen/logrus"
 )
@@ -26,8 +25,9 @@ type cache struct {
        logger    logrus.FieldLogger
        registry  *prometheus.Registry
        metrics   cacheMetrics
-       sessions  *lru.TwoQueueCache
+       sessions  map[string]*cachedSession
        setupOnce sync.Once
+       mtx       sync.Mutex
 
        chPruneSessions chan struct{}
 }
@@ -72,17 +72,69 @@ func (m *cacheMetrics) setup(reg *prometheus.Registry) {
 }
 
 type cachedSession struct {
+       cache         *cache
        expire        time.Time
-       fs            atomic.Value
        client        *arvados.Client
        arvadosclient *arvadosclient.ArvadosClient
        keepclient    *keepclient.KeepClient
-       user          atomic.Value
+
+       // Each session uses a system of three mutexes (plus the
+       // cache-wide mutex) to enable the following semantics:
+       //
+       // - There are never multiple sessions in use for a given
+       // token.
+       //
+       // - If the cached in-memory filesystems/user records are
+       // older than the configured cache TTL when a request starts,
+       // the request will use new ones.
+       //
+       // - Unused sessions are garbage-collected.
+       //
+       // In particular, when it is necessary to reset a session's
+       // filesystem/user record (to save memory or respect the
+       // configured cache TTL), any operations that are already
+       // using the existing filesystem/user record are allowed to
+       // finish before the new filesystem is constructed.
+       //
+       // The locks must be acquired in the following order:
+       // cache.mtx, session.mtx, session.refresh, session.inuse.
+
+       // mtx is RLocked while the session is not safe to evict from
+       // the cache -- i.e., a checkout() has decided to use it, and
+       // its caller is not finished with it. When locking or
+       // RLocking this mtx, the cache-wide mtx MUST already be held.
+       //
+       // This mutex enables pruneSessions to detect when it is safe
+       // to completely remove the session entry from the cache.
+       mtx sync.RWMutex
+       // refresh must be locked in order to read or write the
+       // fs/user/userLoaded/lastuse fields. This mutex enables
+       // GetSession and pruneSessions to remove/replace fs and user
+       // values safely.
+       refresh sync.Mutex
+       // inuse must be RLocked while the session is in use by a
+       // caller. This mutex enables pruneSessions() to wait for all
+       // existing usage to finish by calling inuse.Lock().
+       inuse sync.RWMutex
+
+       fs         arvados.CustomFileSystem
+       user       arvados.User
+       userLoaded bool
+       lastuse    time.Time
+}
+
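+// Release drops the locks acquired by GetSession (the session's
+// inuse read lock and its eviction read lock), then nudges the
+// pruner, without blocking, so it can re-evaluate cache size.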
+func (sess *cachedSession) Release() {
+       sess.inuse.RUnlock()
+       sess.mtx.RUnlock()
+       select {
+       case sess.cache.chPruneSessions <- struct{}{}:
+       default:
+       }
 }
 
 func (c *cache) setup() {
-       var err error
-       c.sessions, err = lru.New2Q(c.cluster.Collections.WebDAVCache.MaxSessions)
-       if err != nil {
-               panic(err)
-       }
+       c.sessions = map[string]*cachedSession{}
@@ -106,132 +158,232 @@ func (c *cache) setup() {
 }
 
 func (c *cache) updateGauges() {
-       c.metrics.collectionBytes.Set(float64(c.collectionBytes()))
-       c.metrics.sessionEntries.Set(float64(c.sessions.Len()))
+       n, size := c.sessionsSize()
+       c.metrics.collectionBytes.Set(float64(size))
+       c.metrics.sessionEntries.Set(float64(n))
 }
 
 var selectPDH = map[string]interface{}{
        "select": []string{"portable_data_hash"},
 }
 
-// ResetSession unloads any potentially stale state. Should be called
-// after write operations, so subsequent reads don't return stale
-// data.
-func (c *cache) ResetSession(token string) {
-       c.setupOnce.Do(c.setup)
-       c.sessions.Remove(token)
-}
-
-// Get a long-lived CustomFileSystem suitable for doing a read operation
-// with the given token.
-func (c *cache) GetSession(token string) (arvados.CustomFileSystem, *cachedSession, *arvados.User, error) {
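+// checkout returns the cached session for the given token, creating
+// a new one if needed. On success the returned session's mtx is
+// RLocked, which prevents pruneSessions from evicting it; the caller
+// must eventually RUnlock it (normally via Release).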
+func (c *cache) checkout(token string) (*cachedSession, error) {
        c.setupOnce.Do(c.setup)
-       now := time.Now()
-       ent, _ := c.sessions.Get(token)
-       sess, _ := ent.(*cachedSession)
-       expired := false
+       c.mtx.Lock()
+       defer c.mtx.Unlock()
+       sess := c.sessions[token]
        if sess == nil {
-               c.metrics.sessionMisses.Inc()
-               sess = &cachedSession{
-                       expire: now.Add(c.cluster.Collections.WebDAVCache.TTL.Duration()),
-               }
-               var err error
-               sess.client, err = arvados.NewClientFromConfig(c.cluster)
+               client, err := arvados.NewClientFromConfig(c.cluster)
                if err != nil {
-                       return nil, nil, nil, err
+                       return nil, err
                }
-               sess.client.AuthToken = token
-               sess.client.Timeout = time.Minute
+               client.AuthToken = token
+               client.Timeout = time.Minute
                // A non-empty origin header tells controller to
                // prioritize our traffic as interactive, which is
                // true most of the time.
                origin := c.cluster.Services.WebDAVDownload.ExternalURL
-               sess.client.SendHeader = http.Header{"Origin": {origin.Scheme + "://" + origin.Host}}
-               sess.arvadosclient, err = arvadosclient.New(sess.client)
+               client.SendHeader = http.Header{"Origin": {origin.Scheme + "://" + origin.Host}}
+               arvadosclient, err := arvadosclient.New(client)
                if err != nil {
-                       return nil, nil, nil, err
+                       return nil, err
                }
-               sess.keepclient = keepclient.New(sess.arvadosclient)
-               c.sessions.Add(token, sess)
-       } else if sess.expire.Before(now) {
-               c.metrics.sessionMisses.Inc()
-               expired = true
-       } else {
-               c.metrics.sessionHits.Inc()
-       }
-       select {
-       case c.chPruneSessions <- struct{}{}:
-       default:
+               sess = &cachedSession{
+                       cache:         c,
+                       client:        client,
+                       arvadosclient: arvadosclient,
+                       keepclient:    keepclient.New(arvadosclient),
+               }
+               c.sessions[token] = sess
        }
+       sess.mtx.RLock()
+       return sess, nil
+}
 
-       fs, _ := sess.fs.Load().(arvados.CustomFileSystem)
-       if fs == nil || expired {
-               fs = sess.client.SiteFileSystem(sess.keepclient)
-               fs.ForwardSlashNameSubstitution(c.cluster.Collections.ForwardSlashNameSubstitution)
-               sess.fs.Store(fs)
+// Get a long-lived CustomFileSystem suitable for doing a read or
+// write operation with the given token.
+//
+// If the returned error is nil, the caller must call Release() on the
+// returned session when finished using it.
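+//
+// A typical call site (cf. handler.ServeHTTP in this commit) looks
+// like:
+//
+//     fs, sess, user, err := h.Cache.GetSession(token)
+//     if err != nil {
+//             return err
+//     }
+//     defer sess.Release()
+//     // ... read/write through fs ...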
+func (c *cache) GetSession(token string) (arvados.CustomFileSystem, *cachedSession, *arvados.User, error) {
+       sess, err := c.checkout(token)
+       if err != nil {
+               return nil, nil, nil, err
        }
+       sess.refresh.Lock()
+       defer sess.refresh.Unlock()
+       now := time.Now()
+       sess.lastuse = now
+       refresh := sess.expire.Before(now)
+       if sess.fs == nil || !sess.userLoaded || refresh {
+               // Wait for all active users to finish (otherwise they
+               // might make changes to an old fs after we start
+               // using the new fs).
+               sess.inuse.Lock()
+               if !sess.userLoaded || refresh {
+                       err := sess.client.RequestAndDecode(&sess.user, "GET", "/arvados/v1/users/current", nil, nil)
+                       if he := errorWithHTTPStatus(nil); errors.As(err, &he) && he.HTTPStatus() == http.StatusForbidden {
+                               // token is OK, but "get user id" api is out
+                               // of scope -- use existing/expired info if
+                               // any, or leave empty for unknown user
+                       } else if err != nil {
+                               sess.inuse.Unlock()
+                               sess.mtx.RUnlock()
+                               return nil, nil, nil, err
+                       }
+                       sess.userLoaded = true
+               }
 
-       user, _ := sess.user.Load().(*arvados.User)
-       if user == nil || expired {
-               user = new(arvados.User)
-               err := sess.client.RequestAndDecode(user, "GET", "/arvados/v1/users/current", nil, nil)
-               if he := errorWithHTTPStatus(nil); errors.As(err, &he) && he.HTTPStatus() == http.StatusForbidden {
-                       // token is OK, but "get user id" api is out
-                       // of scope -- return nil, signifying unknown
-                       // user
-               } else if err != nil {
-                       return nil, nil, nil, err
+               if sess.fs == nil || refresh {
+                       sess.fs = sess.client.SiteFileSystem(sess.keepclient)
+                       sess.fs.ForwardSlashNameSubstitution(c.cluster.Collections.ForwardSlashNameSubstitution)
+                       sess.expire = now.Add(c.cluster.Collections.WebDAVCache.TTL.Duration())
+                       c.metrics.sessionMisses.Inc()
+               } else {
+                       c.metrics.sessionHits.Inc()
                }
-               sess.user.Store(user)
+               sess.inuse.Unlock()
+       } else {
+               c.metrics.sessionHits.Inc()
        }
+       sess.inuse.RLock()
+       return sess.fs, sess, &sess.user, nil
+}
 
-       return fs, sess, user, nil
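+// sessionSnapshot records one session's state as of the start of a
+// pruneSessions call, so sorting and size accounting can be done
+// without holding any per-session locks.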
+type sessionSnapshot struct {
+       token   string
+       sess    *cachedSession
+       lastuse time.Time
+       fs      arvados.CustomFileSystem
+       size    int64
+       prune   bool
 }
 
-// Remove all expired session cache entries, then remove more entries
-// until approximate remaining size <= maxsize/2
+// Remove all expired, idle session cache entries, then unload
+// in-memory filesystems until the approximate remaining size is at
+// most maxsize.
 func (c *cache) pruneSessions() {
        now := time.Now()
-       keys := c.sessions.Keys()
-       sizes := make([]int64, len(keys))
+       c.mtx.Lock()
+       snaps := make([]sessionSnapshot, 0, len(c.sessions))
+       for token, sess := range c.sessions {
+               snaps = append(snaps, sessionSnapshot{
+                       token: token,
+                       sess:  sess,
+               })
+       }
+       c.mtx.Unlock()
+
+       // Load lastuse/fs/expire data from sessions. Note we do this
+       // after unlocking c.mtx, because sess.refresh.Lock sometimes
+       // waits for another goroutine to finish [re]fetching the user
+       // record.
+       for i := range snaps {
+               snaps[i].sess.refresh.Lock()
+               snaps[i].lastuse = snaps[i].sess.lastuse
+               snaps[i].fs = snaps[i].sess.fs
+               snaps[i].prune = snaps[i].sess.expire.Before(now)
+               snaps[i].sess.refresh.Unlock()
+       }
+
+       // Sort sessions with oldest first.
+       sort.Slice(snaps, func(i, j int) bool {
+               return snaps[i].lastuse.Before(snaps[j].lastuse)
+       })
+
+       // Add up size of sessions that aren't already marked for
+       // pruning based on expire time.
        var size int64
-       for i, token := range keys {
-               ent, ok := c.sessions.Peek(token)
-               if !ok {
-                       continue
+       for i, snap := range snaps {
+               if !snap.prune && snap.fs != nil {
+                       fssize := snap.fs.MemorySize()
+                       snaps[i].size = fssize
+                       size += fssize
+               }
+       }
+       // Mark more sessions for deletion until reaching desired
+       // memory size limit, starting with the oldest entries.
+       for i, snap := range snaps {
+               if size <= c.cluster.Collections.WebDAVCache.MaxCollectionBytes {
+                       break
                }
-               s := ent.(*cachedSession)
-               if s.expire.Before(now) {
-                       c.sessions.Remove(token)
+               if snap.prune {
                        continue
                }
-               if fs, ok := s.fs.Load().(arvados.CustomFileSystem); ok {
-                       sizes[i] = fs.MemorySize()
-                       size += sizes[i]
+               snaps[i].prune = true
+               size -= snap.size
+       }
+
+       // Mark more sessions for deletion until reaching desired
+       // session count limit.
+       mustprune := len(snaps) - c.cluster.Collections.WebDAVCache.MaxSessions
+       for i := range snaps {
+               if snaps[i].prune {
+                       mustprune--
                }
        }
-       // Remove tokens until reaching size limit, starting with the
-       // least frequently used entries (which Keys() returns last).
-       for i := len(keys) - 1; i >= 0 && size > c.cluster.Collections.WebDAVCache.MaxCollectionBytes; i-- {
-               if sizes[i] > 0 {
-                       c.sessions.Remove(keys[i])
-                       size -= sizes[i]
+       for i := range snaps {
+               if mustprune < 1 {
+                       break
+               } else if !snaps[i].prune {
+                       snaps[i].prune = true
+                       mustprune--
                }
        }
-}
 
-// collectionBytes returns the approximate combined memory size of the
-// collection cache and session filesystem cache.
-func (c *cache) collectionBytes() uint64 {
-       var size uint64
-       for _, token := range c.sessions.Keys() {
-               ent, ok := c.sessions.Peek(token)
-               if !ok {
+       c.mtx.Lock()
+       defer c.mtx.Unlock()
+       for _, snap := range snaps {
+               if !snap.prune {
                        continue
                }
-               if fs, ok := ent.(*cachedSession).fs.Load().(arvados.CustomFileSystem); ok {
-                       size += uint64(fs.MemorySize())
+               sess := snap.sess
+               if sess.mtx.TryLock() {
+                       delete(c.sessions, snap.token)
+                       continue
+               }
+               // We can't remove a session that's been checked out
+               // -- that would allow another session to be created
+               // for the same token using a different in-memory
+               // filesystem. Instead, we wait for active requests to
+               // finish and then "unload" it. After this, either the
+               // next GetSession will reload fs/user, or a
+               // subsequent pruneSessions will remove the session.
+               go func() {
+                       // Ensure nobody is mid-GetSession() (note we
+                       // already know nobody is mid-checkout()
+                       // because we have c.mtx locked)
+                       sess.refresh.Lock()
+                       defer sess.refresh.Unlock()
+                       // Wait for current usage to finish (i.e.,
+                       // anyone who has decided to use the current
+                       // values of sess.fs and sess.user, and hasn't
+                       // called Release() yet)
+                       sess.inuse.Lock()
+                       defer sess.inuse.Unlock()
+                       // Release memory
+                       sess.fs = nil
+                       // Next GetSession will make a new fs
+               }()
+       }
+}
+
+// sessionsSize returns the number and approximate total memory size
+// of all cached sessions.
+func (c *cache) sessionsSize() (n int, size int64) {
+       c.mtx.Lock()
+       n = len(c.sessions)
+       sessions := make([]*cachedSession, 0, n)
+       for _, sess := range c.sessions {
+               sessions = append(sessions, sess)
+       }
+       c.mtx.Unlock()
+       for _, sess := range sessions {
+               sess.refresh.Lock()
+               fs := sess.fs
+               sess.refresh.Unlock()
+               if fs != nil {
+                       size += fs.MemorySize()
                }
        }
-       return size
+       return
 }
services/keep-web/handler.go
index 3cdaf5d2b51c5e2d663fd8e53dc29739ffe76551..3af326a1ad451483da29bb6398624b508e0d1fe2 100644 (file)
@@ -18,6 +18,7 @@ import (
        "strconv"
        "strings"
        "sync"
+       "time"
 
        "git.arvados.org/arvados.git/lib/cmd"
        "git.arvados.org/arvados.git/lib/webdavfs"
@@ -35,6 +36,10 @@ type handler struct {
        Cache     cache
        Cluster   *arvados.Cluster
        setupOnce sync.Once
+
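+       // lockMtx guards lock and lockTidied. lock maps a collection
+       // ID to a per-collection RWMutex: writers take the write
+       // lock, concurrent readers share the read lock (see
+       // collectionLock below).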
+       lockMtx    sync.Mutex
+       lock       map[string]*sync.RWMutex
+       lockTidied time.Time
 }
 
 var urlPDHDecoder = strings.NewReplacer(" ", "+", "-", "+")
@@ -406,16 +411,20 @@ func (h *handler) ServeHTTP(wOrig http.ResponseWriter, r *http.Request) {
                        // collection id is outside scope of supplied
                        // token
                        tokenScopeProblem = true
+                       sess.Release()
                        continue
                } else if os.IsNotExist(err) {
                        // collection does not exist or is not
                        // readable using this token
+                       sess.Release()
                        continue
                } else if err != nil {
                        http.Error(w, err.Error(), http.StatusInternalServerError)
+                       sess.Release()
                        return
                }
                defer f.Close()
+               defer sess.Release()
 
                collectionDir, sessionFS, session, tokenUser = f, fs, sess, user
                break
@@ -530,7 +539,11 @@ func (h *handler) ServeHTTP(wOrig http.ResponseWriter, r *http.Request) {
        }
        h.logUploadOrDownload(r, session.arvadosclient, sessionFS, fsprefix+strings.Join(targetPath, "/"), nil, tokenUser)
 
-       if writeMethod[r.Method] {
+       writing := writeMethod[r.Method]
+       locker := h.collectionLock(collectionID, writing)
+       defer locker.Unlock()
+
+       if writing {
                // Save the collection only if/when all
                // webdav->filesystem operations succeed --
                // and send a 500 error if the modified
@@ -942,6 +955,41 @@ func (h *handler) determineCollection(fs arvados.CustomFileSystem, path string)
        return nil, ""
 }
 
+var lockTidyInterval = time.Minute * 10
+
+// Lock the specified collection for reading or writing. Caller must
+// call Unlock() on the returned Locker when the operation is
+// finished.
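+//
+// For readers the returned value is locker.RLocker(), a sync.Locker
+// whose Unlock() calls RUnlock() -- so a single "defer
+// locker.Unlock()" works for both read and write modes.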
+func (h *handler) collectionLock(collectionID string, writing bool) sync.Locker {
+       h.lockMtx.Lock()
+       defer h.lockMtx.Unlock()
+       if time.Since(h.lockTidied) > lockTidyInterval {
+               // Periodically delete all locks that aren't in use.
+               h.lockTidied = time.Now()
+               for id, locker := range h.lock {
+                       if locker.TryLock() {
+                               locker.Unlock()
+                               delete(h.lock, id)
+                       }
+               }
+       }
+       locker := h.lock[collectionID]
+       if locker == nil {
+               locker = new(sync.RWMutex)
+               if h.lock == nil {
+                       h.lock = map[string]*sync.RWMutex{}
+               }
+               h.lock[collectionID] = locker
+       }
+       if writing {
+               locker.Lock()
+               return locker
+       } else {
+               locker.RLock()
+               return locker.RLocker()
+       }
+}
+
 func ServeCORSPreflight(w http.ResponseWriter, header http.Header) bool {
        method := header.Get("Access-Control-Request-Method")
        if method == "" {
services/keep-web/handler_test.go
index c9b48f99a73c01407374c046fcde443ce11abe4b..4a76276392ca5b9772d922287620ac29c55774c4 100644 (file)
@@ -18,6 +18,7 @@ import (
        "path/filepath"
        "regexp"
        "strings"
+       "sync"
        "time"
 
        "git.arvados.org/arvados.git/lib/config"
@@ -1624,3 +1625,72 @@ func (s *IntegrationSuite) TestUploadLoggingPermission(c *check.C) {
                }
        }
 }
+
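+// TestConcurrentWrites exercises the per-collection locking and the
+// reworked session cache under concurrent MKCOL/PUT/GET traffic,
+// using a short cache TTL and lock-tidy interval so that session
+// refreshes and lock tidying happen while requests are in flight.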
+func (s *IntegrationSuite) TestConcurrentWrites(c *check.C) {
+       s.handler.Cluster.Collections.WebDAVCache.TTL = arvados.Duration(time.Second * 2)
+       lockTidyInterval = time.Second
+       client := arvados.NewClientFromEnv()
+       client.AuthToken = arvadostest.ActiveTokenV2
+       // Start small, and double the concurrency (for 2^2, 4^2, ...
+       // total writes) only until hitting failure. Avoids
+       // unnecessarily long failure reports.
+       for n := 2; n < 16 && !c.Failed(); n = n * 2 {
+               c.Logf("%s: n=%d", c.TestName(), n)
+
+               var coll arvados.Collection
+               err := client.RequestAndDecode(&coll, "POST", "arvados/v1/collections", nil, nil)
+               c.Assert(err, check.IsNil)
+               defer client.RequestAndDecode(&coll, "DELETE", "arvados/v1/collections/"+coll.UUID, nil, nil)
+
+               var wg sync.WaitGroup
+               for i := 0; i < n && !c.Failed(); i++ {
+                       i := i
+                       wg.Add(1)
+                       go func() {
+                               defer wg.Done()
+                               u := mustParseURL(fmt.Sprintf("http://%s.collections.example.com/i=%d", coll.UUID, i))
+                               resp := httptest.NewRecorder()
+                               req, err := http.NewRequest("MKCOL", u.String(), nil)
+                               c.Assert(err, check.IsNil)
+                               req.Header.Set("Authorization", "Bearer "+client.AuthToken)
+                               s.handler.ServeHTTP(resp, req)
+                               c.Assert(resp.Code, check.Equals, http.StatusCreated)
+                               for j := 0; j < n && !c.Failed(); j++ {
+                                       j := j
+                                       wg.Add(1)
+                                       go func() {
+                                               defer wg.Done()
+                                               content := fmt.Sprintf("i=%d/j=%d", i, j)
+                                               u := mustParseURL("http://" + coll.UUID + ".collections.example.com/" + content)
+
+                                               resp := httptest.NewRecorder()
+                                               req, err := http.NewRequest("PUT", u.String(), strings.NewReader(content))
+                                               c.Assert(err, check.IsNil)
+                                               req.Header.Set("Authorization", "Bearer "+client.AuthToken)
+                                               s.handler.ServeHTTP(resp, req)
+                                               c.Check(resp.Code, check.Equals, http.StatusCreated)
+
+                                               time.Sleep(time.Second)
+                                               resp = httptest.NewRecorder()
+                                               req, err = http.NewRequest("GET", u.String(), nil)
+                                               c.Assert(err, check.IsNil)
+                                               req.Header.Set("Authorization", "Bearer "+client.AuthToken)
+                                               s.handler.ServeHTTP(resp, req)
+                                               c.Check(resp.Code, check.Equals, http.StatusOK)
+                                               c.Check(resp.Body.String(), check.Equals, content)
+                                       }()
+                               }
+                       }()
+               }
+               wg.Wait()
+               for i := 0; i < n; i++ {
+                       u := mustParseURL(fmt.Sprintf("http://%s.collections.example.com/i=%d", coll.UUID, i))
+                       resp := httptest.NewRecorder()
+                       req, err := http.NewRequest("PROPFIND", u.String(), &bytes.Buffer{})
+                       c.Assert(err, check.IsNil)
+                       req.Header.Set("Authorization", "Bearer "+client.AuthToken)
+                       s.handler.ServeHTTP(resp, req)
+                       c.Assert(resp.Code, check.Equals, http.StatusMultiStatus)
+               }
+       }
+}