8460: Refactor session logic (subscription protocol) out of handler (queueing and...
[arvados.git] / services / ws / pg.go
1 package main
2
3 import (
4         "database/sql"
5         "log"
6         "strconv"
7         "strings"
8         "sync"
9         "time"
10
11         "github.com/lib/pq"
12 )
13
14 type pgConfig map[string]string
15
16 func (c pgConfig) ConnectionString() string {
17         s := ""
18         for k, v := range c {
19                 s += k
20                 s += "='"
21                 s += strings.Replace(
22                         strings.Replace(v, `\`, `\\`, -1),
23                         `'`, `\'`, -1)
24                 s += "' "
25         }
26         return s
27 }
28
29 type pgEventSource struct {
30         DataSource string
31         QueueSize  int
32
33         pqListener *pq.Listener
34         sinks      map[*pgEventSink]bool
35         setupOnce  sync.Once
36         mtx        sync.Mutex
37 }
38
39 func (ps *pgEventSource) setup() {
40         ps.sinks = make(map[*pgEventSink]bool)
41         go ps.run()
42 }
43
44 func (ps *pgEventSource) run() {
45         db, err := sql.Open("postgres", ps.DataSource)
46         if err != nil {
47                 log.Fatalf("sql.Open: %s", err)
48         }
49         if err = db.Ping(); err != nil {
50                 log.Fatalf("db.Ping: %s", err)
51         }
52
53         listener := pq.NewListener(ps.DataSource, time.Second, time.Minute, func(ev pq.ListenerEventType, err error) {
54                 if err != nil {
55                         // Until we have a mechanism for catching up
56                         // on missed events, we cannot recover from a
57                         // dropped connection without breaking our
58                         // promises to clients.
59                         log.Fatalf("pgEventSource listener problem: %s", err)
60                 }
61         })
62         err = listener.Listen("logs")
63         if err != nil {
64                 log.Fatal(err)
65         }
66
67         debugLogf("pgEventSource listening")
68
69         eventQueue := make(chan *event, ps.QueueSize)
70
71         go func() {
72                 for e := range eventQueue {
73                         // Wait for the "select ... from logs" call to
74                         // finish. This limits max concurrent queries
75                         // to ps.QueueSize. Without this, max
76                         // concurrent queries would be bounded by
77                         // client_count X client_queue_size.
78                         e.Detail()
79                         debugLogf("event %d detail %+v", e.Serial, e.Detail())
80                         ps.mtx.Lock()
81                         for sink := range ps.sinks {
82                                 sink.channel <- e
83                         }
84                         ps.mtx.Unlock()
85                 }
86         }()
87
88         var serial uint64
89         ticker := time.NewTicker(time.Minute)
90         defer ticker.Stop()
91         for {
92                 select {
93                 case <-ticker.C:
94                         debugLogf("pgEventSource listener ping")
95                         listener.Ping()
96
97                 case pqEvent, ok := <-listener.Notify:
98                         if !ok {
99                                 close(eventQueue)
100                                 return
101                         }
102                         if pqEvent.Channel != "logs" {
103                                 continue
104                         }
105                         logID, err := strconv.ParseUint(pqEvent.Extra, 10, 64)
106                         if err != nil {
107                                 log.Printf("bad notify payload: %+v", pqEvent)
108                                 continue
109                         }
110                         serial++
111                         e := &event{
112                                 LogID:    logID,
113                                 Received: time.Now(),
114                                 Serial:   serial,
115                                 db:       db,
116                         }
117                         debugLogf("event %d %+v", e.Serial, e)
118                         eventQueue <- e
119                         go e.Detail()
120                 }
121         }
122 }
123
124 // NewSink subscribes to the event source. NewSink returns an
125 // eventSink, whose Channel() method returns a channel: a pointer to
126 // each subsequent event will be sent to that channel.
127 //
128 // The caller must ensure events are received from the sink channel as
129 // quickly as possible because when one sink stops being ready, all
130 // other sinks block.
131 func (ps *pgEventSource) NewSink() eventSink {
132         ps.setupOnce.Do(ps.setup)
133         sink := &pgEventSink{
134                 channel: make(chan *event, 1),
135                 source:  ps,
136         }
137         ps.mtx.Lock()
138         ps.sinks[sink] = true
139         ps.mtx.Unlock()
140         return sink
141 }
142
143 type pgEventSink struct {
144         channel chan *event
145         source  *pgEventSource
146 }
147
148 func (sink *pgEventSink) Channel() <-chan *event {
149         return sink.channel
150 }
151
152 func (sink *pgEventSink) Stop() {
153         go func() {
154                 // Ensure this sink cannot fill up and block the
155                 // server-side queue (which otherwise could in turn
156                 // block our mtx.Lock() here)
157                 for _ = range sink.channel {
158                 }
159         }()
160         sink.source.mtx.Lock()
161         delete(sink.source.sinks, sink)
162         sink.source.mtx.Unlock()
163         close(sink.channel)
164 }