16217: Refactor ws to use lib/service.
author Tom Clegg <tom@tomclegg.ca>
Wed, 25 Mar 2020 15:00:03 +0000 (11:00 -0400)
committer Tom Clegg <tom@tomclegg.ca>
Wed, 25 Mar 2020 15:00:03 +0000 (11:00 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@tomclegg.ca>

22 files changed:
build/run-build-packages.sh
cmd/arvados-server/arvados-ws.service [moved from services/ws/arvados-ws.service with 94% similarity]
cmd/arvados-server/cmd.go
go.mod
services/ws/doc.go
services/ws/event.go
services/ws/event_source.go
services/ws/event_source_test.go
services/ws/event_test.go
services/ws/gocheck_test.go
services/ws/handler.go
services/ws/main.go [deleted file]
services/ws/permission.go
services/ws/permission_test.go
services/ws/router.go
services/ws/server.go [deleted file]
services/ws/service.go [new file with mode: 0644]
services/ws/service_test.go [moved from services/ws/server_test.go with 68% similarity]
services/ws/session.go
services/ws/session_v0.go
services/ws/session_v0_test.go
services/ws/session_v1.go

index 4faa1c6b0d4b0e83d12d27b997615fbf78031284..3ba1dcc05e8776fc57a205e2deb79a0224a8e370 100755 (executable)
@@ -308,7 +308,7 @@ package_go_binary services/keepstore keepstore \
     "Keep storage daemon, accessible to clients on the LAN"
 package_go_binary services/keep-web keep-web \
     "Static web hosting service for user data stored in Arvados Keep"
-package_go_binary services/ws arvados-ws \
+package_go_binary cmd/arvados-server arvados-ws \
     "Arvados Websocket server"
 package_go_binary tools/sync-groups arvados-sync-groups \
     "Synchronize remote groups into Arvados from an external source"
similarity index 94%
rename from services/ws/arvados-ws.service
rename to cmd/arvados-server/arvados-ws.service
index 36624c78779c02cfde829323551ca9c2cb19eda3..aebc56a79f333b19f061f5f0aadce793e799529c 100644 (file)
@@ -6,6 +6,7 @@
 Description=Arvados websocket server
 Documentation=https://doc.arvados.org/
 After=network.target
+AssertPathExists=/etc/arvados/config.yml
 
 # systemd==229 (ubuntu:xenial) obeys StartLimitInterval in the [Unit] section
 StartLimitInterval=0
index a9d927d8734401f76fa173bff7214e0038fc4c68..80d43ad848fa0925327a2abe6fa785c968b536ae 100644 (file)
@@ -14,6 +14,7 @@ import (
        "git.arvados.org/arvados.git/lib/controller"
        "git.arvados.org/arvados.git/lib/crunchrun"
        "git.arvados.org/arvados.git/lib/dispatchcloud"
+       "git.arvados.org/arvados.git/services/ws"
 )
 
 var (
@@ -30,6 +31,7 @@ var (
                "controller":      controller.Command,
                "crunch-run":      crunchrun.Command,
                "dispatch-cloud":  dispatchcloud.Command,
+               "ws":              ws.Command,
        })
 )
 
