18339: Extract dblocker to a package.
[arvados.git] / lib / controller / dblock / dblock.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package dblock
6
7 import (
8         "context"
9         "database/sql"
10         "sync"
11         "time"
12
13         "git.arvados.org/arvados.git/sdk/go/ctxlog"
14         "github.com/jmoiron/sqlx"
15 )
16
17 var (
18         TrashSweep = &DBLocker{key: 10001}
19         retryDelay = 5 * time.Second
20 )
21
22 // DBLocker uses pg_advisory_lock to maintain a cluster-wide lock for
23 // a long-running task like "do X every N seconds".
24 type DBLocker struct {
25         key   int
26         mtx   sync.Mutex
27         ctx   context.Context
28         getdb func(context.Context) (*sqlx.DB, error)
29         conn  *sql.Conn // != nil if advisory lock has been acquired
30 }
31
32 // Lock acquires the advisory lock, waiting/reconnecting if needed.
33 func (dbl *DBLocker) Lock(ctx context.Context, getdb func(context.Context) (*sqlx.DB, error)) {
34         logger := ctxlog.FromContext(ctx)
35         for ; ; time.Sleep(retryDelay) {
36                 dbl.mtx.Lock()
37                 if dbl.conn != nil {
38                         // Already locked by another caller in this
39                         // process. Wait for them to release.
40                         dbl.mtx.Unlock()
41                         continue
42                 }
43                 db, err := getdb(ctx)
44                 if err != nil {
45                         logger.WithError(err).Infof("error getting database pool")
46                         dbl.mtx.Unlock()
47                         continue
48                 }
49                 conn, err := db.Conn(ctx)
50                 if err != nil {
51                         logger.WithError(err).Info("error getting database connection")
52                         dbl.mtx.Unlock()
53                         continue
54                 }
55                 _, err = conn.ExecContext(ctx, `SELECT pg_advisory_lock($1)`, dbl.key)
56                 if err != nil {
57                         logger.WithError(err).Infof("error getting pg_advisory_lock %d", dbl.key)
58                         conn.Close()
59                         dbl.mtx.Unlock()
60                         continue
61                 }
62                 logger.Debugf("acquired pg_advisory_lock %d", dbl.key)
63                 dbl.ctx, dbl.getdb, dbl.conn = ctx, getdb, conn
64                 dbl.mtx.Unlock()
65                 return
66         }
67 }
68
69 // Check confirms that the lock is still active (i.e., the session is
70 // still alive), and re-acquires if needed. Panics if Lock is not
71 // acquired first.
72 func (dbl *DBLocker) Check() {
73         dbl.mtx.Lock()
74         err := dbl.conn.PingContext(dbl.ctx)
75         if err == nil {
76                 ctxlog.FromContext(dbl.ctx).Debugf("pg_advisory_lock %d connection still alive", dbl.key)
77                 dbl.mtx.Unlock()
78                 return
79         }
80         ctxlog.FromContext(dbl.ctx).WithError(err).Info("database connection ping failed")
81         dbl.conn.Close()
82         dbl.conn = nil
83         ctx, getdb := dbl.ctx, dbl.getdb
84         dbl.mtx.Unlock()
85         dbl.Lock(ctx, getdb)
86 }
87
88 func (dbl *DBLocker) Unlock() {
89         dbl.mtx.Lock()
90         defer dbl.mtx.Unlock()
91         if dbl.conn != nil {
92                 _, err := dbl.conn.ExecContext(context.Background(), `SELECT pg_advisory_unlock($1)`, dbl.key)
93                 if err != nil {
94                         ctxlog.FromContext(dbl.ctx).WithError(err).Infof("error releasing pg_advisory_lock %d", dbl.key)
95                 } else {
96                         ctxlog.FromContext(dbl.ctx).Debugf("released pg_advisory_lock %d", dbl.key)
97                 }
98                 dbl.conn.Close()
99                 dbl.conn = nil
100         }
101 }