8460: Return recent events if last_log_id given.
[arvados.git] / services / ws / pg.go
package main

import (
        "database/sql"
        "fmt"
        "log"
        "strconv"
        "strings"
        "sync"
        "time"

        "github.com/lib/pq"
)

// pgConfig is a set of PostgreSQL connection parameters, e.g.,
// "dbname", "user", "password".
type pgConfig map[string]string

// ConnectionString returns the parameters in libpq keyword=value
// format, with each value single-quoted and any backslashes and
// single quotes escaped.
func (c pgConfig) ConnectionString() string {
        s := ""
        for k, v := range c {
                s += k
                s += "='"
                s += strings.Replace(
                        strings.Replace(v, `\`, `\\`, -1),
                        `'`, `\'`, -1)
                s += "' "
        }
        return s
}
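
// Illustrative example (the values here are invented for
// documentation, not taken from any real configuration): a config
// such as
//
//     pgConfig{"dbname": "arvados", "password": "o'clock"}
//
// produces a connection string like
//
//     dbname='arvados' password='o\'clock'
//
// possibly with the parameters in the other order, since Go map
// iteration order is not deterministic; libpq does not care about
// parameter order.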

// pgEventSource listens for NOTIFY events on the PostgreSQL "logs"
// channel and fans each one out to all subscribed sinks.
type pgEventSource struct {
        DataSource string // PostgreSQL connection string, passed to sql.Open and pq.NewListener
        QueueSize  int    // buffer size of the internal event queue

        db         *sql.DB
        pqListener *pq.Listener
        sinks      map[*pgEventSink]bool
        setupOnce  sync.Once
        mtx        sync.Mutex
        shutdown   chan error
}

// setup connects to the database, subscribes to the "logs" NOTIFY
// channel, and starts the dispatch loop. It is called at most once,
// via setupOnce, from NewSink or DB.
func (ps *pgEventSource) setup() {
        ps.shutdown = make(chan error, 1)
        ps.sinks = make(map[*pgEventSink]bool)

        db, err := sql.Open("postgres", ps.DataSource)
        if err != nil {
                log.Fatalf("sql.Open: %s", err)
        }
        if err = db.Ping(); err != nil {
                log.Fatalf("db.Ping: %s", err)
        }
        ps.db = db

        ps.pqListener = pq.NewListener(ps.DataSource, time.Second, time.Minute, func(ev pq.ListenerEventType, err error) {
                if err != nil {
                        // Until we have a mechanism for catching up
                        // on missed events, we cannot recover from a
                        // dropped connection without breaking our
                        // promises to clients.
                        ps.shutdown <- fmt.Errorf("pgEventSource listener problem: %s", err)
                }
        })
        err = ps.pqListener.Listen("logs")
        if err != nil {
                log.Fatal(err)
        }
        debugLogf("pgEventSource listening")

        go ps.run()
}
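
// Note on the wire protocol: this source expects the database server
// to send notifications on the "logs" channel whose payload is a log
// row id in decimal form -- in SQL terms, something like
//
//     NOTIFY logs, '12345'
//
// presumably issued by a trigger on the logs table (the trigger is
// not part of this file). run() parses the payload as a uint64 and
// later fetches the corresponding row via event.Detail().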

// run converts each notification from the pq listener into an event
// and dispatches it to every subscribed sink.
func (ps *pgEventSource) run() {
        eventQueue := make(chan *event, ps.QueueSize)

        go func() {
                for e := range eventQueue {
                        // Wait for the "select ... from logs" call to
                        // finish. This limits max concurrent queries
                        // to ps.QueueSize. Without this, max
                        // concurrent queries would be bounded by
                        // client_count X client_queue_size.
                        e.Detail()
                        debugLogf("event %d detail %+v", e.Serial, e.Detail())
                        ps.mtx.Lock()
                        for sink := range ps.sinks {
                                sink.channel <- e
                        }
                        ps.mtx.Unlock()
                }
        }()

        var serial uint64
        ticker := time.NewTicker(time.Minute)
        defer ticker.Stop()
        for {
                select {
                case err, ok := <-ps.shutdown:
                        if ok {
                                debugLogf("shutdown on error: %s", err)
                        }
                        close(eventQueue)
                        return

                case <-ticker.C:
                        // Confirm the listener connection is still alive.
                        debugLogf("pgEventSource listener ping")
                        ps.pqListener.Ping()

                case pqEvent, ok := <-ps.pqListener.Notify:
                        if !ok {
                                close(eventQueue)
                                return
                        }
                        if pqEvent.Channel != "logs" {
                                continue
                        }
                        logID, err := strconv.ParseUint(pqEvent.Extra, 10, 64)
                        if err != nil {
                                log.Printf("bad notify payload: %+v", pqEvent)
                                continue
                        }
                        serial++
                        e := &event{
                                LogID:    logID,
                                Received: time.Now(),
                                Serial:   serial,
                                db:       ps.db,
                        }
                        debugLogf("event %d %+v", e.Serial, e)
                        eventQueue <- e
                        // Start fetching the log row now, so it is
                        // (more likely to be) ready by the time the
                        // dispatcher goroutine gets to this event.
                        go e.Detail()
                }
        }
}

// NewSink subscribes to the event source and returns an eventSink. A
// pointer to each subsequent event is sent to the channel returned by
// the sink's Channel() method.
//
// The caller must receive from the sink's channel as quickly as
// possible: all sinks are fed by a single dispatcher, so while one
// sink is not ready to receive, every other sink is blocked as well.
func (ps *pgEventSource) NewSink() eventSink {
        ps.setupOnce.Do(ps.setup)
        sink := &pgEventSink{
                channel: make(chan *event, 1),
                source:  ps,
        }
        ps.mtx.Lock()
        ps.sinks[sink] = true
        ps.mtx.Unlock()
        return sink
}
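
// A minimal usage sketch. The names eventSource and send() are
// placeholders invented for this comment; the real subscribers are
// presumably the websocket session handlers elsewhere in this
// service.
//
//     sink := eventSource.NewSink()
//     for e := range sink.Channel() {
//             if err := send(e); err != nil {
//                     break
//             }
//     }
//     sink.Stop()
//
// Stop() unregisters the sink and drains anything still queued for
// it, so an abandoned sink cannot block delivery to the others.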

// DB returns the database handle, running setup (and therefore
// starting the event source) on first use.
func (ps *pgEventSource) DB() *sql.DB {
        ps.setupOnce.Do(ps.setup)
        return ps.db
}

// pgEventSink is one subscriber's connection to the event source.
type pgEventSink struct {
        channel chan *event
        source  *pgEventSource
}

// Channel returns the channel on which this sink's events are
// delivered.
func (sink *pgEventSink) Channel() <-chan *event {
        return sink.channel
}

// Stop unsubscribes the sink from the event source and closes its
// channel.
func (sink *pgEventSink) Stop() {
        go func() {
                // Ensure this sink cannot fill up and block the
                // server-side queue (which otherwise could in turn
                // block our mtx.Lock() here)
                for range sink.channel {
                }
        }()
        sink.source.mtx.Lock()
        delete(sink.source.sinks, sink)
        sink.source.mtx.Unlock()
        close(sink.channel)
}