From: Tom Clegg Date: Wed, 25 Mar 2020 15:00:03 +0000 (-0400) Subject: 16217: Refactor ws to use lib/service. X-Git-Tag: 2.1.0~260^2~9 X-Git-Url: https://git.arvados.org/arvados.git/commitdiff_plain/ba418300c50e1375ca9938562579b7bd6bf9490d 16217: Refactor ws to use lib/service. Arvados-DCO-1.1-Signed-off-by: Tom Clegg --- diff --git a/build/run-build-packages.sh b/build/run-build-packages.sh index 4faa1c6b0d..3ba1dcc05e 100755 --- a/build/run-build-packages.sh +++ b/build/run-build-packages.sh @@ -308,7 +308,7 @@ package_go_binary services/keepstore keepstore \ "Keep storage daemon, accessible to clients on the LAN" package_go_binary services/keep-web keep-web \ "Static web hosting service for user data stored in Arvados Keep" -package_go_binary services/ws arvados-ws \ +package_go_binary cmd/arvados-server arvados-ws \ "Arvados Websocket server" package_go_binary tools/sync-groups arvados-sync-groups \ "Synchronize remote groups into Arvados from an external source" diff --git a/services/ws/arvados-ws.service b/cmd/arvados-server/arvados-ws.service similarity index 94% rename from services/ws/arvados-ws.service rename to cmd/arvados-server/arvados-ws.service index 36624c7877..aebc56a79f 100644 --- a/services/ws/arvados-ws.service +++ b/cmd/arvados-server/arvados-ws.service @@ -6,6 +6,7 @@ Description=Arvados websocket server Documentation=https://doc.arvados.org/ After=network.target +AssertPathExists=/etc/arvados/config.yml # systemd==229 (ubuntu:xenial) obeys StartLimitInterval in the [Unit] section StartLimitInterval=0 diff --git a/cmd/arvados-server/cmd.go b/cmd/arvados-server/cmd.go index a9d927d873..80d43ad848 100644 --- a/cmd/arvados-server/cmd.go +++ b/cmd/arvados-server/cmd.go @@ -14,6 +14,7 @@ import ( "git.arvados.org/arvados.git/lib/controller" "git.arvados.org/arvados.git/lib/crunchrun" "git.arvados.org/arvados.git/lib/dispatchcloud" + "git.arvados.org/arvados.git/services/ws" ) var ( @@ -30,6 +31,7 @@ var ( "controller": controller.Command, "crunch-run": crunchrun.Command, "dispatch-cloud": dispatchcloud.Command, + "ws": ws.Command, }) ) diff --git a/go.mod b/go.mod index 2cc5e89eb1..48b1c725a5 100644 --- a/go.mod +++ b/go.mod @@ -52,7 +52,7 @@ require ( golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 golang.org/x/net v0.0.0-20190620200207-3b0461eec859 golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 - golang.org/x/sys v0.0.0-20191105231009-c1f44814a5cd // indirect + golang.org/x/sys v0.0.0-20191105231009-c1f44814a5cd google.golang.org/api v0.13.0 gopkg.in/check.v1 v1.0.0-20161208181325-20d25e280405 gopkg.in/square/go-jose.v2 v2.3.1 diff --git a/services/ws/doc.go b/services/ws/doc.go index 806c3355da..6a86cbe7a8 100644 --- a/services/ws/doc.go +++ b/services/ws/doc.go @@ -13,36 +13,4 @@ // Developer info // // See https://dev.arvados.org/projects/arvados/wiki/Hacking_websocket_server. -// -// Usage -// -// arvados-ws [-legacy-ws-config /etc/arvados/ws/ws.yml] [-dump-config] -// -// Options -// -// -legacy-ws-config path -// -// Load legacy configuration from the given file instead of the default -// /etc/arvados/ws/ws.yml, legacy config overrides the clusterwide config.yml. -// -// -dump-config -// -// Print the loaded configuration to stdout and exit. -// -// Logs -// -// Logs are printed to stderr, formatted as JSON. -// -// A log is printed each time a client connects or disconnects. -// -// Enable additional logs by configuring: -// -// LogLevel: debug -// -// Runtime status -// -// GET /debug.json responds with debug stats. -// -// GET /status.json responds with health check results and -// activity/usage metrics. -package main +package ws diff --git a/services/ws/event.go b/services/ws/event.go index ae545c092c..c989c0ca55 100644 --- a/services/ws/event.go +++ b/services/ws/event.go @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import ( "database/sql" @@ -11,6 +11,7 @@ import ( "git.arvados.org/arvados.git/sdk/go/arvados" "github.com/ghodss/yaml" + "github.com/sirupsen/logrus" ) type eventSink interface { @@ -31,6 +32,7 @@ type event struct { Serial uint64 db *sql.DB + logger logrus.FieldLogger logRow *arvados.Log err error mtx sync.Mutex @@ -57,12 +59,12 @@ func (e *event) Detail() *arvados.Log { &logRow.CreatedAt, &propYAML) if e.err != nil { - logger(nil).WithField("LogID", e.LogID).WithError(e.err).Error("QueryRow failed") + e.logger.WithField("LogID", e.LogID).WithError(e.err).Error("QueryRow failed") return nil } e.err = yaml.Unmarshal(propYAML, &logRow.Properties) if e.err != nil { - logger(nil).WithField("LogID", e.LogID).WithError(e.err).Error("yaml decode failed") + e.logger.WithField("LogID", e.LogID).WithError(e.err).Error("yaml decode failed") return nil } e.logRow = &logRow diff --git a/services/ws/event_source.go b/services/ws/event_source.go index 3a82bf62b3..341464de50 100644 --- a/services/ws/event_source.go +++ b/services/ws/event_source.go @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import ( "context" @@ -16,12 +16,14 @@ import ( "git.arvados.org/arvados.git/sdk/go/stats" "github.com/lib/pq" + "github.com/sirupsen/logrus" ) type pgEventSource struct { DataSource string MaxOpenConns int QueueSize int + Logger logrus.FieldLogger db *sql.DB pqListener *pq.Listener @@ -43,14 +45,14 @@ var _ debugStatuser = (*pgEventSource)(nil) func (ps *pgEventSource) listenerProblem(et pq.ListenerEventType, err error) { if et == pq.ListenerEventConnected { - logger(nil).Debug("pgEventSource connected") + ps.Logger.Debug("pgEventSource connected") return } // Until we have a mechanism for catching up on missed events, // we cannot recover from a dropped connection without // breaking our promises to clients. - logger(nil). + ps.Logger. WithField("eventType", et). WithError(err). Error("listener problem") @@ -76,8 +78,8 @@ func (ps *pgEventSource) WaitReady() { // Run listens for event notifications on the "logs" channel and sends // them to all subscribers. func (ps *pgEventSource) Run() { - logger(nil).Debug("pgEventSource Run starting") - defer logger(nil).Debug("pgEventSource Run finished") + ps.Logger.Debug("pgEventSource Run starting") + defer ps.Logger.Debug("pgEventSource Run finished") ps.setupOnce.Do(ps.setup) ready := ps.ready @@ -103,15 +105,15 @@ func (ps *pgEventSource) Run() { db, err := sql.Open("postgres", ps.DataSource) if err != nil { - logger(nil).WithError(err).Error("sql.Open failed") + ps.Logger.WithError(err).Error("sql.Open failed") return } if ps.MaxOpenConns <= 0 { - logger(nil).Warn("no database connection limit configured -- consider setting PostgresPool>0 in arvados-ws configuration file") + ps.Logger.Warn("no database connection limit configured -- consider setting PostgresPool>0 in arvados-ws configuration file") } db.SetMaxOpenConns(ps.MaxOpenConns) if err = db.Ping(); err != nil { - logger(nil).WithError(err).Error("db.Ping failed") + ps.Logger.WithError(err).Error("db.Ping failed") return } ps.db = db @@ -119,11 +121,11 @@ func (ps *pgEventSource) Run() { ps.pqListener = pq.NewListener(ps.DataSource, time.Second, time.Minute, ps.listenerProblem) err = ps.pqListener.Listen("logs") if err != nil { - logger(nil).WithError(err).Error("pq Listen failed") + ps.Logger.WithError(err).Error("pq Listen failed") return } defer ps.pqListener.Close() - logger(nil).Debug("pq Listen setup done") + ps.Logger.Debug("pq Listen setup done") close(ready) // Avoid double-close in deferred func @@ -141,7 +143,7 @@ func (ps *pgEventSource) Run() { // client_count X client_queue_size. e.Detail() - logger(nil). + ps.Logger. WithField("serial", e.Serial). WithField("detail", e.Detail()). Debug("event ready") @@ -163,11 +165,11 @@ func (ps *pgEventSource) Run() { for { select { case <-ctx.Done(): - logger(nil).Debug("ctx done") + ps.Logger.Debug("ctx done") return case <-ticker.C: - logger(nil).Debug("listener ping") + ps.Logger.Debug("listener ping") err := ps.pqListener.Ping() if err != nil { ps.listenerProblem(-1, fmt.Errorf("pqListener ping failed: %s", err)) @@ -176,7 +178,7 @@ func (ps *pgEventSource) Run() { case pqEvent, ok := <-ps.pqListener.Notify: if !ok { - logger(nil).Error("pqListener Notify chan closed") + ps.Logger.Error("pqListener Notify chan closed") return } if pqEvent == nil { @@ -188,12 +190,12 @@ func (ps *pgEventSource) Run() { continue } if pqEvent.Channel != "logs" { - logger(nil).WithField("pqEvent", pqEvent).Error("unexpected notify from wrong channel") + ps.Logger.WithField("pqEvent", pqEvent).Error("unexpected notify from wrong channel") continue } logID, err := strconv.ParseUint(pqEvent.Extra, 10, 64) if err != nil { - logger(nil).WithField("pqEvent", pqEvent).Error("bad notify payload") + ps.Logger.WithField("pqEvent", pqEvent).Error("bad notify payload") continue } serial++ @@ -202,8 +204,9 @@ func (ps *pgEventSource) Run() { Received: time.Now(), Serial: serial, db: ps.db, + logger: ps.Logger, } - logger(nil).WithField("event", e).Debug("incoming") + ps.Logger.WithField("event", e).Debug("incoming") atomic.AddUint64(&ps.eventsIn, 1) ps.queue <- e go e.Detail() @@ -238,6 +241,9 @@ func (ps *pgEventSource) DB() *sql.DB { } func (ps *pgEventSource) DBHealth() error { + if ps.db == nil { + return errors.New("database not connected") + } ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) defer cancel() var i int diff --git a/services/ws/event_source_test.go b/services/ws/event_source_test.go index 98a9e8b978..dd40835b6e 100644 --- a/services/ws/event_source_test.go +++ b/services/ws/event_source_test.go @@ -2,17 +2,16 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import ( "database/sql" "fmt" - "os" - "path/filepath" "sync" "time" "git.arvados.org/arvados.git/sdk/go/arvados" + "git.arvados.org/arvados.git/sdk/go/ctxlog" check "gopkg.in/check.v1" ) @@ -21,7 +20,7 @@ var _ = check.Suite(&eventSourceSuite{}) type eventSourceSuite struct{} func testDBConfig() arvados.PostgreSQLConnection { - cfg, err := arvados.GetConfig(filepath.Join(os.Getenv("WORKSPACE"), "tmp", "arvados.yml")) + cfg, err := arvados.GetConfig(arvados.DefaultConfigFile) if err != nil { panic(err) } @@ -46,6 +45,7 @@ func (*eventSourceSuite) TestEventSource(c *check.C) { pges := &pgEventSource{ DataSource: cfg.String(), QueueSize: 4, + Logger: ctxlog.TestLogger(c), } go pges.Run() sinks := make([]eventSink, 18) diff --git a/services/ws/event_test.go b/services/ws/event_test.go index dc324464ec..4665dfcd9e 100644 --- a/services/ws/event_test.go +++ b/services/ws/event_test.go @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import check "gopkg.in/check.v1" diff --git a/services/ws/gocheck_test.go b/services/ws/gocheck_test.go index ea8dfc30c9..df1ca7ab31 100644 --- a/services/ws/gocheck_test.go +++ b/services/ws/gocheck_test.go @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import ( "testing" @@ -13,3 +13,7 @@ import ( func TestGocheck(t *testing.T) { check.TestingT(t) } + +func init() { + testMode = true +} diff --git a/services/ws/handler.go b/services/ws/handler.go index 913b1ee800..912643ad97 100644 --- a/services/ws/handler.go +++ b/services/ws/handler.go @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import ( "context" @@ -12,6 +12,7 @@ import ( "git.arvados.org/arvados.git/sdk/go/arvados" "git.arvados.org/arvados.git/sdk/go/stats" + "github.com/sirupsen/logrus" ) type handler struct { @@ -31,12 +32,11 @@ type handlerStats struct { EventCount uint64 } -func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsConn, chan<- interface{}) (session, error)) (hStats handlerStats) { +func (h *handler) Handle(ws wsConn, logger logrus.FieldLogger, eventSource eventSource, newSession func(wsConn, chan<- interface{}) (session, error)) (hStats handlerStats) { h.setupOnce.Do(h.setup) ctx, cancel := context.WithCancel(ws.Request().Context()) defer cancel() - log := logger(ctx) incoming := eventSource.NewSink() defer incoming.Stop() @@ -53,7 +53,7 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC sess, err := newSession(ws, queue) if err != nil { - log.WithError(err).Error("newSession failed") + logger.WithError(err).Error("newSession failed") return } @@ -71,19 +71,19 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC ws.SetReadDeadline(time.Now().Add(24 * 365 * time.Hour)) n, err := ws.Read(buf) buf := buf[:n] - log.WithField("frame", string(buf[:n])).Debug("received frame") + logger.WithField("frame", string(buf[:n])).Debug("received frame") if err == nil && n == cap(buf) { err = errFrameTooBig } if err != nil { if err != io.EOF && ctx.Err() == nil { - log.WithError(err).Info("read error") + logger.WithError(err).Info("read error") } return } err = sess.Receive(buf) if err != nil { - log.WithError(err).Error("sess.Receive() failed") + logger.WithError(err).Error("sess.Receive() failed") return } } @@ -108,38 +108,38 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC var e *event var buf []byte var err error - log := log + logger := logger switch data := data.(type) { case []byte: buf = data case *event: e = data - log = log.WithField("serial", e.Serial) + logger = logger.WithField("serial", e.Serial) buf, err = sess.EventMessage(e) if err != nil { - log.WithError(err).Error("EventMessage failed") + logger.WithError(err).Error("EventMessage failed") return } else if len(buf) == 0 { - log.Debug("skip") + logger.Debug("skip") continue } default: - log.WithField("data", data).Error("bad object in client queue") + logger.WithField("data", data).Error("bad object in client queue") continue } - log.WithField("frame", string(buf)).Debug("send event") + logger.WithField("frame", string(buf)).Debug("send event") ws.SetWriteDeadline(time.Now().Add(h.PingTimeout)) t0 := time.Now() _, err = ws.Write(buf) if err != nil { if ctx.Err() == nil { - log.WithError(err).Error("write failed") + logger.WithError(err).Error("write failed") } return } - log.Debug("sent") + logger.Debug("sent") if e != nil { hStats.QueueDelayNs += t0.Sub(e.Ready) @@ -189,7 +189,7 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC select { case queue <- e: default: - log.WithError(errQueueFull).Error("terminate") + logger.WithError(errQueueFull).Error("terminate") return } } diff --git a/services/ws/main.go b/services/ws/main.go deleted file mode 100644 index 5b42c44210..0000000000 --- a/services/ws/main.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (C) The Arvados Authors. All rights reserved. -// -// SPDX-License-Identifier: AGPL-3.0 - -package main - -import ( - "flag" - "fmt" - "os" - - "git.arvados.org/arvados.git/lib/config" - "git.arvados.org/arvados.git/sdk/go/arvados" - "git.arvados.org/arvados.git/sdk/go/ctxlog" - "github.com/ghodss/yaml" - "github.com/sirupsen/logrus" -) - -var logger = ctxlog.FromContext -var version = "dev" - -func configure(log logrus.FieldLogger, args []string) *arvados.Cluster { - flags := flag.NewFlagSet(args[0], flag.ExitOnError) - dumpConfig := flags.Bool("dump-config", false, "show current configuration and exit") - getVersion := flags.Bool("version", false, "Print version information and exit.") - - loader := config.NewLoader(nil, log) - loader.SetupFlags(flags) - args = loader.MungeLegacyConfigArgs(log, args[1:], "-legacy-ws-config") - - flags.Parse(args) - - // Print version information if requested - if *getVersion { - fmt.Printf("arvados-ws %s\n", version) - return nil - } - - cfg, err := loader.Load() - if err != nil { - log.Fatal(err) - } - - cluster, err := cfg.GetCluster("") - if err != nil { - log.Fatal(err) - } - - ctxlog.SetLevel(cluster.SystemLogs.LogLevel) - ctxlog.SetFormat(cluster.SystemLogs.Format) - - if *dumpConfig { - out, err := yaml.Marshal(cfg) - if err != nil { - log.Fatal(err) - } - _, err = os.Stdout.Write(out) - if err != nil { - log.Fatal(err) - } - return nil - } - return cluster -} - -func main() { - log := logger(nil) - - cluster := configure(log, os.Args) - if cluster == nil { - return - } - - log.Printf("arvados-ws %s started", version) - srv := &server{cluster: cluster} - log.Fatal(srv.Run()) -} diff --git a/services/ws/permission.go b/services/ws/permission.go index 745d28f952..ac895f80e5 100644 --- a/services/ws/permission.go +++ b/services/ws/permission.go @@ -2,14 +2,16 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import ( + "context" "net/http" "net/url" "time" "git.arvados.org/arvados.git/sdk/go/arvados" + "git.arvados.org/arvados.git/sdk/go/ctxlog" ) const ( @@ -19,7 +21,7 @@ const ( type permChecker interface { SetToken(token string) - Check(uuid string) (bool, error) + Check(ctx context.Context, uuid string) (bool, error) } func newPermChecker(ac arvados.Client) permChecker { @@ -54,9 +56,9 @@ func (pc *cachingPermChecker) SetToken(token string) { pc.cache = make(map[string]cacheEnt) } -func (pc *cachingPermChecker) Check(uuid string) (bool, error) { +func (pc *cachingPermChecker) Check(ctx context.Context, uuid string) (bool, error) { pc.nChecks++ - logger := logger(nil). + logger := ctxlog.FromContext(ctx). WithField("token", pc.Client.AuthToken). WithField("uuid", uuid) pc.tidy() diff --git a/services/ws/permission_test.go b/services/ws/permission_test.go index 5f972551ff..023656c01f 100644 --- a/services/ws/permission_test.go +++ b/services/ws/permission_test.go @@ -2,9 +2,11 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import ( + "context" + "git.arvados.org/arvados.git/sdk/go/arvados" "git.arvados.org/arvados.git/sdk/go/arvadostest" check "gopkg.in/check.v1" @@ -22,19 +24,19 @@ func (s *permSuite) TestCheck(c *check.C) { } wantError := func(uuid string) { c.Log(uuid) - ok, err := pc.Check(uuid) + ok, err := pc.Check(context.Background(), uuid) c.Check(ok, check.Equals, false) c.Check(err, check.NotNil) } wantYes := func(uuid string) { c.Log(uuid) - ok, err := pc.Check(uuid) + ok, err := pc.Check(context.Background(), uuid) c.Check(ok, check.Equals, true) c.Check(err, check.IsNil) } wantNo := func(uuid string) { c.Log(uuid) - ok, err := pc.Check(uuid) + ok, err := pc.Check(context.Background(), uuid) c.Check(ok, check.Equals, false) c.Check(err, check.IsNil) } @@ -67,7 +69,7 @@ func (s *permSuite) TestCheck(c *check.C) { pc.SetToken(arvadostest.ActiveToken) c.Log("...network error") - pc.Client.APIHost = "127.0.0.1:discard" + pc.Client.APIHost = "127.0.0.1:9" wantError(arvadostest.UserAgreementCollection) wantError(arvadostest.FooBarDirCollection) diff --git a/services/ws/router.go b/services/ws/router.go index f8c273c514..b1764c156c 100644 --- a/services/ws/router.go +++ b/services/ws/router.go @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import ( "encoding/json" @@ -13,6 +13,7 @@ import ( "sync/atomic" "time" + "git.arvados.org/arvados.git/lib/cmd" "git.arvados.org/arvados.git/sdk/go/arvados" "git.arvados.org/arvados.git/sdk/go/ctxlog" "git.arvados.org/arvados.git/sdk/go/health" @@ -28,7 +29,7 @@ type wsConn interface { } type router struct { - client arvados.Client + client *arvados.Client cluster *arvados.Cluster eventSource eventSource newPermChecker func() permChecker @@ -71,7 +72,7 @@ func (rtr *router) setup() { }, Log: func(r *http.Request, err error) { if err != nil { - logger(r.Context()).WithError(err).Error("error") + ctxlog.FromContext(r.Context()).WithError(err).Error("error") } }, }) @@ -84,15 +85,15 @@ func (rtr *router) makeServer(newSession sessionFactory) *websocket.Server { }, Handler: websocket.Handler(func(ws *websocket.Conn) { t0 := time.Now() - log := logger(ws.Request().Context()) - log.Info("connected") + logger := ctxlog.FromContext(ws.Request().Context()) + logger.Info("connected") - stats := rtr.handler.Handle(ws, rtr.eventSource, + stats := rtr.handler.Handle(ws, logger, rtr.eventSource, func(ws wsConn, sendq chan<- interface{}) (session, error) { - return newSession(ws, sendq, rtr.eventSource.DB(), rtr.newPermChecker(), &rtr.client) + return newSession(ws, sendq, rtr.eventSource.DB(), rtr.newPermChecker(), rtr.client) }) - log.WithFields(logrus.Fields{ + logger.WithFields(logrus.Fields{ "elapsed": time.Now().Sub(t0).Seconds(), "stats": stats, }).Info("disconnect") @@ -125,7 +126,7 @@ func (rtr *router) DebugStatus() interface{} { func (rtr *router) Status() interface{} { return map[string]interface{}{ "Clients": atomic.LoadInt64(&rtr.status.ReqsActive), - "Version": version, + "Version": cmd.Version.String(), } } @@ -135,7 +136,7 @@ func (rtr *router) ServeHTTP(resp http.ResponseWriter, req *http.Request) { atomic.AddInt64(&rtr.status.ReqsActive, 1) defer atomic.AddInt64(&rtr.status.ReqsActive, -1) - logger := logger(req.Context()). + logger := ctxlog.FromContext(req.Context()). WithField("RequestID", rtr.newReqID()) ctx := ctxlog.Context(req.Context(), logger) req = req.WithContext(ctx) @@ -148,7 +149,7 @@ func (rtr *router) ServeHTTP(resp http.ResponseWriter, req *http.Request) { func (rtr *router) jsonHandler(fn func() interface{}) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - logger := logger(r.Context()) + logger := ctxlog.FromContext(r.Context()) w.Header().Set("Content-Type", "application/json") enc := json.NewEncoder(w) err := enc.Encode(fn()) @@ -159,3 +160,8 @@ func (rtr *router) jsonHandler(fn func() interface{}) http.Handler { } }) } + +func (rtr *router) CheckHealth() error { + rtr.setupOnce.Do(rtr.setup) + return rtr.eventSource.DBHealth() +} diff --git a/services/ws/server.go b/services/ws/server.go deleted file mode 100644 index 9747ea1b85..0000000000 --- a/services/ws/server.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (C) The Arvados Authors. All rights reserved. -// -// SPDX-License-Identifier: AGPL-3.0 - -package main - -import ( - "net" - "net/http" - "sync" - "time" - - "git.arvados.org/arvados.git/sdk/go/arvados" - "github.com/coreos/go-systemd/daemon" -) - -type server struct { - httpServer *http.Server - listener net.Listener - cluster *arvados.Cluster - eventSource *pgEventSource - setupOnce sync.Once -} - -func (srv *server) Close() { - srv.WaitReady() - srv.eventSource.Close() - srv.httpServer.Close() - srv.listener.Close() -} - -func (srv *server) WaitReady() { - srv.setupOnce.Do(srv.setup) - srv.eventSource.WaitReady() -} - -func (srv *server) Run() error { - srv.setupOnce.Do(srv.setup) - return srv.httpServer.Serve(srv.listener) -} - -func (srv *server) setup() { - log := logger(nil) - - var listen arvados.URL - for listen, _ = range srv.cluster.Services.Websocket.InternalURLs { - break - } - ln, err := net.Listen("tcp", listen.Host) - if err != nil { - log.WithField("Listen", listen).Fatal(err) - } - log.WithField("Listen", ln.Addr().String()).Info("listening") - - client := arvados.Client{} - client.APIHost = srv.cluster.Services.Controller.ExternalURL.Host - client.AuthToken = srv.cluster.SystemRootToken - client.Insecure = srv.cluster.TLS.Insecure - - srv.listener = ln - srv.eventSource = &pgEventSource{ - DataSource: srv.cluster.PostgreSQL.Connection.String(), - MaxOpenConns: srv.cluster.PostgreSQL.ConnectionPool, - QueueSize: srv.cluster.API.WebsocketServerEventQueue, - } - - srv.httpServer = &http.Server{ - Addr: listen.Host, - ReadTimeout: time.Minute, - WriteTimeout: time.Minute, - MaxHeaderBytes: 1 << 20, - Handler: &router{ - cluster: srv.cluster, - client: client, - eventSource: srv.eventSource, - newPermChecker: func() permChecker { return newPermChecker(client) }, - }, - } - - go func() { - srv.eventSource.Run() - log.Info("event source stopped") - srv.Close() - }() - - if _, err := daemon.SdNotify(false, "READY=1"); err != nil { - log.WithError(err).Warn("error notifying init daemon") - } -} diff --git a/services/ws/service.go b/services/ws/service.go new file mode 100644 index 0000000000..fb313bb799 --- /dev/null +++ b/services/ws/service.go @@ -0,0 +1,52 @@ +// Copyright (C) The Arvados Authors. All rights reserved. +// +// SPDX-License-Identifier: AGPL-3.0 + +package ws + +import ( + "context" + "fmt" + "os" + + "git.arvados.org/arvados.git/lib/cmd" + "git.arvados.org/arvados.git/lib/service" + "git.arvados.org/arvados.git/sdk/go/arvados" + "git.arvados.org/arvados.git/sdk/go/ctxlog" + "github.com/prometheus/client_golang/prometheus" +) + +var testMode = false + +var Command cmd.Handler = service.Command(arvados.ServiceNameWebsocket, newHandler) + +func newHandler(ctx context.Context, cluster *arvados.Cluster, token string, reg *prometheus.Registry) service.Handler { + client, err := arvados.NewClientFromConfig(cluster) + if err != nil { + return service.ErrorHandler(ctx, cluster, fmt.Errorf("error initializing client from cluster config: %s", err)) + } + eventSource := &pgEventSource{ + DataSource: cluster.PostgreSQL.Connection.String(), + MaxOpenConns: cluster.PostgreSQL.ConnectionPool, + QueueSize: cluster.API.WebsocketServerEventQueue, + Logger: ctxlog.FromContext(ctx), + } + go func() { + eventSource.Run() + ctxlog.FromContext(ctx).Error("event source stopped") + if !testMode { + os.Exit(1) + } + }() + eventSource.WaitReady() + if err := eventSource.DBHealth(); err != nil { + return service.ErrorHandler(ctx, cluster, err) + } + rtr := &router{ + cluster: cluster, + client: client, + eventSource: eventSource, + newPermChecker: func() permChecker { return newPermChecker(*client) }, + } + return rtr +} diff --git a/services/ws/server_test.go b/services/ws/service_test.go similarity index 68% rename from services/ws/server_test.go rename to services/ws/service_test.go index 88279ec9b2..1afd8e0064 100644 --- a/services/ws/server_test.go +++ b/services/ws/service_test.go @@ -2,39 +2,57 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import ( + "bytes" + "context" "encoding/json" + "flag" "io/ioutil" "net/http" + "net/http/httptest" "os" "sync" "time" "git.arvados.org/arvados.git/lib/config" + "git.arvados.org/arvados.git/lib/service" "git.arvados.org/arvados.git/sdk/go/arvados" "git.arvados.org/arvados.git/sdk/go/arvadostest" "git.arvados.org/arvados.git/sdk/go/ctxlog" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" check "gopkg.in/check.v1" ) -var _ = check.Suite(&serverSuite{}) +var _ = check.Suite(&serviceSuite{}) -type serverSuite struct { +type serviceSuite struct { + handler service.Handler + srv *httptest.Server cluster *arvados.Cluster - srv *server wg sync.WaitGroup } -func (s *serverSuite) SetUpTest(c *check.C) { +func (s *serviceSuite) SetUpTest(c *check.C) { var err error s.cluster, err = s.testConfig(c) c.Assert(err, check.IsNil) - s.srv = &server{cluster: s.cluster} } -func (*serverSuite) testConfig(c *check.C) (*arvados.Cluster, error) { +func (s *serviceSuite) start() { + s.handler = newHandler(context.Background(), s.cluster, "", prometheus.NewRegistry()) + s.srv = httptest.NewServer(s.handler) +} + +func (s *serviceSuite) TearDownTest(c *check.C) { + if s.srv != nil { + s.srv.Close() + } +} + +func (*serviceSuite) testConfig(c *check.C) (*arvados.Cluster, error) { ldr := config.NewLoader(nil, ctxlog.TestLogger(c)) cfg, err := ldr.Load() if err != nil { @@ -54,42 +72,24 @@ func (*serverSuite) testConfig(c *check.C) (*arvados.Cluster, error) { return cluster, nil } -// TestBadDB ensures Run() returns an error (instead of panicking or -// deadlocking) if it can't connect to the database server at startup. -func (s *serverSuite) TestBadDB(c *check.C) { +// TestBadDB ensures the server returns an error (instead of panicking +// or deadlocking) if it can't connect to the database server at +// startup. +func (s *serviceSuite) TestBadDB(c *check.C) { s.cluster.PostgreSQL.Connection["password"] = "1234" - - var wg sync.WaitGroup - wg.Add(1) - go func() { - err := s.srv.Run() - c.Check(err, check.NotNil) - wg.Done() - }() - wg.Add(1) - go func() { - s.srv.WaitReady() - wg.Done() - }() - - done := make(chan bool) - go func() { - wg.Wait() - close(done) - }() - select { - case <-done: - case <-time.After(10 * time.Second): - c.Fatal("timeout") - } + s.start() + resp, err := http.Get(s.srv.URL) + c.Check(err, check.IsNil) + c.Check(resp.StatusCode, check.Equals, http.StatusInternalServerError) + c.Check(s.handler.CheckHealth(), check.ErrorMatches, "database not connected") + c.Check(err, check.IsNil) + c.Check(resp.StatusCode, check.Equals, http.StatusInternalServerError) } -func (s *serverSuite) TestHealth(c *check.C) { - go s.srv.Run() - defer s.srv.Close() - s.srv.WaitReady() +func (s *serviceSuite) TestHealth(c *check.C) { + s.start() for _, token := range []string{"", "foo", s.cluster.ManagementToken} { - req, err := http.NewRequest("GET", "http://"+s.srv.listener.Addr().String()+"/_health/ping", nil) + req, err := http.NewRequest("GET", s.srv.URL+"/_health/ping", nil) c.Assert(err, check.IsNil) if token != "" { req.Header.Add("Authorization", "Bearer "+token) @@ -107,11 +107,9 @@ func (s *serverSuite) TestHealth(c *check.C) { } } -func (s *serverSuite) TestStatus(c *check.C) { - go s.srv.Run() - defer s.srv.Close() - s.srv.WaitReady() - req, err := http.NewRequest("GET", "http://"+s.srv.listener.Addr().String()+"/status.json", nil) +func (s *serviceSuite) TestStatus(c *check.C) { + s.start() + req, err := http.NewRequest("GET", s.srv.URL+"/status.json", nil) c.Assert(err, check.IsNil) resp, err := http.DefaultClient.Do(req) c.Check(err, check.IsNil) @@ -122,15 +120,11 @@ func (s *serverSuite) TestStatus(c *check.C) { c.Check(status["Version"], check.Not(check.Equals), "") } -func (s *serverSuite) TestHealthDisabled(c *check.C) { +func (s *serviceSuite) TestHealthDisabled(c *check.C) { s.cluster.ManagementToken = "" - - go s.srv.Run() - defer s.srv.Close() - s.srv.WaitReady() - + s.start() for _, token := range []string{"", "foo", arvadostest.ManagementToken} { - req, err := http.NewRequest("GET", "http://"+s.srv.listener.Addr().String()+"/_health/ping", nil) + req, err := http.NewRequest("GET", s.srv.URL+"/_health/ping", nil) c.Assert(err, check.IsNil) req.Header.Add("Authorization", "Bearer "+token) resp, err := http.DefaultClient.Do(req) @@ -139,7 +133,7 @@ func (s *serverSuite) TestHealthDisabled(c *check.C) { } } -func (s *serverSuite) TestLoadLegacyConfig(c *check.C) { +func (s *serviceSuite) TestLoadLegacyConfig(c *check.C) { content := []byte(` Client: APIHost: example.com @@ -175,7 +169,14 @@ ManagementToken: qqqqq c.Error(err) } - cluster := configure(logger(nil), []string{"arvados-ws", "-config", tmpfile.Name()}) + ldr := config.NewLoader(&bytes.Buffer{}, logrus.New()) + flagset := flag.NewFlagSet("", flag.ContinueOnError) + ldr.SetupFlags(flagset) + flagset.Parse(ldr.MungeLegacyConfigArgs(ctxlog.TestLogger(c), []string{"-config", tmpfile.Name()}, "-legacy-ws-config")) + cfg, err := ldr.Load() + c.Check(err, check.IsNil) + cluster, err := cfg.GetCluster("") + c.Check(err, check.IsNil) c.Check(cluster, check.NotNil) c.Check(cluster.Services.Controller.ExternalURL, check.Equals, arvados.URL{Scheme: "https", Host: "example.com"}) diff --git a/services/ws/session.go b/services/ws/session.go index 53b02146d5..c0cfbd6d02 100644 --- a/services/ws/session.go +++ b/services/ws/session.go @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import ( "database/sql" diff --git a/services/ws/session_v0.go b/services/ws/session_v0.go index b0f40371ff..309352b39e 100644 --- a/services/ws/session_v0.go +++ b/services/ws/session_v0.go @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import ( "database/sql" @@ -14,6 +14,7 @@ import ( "time" "git.arvados.org/arvados.git/sdk/go/arvados" + "git.arvados.org/arvados.git/sdk/go/ctxlog" "github.com/sirupsen/logrus" ) @@ -59,7 +60,7 @@ func newSessionV0(ws wsConn, sendq chan<- interface{}, db *sql.DB, pc permChecke db: db, ac: ac, permChecker: pc, - log: logger(ws.Request().Context()), + log: ctxlog.FromContext(ws.Request().Context()), } err := ws.Request().ParseForm() @@ -128,7 +129,7 @@ func (sess *v0session) EventMessage(e *event) ([]byte, error) { } else { permTarget = detail.ObjectUUID } - ok, err := sess.permChecker.Check(permTarget) + ok, err := sess.permChecker.Check(sess.ws.Request().Context(), permTarget) if err != nil || !ok { return nil, err } diff --git a/services/ws/session_v0_test.go b/services/ws/session_v0_test.go index bd70b44459..45baaa334b 100644 --- a/services/ws/session_v0_test.go +++ b/services/ws/session_v0_test.go @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import ( "bytes" @@ -11,6 +11,7 @@ import ( "io" "net/url" "os" + "strings" "sync" "time" @@ -30,17 +31,16 @@ func init() { var _ = check.Suite(&v0Suite{}) type v0Suite struct { - serverSuite serverSuite - token string - toDelete []string - wg sync.WaitGroup - ignoreLogID uint64 + serviceSuite serviceSuite + token string + toDelete []string + wg sync.WaitGroup + ignoreLogID uint64 } func (s *v0Suite) SetUpTest(c *check.C) { - s.serverSuite.SetUpTest(c) - go s.serverSuite.srv.Run() - s.serverSuite.srv.WaitReady() + s.serviceSuite.SetUpTest(c) + s.serviceSuite.start() s.token = arvadostest.ActiveToken s.ignoreLogID = s.lastLogID(c) @@ -48,7 +48,7 @@ func (s *v0Suite) SetUpTest(c *check.C) { func (s *v0Suite) TearDownTest(c *check.C) { s.wg.Wait() - s.serverSuite.srv.Close() + s.serviceSuite.TearDownTest(c) } func (s *v0Suite) TearDownSuite(c *check.C) { @@ -353,8 +353,8 @@ func (s *v0Suite) expectLog(c *check.C, r *json.Decoder) *arvados.Log { } func (s *v0Suite) testClient() (*websocket.Conn, *json.Decoder, *json.Encoder) { - srv := s.serverSuite.srv - conn, err := websocket.Dial("ws://"+srv.listener.Addr().String()+"/websocket?api_token="+s.token, "", "http://"+srv.listener.Addr().String()) + srv := s.serviceSuite.srv + conn, err := websocket.Dial(strings.Replace(srv.URL, "http", "ws", 1)+"/websocket?api_token="+s.token, "", srv.URL) if err != nil { panic(err) } diff --git a/services/ws/session_v1.go b/services/ws/session_v1.go index 58f77df430..60b980d58e 100644 --- a/services/ws/session_v1.go +++ b/services/ws/session_v1.go @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: AGPL-3.0 -package main +package ws import ( "database/sql"