18071: Use dblock to avoid concurrent keep-balance ops.
author: Tom Clegg <tom@curii.com>
Wed, 26 Oct 2022 21:00:42 +0000 (17:00 -0400)
committer: Tom Clegg <tom@curii.com>
Thu, 27 Oct 2022 15:03:26 +0000 (11:03 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

lib/controller/dblock/dblock.go
lib/controller/trash.go
services/keep-balance/balance.go
services/keep-balance/balance_run_test.go
services/keep-balance/integration_test.go
services/keep-balance/main.go
services/keep-balance/server.go

index a46201bb45af793062126689d773be61f9bbe232..472633747c5d9ba6d832777de704fc3cd25fa387 100644 (file)
@@ -15,9 +15,11 @@ import (
 )
 
 var (
-       TrashSweep        = &DBLocker{key: 10001}
-       ContainerLogSweep = &DBLocker{key: 10002}
-       retryDelay        = 5 * time.Second
+       TrashSweep         = &DBLocker{key: 10001}
+       ContainerLogSweep  = &DBLocker{key: 10002}
+       KeepBalanceService = &DBLocker{key: 10003} // keep-balance service in periodic-sweep loop
+       KeepBalanceActive  = &DBLocker{key: 10004} // keep-balance sweep in progress (either -once=true or service loop)
+       retryDelay         = 5 * time.Second
 )
 
 // DBLocker uses pg_advisory_lock to maintain a cluster-wide lock for
@@ -31,7 +33,9 @@ type DBLocker struct {
 }
 
 // Lock acquires the advisory lock, waiting/reconnecting if needed.
-func (dbl *DBLocker) Lock(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) {
+//
+// Returns false if ctx is canceled before the lock is acquired.
+func (dbl *DBLocker) Lock(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) bool {
        logger := ctxlog.FromContext(ctx)
        for ; ; time.Sleep(retryDelay) {
                dbl.mtx.Lock()
@@ -41,21 +45,33 @@ func (dbl *DBLocker) Lock(ctx context.Context, getdb func(context.Context) (*sql
                        dbl.mtx.Unlock()
                        continue
                }
+               if ctx.Err() != nil {
+                       dbl.mtx.Unlock()
+                       return false
+               }
                db, err := getdb(ctx)
-               if err != nil {
+               if err == context.Canceled {
+                       dbl.mtx.Unlock()
+                       return false
+               } else if err != nil {
                        logger.WithError(err).Infof("error getting database pool")
                        dbl.mtx.Unlock()
                        continue
                }
                conn, err := db.Conn(ctx)
-               if err != nil {
+               if err == context.Canceled {
+                       dbl.mtx.Unlock()
+                       return false
+               } else if err != nil {
                        logger.WithError(err).Info("error getting database connection")
                        dbl.mtx.Unlock()
                        continue
                }
                var locked bool
                err = conn.QueryRowContext(ctx, `SELECT pg_try_advisory_lock($1)`, dbl.key).Scan(&locked)
-               if err != nil {
+               if err == context.Canceled {
+                       return false
+               } else if err != nil {
                        logger.WithError(err).Infof("error getting pg_try_advisory_lock %d", dbl.key)
                        conn.Close()
                        dbl.mtx.Unlock()
@@ -69,27 +85,33 @@ func (dbl *DBLocker) Lock(ctx context.Context, getdb func(context.Context) (*sql
                logger.Debugf("acquired pg_advisory_lock %d", dbl.key)
                dbl.ctx, dbl.getdb, dbl.conn = ctx, getdb, conn
                dbl.mtx.Unlock()
-               return
+               return true
        }
 }
 
 // Check confirms that the lock is still active (i.e., the session is
 // still alive), and re-acquires if needed. Panics if Lock is not
 // acquired first.
-func (dbl *DBLocker) Check() {
+//
+// Returns false if the context passed to Lock() is canceled before
+// the lock is confirmed or reacquired.
+func (dbl *DBLocker) Check() bool {
        dbl.mtx.Lock()
        err := dbl.conn.PingContext(dbl.ctx)
-       if err == nil {
+       if err == context.Canceled {
+               dbl.mtx.Unlock()
+               return false
+       } else if err == nil {
                ctxlog.FromContext(dbl.ctx).Debugf("pg_advisory_lock %d connection still alive", dbl.key)
                dbl.mtx.Unlock()
-               return
+               return true
        }
        ctxlog.FromContext(dbl.ctx).WithError(err).Info("database connection ping failed")
        dbl.conn.Close()
        dbl.conn = nil
        ctx, getdb := dbl.ctx, dbl.getdb
        dbl.mtx.Unlock()
-       dbl.Lock(ctx, getdb)
+       return dbl.Lock(ctx, getdb)
 }
 
 func (dbl *DBLocker) Unlock() {
index 9a7b0814cee7477dc7b506921aa99be0ef34ed77..afdf95b782647b2f5f2b372eaaadb7a888fb1936 100644 (file)
@@ -20,10 +20,16 @@ func (h *Handler) periodicWorker(workerName string, interval time.Duration, lock
                logger.Debugf("interval is %v, not running worker", interval)
                return
        }
-       locker.Lock(ctx, h.db)
+       if !locker.Lock(ctx, h.db) {
+               // context canceled
+               return
+       }
        defer locker.Unlock()
        for time.Sleep(interval); ctx.Err() == nil; time.Sleep(interval) {
-               locker.Check()
+               if !locker.Check() {
+                       // context canceled
+                       return
+               }
                err := run(ctx)
                if err != nil {
                        logger.WithError(err).Infof("%s failed", workerName)
index 1dedb409a4a2de5c4f414959b024e291007d42b1..50c4dae1886ca292e3904f813d356a18d8608ed0 100644 (file)
@@ -23,7 +23,9 @@ import (
        "syscall"
        "time"
 
+       "git.arvados.org/arvados.git/lib/controller/dblock"
        "git.arvados.org/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/ctxlog"
        "git.arvados.org/arvados.git/sdk/go/keepclient"
        "github.com/jmoiron/sqlx"
        "github.com/sirupsen/logrus"
@@ -70,13 +72,20 @@ type Balancer struct {
 //
 // Typical usage:
 //
-//   runOptions, err = (&Balancer{}).Run(config, runOptions)
-func (bal *Balancer) Run(client *arvados.Client, cluster *arvados.Cluster, runOptions RunOptions) (nextRunOptions RunOptions, err error) {
+//     runOptions, err = (&Balancer{}).Run(config, runOptions)
+func (bal *Balancer) Run(ctx context.Context, client *arvados.Client, cluster *arvados.Cluster, runOptions RunOptions) (nextRunOptions RunOptions, err error) {
        nextRunOptions = runOptions
 
+       ctxlog.FromContext(ctx).Info("acquiring active lock")
+       if !dblock.KeepBalanceActive.Lock(ctx, func(context.Context) (*sqlx.DB, error) { return bal.DB, nil }) {
+               // context canceled
+               return
+       }
+       defer dblock.KeepBalanceActive.Unlock()
+
        defer bal.time("sweep", "wall clock time to run one full sweep")()
 
-       ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(cluster.Collections.BalanceTimeout.Duration()))
+       ctx, cancel := context.WithDeadline(ctx, time.Now().Add(cluster.Collections.BalanceTimeout.Duration()))
        defer cancel()
 
        var lbFile *os.File
index 2db7bea173c17dc41f6943b4fe579cbc7d15a24f..4772da55a2d6dddc79acff891dee1034781f7582 100644 (file)
@@ -6,6 +6,7 @@ package keepbalance
 
 import (
        "bytes"
+       "context"
        "encoding/json"
        "fmt"
        "io"
@@ -372,7 +373,7 @@ func (s *runSuite) TestRefuseZeroCollections(c *check.C) {
        trashReqs := s.stub.serveKeepstoreTrash()
        pullReqs := s.stub.serveKeepstorePull()
        srv := s.newServer(&opts)
-       _, err = srv.runOnce()
+       _, err = srv.runOnce(context.Background())
        c.Check(err, check.ErrorMatches, "received zero collections")
        c.Check(trashReqs.Count(), check.Equals, 4)
        c.Check(pullReqs.Count(), check.Equals, 0)
@@ -391,7 +392,7 @@ func (s *runSuite) TestRefuseNonAdmin(c *check.C) {
        trashReqs := s.stub.serveKeepstoreTrash()
        pullReqs := s.stub.serveKeepstorePull()
        srv := s.newServer(&opts)
-       _, err := srv.runOnce()
+       _, err := srv.runOnce(context.Background())
        c.Check(err, check.ErrorMatches, "current user .* is not .* admin user")
        c.Check(trashReqs.Count(), check.Equals, 0)
        c.Check(pullReqs.Count(), check.Equals, 0)
@@ -417,7 +418,7 @@ func (s *runSuite) TestRefuseSameDeviceDifferentVolumes(c *check.C) {
        trashReqs := s.stub.serveKeepstoreTrash()
        pullReqs := s.stub.serveKeepstorePull()
        srv := s.newServer(&opts)
-       _, err := srv.runOnce()
+       _, err := srv.runOnce(context.Background())
        c.Check(err, check.ErrorMatches, "cannot continue with config errors.*")
        c.Check(trashReqs.Count(), check.Equals, 0)
        c.Check(pullReqs.Count(), check.Equals, 0)
@@ -442,7 +443,7 @@ func (s *runSuite) TestWriteLostBlocks(c *check.C) {
        s.stub.serveKeepstorePull()
        srv := s.newServer(&opts)
        c.Assert(err, check.IsNil)
-       _, err = srv.runOnce()
+       _, err = srv.runOnce(context.Background())
        c.Check(err, check.IsNil)
        lost, err := ioutil.ReadFile(lostf.Name())
        c.Assert(err, check.IsNil)
@@ -463,7 +464,7 @@ func (s *runSuite) TestDryRun(c *check.C) {
        trashReqs := s.stub.serveKeepstoreTrash()
        pullReqs := s.stub.serveKeepstorePull()
        srv := s.newServer(&opts)
-       bal, err := srv.runOnce()
+       bal, err := srv.runOnce(context.Background())
        c.Check(err, check.IsNil)
        for _, req := range collReqs.reqs {
                c.Check(req.Form.Get("include_trash"), check.Equals, "true")
@@ -493,7 +494,7 @@ func (s *runSuite) TestCommit(c *check.C) {
        trashReqs := s.stub.serveKeepstoreTrash()
        pullReqs := s.stub.serveKeepstorePull()
        srv := s.newServer(&opts)
-       bal, err := srv.runOnce()
+       bal, err := srv.runOnce(context.Background())
        c.Check(err, check.IsNil)
        c.Check(trashReqs.Count(), check.Equals, 8)
        c.Check(pullReqs.Count(), check.Equals, 4)
@@ -533,13 +534,14 @@ func (s *runSuite) TestRunForever(c *check.C) {
        trashReqs := s.stub.serveKeepstoreTrash()
        pullReqs := s.stub.serveKeepstorePull()
 
-       stop := make(chan interface{})
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
        s.config.Collections.BalancePeriod = arvados.Duration(time.Millisecond)
        srv := s.newServer(&opts)
 
        done := make(chan bool)
        go func() {
-               srv.runForever(stop)
+               srv.runForever(ctx)
                close(done)
        }()
 
@@ -550,7 +552,7 @@ func (s *runSuite) TestRunForever(c *check.C) {
        for t0 := time.Now(); pullReqs.Count() < 16 && time.Since(t0) < 10*time.Second; {
                time.Sleep(time.Millisecond)
        }
-       stop <- true
+       cancel()
        <-done
        c.Check(pullReqs.Count() >= 16, check.Equals, true)
        c.Check(trashReqs.Count(), check.Equals, pullReqs.Count()+4)
index 3cfb5cdeda5039fb37f414f5cd0b095eea0e772d..42463a002a5ec73652f7f7ef6f00f8a8c4fb44a1 100644 (file)
@@ -6,6 +6,7 @@ package keepbalance
 
 import (
        "bytes"
+       "context"
        "io"
        "os"
        "strings"
@@ -97,7 +98,7 @@ func (s *integrationSuite) TestBalanceAPIFixtures(c *check.C) {
                        Logger:  logger,
                        Metrics: newMetrics(prometheus.NewRegistry()),
                }
-               nextOpts, err := bal.Run(s.client, s.config, opts)
+               nextOpts, err := bal.Run(context.Background(), s.client, s.config, opts)
                c.Check(err, check.IsNil)
                c.Check(nextOpts.SafeRendezvousState, check.Not(check.Equals), "")
                c.Check(nextOpts.CommitPulls, check.Equals, true)
index f0b0df5bd331d6a97a2cdaab0a8d968cfdbfc550..b016db22ffe67f6316f1e4f537bfa680f135ecad 100644 (file)
@@ -112,7 +112,7 @@ func (command) RunCommand(prog string, args []string, stdin io.Reader, stdout, s
                                Routes: health.Routes{"ping": srv.CheckHealth},
                        }
 
-                       go srv.run()
+                       go srv.run(ctx)
                        return srv
                }).RunCommand(prog, args, stdin, stdout, stderr)
 }
index e485f5b2061f28134306d1d897b22cb62e4190e9..fd53497f789ed4f5f1db458f99e69f8e7f10c1a7 100644 (file)
@@ -5,12 +5,14 @@
 package keepbalance
 
 import (
+       "context"
        "net/http"
        "os"
        "os/signal"
        "syscall"
        "time"
 
+       "git.arvados.org/arvados.git/lib/controller/dblock"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "github.com/jmoiron/sqlx"
        "github.com/sirupsen/logrus"
@@ -62,12 +64,12 @@ func (srv *Server) Done() <-chan struct{} {
        return nil
 }
 
-func (srv *Server) run() {
+func (srv *Server) run(ctx context.Context) {
        var err error
        if srv.RunOptions.Once {
-               _, err = srv.runOnce()
+               _, err = srv.runOnce(ctx)
        } else {
-               err = srv.runForever(nil)
+               err = srv.runForever(ctx)
        }
        if err != nil {
                srv.Logger.Error(err)
@@ -77,7 +79,7 @@ func (srv *Server) run() {
        }
 }
 
-func (srv *Server) runOnce() (*Balancer, error) {
+func (srv *Server) runOnce(ctx context.Context) (*Balancer, error) {
        bal := &Balancer{
                DB:             srv.DB,
                Logger:         srv.Logger,
@@ -86,13 +88,12 @@ func (srv *Server) runOnce() (*Balancer, error) {
                LostBlocksFile: srv.Cluster.Collections.BlobMissingReport,
        }
        var err error
-       srv.RunOptions, err = bal.Run(srv.ArvClient, srv.Cluster, srv.RunOptions)
+       srv.RunOptions, err = bal.Run(ctx, srv.ArvClient, srv.Cluster, srv.RunOptions)
        return bal, err
 }
 
-// RunForever runs forever, or (for testing purposes) until the given
-// stop channel is ready to receive.
-func (srv *Server) runForever(stop <-chan interface{}) error {
+// RunForever runs forever, or until ctx is cancelled.
+func (srv *Server) runForever(ctx context.Context) error {
        logger := srv.Logger
 
        ticker := time.NewTicker(time.Duration(srv.Cluster.Collections.BalancePeriod))
@@ -102,6 +103,10 @@ func (srv *Server) runForever(stop <-chan interface{}) error {
        sigUSR1 := make(chan os.Signal)
        signal.Notify(sigUSR1, syscall.SIGUSR1)
 
+       logger.Info("acquiring service lock")
+       dblock.KeepBalanceService.Lock(ctx, func(context.Context) (*sqlx.DB, error) { return srv.DB, nil })
+       defer dblock.KeepBalanceService.Unlock()
+
        logger.Printf("starting up: will scan every %v and on SIGUSR1", srv.Cluster.Collections.BalancePeriod)
 
        for {
@@ -110,7 +115,11 @@ func (srv *Server) runForever(stop <-chan interface{}) error {
                        logger.Print("=======  Consider using -commit-pulls and -commit-trash flags.")
                }
 
-               _, err := srv.runOnce()
+               if !dblock.KeepBalanceService.Check() {
+                       // context canceled
+                       return nil
+               }
+               _, err := srv.runOnce(ctx)
                if err != nil {
                        logger.Print("run failed: ", err)
                } else {
@@ -118,7 +127,7 @@ func (srv *Server) runForever(stop <-chan interface{}) error {
                }
 
                select {
-               case <-stop:
+               case <-ctx.Done():
                        signal.Stop(sigUSR1)
                        return nil
                case <-ticker.C: