Merge branch '10701-refactor-dispatch'
[arvados.git] / services / ws / event_source.go
1 package main
2
3 import (
4         "context"
5         "database/sql"
6         "strconv"
7         "strings"
8         "sync"
9         "sync/atomic"
10         "time"
11
12         "git.curoverse.com/arvados.git/sdk/go/stats"
13         "github.com/lib/pq"
14 )
15
// pgConfig holds PostgreSQL connection parameters as keyword/value
// pairs.
type pgConfig map[string]string

// ConnectionString renders the parameters as a sequence of
// keyword='value' settings, each one followed by a single space.
// Inside a value, every backslash and single quote is escaped with a
// backslash. Settings are emitted in map iteration order, which is
// unspecified.
func (c pgConfig) ConnectionString() string {
	settings := make([]string, 0, len(c))
	for keyword, value := range c {
		// Backslashes must be escaped first so that the
		// backslashes added for quotes are not doubled.
		escaped := strings.Replace(value, `\`, `\\`, -1)
		escaped = strings.Replace(escaped, `'`, `\'`, -1)
		settings = append(settings, keyword+"='"+escaped+"' ")
	}
	return strings.Join(settings, "")
}
30
31 type pgEventSource struct {
32         DataSource string
33         QueueSize  int
34
35         db         *sql.DB
36         pqListener *pq.Listener
37         queue      chan *event
38         sinks      map[*pgEventSink]bool
39         mtx        sync.Mutex
40
41         lastQDelay time.Duration
42         eventsIn   uint64
43         eventsOut  uint64
44
45         cancel func()
46 }
47
48 var _ debugStatuser = (*pgEventSource)(nil)
49
50 func (ps *pgEventSource) listenerProblem(et pq.ListenerEventType, err error) {
51         if et == pq.ListenerEventConnected {
52                 logger(nil).Debug("pgEventSource connected")
53                 return
54         }
55
56         // Until we have a mechanism for catching up on missed events,
57         // we cannot recover from a dropped connection without
58         // breaking our promises to clients.
59         logger(nil).
60                 WithField("eventType", et).
61                 WithError(err).
62                 Error("listener problem")
63         ps.cancel()
64 }
65
66 // Run listens for event notifications on the "logs" channel and sends
67 // them to all subscribers.
68 func (ps *pgEventSource) Run() {
69         logger(nil).Debug("pgEventSource Run starting")
70         defer logger(nil).Debug("pgEventSource Run finished")
71
72         ctx, cancel := context.WithCancel(context.Background())
73         ps.cancel = cancel
74         defer cancel()
75
76         defer func() {
77                 // Disconnect all clients
78                 ps.mtx.Lock()
79                 for sink := range ps.sinks {
80                         close(sink.channel)
81                 }
82                 ps.sinks = nil
83                 ps.mtx.Unlock()
84         }()
85
86         db, err := sql.Open("postgres", ps.DataSource)
87         if err != nil {
88                 logger(nil).WithError(err).Fatal("sql.Open failed")
89                 return
90         }
91         if err = db.Ping(); err != nil {
92                 logger(nil).WithError(err).Fatal("db.Ping failed")
93                 return
94         }
95         ps.db = db
96
97         ps.pqListener = pq.NewListener(ps.DataSource, time.Second, time.Minute, ps.listenerProblem)
98         err = ps.pqListener.Listen("logs")
99         if err != nil {
100                 logger(nil).WithError(err).Fatal("pq Listen failed")
101         }
102         defer ps.pqListener.Close()
103         logger(nil).Debug("pq Listen setup done")
104
105         ps.queue = make(chan *event, ps.QueueSize)
106         defer close(ps.queue)
107
108         go func() {
109                 for e := range ps.queue {
110                         // Wait for the "select ... from logs" call to
111                         // finish. This limits max concurrent queries
112                         // to ps.QueueSize. Without this, max
113                         // concurrent queries would be bounded by
114                         // client_count X client_queue_size.
115                         e.Detail()
116
117                         logger(nil).
118                                 WithField("serial", e.Serial).
119                                 WithField("detail", e.Detail()).
120                                 Debug("event ready")
121                         e.Ready = time.Now()
122                         ps.lastQDelay = e.Ready.Sub(e.Received)
123
124                         ps.mtx.Lock()
125                         atomic.AddUint64(&ps.eventsOut, uint64(len(ps.sinks)))
126                         for sink := range ps.sinks {
127                                 sink.channel <- e
128                         }
129                         ps.mtx.Unlock()
130                 }
131         }()
132
133         var serial uint64
134         ticker := time.NewTicker(time.Minute)
135         defer ticker.Stop()
136         for {
137                 select {
138                 case <-ctx.Done():
139                         logger(nil).Debug("ctx done")
140                         return
141
142                 case <-ticker.C:
143                         logger(nil).Debug("listener ping")
144                         ps.pqListener.Ping()
145
146                 case pqEvent, ok := <-ps.pqListener.Notify:
147                         if !ok {
148                                 logger(nil).Debug("pqListener Notify chan closed")
149                                 return
150                         }
151                         if pqEvent == nil {
152                                 // pq should call listenerProblem
153                                 // itself in addition to sending us a
154                                 // nil event, so this might be
155                                 // superfluous:
156                                 ps.listenerProblem(-1, nil)
157                                 continue
158                         }
159                         if pqEvent.Channel != "logs" {
160                                 logger(nil).WithField("pqEvent", pqEvent).Error("unexpected notify from wrong channel")
161                                 continue
162                         }
163                         logID, err := strconv.ParseUint(pqEvent.Extra, 10, 64)
164                         if err != nil {
165                                 logger(nil).WithField("pqEvent", pqEvent).Error("bad notify payload")
166                                 continue
167                         }
168                         serial++
169                         e := &event{
170                                 LogID:    logID,
171                                 Received: time.Now(),
172                                 Serial:   serial,
173                                 db:       ps.db,
174                         }
175                         logger(nil).WithField("event", e).Debug("incoming")
176                         atomic.AddUint64(&ps.eventsIn, 1)
177                         ps.queue <- e
178                         go e.Detail()
179                 }
180         }
181 }
182
183 // NewSink subscribes to the event source. NewSink returns an
184 // eventSink, whose Channel() method returns a channel: a pointer to
185 // each subsequent event will be sent to that channel.
186 //
187 // The caller must ensure events are received from the sink channel as
188 // quickly as possible because when one sink stops being ready, all
189 // other sinks block.
190 func (ps *pgEventSource) NewSink() eventSink {
191         sink := &pgEventSink{
192                 channel: make(chan *event, 1),
193                 source:  ps,
194         }
195         ps.mtx.Lock()
196         if ps.sinks == nil {
197                 ps.sinks = make(map[*pgEventSink]bool)
198         }
199         ps.sinks[sink] = true
200         ps.mtx.Unlock()
201         return sink
202 }
203
204 func (ps *pgEventSource) DB() *sql.DB {
205         return ps.db
206 }
207
208 func (ps *pgEventSource) DebugStatus() interface{} {
209         ps.mtx.Lock()
210         defer ps.mtx.Unlock()
211         blocked := 0
212         for sink := range ps.sinks {
213                 blocked += len(sink.channel)
214         }
215         return map[string]interface{}{
216                 "EventsIn":     atomic.LoadUint64(&ps.eventsIn),
217                 "EventsOut":    atomic.LoadUint64(&ps.eventsOut),
218                 "Queue":        len(ps.queue),
219                 "QueueLimit":   cap(ps.queue),
220                 "QueueDelay":   stats.Duration(ps.lastQDelay),
221                 "Sinks":        len(ps.sinks),
222                 "SinksBlocked": blocked,
223         }
224 }
225
226 type pgEventSink struct {
227         channel chan *event
228         source  *pgEventSource
229 }
230
231 func (sink *pgEventSink) Channel() <-chan *event {
232         return sink.channel
233 }
234
235 // Stop sending events to the sink's channel.
236 func (sink *pgEventSink) Stop() {
237         go func() {
238                 // Ensure this sink cannot fill up and block the
239                 // server-side queue (which otherwise could in turn
240                 // block our mtx.Lock() here)
241                 for _ = range sink.channel {
242                 }
243         }()
244         sink.source.mtx.Lock()
245         if _, ok := sink.source.sinks[sink]; ok {
246                 delete(sink.source.sinks, sink)
247                 close(sink.channel)
248         }
249         sink.source.mtx.Unlock()
250 }