Refactor the multi-host salt install page.
[arvados.git] / services / ws / event_source.go
index bb323745d50f56a79c8ce968644710adaae30a73..3593c3aebd58ceae6932e9667eca43aba8a8c0cf 100644 (file)
-package main
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package ws
 
 import (
+       "context"
        "database/sql"
+       "errors"
+       "fmt"
        "strconv"
-       "strings"
        "sync"
-       "sync/atomic"
        "time"
 
-       "git.curoverse.com/arvados.git/sdk/go/stats"
+       "git.arvados.org/arvados.git/sdk/go/stats"
        "github.com/lib/pq"
+       "github.com/prometheus/client_golang/prometheus"
+       "github.com/sirupsen/logrus"
 )
 
-type pgConfig map[string]string
-
-func (c pgConfig) ConnectionString() string {
-       s := ""
-       for k, v := range c {
-               s += k
-               s += "='"
-               s += strings.Replace(
-                       strings.Replace(v, `\`, `\\`, -1),
-                       `'`, `\'`, -1)
-               s += "' "
-       }
-       return s
-}
-
+// pgEventSource receives notifications from PostgreSQL (see Run) and
+// delivers the resulting events to every sink created by NewSink.
 type pgEventSource struct {
-       DataSource string
-       QueueSize  int
+       // DataSource is the connection string passed to sql.Open and
+       // pq.NewListener.
+       DataSource   string
+       // MaxOpenConns is passed to db.SetMaxOpenConns; Run logs a
+       // warning if it is not positive.
+       MaxOpenConns int
+       // QueueSize is the capacity of the internal event queue.
+       QueueSize    int
+       Logger       logrus.FieldLogger
+       // Reg is the registry in which setup registers its metrics.
+       Reg          *prometheus.Registry
 
        db         *sql.DB
        pqListener *pq.Listener
        queue      chan *event
        sinks      map[*pgEventSink]bool
-       setupOnce  sync.Once
        mtx        sync.Mutex
-       shutdown   chan error
 
        lastQDelay time.Duration
-       eventsIn   uint64
-       eventsOut  uint64
+       eventsIn   prometheus.Counter
+       eventsOut  prometheus.Counter
+
+       // cancel cancels the context created by Run.
+       cancel func()
+
+       setupOnce sync.Once
+       // ready is created by setup and closed by Run once listener
+       // setup has either succeeded or failed.
+       ready     chan bool
 }
 
-var _ DebugStatuser = (*pgEventSource)(nil)
+// listenerProblem is the event callback passed to pq.NewListener in
+// Run; Run also calls it directly when a listener ping fails or a nil
+// notification arrives. A "connected" event is logged at debug level
+// and otherwise ignored. Any other event type is logged as an error
+// and cancels the context created by Run, which ends Run's event loop.
+func (ps *pgEventSource) listenerProblem(et pq.ListenerEventType, err error) {
+       if et == pq.ListenerEventConnected {
+               ps.Logger.Debug("pgEventSource connected")
+               return
+       }
+
+       // Until we have a mechanism for catching up on missed events,
+       // we cannot recover from a dropped connection without
+       // breaking our promises to clients.
+       ps.Logger.
+               WithField("eventType", et).
+               WithError(err).
+               Error("listener problem")
+       ps.cancel()
+}
 
 func (ps *pgEventSource) setup() {
-       ps.shutdown = make(chan error, 1)
-       ps.sinks = make(map[*pgEventSink]bool)
+       ps.ready = make(chan bool)
+       ps.Reg.MustRegister(prometheus.NewGaugeFunc(
+               prometheus.GaugeOpts{
+                       Namespace: "arvados",
+                       Subsystem: "ws",
+                       Name:      "queue_len",
+                       Help:      "Current number of events in queue",
+               }, func() float64 { return float64(len(ps.queue)) }))
+       ps.Reg.MustRegister(prometheus.NewGaugeFunc(
+               prometheus.GaugeOpts{
+                       Namespace: "arvados",
+                       Subsystem: "ws",
+                       Name:      "queue_cap",
+                       Help:      "Event queue capacity",
+               }, func() float64 { return float64(cap(ps.queue)) }))
+       ps.Reg.MustRegister(prometheus.NewGaugeFunc(
+               prometheus.GaugeOpts{
+                       Namespace: "arvados",
+                       Subsystem: "ws",
+                       Name:      "queue_delay",
+                       Help:      "Queue delay of the last emitted event",
+               }, func() float64 { return ps.lastQDelay.Seconds() }))
+       ps.Reg.MustRegister(prometheus.NewGaugeFunc(
+               prometheus.GaugeOpts{
+                       Namespace: "arvados",
+                       Subsystem: "ws",
+                       Name:      "sinks",
+                       Help:      "Number of active sinks (connections)",
+               }, func() float64 { return float64(len(ps.sinks)) }))
+       ps.Reg.MustRegister(prometheus.NewGaugeFunc(
+               prometheus.GaugeOpts{
+                       Namespace: "arvados",
+                       Subsystem: "ws",
+                       Name:      "sinks_blocked",
+                       Help:      "Number of sinks (connections) that are busy and blocking the main event stream",
+               }, func() float64 {
+                       ps.mtx.Lock()
+                       defer ps.mtx.Unlock()
+                       blocked := 0
+                       for sink := range ps.sinks {
+                               blocked += len(sink.channel)
+                       }
+                       return float64(blocked)
+               }))
+       ps.eventsIn = prometheus.NewCounter(prometheus.CounterOpts{
+               Namespace: "arvados",
+               Subsystem: "ws",
+               Name:      "events_in",
+               Help:      "Number of events received from postgresql notify channel",
+       })
+       ps.Reg.MustRegister(ps.eventsIn)
+       ps.eventsOut = prometheus.NewCounter(prometheus.CounterOpts{
+               Namespace: "arvados",
+               Subsystem: "ws",
+               Name:      "events_out",
+               Help:      "Number of events sent to client sessions (before filtering)",
+       })
+       ps.Reg.MustRegister(ps.eventsOut)
+
+       maxConnections := prometheus.NewGauge(prometheus.GaugeOpts{
+               Namespace: "arvados",
+               Subsystem: "ws",
+               Name:      "db_max_connections",
+               Help:      "Maximum number of open connections to the database",
+       })
+       ps.Reg.MustRegister(maxConnections)
+       openConnections := prometheus.NewGaugeVec(prometheus.GaugeOpts{
+               Namespace: "arvados",
+               Subsystem: "ws",
+               Name:      "db_open_connections",
+               Help:      "Open connections to the database",
+       }, []string{"inuse"})
+       ps.Reg.MustRegister(openConnections)
+
+       updateDBStats := func() {
+               stats := ps.db.Stats()
+               maxConnections.Set(float64(stats.MaxOpenConnections))
+               openConnections.WithLabelValues("0").Set(float64(stats.Idle))
+               openConnections.WithLabelValues("1").Set(float64(stats.InUse))
+       }
+       go func() {
+               <-ps.ready
+               if ps.db == nil {
+                       return
+               }
+               updateDBStats()
+               for range time.Tick(time.Second) {
+                       updateDBStats()
+               }
+       }()
+}
+
+// Close stops listening for new events and disconnects all clients.
+//
+// It waits in WaitReady until Run has closed the ready channel, then
+// cancels the context created by Run.
+//
+// NOTE(review): ps.cancel is assigned in Run, and the ready channel
+// is only closed by Run, so Close appears to block indefinitely if
+// Run is never started -- confirm callers always start Run.
+func (ps *pgEventSource) Close() {
+       ps.WaitReady()
+       ps.cancel()
+}
+
+// WaitReady returns when Run has closed the ready channel, which it
+// does either after the "logs" listener is set up or, on failure,
+// when Run returns. It runs setup first if that has not happened yet.
+func (ps *pgEventSource) WaitReady() {
+       ps.setupOnce.Do(ps.setup)
+       <-ps.ready
+}
+
+// Run listens for event notifications on the "logs" channel and sends
+// them to all subscribers.
+func (ps *pgEventSource) Run() {
+       ps.Logger.Debug("pgEventSource Run starting")
+       defer ps.Logger.Debug("pgEventSource Run finished")
+
+       ps.setupOnce.Do(ps.setup)
+       ready := ps.ready
+       defer func() {
+               if ready != nil {
+                       close(ready)
+               }
+       }()
+
+       ctx, cancel := context.WithCancel(context.Background())
+       ps.cancel = cancel
+       defer cancel()
+
+       defer func() {
+               // Disconnect all clients
+               ps.mtx.Lock()
+               for sink := range ps.sinks {
+                       close(sink.channel)
+               }
+               ps.sinks = nil
+               ps.mtx.Unlock()
+       }()
 
        db, err := sql.Open("postgres", ps.DataSource)
        if err != nil {
-               logger(nil).WithError(err).Fatal("sql.Open failed")
+               ps.Logger.WithError(err).Error("sql.Open failed")
+               return
        }
+       if ps.MaxOpenConns <= 0 {
+               ps.Logger.Warn("no database connection limit configured -- consider setting PostgreSQL.ConnectionPool>0 in arvados-ws configuration file")
+       }
+       db.SetMaxOpenConns(ps.MaxOpenConns)
        if err = db.Ping(); err != nil {
-               logger(nil).WithError(err).Fatal("db.Ping failed")
+               ps.Logger.WithError(err).Error("db.Ping failed")
+               return
        }
        ps.db = db
 
-       ps.pqListener = pq.NewListener(ps.DataSource, time.Second, time.Minute, func(ev pq.ListenerEventType, err error) {
-               if err != nil {
-                       // Until we have a mechanism for catching up
-                       // on missed events, we cannot recover from a
-                       // dropped connection without breaking our
-                       // promises to clients.
-                       logger(nil).WithError(err).Error("listener problem")
-                       ps.shutdown <- err
-               }
-       })
+       ps.pqListener = pq.NewListener(ps.DataSource, time.Second, time.Minute, ps.listenerProblem)
        err = ps.pqListener.Listen("logs")
        if err != nil {
-               logger(nil).WithError(err).Fatal("pq Listen failed")
+               ps.Logger.WithError(err).Error("pq Listen failed")
+               return
        }
-       logger(nil).Debug("pgEventSource listening")
+       defer ps.pqListener.Close()
+       ps.Logger.Debug("pq Listen setup done")
 
-       go ps.run()
-}
+       close(ready)
+       // Avoid double-close in deferred func
+       ready = nil
 
-func (ps *pgEventSource) run() {
        ps.queue = make(chan *event, ps.QueueSize)
+       defer close(ps.queue)
 
        go func() {
                for e := range ps.queue {
@@ -90,7 +231,7 @@ func (ps *pgEventSource) run() {
                        // client_count X client_queue_size.
                        e.Detail()
 
-                       logger(nil).
+                       ps.Logger.
                                WithField("serial", e.Serial).
                                WithField("detail", e.Detail()).
                                Debug("event ready")
@@ -98,9 +239,9 @@ func (ps *pgEventSource) run() {
                        ps.lastQDelay = e.Ready.Sub(e.Received)
 
                        ps.mtx.Lock()
-                       atomic.AddUint64(&ps.eventsOut, uint64(len(ps.sinks)))
                        for sink := range ps.sinks {
                                sink.channel <- e
+                               ps.eventsOut.Inc()
                        }
                        ps.mtx.Unlock()
                }
@@ -111,28 +252,38 @@ func (ps *pgEventSource) run() {
        defer ticker.Stop()
        for {
                select {
-               case err, ok := <-ps.shutdown:
-                       if ok {
-                               logger(nil).WithError(err).Info("shutdown")
-                       }
-                       close(ps.queue)
+               case <-ctx.Done():
+                       ps.Logger.Debug("ctx done")
                        return
 
                case <-ticker.C:
-                       logger(nil).Debug("listener ping")
-                       ps.pqListener.Ping()
+                       ps.Logger.Debug("listener ping")
+                       err := ps.pqListener.Ping()
+                       if err != nil {
+                               ps.listenerProblem(-1, fmt.Errorf("pqListener ping failed: %s", err))
+                               continue
+                       }
 
                case pqEvent, ok := <-ps.pqListener.Notify:
                        if !ok {
-                               close(ps.queue)
+                               ps.Logger.Error("pqListener Notify chan closed")
                                return
                        }
+                       if pqEvent == nil {
+                               // pq should call listenerProblem
+                               // itself in addition to sending us a
+                               // nil event, so this might be
+                               // superfluous:
+                               ps.listenerProblem(-1, errors.New("pqListener Notify chan received nil event"))
+                               continue
+                       }
                        if pqEvent.Channel != "logs" {
+                               ps.Logger.WithField("pqEvent", pqEvent).Error("unexpected notify from wrong channel")
                                continue
                        }
                        logID, err := strconv.ParseUint(pqEvent.Extra, 10, 64)
                        if err != nil {
-                               logger(nil).WithField("pqEvent", pqEvent).Error("bad notify payload")
+                               ps.Logger.WithField("pqEvent", pqEvent).Error("bad notify payload")
                                continue
                        }
                        serial++
@@ -141,9 +292,10 @@ func (ps *pgEventSource) run() {
                                Received: time.Now(),
                                Serial:   serial,
                                db:       ps.db,
+                               logger:   ps.Logger,
                        }
-                       logger(nil).WithField("event", e).Debug("incoming")
-                       atomic.AddUint64(&ps.eventsIn, 1)
+                       ps.Logger.WithField("event", e).Debug("incoming")
+                       ps.eventsIn.Inc()
                        ps.queue <- e
                        go e.Detail()
                }
@@ -158,22 +310,34 @@ func (ps *pgEventSource) run() {
 // quickly as possible because when one sink stops being ready, all
 // other sinks block.
 func (ps *pgEventSource) NewSink() eventSink {
-       ps.setupOnce.Do(ps.setup)
        sink := &pgEventSink{
                channel: make(chan *event, 1),
                source:  ps,
        }
        ps.mtx.Lock()
+       // ps.sinks starts out nil (and Run resets it to nil when it
+       // exits), so allocate the map on first use.
+       if ps.sinks == nil {
+               ps.sinks = make(map[*pgEventSink]bool)
+       }
        ps.sinks[sink] = true
        ps.mtx.Unlock()
        return sink
 }
 
+// DB returns the database handle opened by Run, after waiting for
+// Run to close the ready channel. The result is nil if Run returned
+// before a connection was established.
 func (ps *pgEventSource) DB() *sql.DB {
-       ps.setupOnce.Do(ps.setup)
+       ps.WaitReady()
        return ps.db
 }
 
+// DBHealth reports whether the database is usable. It returns an
+// error if no connection has been opened yet, or if a trivial query
+// ("SELECT 1") fails or does not complete within one second.
+func (ps *pgEventSource) DBHealth() error {
+       if ps.db == nil {
+               return errors.New("database not connected")
+       }
+       ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+       defer cancel()
+       var i int
+       return ps.db.QueryRowContext(ctx, "SELECT 1").Scan(&i)
+}
+
 func (ps *pgEventSource) DebugStatus() interface{} {
        ps.mtx.Lock()
        defer ps.mtx.Unlock()
@@ -182,13 +346,12 @@ func (ps *pgEventSource) DebugStatus() interface{} {
                blocked += len(sink.channel)
        }
        return map[string]interface{}{
-               "EventsIn":     atomic.LoadUint64(&ps.eventsIn),
-               "EventsOut":    atomic.LoadUint64(&ps.eventsOut),
                "Queue":        len(ps.queue),
                "QueueLimit":   cap(ps.queue),
                "QueueDelay":   stats.Duration(ps.lastQDelay),
                "Sinks":        len(ps.sinks),
                "SinksBlocked": blocked,
+               "DBStats":      ps.db.Stats(),
        }
 }
 
@@ -201,16 +364,19 @@ func (sink *pgEventSink) Channel() <-chan *event {
        return sink.channel
 }
 
+// Stop sending events to the sink's channel.
+//
+// The sink is removed from its source and its channel is closed,
+// unless the sink is no longer registered (Run closes every sink's
+// channel and discards the sink map when it exits), in which case
+// Stop leaves the channel alone so it is not closed twice.
 func (sink *pgEventSink) Stop() {
        go func() {
                // Ensure this sink cannot fill up and block the
                // server-side queue (which otherwise could in turn
                // block our mtx.Lock() here)
-               for _ = range sink.channel {
+               for range sink.channel {
                }
        }()
        sink.source.mtx.Lock()
-       delete(sink.source.sinks, sink)
+       if _, ok := sink.source.sinks[sink]; ok {
+               delete(sink.source.sinks, sink)
+               close(sink.channel)
+       }
        sink.source.mtx.Unlock()
-       close(sink.channel)
 }