16217: Refactor ws to use lib/service.
[arvados.git] / services / ws / event_source.go
index 3a82bf62b3e9351a95d2abe4c56ae942fededa4c..341464de500cf784399f8df17b6d42acf4c4ebd2 100644 (file)
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
        "context"
@@ -16,12 +16,14 @@ import (
 
        "git.arvados.org/arvados.git/sdk/go/stats"
        "github.com/lib/pq"
+       "github.com/sirupsen/logrus"
 )
 
 type pgEventSource struct {
        DataSource   string
        MaxOpenConns int
        QueueSize    int
+       Logger       logrus.FieldLogger
 
        db         *sql.DB
        pqListener *pq.Listener
@@ -43,14 +45,14 @@ var _ debugStatuser = (*pgEventSource)(nil)
 
 func (ps *pgEventSource) listenerProblem(et pq.ListenerEventType, err error) {
        if et == pq.ListenerEventConnected {
-               logger(nil).Debug("pgEventSource connected")
+               ps.Logger.Debug("pgEventSource connected")
                return
        }
 
        // Until we have a mechanism for catching up on missed events,
        // we cannot recover from a dropped connection without
        // breaking our promises to clients.
-       logger(nil).
+       ps.Logger.
                WithField("eventType", et).
                WithError(err).
                Error("listener problem")
@@ -76,8 +78,8 @@ func (ps *pgEventSource) WaitReady() {
 // Run listens for event notifications on the "logs" channel and sends
 // them to all subscribers.
 func (ps *pgEventSource) Run() {
-       logger(nil).Debug("pgEventSource Run starting")
-       defer logger(nil).Debug("pgEventSource Run finished")
+       ps.Logger.Debug("pgEventSource Run starting")
+       defer ps.Logger.Debug("pgEventSource Run finished")
 
        ps.setupOnce.Do(ps.setup)
        ready := ps.ready
@@ -103,15 +105,15 @@ func (ps *pgEventSource) Run() {
 
        db, err := sql.Open("postgres", ps.DataSource)
        if err != nil {
-               logger(nil).WithError(err).Error("sql.Open failed")
+               ps.Logger.WithError(err).Error("sql.Open failed")
                return
        }
        if ps.MaxOpenConns <= 0 {
-               logger(nil).Warn("no database connection limit configured -- consider setting PostgresPool>0 in arvados-ws configuration file")
+               ps.Logger.Warn("no database connection limit configured -- consider setting PostgresPool>0 in arvados-ws configuration file")
        }
        db.SetMaxOpenConns(ps.MaxOpenConns)
        if err = db.Ping(); err != nil {
-               logger(nil).WithError(err).Error("db.Ping failed")
+               ps.Logger.WithError(err).Error("db.Ping failed")
                return
        }
        ps.db = db
@@ -119,11 +121,11 @@ func (ps *pgEventSource) Run() {
        ps.pqListener = pq.NewListener(ps.DataSource, time.Second, time.Minute, ps.listenerProblem)
        err = ps.pqListener.Listen("logs")
        if err != nil {
-               logger(nil).WithError(err).Error("pq Listen failed")
+               ps.Logger.WithError(err).Error("pq Listen failed")
                return
        }
        defer ps.pqListener.Close()
-       logger(nil).Debug("pq Listen setup done")
+       ps.Logger.Debug("pq Listen setup done")
 
        close(ready)
        // Avoid double-close in deferred func
@@ -141,7 +143,7 @@ func (ps *pgEventSource) Run() {
                        // client_count X client_queue_size.
                        e.Detail()
 
-                       logger(nil).
+                       ps.Logger.
                                WithField("serial", e.Serial).
                                WithField("detail", e.Detail()).
                                Debug("event ready")
@@ -163,11 +165,11 @@ func (ps *pgEventSource) Run() {
        for {
                select {
                case <-ctx.Done():
-                       logger(nil).Debug("ctx done")
+                       ps.Logger.Debug("ctx done")
                        return
 
                case <-ticker.C:
-                       logger(nil).Debug("listener ping")
+                       ps.Logger.Debug("listener ping")
                        err := ps.pqListener.Ping()
                        if err != nil {
                                ps.listenerProblem(-1, fmt.Errorf("pqListener ping failed: %s", err))
@@ -176,7 +178,7 @@ func (ps *pgEventSource) Run() {
 
                case pqEvent, ok := <-ps.pqListener.Notify:
                        if !ok {
-                               logger(nil).Error("pqListener Notify chan closed")
+                               ps.Logger.Error("pqListener Notify chan closed")
                                return
                        }
                        if pqEvent == nil {
@@ -188,12 +190,12 @@ func (ps *pgEventSource) Run() {
                                continue
                        }
                        if pqEvent.Channel != "logs" {
-                               logger(nil).WithField("pqEvent", pqEvent).Error("unexpected notify from wrong channel")
+                               ps.Logger.WithField("pqEvent", pqEvent).Error("unexpected notify from wrong channel")
                                continue
                        }
                        logID, err := strconv.ParseUint(pqEvent.Extra, 10, 64)
                        if err != nil {
-                               logger(nil).WithField("pqEvent", pqEvent).Error("bad notify payload")
+                               ps.Logger.WithField("pqEvent", pqEvent).Error("bad notify payload")
                                continue
                        }
                        serial++
@@ -202,8 +204,9 @@ func (ps *pgEventSource) Run() {
                                Received: time.Now(),
                                Serial:   serial,
                                db:       ps.db,
+                               logger:   ps.Logger,
                        }
-                       logger(nil).WithField("event", e).Debug("incoming")
+                       ps.Logger.WithField("event", e).Debug("incoming")
                        atomic.AddUint64(&ps.eventsIn, 1)
                        ps.queue <- e
                        go e.Detail()
@@ -238,6 +241,9 @@ func (ps *pgEventSource) DB() *sql.DB {
 }
 
 func (ps *pgEventSource) DBHealth() error {
+       if ps.db == nil {
+               return errors.New("database not connected")
+       }
        ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
        defer cancel()
        var i int