8460: Inject permChecker from main.
[arvados.git] / services / ws / pg.go
1 package main
2
import (
	"database/sql"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/lib/pq"
)
12
// pgConfig holds database connection parameters as keyword/value
// pairs, rendered by ConnectionString into a libpq-style connection
// string.
type pgConfig map[string]string

// pgConnValueEscaper escapes a value for use inside a single-quoted
// connection-string value. Each backslash and each single quote is
// prefixed with a backslash, in a single pass.
var pgConnValueEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// ConnectionString renders c as "key1='value1' key2='value2' " (note
// the trailing space after every pair). Keys are emitted in sorted
// order so the result is deterministic regardless of map iteration
// order. Values are quoted and escaped; keys are emitted verbatim. An
// empty config yields "".
func (c pgConfig) ConnectionString() string {
	keys := make([]string, 0, len(c))
	for k := range c {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	parts := make([]string, 0, len(keys))
	for _, k := range keys {
		parts = append(parts, k+"='"+pgConnValueEscaper.Replace(c[k])+"' ")
	}
	return strings.Join(parts, "")
}
27
28 type pgEventSource struct {
29         DataSource string
30         QueueSize  int
31
32         db         *sql.DB
33         pqListener *pq.Listener
34         sinks      map[*pgEventSink]bool
35         setupOnce  sync.Once
36         mtx        sync.Mutex
37         shutdown   chan error
38 }
39
40 func (ps *pgEventSource) setup() {
41         ps.shutdown = make(chan error, 1)
42         ps.sinks = make(map[*pgEventSink]bool)
43
44         db, err := sql.Open("postgres", ps.DataSource)
45         if err != nil {
46                 logger(nil).WithError(err).Fatal("sql.Open failed")
47         }
48         if err = db.Ping(); err != nil {
49                 logger(nil).WithError(err).Fatal("db.Ping failed")
50         }
51         ps.db = db
52
53         ps.pqListener = pq.NewListener(ps.DataSource, time.Second, time.Minute, func(ev pq.ListenerEventType, err error) {
54                 if err != nil {
55                         // Until we have a mechanism for catching up
56                         // on missed events, we cannot recover from a
57                         // dropped connection without breaking our
58                         // promises to clients.
59                         logger(nil).WithError(err).Error("listener problem")
60                         ps.shutdown <- err
61                 }
62         })
63         err = ps.pqListener.Listen("logs")
64         if err != nil {
65                 logger(nil).WithError(err).Fatal("pq Listen failed")
66         }
67         logger(nil).Debug("pgEventSource listening")
68
69         go ps.run()
70 }
71
72 func (ps *pgEventSource) run() {
73         eventQueue := make(chan *event, ps.QueueSize)
74
75         go func() {
76                 for e := range eventQueue {
77                         // Wait for the "select ... from logs" call to
78                         // finish. This limits max concurrent queries
79                         // to ps.QueueSize. Without this, max
80                         // concurrent queries would be bounded by
81                         // client_count X client_queue_size.
82                         e.Detail()
83
84                         logger(nil).
85                                 WithField("serial", e.Serial).
86                                 WithField("detail", e.Detail()).
87                                 Debug("event ready")
88
89                         ps.mtx.Lock()
90                         for sink := range ps.sinks {
91                                 sink.channel <- e
92                         }
93                         ps.mtx.Unlock()
94                 }
95         }()
96
97         var serial uint64
98         ticker := time.NewTicker(time.Minute)
99         defer ticker.Stop()
100         for {
101                 select {
102                 case err, ok := <-ps.shutdown:
103                         if ok {
104                                 logger(nil).WithError(err).Info("shutdown")
105                         }
106                         close(eventQueue)
107                         return
108
109                 case <-ticker.C:
110                         logger(nil).Debug("listener ping")
111                         ps.pqListener.Ping()
112
113                 case pqEvent, ok := <-ps.pqListener.Notify:
114                         if !ok {
115                                 close(eventQueue)
116                                 return
117                         }
118                         if pqEvent.Channel != "logs" {
119                                 continue
120                         }
121                         logID, err := strconv.ParseUint(pqEvent.Extra, 10, 64)
122                         if err != nil {
123                                 logger(nil).WithField("pqEvent", pqEvent).Error("bad notify payload")
124                                 continue
125                         }
126                         serial++
127                         e := &event{
128                                 LogID:    logID,
129                                 Received: time.Now(),
130                                 Serial:   serial,
131                                 db:       ps.db,
132                         }
133                         logger(nil).WithField("event", e).Debug("incoming")
134                         eventQueue <- e
135                         go e.Detail()
136                 }
137         }
138 }
139
140 // NewSink subscribes to the event source. NewSink returns an
141 // eventSink, whose Channel() method returns a channel: a pointer to
142 // each subsequent event will be sent to that channel.
143 //
144 // The caller must ensure events are received from the sink channel as
145 // quickly as possible because when one sink stops being ready, all
146 // other sinks block.
147 func (ps *pgEventSource) NewSink() eventSink {
148         ps.setupOnce.Do(ps.setup)
149         sink := &pgEventSink{
150                 channel: make(chan *event, 1),
151                 source:  ps,
152         }
153         ps.mtx.Lock()
154         ps.sinks[sink] = true
155         ps.mtx.Unlock()
156         return sink
157 }
158
159 func (ps *pgEventSource) DB() *sql.DB {
160         ps.setupOnce.Do(ps.setup)
161         return ps.db
162 }
163
164 type pgEventSink struct {
165         channel chan *event
166         source  *pgEventSource
167 }
168
169 func (sink *pgEventSink) Channel() <-chan *event {
170         return sink.channel
171 }
172
173 func (sink *pgEventSink) Stop() {
174         go func() {
175                 // Ensure this sink cannot fill up and block the
176                 // server-side queue (which otherwise could in turn
177                 // block our mtx.Lock() here)
178                 for _ = range sink.channel {
179                 }
180         }()
181         sink.source.mtx.Lock()
182         delete(sink.source.sinks, sink)
183         sink.source.mtx.Unlock()
184         close(sink.channel)
185 }