8460: Combine ping and notify goroutines.
[arvados.git] / services / ws / pg.go
1 package main
2
import (
	"database/sql"
	"log"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/lib/pq"
)
13
// pgConfig holds PostgreSQL connection parameters as keyword/value
// pairs, e.g. {"host": "localhost", "sslmode": "disable"}.
type pgConfig map[string]string

// ConnectionString renders c as a key/value connection string of the
// form "k1='v1' k2='v2' " (note the trailing space), suitable for the
// "postgres" database/sql driver and pq.NewListener.
//
// Each value is single-quoted, with backslashes and single quotes
// escaped by a preceding backslash. Keys are written verbatim (not
// escaped) and in sorted order, so the result is deterministic for a
// given config. An empty config yields "".
func (c pgConfig) ConnectionString() string {
	keys := make([]string, 0, len(c))
	for k := range c {
		keys = append(keys, k)
	}
	// Map iteration order is random; sort so the same config always
	// produces the same string (useful in logs and tests).
	sort.Strings(keys)

	// A Replacer makes a single pass and never re-examines its own
	// output, which is equivalent to escaping backslashes first and
	// then single quotes.
	esc := strings.NewReplacer(`\`, `\\`, `'`, `\'`)
	parts := make([]string, len(keys))
	for i, k := range keys {
		parts[i] = k + "='" + esc.Replace(c[k]) + "' "
	}
	return strings.Join(parts, "")
}
28
29 type pgEventSource struct {
30         PgConfig  pgConfig
31         QueueSize int
32
33         pqListener *pq.Listener
34         sinks      map[*pgEventSink]bool
35         setupOnce  sync.Once
36         mtx        sync.Mutex
37 }
38
39 func (ps *pgEventSource) setup() {
40         ps.sinks = make(map[*pgEventSink]bool)
41         go ps.run()
42 }
43
44 func (ps *pgEventSource) run() {
45         db, err := sql.Open("postgres", ps.PgConfig.ConnectionString())
46         if err != nil {
47                 log.Fatal(err)
48         }
49
50         listener := pq.NewListener(ps.PgConfig.ConnectionString(), time.Second, time.Minute, func(ev pq.ListenerEventType, err error) {
51                 if err != nil {
52                         // Until we have a mechanism for catching up
53                         // on missed events, we cannot recover from a
54                         // dropped connection without breaking our
55                         // promises to clients.
56                         log.Fatalf("pgEventSource listener problem: %s", err)
57                 }
58         })
59         err = listener.Listen("logs")
60         if err != nil {
61                 log.Fatal(err)
62         }
63
64         debugLogf("pgEventSource listening")
65
66         eventQueue := make(chan *event, ps.QueueSize)
67
68         go func() {
69                 for e := range eventQueue {
70                         // Wait for the "select ... from logs" call to
71                         // finish. This limits max concurrent queries
72                         // to ps.QueueSize. Without this, max
73                         // concurrent queries would be bounded by
74                         // client_count X client_queue_size.
75                         e.Detail()
76                         debugLogf("event %d detail %+v", e.Serial, e.Detail())
77                         ps.mtx.Lock()
78                         for sink := range ps.sinks {
79                                 sink.channel <- e
80                         }
81                         ps.mtx.Unlock()
82                 }
83         }()
84
85         var serial uint64
86         ticker := time.NewTicker(time.Minute)
87         defer ticker.Stop()
88         for {
89                 select {
90                 case <-ticker.C:
91                         debugLogf("pgEventSource listener ping")
92                         listener.Ping()
93
94                 case pqEvent, ok := <-listener.Notify:
95                         if !ok {
96                                 close(eventQueue)
97                                 return
98                         }
99                         if pqEvent.Channel != "logs" {
100                                 continue
101                         }
102                         logID, err := strconv.ParseUint(pqEvent.Extra, 10, 64)
103                         if err != nil {
104                                 log.Printf("bad notify payload: %+v", pqEvent)
105                                 continue
106                         }
107                         serial++
108                         e := &event{
109                                 LogID:    logID,
110                                 Received: time.Now(),
111                                 Serial:   serial,
112                                 db:       db,
113                         }
114                         debugLogf("event %d %+v", e.Serial, e)
115                         eventQueue <- e
116                         go e.Detail()
117                 }
118         }
119 }
120
121 // NewSink subscribes to the event source. NewSink returns an
122 // eventSink, whose Channel() method returns a channel: a pointer to
123 // each subsequent event will be sent to that channel.
124 //
125 // The caller must ensure events are received from the sink channel as
126 // quickly as possible because when one sink stops being ready, all
127 // other sinks block.
128 func (ps *pgEventSource) NewSink() eventSink {
129         ps.setupOnce.Do(ps.setup)
130         sink := &pgEventSink{
131                 channel: make(chan *event, 1),
132                 source:  ps,
133         }
134         ps.mtx.Lock()
135         ps.sinks[sink] = true
136         ps.mtx.Unlock()
137         return sink
138 }
139
140 type pgEventSink struct {
141         channel chan *event
142         source  *pgEventSource
143 }
144
145 func (sink *pgEventSink) Channel() <-chan *event {
146         return sink.channel
147 }
148
149 func (sink *pgEventSink) Stop() {
150         go func() {
151                 // Ensure this sink cannot fill up and block the
152                 // server-side queue (which otherwise could in turn
153                 // block our mtx.Lock() here)
154                 for _ = range sink.channel {
155                 }
156         }()
157         sink.source.mtx.Lock()
158         delete(sink.source.sinks, sink)
159         sink.source.mtx.Unlock()
160         close(sink.channel)
161 }