diff --git a/go.mod b/go.mod
index 2cc5e89eb1fe68c88335e2ba2e2906e1bb2d9c33..48b1c725a5986bfe249d3444ed5698a9ad193bd3 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -52,7 +52,7 @@ require (
        golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550
        golang.org/x/net v0.0.0-20190620200207-3b0461eec859
        golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
-       golang.org/x/sys v0.0.0-20191105231009-c1f44814a5cd // indirect
+       golang.org/x/sys v0.0.0-20191105231009-c1f44814a5cd
        google.golang.org/api v0.13.0
        gopkg.in/check.v1 v1.0.0-20161208181325-20d25e280405
        gopkg.in/square/go-jose.v2 v2.3.1
index 806c3355da6c693350493a7471bc59e270bfb1e3..6a86cbe7a8307e1683dbd09ea506bc8cd79f52e3 100644 (file)
 // Developer info
 //
 // See https://dev.arvados.org/projects/arvados/wiki/Hacking_websocket_server.
-//
-// Usage
-//
-//     arvados-ws [-legacy-ws-config /etc/arvados/ws/ws.yml] [-dump-config]
-//
-// Options
-//
-// -legacy-ws-config path
-//
-// Load legacy configuration from the given file instead of the default
-// /etc/arvados/ws/ws.yml, legacy config overrides the clusterwide config.yml.
-//
-// -dump-config
-//
-// Print the loaded configuration to stdout and exit.
-//
-// Logs
-//
-// Logs are printed to stderr, formatted as JSON.
-//
-// A log is printed each time a client connects or disconnects.
-//
-// Enable additional logs by configuring:
-//
-//     LogLevel: debug
-//
-// Runtime status
-//
-// GET /debug.json responds with debug stats.
-//
-// GET /status.json responds with health check results and
-// activity/usage metrics.
-package main
+package ws
index ae545c092cf8ddece45cfbebdddb542e08de16b4..c989c0ca559b1a1cff472b2cc1bdb95b4fd021ce 100644 (file)
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
        "database/sql"
@@ -11,6 +11,7 @@ import (
 
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "github.com/ghodss/yaml"
+       "github.com/sirupsen/logrus"
 )
 
 type eventSink interface {
@@ -31,6 +32,7 @@ type event struct {
        Serial   uint64
 
        db     *sql.DB
+       logger logrus.FieldLogger
        logRow *arvados.Log
        err    error
        mtx    sync.Mutex
@@ -57,12 +59,12 @@ func (e *event) Detail() *arvados.Log {
                &logRow.CreatedAt,
                &propYAML)
        if e.err != nil {
-               logger(nil).WithField("LogID", e.LogID).WithError(e.err).Error("QueryRow failed")
+               e.logger.WithField("LogID", e.LogID).WithError(e.err).Error("QueryRow failed")
                return nil
        }
        e.err = yaml.Unmarshal(propYAML, &logRow.Properties)
        if e.err != nil {
-               logger(nil).WithField("LogID", e.LogID).WithError(e.err).Error("yaml decode failed")
+               e.logger.WithField("LogID", e.LogID).WithError(e.err).Error("yaml decode failed")
                return nil
        }
        e.logRow = &logRow
index 3a82bf62b3e9351a95d2abe4c56ae942fededa4c..341464de500cf784399f8df17b6d42acf4c4ebd2 100644 (file)
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
        "context"
@@ -16,12 +16,14 @@ import (
 
        "git.arvados.org/arvados.git/sdk/go/stats"
        "github.com/lib/pq"
+       "github.com/sirupsen/logrus"
 )
 
 type pgEventSource struct {
        DataSource   string
        MaxOpenConns int
        QueueSize    int
+       Logger       logrus.FieldLogger
 
        db         *sql.DB
        pqListener *pq.Listener
@@ -43,14 +45,14 @@ var _ debugStatuser = (*pgEventSource)(nil)
 
 func (ps *pgEventSource) listenerProblem(et pq.ListenerEventType, err error) {
        if et == pq.ListenerEventConnected {
-               logger(nil).Debug("pgEventSource connected")
+               ps.Logger.Debug("pgEventSource connected")
                return
        }
 
        // Until we have a mechanism for catching up on missed events,
        // we cannot recover from a dropped connection without
        // breaking our promises to clients.
-       logger(nil).
+       ps.Logger.
                WithField("eventType", et).
                WithError(err).
                Error("listener problem")
@@ -76,8 +78,8 @@ func (ps *pgEventSource) WaitReady() {
 // Run listens for event notifications on the "logs" channel and sends
 // them to all subscribers.
 func (ps *pgEventSource) Run() {
-       logger(nil).Debug("pgEventSource Run starting")
-       defer logger(nil).Debug("pgEventSource Run finished")
+       ps.Logger.Debug("pgEventSource Run starting")
+       defer ps.Logger.Debug("pgEventSource Run finished")
 
        ps.setupOnce.Do(ps.setup)
        ready := ps.ready
@@ -103,15 +105,15 @@ func (ps *pgEventSource) Run() {
 
        db, err := sql.Open("postgres", ps.DataSource)
        if err != nil {
-               logger(nil).WithError(err).Error("sql.Open failed")
+               ps.Logger.WithError(err).Error("sql.Open failed")
                return
        }
        if ps.MaxOpenConns <= 0 {
-               logger(nil).Warn("no database connection limit configured -- consider setting PostgresPool>0 in arvados-ws configuration file")
+               ps.Logger.Warn("no database connection limit configured -- consider setting PostgresPool>0 in arvados-ws configuration file")
        }
        db.SetMaxOpenConns(ps.MaxOpenConns)
        if err = db.Ping(); err != nil {
-               logger(nil).WithError(err).Error("db.Ping failed")
+               ps.Logger.WithError(err).Error("db.Ping failed")
                return
        }
        ps.db = db
@@ -119,11 +121,11 @@ func (ps *pgEventSource) Run() {
        ps.pqListener = pq.NewListener(ps.DataSource, time.Second, time.Minute, ps.listenerProblem)
        err = ps.pqListener.Listen("logs")
        if err != nil {
-               logger(nil).WithError(err).Error("pq Listen failed")
+               ps.Logger.WithError(err).Error("pq Listen failed")
                return
        }
        defer ps.pqListener.Close()
-       logger(nil).Debug("pq Listen setup done")
+       ps.Logger.Debug("pq Listen setup done")
 
        close(ready)
        // Avoid double-close in deferred func
@@ -141,7 +143,7 @@ func (ps *pgEventSource) Run() {
                        // client_count X client_queue_size.
                        e.Detail()
 
-                       logger(nil).
+                       ps.Logger.
                                WithField("serial", e.Serial).
                                WithField("detail", e.Detail()).
                                Debug("event ready")
@@ -163,11 +165,11 @@ func (ps *pgEventSource) Run() {
        for {
                select {
                case <-ctx.Done():
-                       logger(nil).Debug("ctx done")
+                       ps.Logger.Debug("ctx done")
                        return
 
                case <-ticker.C:
-                       logger(nil).Debug("listener ping")
+                       ps.Logger.Debug("listener ping")
                        err := ps.pqListener.Ping()
                        if err != nil {
                                ps.listenerProblem(-1, fmt.Errorf("pqListener ping failed: %s", err))
@@ -176,7 +178,7 @@ func (ps *pgEventSource) Run() {
 
                case pqEvent, ok := <-ps.pqListener.Notify:
                        if !ok {
-                               logger(nil).Error("pqListener Notify chan closed")
+                               ps.Logger.Error("pqListener Notify chan closed")
                                return
                        }
                        if pqEvent == nil {
@@ -188,12 +190,12 @@ func (ps *pgEventSource) Run() {
                                continue
                        }
                        if pqEvent.Channel != "logs" {
-                               logger(nil).WithField("pqEvent", pqEvent).Error("unexpected notify from wrong channel")
+                               ps.Logger.WithField("pqEvent", pqEvent).Error("unexpected notify from wrong channel")
                                continue
                        }
                        logID, err := strconv.ParseUint(pqEvent.Extra, 10, 64)
                        if err != nil {
-                               logger(nil).WithField("pqEvent", pqEvent).Error("bad notify payload")
+                               ps.Logger.WithField("pqEvent", pqEvent).Error("bad notify payload")
                                continue
                        }
                        serial++
@@ -202,8 +204,9 @@ func (ps *pgEventSource) Run() {
                                Received: time.Now(),
                                Serial:   serial,
                                db:       ps.db,
+                               logger:   ps.Logger,
                        }
-                       logger(nil).WithField("event", e).Debug("incoming")
+                       ps.Logger.WithField("event", e).Debug("incoming")
                        atomic.AddUint64(&ps.eventsIn, 1)
                        ps.queue <- e
                        go e.Detail()
@@ -238,6 +241,9 @@ func (ps *pgEventSource) DB() *sql.DB {
 }
 
 func (ps *pgEventSource) DBHealth() error {
+       if ps.db == nil {
+               return errors.New("database not connected")
+       }
        ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
        defer cancel()
        var i int
index 98a9e8b9785b40dbd8f5314bcedb98bd083efe44..dd40835b6e56f5a3ed5ff07c91e32fbe4b920882 100644 (file)
@@ -2,17 +2,16 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
        "database/sql"
        "fmt"
-       "os"
-       "path/filepath"
        "sync"
        "time"
 
        "git.arvados.org/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/ctxlog"
        check "gopkg.in/check.v1"
 )
 
@@ -21,7 +20,7 @@ var _ = check.Suite(&eventSourceSuite{})
 type eventSourceSuite struct{}
 
 func testDBConfig() arvados.PostgreSQLConnection {
-       cfg, err := arvados.GetConfig(filepath.Join(os.Getenv("WORKSPACE"), "tmp", "arvados.yml"))
+       cfg, err := arvados.GetConfig(arvados.DefaultConfigFile)
        if err != nil {
                panic(err)
        }
@@ -46,6 +45,7 @@ func (*eventSourceSuite) TestEventSource(c *check.C) {
        pges := &pgEventSource{
                DataSource: cfg.String(),
                QueueSize:  4,
+               Logger:     ctxlog.TestLogger(c),
        }
        go pges.Run()
        sinks := make([]eventSink, 18)
index dc324464ec3d15f4b473b5d9b91f3557c7a90abd..4665dfcd9ee9208fcb71794189ba115d0285fa55 100644 (file)
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import check "gopkg.in/check.v1"
 
index ea8dfc30c94c94e19308192c8c6713f745ce3a9b..df1ca7ab31c292280ab8a72c2f56155ef4c68e84 100644 (file)
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
        "testing"
@@ -13,3 +13,7 @@ import (
 func TestGocheck(t *testing.T) {
        check.TestingT(t)
 }
+
+func init() {
+       testMode = true
+}
index 913b1ee8000cbd274039483df70bad7896d52df5..912643ad97c6374006b3fd4b00f90d340157d687 100644 (file)
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
        "context"
@@ -12,6 +12,7 @@ import (
 
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/stats"
+       "github.com/sirupsen/logrus"
 )
 
 type handler struct {
@@ -31,12 +32,11 @@ type handlerStats struct {
        EventCount   uint64
 }
 
-func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsConn, chan<- interface{}) (session, error)) (hStats handlerStats) {
+func (h *handler) Handle(ws wsConn, logger logrus.FieldLogger, eventSource eventSource, newSession func(wsConn, chan<- interface{}) (session, error)) (hStats handlerStats) {
        h.setupOnce.Do(h.setup)
 
        ctx, cancel := context.WithCancel(ws.Request().Context())
        defer cancel()
-       log := logger(ctx)
 
        incoming := eventSource.NewSink()
        defer incoming.Stop()
@@ -53,7 +53,7 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
 
        sess, err := newSession(ws, queue)
        if err != nil {
-               log.WithError(err).Error("newSession failed")
+               logger.WithError(err).Error("newSession failed")
                return
        }
 
@@ -71,19 +71,19 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
                        ws.SetReadDeadline(time.Now().Add(24 * 365 * time.Hour))
                        n, err := ws.Read(buf)
                        buf := buf[:n]
-                       log.WithField("frame", string(buf[:n])).Debug("received frame")
+                       logger.WithField("frame", string(buf[:n])).Debug("received frame")
                        if err == nil && n == cap(buf) {
                                err = errFrameTooBig
                        }
                        if err != nil {
                                if err != io.EOF && ctx.Err() == nil {
-                                       log.WithError(err).Info("read error")
+                                       logger.WithError(err).Info("read error")
                                }
                                return
                        }
                        err = sess.Receive(buf)
                        if err != nil {
-                               log.WithError(err).Error("sess.Receive() failed")
+                               logger.WithError(err).Error("sess.Receive() failed")
                                return
                        }
                }
@@ -108,38 +108,38 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
                        var e *event
                        var buf []byte
                        var err error
-                       log := log
+                       logger := logger
 
                        switch data := data.(type) {
                        case []byte:
                                buf = data
                        case *event:
                                e = data
-                               log = log.WithField("serial", e.Serial)
+                               logger = logger.WithField("serial", e.Serial)
                                buf, err = sess.EventMessage(e)
                                if err != nil {
-                                       log.WithError(err).Error("EventMessage failed")
+                                       logger.WithError(err).Error("EventMessage failed")
                                        return
                                } else if len(buf) == 0 {
-                                       log.Debug("skip")
+                                       logger.Debug("skip")
                                        continue
                                }
                        default:
-                               log.WithField("data", data).Error("bad object in client queue")
+                               logger.WithField("data", data).Error("bad object in client queue")
                                continue
                        }
 
-                       log.WithField("frame", string(buf)).Debug("send event")
+                       logger.WithField("frame", string(buf)).Debug("send event")
                        ws.SetWriteDeadline(time.Now().Add(h.PingTimeout))
                        t0 := time.Now()
                        _, err = ws.Write(buf)
                        if err != nil {
                                if ctx.Err() == nil {
-                                       log.WithError(err).Error("write failed")
+                                       logger.WithError(err).Error("write failed")
                                }
                                return
                        }
-                       log.Debug("sent")
+                       logger.Debug("sent")
 
                        if e != nil {
                                hStats.QueueDelayNs += t0.Sub(e.Ready)
@@ -189,7 +189,7 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
                                select {
                                case queue <- e:
                                default:
-                                       log.WithError(errQueueFull).Error("terminate")
+                                       logger.WithError(errQueueFull).Error("terminate")
                                        return
                                }
                        }
diff --git a/services/ws/main.go b/services/ws/main.go
deleted file mode 100644 (file)
index 5b42c44..0000000
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright (C) The Arvados Authors. All rights reserved.
-//
-// SPDX-License-Identifier: AGPL-3.0
-
-package main
-
-import (
-       "flag"
-       "fmt"
-       "os"
-
-       "git.arvados.org/arvados.git/lib/config"
-       "git.arvados.org/arvados.git/sdk/go/arvados"
-       "git.arvados.org/arvados.git/sdk/go/ctxlog"
-       "github.com/ghodss/yaml"
-       "github.com/sirupsen/logrus"
-)
-
-var logger = ctxlog.FromContext
-var version = "dev"
-
-func configure(log logrus.FieldLogger, args []string) *arvados.Cluster {
-       flags := flag.NewFlagSet(args[0], flag.ExitOnError)
-       dumpConfig := flags.Bool("dump-config", false, "show current configuration and exit")
-       getVersion := flags.Bool("version", false, "Print version information and exit.")
-
-       loader := config.NewLoader(nil, log)
-       loader.SetupFlags(flags)
-       args = loader.MungeLegacyConfigArgs(log, args[1:], "-legacy-ws-config")
-
-       flags.Parse(args)
-
-       // Print version information if requested
-       if *getVersion {
-               fmt.Printf("arvados-ws %s\n", version)
-               return nil
-       }
-
-       cfg, err := loader.Load()
-       if err != nil {
-               log.Fatal(err)
-       }
-
-       cluster, err := cfg.GetCluster("")
-       if err != nil {
-               log.Fatal(err)
-       }
-
-       ctxlog.SetLevel(cluster.SystemLogs.LogLevel)
-       ctxlog.SetFormat(cluster.SystemLogs.Format)
-
-       if *dumpConfig {
-               out, err := yaml.Marshal(cfg)
-               if err != nil {
-                       log.Fatal(err)
-               }
-               _, err = os.Stdout.Write(out)
-               if err != nil {
-                       log.Fatal(err)
-               }
-               return nil
-       }
-       return cluster
-}
-
-func main() {
-       log := logger(nil)
-
-       cluster := configure(log, os.Args)
-       if cluster == nil {
-               return
-       }
-
-       log.Printf("arvados-ws %s started", version)
-       srv := &server{cluster: cluster}
-       log.Fatal(srv.Run())
-}
index 745d28f9523f36ca83afa0b29e9511e6f98176f9..ac895f80e5fd7ae7933558fbfa6e6acb97a6c7b0 100644 (file)
@@ -2,14 +2,16 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
+       "context"
        "net/http"
        "net/url"
        "time"
 
        "git.arvados.org/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/ctxlog"
 )
 
 const (
@@ -19,7 +21,7 @@ const (
 
 type permChecker interface {
        SetToken(token string)
-       Check(uuid string) (bool, error)
+       Check(ctx context.Context, uuid string) (bool, error)
 }
 
 func newPermChecker(ac arvados.Client) permChecker {
@@ -54,9 +56,9 @@ func (pc *cachingPermChecker) SetToken(token string) {
        pc.cache = make(map[string]cacheEnt)
 }
 
-func (pc *cachingPermChecker) Check(uuid string) (bool, error) {
+func (pc *cachingPermChecker) Check(ctx context.Context, uuid string) (bool, error) {
        pc.nChecks++
-       logger := logger(nil).
+       logger := ctxlog.FromContext(ctx).
                WithField("token", pc.Client.AuthToken).
                WithField("uuid", uuid)
        pc.tidy()
index 5f972551ffe8ffeaa4e11ec81573ae46425591d3..023656c01fd93dc3a912283682ffc9eda59c7e6b 100644 (file)
@@ -2,9 +2,11 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
+       "context"
+
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/arvadostest"
        check "gopkg.in/check.v1"
@@ -22,19 +24,19 @@ func (s *permSuite) TestCheck(c *check.C) {
        }
        wantError := func(uuid string) {
                c.Log(uuid)
-               ok, err := pc.Check(uuid)
+               ok, err := pc.Check(context.Background(), uuid)
                c.Check(ok, check.Equals, false)
                c.Check(err, check.NotNil)
        }
        wantYes := func(uuid string) {
                c.Log(uuid)
-               ok, err := pc.Check(uuid)
+               ok, err := pc.Check(context.Background(), uuid)
                c.Check(ok, check.Equals, true)
                c.Check(err, check.IsNil)
        }
        wantNo := func(uuid string) {
                c.Log(uuid)
-               ok, err := pc.Check(uuid)
+               ok, err := pc.Check(context.Background(), uuid)
                c.Check(ok, check.Equals, false)
                c.Check(err, check.IsNil)
        }
@@ -67,7 +69,7 @@ func (s *permSuite) TestCheck(c *check.C) {
        pc.SetToken(arvadostest.ActiveToken)
 
        c.Log("...network error")
-       pc.Client.APIHost = "127.0.0.1:discard"
+       pc.Client.APIHost = "127.0.0.1:9"
        wantError(arvadostest.UserAgreementCollection)
        wantError(arvadostest.FooBarDirCollection)
 
index f8c273c5141b6f76f73b28c3c2c5d995f0df94dd..b1764c156cad44c5954ca542fc8178b1cc182e1b 100644 (file)
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
        "encoding/json"
@@ -13,6 +13,7 @@ import (
        "sync/atomic"
        "time"
 
+       "git.arvados.org/arvados.git/lib/cmd"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
        "git.arvados.org/arvados.git/sdk/go/health"
@@ -28,7 +29,7 @@ type wsConn interface {
 }
 
 type router struct {
-       client         arvados.Client
+       client         *arvados.Client
        cluster        *arvados.Cluster
        eventSource    eventSource
        newPermChecker func() permChecker
@@ -71,7 +72,7 @@ func (rtr *router) setup() {
                },
                Log: func(r *http.Request, err error) {
                        if err != nil {
-                               logger(r.Context()).WithError(err).Error("error")
+                               ctxlog.FromContext(r.Context()).WithError(err).Error("error")
                        }
                },
        })
@@ -84,15 +85,15 @@ func (rtr *router) makeServer(newSession sessionFactory) *websocket.Server {
                },
                Handler: websocket.Handler(func(ws *websocket.Conn) {
                        t0 := time.Now()
-                       log := logger(ws.Request().Context())
-                       log.Info("connected")
+                       logger := ctxlog.FromContext(ws.Request().Context())
+                       logger.Info("connected")
 
-                       stats := rtr.handler.Handle(ws, rtr.eventSource,
+                       stats := rtr.handler.Handle(ws, logger, rtr.eventSource,
                                func(ws wsConn, sendq chan<- interface{}) (session, error) {
-                                       return newSession(ws, sendq, rtr.eventSource.DB(), rtr.newPermChecker(), &rtr.client)
+                                       return newSession(ws, sendq, rtr.eventSource.DB(), rtr.newPermChecker(), rtr.client)
                                })
 
-                       log.WithFields(logrus.Fields{
+                       logger.WithFields(logrus.Fields{
                                "elapsed": time.Now().Sub(t0).Seconds(),
                                "stats":   stats,
                        }).Info("disconnect")
@@ -125,7 +126,7 @@ func (rtr *router) DebugStatus() interface{} {
 func (rtr *router) Status() interface{} {
        return map[string]interface{}{
                "Clients": atomic.LoadInt64(&rtr.status.ReqsActive),
-               "Version": version,
+               "Version": cmd.Version.String(),
        }
 }
 
@@ -135,7 +136,7 @@ func (rtr *router) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
        atomic.AddInt64(&rtr.status.ReqsActive, 1)
        defer atomic.AddInt64(&rtr.status.ReqsActive, -1)
 
-       logger := logger(req.Context()).
+       logger := ctxlog.FromContext(req.Context()).
                WithField("RequestID", rtr.newReqID())
        ctx := ctxlog.Context(req.Context(), logger)
        req = req.WithContext(ctx)
@@ -148,7 +149,7 @@ func (rtr *router) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
 
 func (rtr *router) jsonHandler(fn func() interface{}) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-               logger := logger(r.Context())
+               logger := ctxlog.FromContext(r.Context())
                w.Header().Set("Content-Type", "application/json")
                enc := json.NewEncoder(w)
                err := enc.Encode(fn())
@@ -159,3 +160,8 @@ func (rtr *router) jsonHandler(fn func() interface{}) http.Handler {
                }
        })
 }
+
+func (rtr *router) CheckHealth() error {
+       rtr.setupOnce.Do(rtr.setup)
+       return rtr.eventSource.DBHealth()
+}
diff --git a/services/ws/server.go b/services/ws/server.go
deleted file mode 100644 (file)
index 9747ea1..0000000
+++ /dev/null
@@ -1,89 +0,0 @@
-// Copyright (C) The Arvados Authors. All rights reserved.
-//
-// SPDX-License-Identifier: AGPL-3.0
-
-package main
-
-import (
-       "net"
-       "net/http"
-       "sync"
-       "time"
-
-       "git.arvados.org/arvados.git/sdk/go/arvados"
-       "github.com/coreos/go-systemd/daemon"
-)
-
-type server struct {
-       httpServer  *http.Server
-       listener    net.Listener
-       cluster     *arvados.Cluster
-       eventSource *pgEventSource
-       setupOnce   sync.Once
-}
-
-func (srv *server) Close() {
-       srv.WaitReady()
-       srv.eventSource.Close()
-       srv.httpServer.Close()
-       srv.listener.Close()
-}
-
-func (srv *server) WaitReady() {
-       srv.setupOnce.Do(srv.setup)
-       srv.eventSource.WaitReady()
-}
-
-func (srv *server) Run() error {
-       srv.setupOnce.Do(srv.setup)
-       return srv.httpServer.Serve(srv.listener)
-}
-
-func (srv *server) setup() {
-       log := logger(nil)
-
-       var listen arvados.URL
-       for listen, _ = range srv.cluster.Services.Websocket.InternalURLs {
-               break
-       }
-       ln, err := net.Listen("tcp", listen.Host)
-       if err != nil {
-               log.WithField("Listen", listen).Fatal(err)
-       }
-       log.WithField("Listen", ln.Addr().String()).Info("listening")
-
-       client := arvados.Client{}
-       client.APIHost = srv.cluster.Services.Controller.ExternalURL.Host
-       client.AuthToken = srv.cluster.SystemRootToken
-       client.Insecure = srv.cluster.TLS.Insecure
-
-       srv.listener = ln
-       srv.eventSource = &pgEventSource{
-               DataSource:   srv.cluster.PostgreSQL.Connection.String(),
-               MaxOpenConns: srv.cluster.PostgreSQL.ConnectionPool,
-               QueueSize:    srv.cluster.API.WebsocketServerEventQueue,
-       }
-
-       srv.httpServer = &http.Server{
-               Addr:           listen.Host,
-               ReadTimeout:    time.Minute,
-               WriteTimeout:   time.Minute,
-               MaxHeaderBytes: 1 << 20,
-               Handler: &router{
-                       cluster:        srv.cluster,
-                       client:         client,
-                       eventSource:    srv.eventSource,
-                       newPermChecker: func() permChecker { return newPermChecker(client) },
-               },
-       }
-
-       go func() {
-               srv.eventSource.Run()
-               log.Info("event source stopped")
-               srv.Close()
-       }()
-
-       if _, err := daemon.SdNotify(false, "READY=1"); err != nil {
-               log.WithError(err).Warn("error notifying init daemon")
-       }
-}
diff --git a/services/ws/service.go b/services/ws/service.go
new file mode 100644 (file)
index 0000000..fb313bb
--- /dev/null
@@ -0,0 +1,68 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package ws
+
+import (
+       "context"
+       "fmt"
+       "os"
+
+       "git.arvados.org/arvados.git/lib/cmd"
+       "git.arvados.org/arvados.git/lib/service"
+       "git.arvados.org/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/ctxlog"
+       "github.com/prometheus/client_golang/prometheus"
+)
+
+// testMode, when true, suppresses the os.Exit(1) that newHandler
+// performs if the event source stops. NOTE(review): presumably set
+// only by tests -- confirm.
+var testMode bool
+
+// Command is the cmd.Handler for the websocket service, built by
+// service.Command from arvados.ServiceNameWebsocket and newHandler.
+var Command cmd.Handler = service.Command(arvados.ServiceNameWebsocket, newHandler)
+
+// newHandler builds the service.Handler for the websocket service.
+//
+// It creates an API client from the cluster config, starts a
+// PostgreSQL event source in a background goroutine, calls
+// eventSource.WaitReady, and then checks DBHealth. If creating the
+// client or the health check fails, it returns a service.ErrorHandler
+// carrying that error instead of a router. The token and reg
+// arguments are not used.
+//
+// If the event source's Run method ever returns, this is logged as an
+// error and, unless testMode is set, the process exits with status 1.
+func newHandler(ctx context.Context, cluster *arvados.Cluster, token string, reg *prometheus.Registry) service.Handler {
+       logger := ctxlog.FromContext(ctx)
+       client, err := arvados.NewClientFromConfig(cluster)
+       if err != nil {
+               return service.ErrorHandler(ctx, cluster, fmt.Errorf("error initializing client from cluster config: %s", err))
+       }
+       eventSource := &pgEventSource{
+               DataSource:   cluster.PostgreSQL.Connection.String(),
+               MaxOpenConns: cluster.PostgreSQL.ConnectionPool,
+               QueueSize:    cluster.API.WebsocketServerEventQueue,
+               Logger:       logger,
+       }
+       go func() {
+               eventSource.Run()
+               logger.Error("event source stopped")
+               if !testMode {
+                       os.Exit(1)
+               }
+       }()
+       eventSource.WaitReady()
+       if err := eventSource.DBHealth(); err != nil {
+               return service.ErrorHandler(ctx, cluster, err)
+       }
+       return &router{
+               cluster:        cluster,
+               client:         client,
+               eventSource:    eventSource,
+               newPermChecker: func() permChecker { return newPermChecker(*client) },
+       }
+}
similarity index 68%
rename from services/ws/server_test.go
rename to services/ws/service_test.go
index 88279ec9b2de83cd28bc191815bd1fa274cfec80..1afd8e006496a88ec30f7edd9f990d9f03e0809d 100644 (file)
@@ -2,39 +2,57 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
+       "bytes"
+       "context"
        "encoding/json"
+       "flag"
        "io/ioutil"
        "net/http"
+       "net/http/httptest"
        "os"
        "sync"
        "time"
 
        "git.arvados.org/arvados.git/lib/config"
+       "git.arvados.org/arvados.git/lib/service"
        "git.arvados.org/arvados.git/sdk/go/arvados"
        "git.arvados.org/arvados.git/sdk/go/arvadostest"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
+       "github.com/prometheus/client_golang/prometheus"
+       "github.com/sirupsen/logrus"
        check "gopkg.in/check.v1"
 )
 
-var _ = check.Suite(&serverSuite{})
+var _ = check.Suite(&serviceSuite{})
 
-type serverSuite struct {
+type serviceSuite struct {
+       handler service.Handler
+       srv     *httptest.Server
        cluster *arvados.Cluster
-       srv     *server
        wg      sync.WaitGroup
 }
 
-func (s *serverSuite) SetUpTest(c *check.C) {
+func (s *serviceSuite) SetUpTest(c *check.C) {
        var err error
        s.cluster, err = s.testConfig(c)
        c.Assert(err, check.IsNil)
-       s.srv = &server{cluster: s.cluster}
 }
 
-func (*serverSuite) testConfig(c *check.C) (*arvados.Cluster, error) {
+func (s *serviceSuite) start() {
+       s.handler = newHandler(context.Background(), s.cluster, "", prometheus.NewRegistry())
+       s.srv = httptest.NewServer(s.handler)
+}
+
+func (s *serviceSuite) TearDownTest(c *check.C) {
+       if s.srv != nil {
+               s.srv.Close()
+       }
+}
+
+func (*serviceSuite) testConfig(c *check.C) (*arvados.Cluster, error) {
        ldr := config.NewLoader(nil, ctxlog.TestLogger(c))
        cfg, err := ldr.Load()
        if err != nil {
@@ -54,42 +72,24 @@ func (*serverSuite) testConfig(c *check.C) (*arvados.Cluster, error) {
        return cluster, nil
 }
 
-// TestBadDB ensures Run() returns an error (instead of panicking or
-// deadlocking) if it can't connect to the database server at startup.
-func (s *serverSuite) TestBadDB(c *check.C) {
+// TestBadDB ensures the service answers with an error (HTTP 500 and a
+// failing health check) instead of panicking or deadlocking if it
+// can't connect to the database server at startup.
+func (s *serviceSuite) TestBadDB(c *check.C) {
        s.cluster.PostgreSQL.Connection["password"] = "1234"
-
-       var wg sync.WaitGroup
-       wg.Add(1)
-       go func() {
-               err := s.srv.Run()
-               c.Check(err, check.NotNil)
-               wg.Done()
-       }()
-       wg.Add(1)
-       go func() {
-               s.srv.WaitReady()
-               wg.Done()
-       }()
-
-       done := make(chan bool)
-       go func() {
-               wg.Wait()
-               close(done)
-       }()
-       select {
-       case <-done:
-       case <-time.After(10 * time.Second):
-               c.Fatal("timeout")
-       }
+       s.start()
+       resp, err := http.Get(s.srv.URL)
+       c.Assert(err, check.IsNil)
+       defer resp.Body.Close()
+       c.Check(resp.StatusCode, check.Equals, http.StatusInternalServerError)
+       // The startup DB failure must also surface through the health check.
+       c.Check(s.handler.CheckHealth(), check.ErrorMatches, "database not connected")
 }
 
-func (s *serverSuite) TestHealth(c *check.C) {
-       go s.srv.Run()
-       defer s.srv.Close()
-       s.srv.WaitReady()
+func (s *serviceSuite) TestHealth(c *check.C) {
+       s.start()
        for _, token := range []string{"", "foo", s.cluster.ManagementToken} {
-               req, err := http.NewRequest("GET", "http://"+s.srv.listener.Addr().String()+"/_health/ping", nil)
+               req, err := http.NewRequest("GET", s.srv.URL+"/_health/ping", nil)
                c.Assert(err, check.IsNil)
                if token != "" {
                        req.Header.Add("Authorization", "Bearer "+token)
@@ -107,11 +107,9 @@ func (s *serverSuite) TestHealth(c *check.C) {
        }
 }
 
-func (s *serverSuite) TestStatus(c *check.C) {
-       go s.srv.Run()
-       defer s.srv.Close()
-       s.srv.WaitReady()
-       req, err := http.NewRequest("GET", "http://"+s.srv.listener.Addr().String()+"/status.json", nil)
+func (s *serviceSuite) TestStatus(c *check.C) {
+       s.start()
+       req, err := http.NewRequest("GET", s.srv.URL+"/status.json", nil)
        c.Assert(err, check.IsNil)
        resp, err := http.DefaultClient.Do(req)
        c.Check(err, check.IsNil)
@@ -122,15 +120,11 @@ func (s *serverSuite) TestStatus(c *check.C) {
        c.Check(status["Version"], check.Not(check.Equals), "")
 }
 
-func (s *serverSuite) TestHealthDisabled(c *check.C) {
+func (s *serviceSuite) TestHealthDisabled(c *check.C) {
        s.cluster.ManagementToken = ""
-
-       go s.srv.Run()
-       defer s.srv.Close()
-       s.srv.WaitReady()
-
+       s.start()
        for _, token := range []string{"", "foo", arvadostest.ManagementToken} {
-               req, err := http.NewRequest("GET", "http://"+s.srv.listener.Addr().String()+"/_health/ping", nil)
+               req, err := http.NewRequest("GET", s.srv.URL+"/_health/ping", nil)
                c.Assert(err, check.IsNil)
                req.Header.Add("Authorization", "Bearer "+token)
                resp, err := http.DefaultClient.Do(req)
@@ -139,7 +133,7 @@ func (s *serverSuite) TestHealthDisabled(c *check.C) {
        }
 }
 
-func (s *serverSuite) TestLoadLegacyConfig(c *check.C) {
+func (s *serviceSuite) TestLoadLegacyConfig(c *check.C) {
        content := []byte(`
 Client:
   APIHost: example.com
@@ -175,7 +169,14 @@ ManagementToken: qqqqq
                c.Error(err)
 
        }
-       cluster := configure(logger(nil), []string{"arvados-ws", "-config", tmpfile.Name()})
+       ldr := config.NewLoader(&bytes.Buffer{}, logrus.New())
+       flagset := flag.NewFlagSet("", flag.ContinueOnError)
+       ldr.SetupFlags(flagset)
+       c.Assert(flagset.Parse(ldr.MungeLegacyConfigArgs(ctxlog.TestLogger(c), []string{"-config", tmpfile.Name()}, "-legacy-ws-config")), check.IsNil)
+       cfg, err := ldr.Load()
+       c.Assert(err, check.IsNil)
+       cluster, err := cfg.GetCluster("")
+       c.Assert(err, check.IsNil)
        c.Check(cluster, check.NotNil)
 
        c.Check(cluster.Services.Controller.ExternalURL, check.Equals, arvados.URL{Scheme: "https", Host: "example.com"})
index 53b02146d560fe3eb4d045227277d60a8c6e072b..c0cfbd6d02f6ff37083f426c85084effae45f212 100644 (file)
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
        "database/sql"
index b0f40371ffeb0ba12c5d3d1e1326d320fb6dbb51..309352b39edbd329aa031ec0c6194791341acec9 100644 (file)
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
        "database/sql"
@@ -14,6 +14,7 @@ import (
        "time"
 
        "git.arvados.org/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/ctxlog"
        "github.com/sirupsen/logrus"
 )
 
@@ -59,7 +60,7 @@ func newSessionV0(ws wsConn, sendq chan<- interface{}, db *sql.DB, pc permChecke
                db:          db,
                ac:          ac,
                permChecker: pc,
-               log:         logger(ws.Request().Context()),
+               log:         ctxlog.FromContext(ws.Request().Context()),
        }
 
        err := ws.Request().ParseForm()
@@ -128,7 +129,7 @@ func (sess *v0session) EventMessage(e *event) ([]byte, error) {
        } else {
                permTarget = detail.ObjectUUID
        }
-       ok, err := sess.permChecker.Check(permTarget)
+       ok, err := sess.permChecker.Check(sess.ws.Request().Context(), permTarget)
        if err != nil || !ok {
                return nil, err
        }
index bd70b44459dd79b5f22b0c08074b2d4bf480d76f..45baaa334bc002786da4848d8dafcaa74a84adce 100644 (file)
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
        "bytes"
@@ -11,6 +11,7 @@ import (
        "io"
        "net/url"
        "os"
+       "strings"
        "sync"
        "time"
 
@@ -30,17 +31,16 @@ func init() {
 var _ = check.Suite(&v0Suite{})
 
 type v0Suite struct {
-       serverSuite serverSuite
-       token       string
-       toDelete    []string
-       wg          sync.WaitGroup
-       ignoreLogID uint64
+       serviceSuite serviceSuite
+       token        string
+       toDelete     []string
+       wg           sync.WaitGroup
+       ignoreLogID  uint64
 }
 
 func (s *v0Suite) SetUpTest(c *check.C) {
-       s.serverSuite.SetUpTest(c)
-       go s.serverSuite.srv.Run()
-       s.serverSuite.srv.WaitReady()
+       s.serviceSuite.SetUpTest(c)
+       s.serviceSuite.start()
 
        s.token = arvadostest.ActiveToken
        s.ignoreLogID = s.lastLogID(c)
@@ -48,7 +48,7 @@ func (s *v0Suite) SetUpTest(c *check.C) {
 
 func (s *v0Suite) TearDownTest(c *check.C) {
        s.wg.Wait()
-       s.serverSuite.srv.Close()
+       s.serviceSuite.TearDownTest(c)
 }
 
 func (s *v0Suite) TearDownSuite(c *check.C) {
@@ -353,8 +353,8 @@ func (s *v0Suite) expectLog(c *check.C, r *json.Decoder) *arvados.Log {
 }
 
 func (s *v0Suite) testClient() (*websocket.Conn, *json.Decoder, *json.Encoder) {
-       srv := s.serverSuite.srv
-       conn, err := websocket.Dial("ws://"+srv.listener.Addr().String()+"/websocket?api_token="+s.token, "", "http://"+srv.listener.Addr().String())
+       srv := s.serviceSuite.srv
+       conn, err := websocket.Dial(strings.Replace(srv.URL, "http", "ws", 1)+"/websocket?api_token="+s.token, "", srv.URL)
        if err != nil {
                panic(err)
        }
index 58f77df430201f79e71f66209711a740dff8a016..60b980d58e2f8f8a9acc67362deb7d7beff21350 100644 (file)
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package ws
 
 import (
        "database/sql"