-package main
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package ws
import (
- "encoding/json"
+ "context"
"io"
- "log"
- "net/http"
+ "sync"
"time"
- "git.curoverse.com/arvados.git/sdk/go/arvados"
+ "git.arvados.org/arvados.git/sdk/go/arvados"
+ "git.arvados.org/arvados.git/sdk/go/stats"
+ "github.com/sirupsen/logrus"
)
-type wsConn interface {
- io.ReadWriter
- Request() *http.Request
- SetReadDeadline(time.Time) error
- SetWriteDeadline(time.Time) error
-}
-
+// handler serves websocket client connections (see Handle) and
+// keeps per-connection outgoing-queue statistics for DebugStatus.
type handler struct {
+ // NOTE(review): Client is not used by Handle, DebugStatus, or
+ // setup; confirm whether anything still reads it.
Client arvados.Client
+ // PingTimeout is the write deadline for each frame sent to a
+ // client, and the interval at which a "{}" keepalive frame is
+ // queued for a connection whose outgoing queue is empty.
PingTimeout time.Duration
+ // QueueSize is the buffer size of each connection's outgoing
+ // queue. A connection is shut down if its queue is full when a
+ // matching event arrives.
QueueSize int
- NewSession func(wsConn, arvados.Client) (session, error)
+
+ // mtx guards lastDelay.
+ mtx sync.Mutex
+ // lastDelay maps each active connection's outgoing queue to
+ // the delay (time since the event became Ready) of the most
+ // recent event written to that connection; 0 until the first
+ // event is written.
+ lastDelay map[chan interface{}]stats.Duration
+ // setupOnce makes the first Handle call initialize lastDelay.
+ setupOnce sync.Once
}
-func (h *handler) Handle(ws wsConn, events <-chan *event) {
- sess, err := h.NewSession(ws, h.Client)
+// handlerStats holds totals for the frames one Handle call wrote to
+// its client. Every successfully written frame is counted, including
+// "{}" keepalives and other raw []byte messages, not only events.
+type handlerStats struct {
+ // QueueDelayNs sums, over written events, the time from the
+ // event becoming Ready until its write began. WriteDelayNs sums,
+ // over all written frames, the time from the start of the write
+ // until its stats were recorded.
+ QueueDelayNs, WriteDelayNs time.Duration
+ // EventBytes and EventCount are the total bytes and number of
+ // frames written.
+ EventBytes, EventCount uint64
+}
+
+// Handle serves a single websocket client connection. It runs three
+// goroutines -- one reading frames from the client and passing them
+// to the session, one writing queued messages to the client, and one
+// filtering incoming events into the outgoing queue -- and returns
+// when the request context is cancelled or any of those goroutines
+// exits (each one cancels ctx on exit).
+//
+// The returned handlerStats cover the frames whose writes completed
+// before Handle returned. If newSession fails, Handle logs the error
+// and returns zero stats.
+//
+// NOTE(review): Handle does not close ws, so the reader goroutine
+// can stay blocked in ws.Read after Handle returns; presumably the
+// caller closes the connection -- confirm.
+func (h *handler) Handle(ws wsConn, logger logrus.FieldLogger, eventSource eventSource, newSession func(wsConn, chan<- interface{}) (session, error)) (hStats handlerStats) {
+ h.setupOnce.Do(h.setup)
+
+ ctx, cancel := context.WithCancel(ws.Request().Context())
+ defer cancel()
+
+ incoming := eventSource.NewSink()
+ defer incoming.Stop()
+
+ queue := make(chan interface{}, h.QueueSize)
+ h.mtx.Lock()
+ h.lastDelay[queue] = 0
+ h.mtx.Unlock()
+ defer func() {
+ h.mtx.Lock()
+ delete(h.lastDelay, queue)
+ h.mtx.Unlock()
+ }()
+
+ sess, err := newSession(ws, queue)
if err != nil {
- log.Printf("%s NewSession: %s", ws.Request().RemoteAddr, err)
+ logger.WithError(err).Error("newSession failed")
return
}
- queue := make(chan *event, h.QueueSize)
-
- stopped := make(chan struct{})
- stop := make(chan error, 5)
-
+ // The writer goroutine accumulates stats in wStats rather than
+ // in the hStats return value: it can still be finishing a
+ // ws.Write() after Handle returns, so access is guarded by
+ // statsMtx and Handle copies a snapshot when it returns.
+ var statsMtx sync.Mutex
+ var wStats handlerStats
+
+ // Receive websocket frames from the client and pass them to
+ // sess.Receive().
go func() {
+ defer cancel()
buf := make([]byte, 2<<20)
for {
select {
- case <-stopped:
+ case <-ctx.Done():
return
default:
}
ws.SetReadDeadline(time.Now().Add(24 * 365 * time.Hour))
n, err := ws.Read(buf)
- sess.debugLogf("received frame: %q", buf[:n])
- if err == nil && n == len(buf) {
+ frame := buf[:n]
+ logger.WithField("frame", string(frame)).Debug("received frame")
+ // A frame that fills the whole read buffer may have
+ // been truncated, so reject it.
+ if err == nil && n == len(buf) {
err = errFrameTooBig
}
if err != nil {
- if err != io.EOF {
- sess.debugLogf("handler: read: %s", err)
+ if err != io.EOF && ctx.Err() == nil {
+ logger.WithError(err).Info("read error")
}
- stop <- err
return
}
- msg := make(map[string]interface{})
- err = json.Unmarshal(buf[:n], &msg)
+ err = sess.Receive(frame)
if err != nil {
- sess.debugLogf("handler: unmarshal: %s", err)
- stop <- err
+ logger.WithError(err).Error("sess.Receive() failed")
return
}
- sess.Receive(msg, buf[:n])
}
}()
+ // Take items from the outgoing queue, serialize them using
+ // sess.EventMessage() as needed, and send them to the client
+ // as websocket frames.
go func() {
- for e := range queue {
- if e == nil {
- ws.SetWriteDeadline(time.Now().Add(h.PingTimeout))
- _, err := ws.Write([]byte("{}"))
- if err != nil {
- sess.debugLogf("handler: write {}: %s", err)
- stop <- err
- break
+ defer cancel()
+ for {
+ var ok bool
+ var data interface{}
+ select {
+ case <-ctx.Done():
+ return
+ case data, ok = <-queue:
+ if !ok {
+ return
}
- continue
}
+ var e *event
+ var buf []byte
+ var err error
+ logger := logger
- buf, err := sess.EventMessage(e)
- if err != nil {
- sess.debugLogf("EventMessage %d: err %s", err)
- stop <- err
- break
- } else if len(buf) == 0 {
- sess.debugLogf("EventMessage %d: skip", e.Serial)
+ switch data := data.(type) {
+ case []byte:
+ buf = data
+ case *event:
+ e = data
+ logger = logger.WithField("serial", e.Serial)
+ buf, err = sess.EventMessage(e)
+ if err != nil {
+ logger.WithError(err).Error("EventMessage failed")
+ return
+ } else if len(buf) == 0 {
+ logger.Debug("skip")
+ continue
+ }
+ default:
+ logger.WithField("data", data).Error("bad object in client queue")
continue
}
- sess.debugLogf("handler: send event %d: %q", e.Serial, buf)
+ logger.WithField("frame", string(buf)).Debug("send event")
ws.SetWriteDeadline(time.Now().Add(h.PingTimeout))
+ t0 := time.Now()
_, err = ws.Write(buf)
if err != nil {
- sess.debugLogf("handler: write: %s", err)
- stop <- err
- break
+ if ctx.Err() == nil {
+ logger.WithError(err).Error("write failed")
+ }
+ return
}
- sess.debugLogf("handler: sent event %d", e.Serial)
- }
- for _ = range queue {
+ logger.Debug("sent")
+
+ if e != nil {
+ h.mtx.Lock()
+ // Handle's deferred cleanup may already have
+ // deleted this queue's entry; don't re-add it,
+ // or DebugStatus would keep reporting a dead
+ // connection and the map would leak the entry.
+ if _, tracked := h.lastDelay[queue]; tracked {
+ h.lastDelay[queue] = stats.Duration(time.Since(e.Ready))
+ }
+ h.mtx.Unlock()
+ }
+ statsMtx.Lock()
+ if e != nil {
+ wStats.QueueDelayNs += t0.Sub(e.Ready)
+ }
+ wStats.WriteDelayNs += time.Since(t0)
+ wStats.EventBytes += uint64(len(buf))
+ wStats.EventCount++
+ statsMtx.Unlock()
}
}()
// Filter incoming events against the current subscription
// list, and forward matching events to the outgoing message
- // queue. Close the queue and return when the "stopped"
- // channel closes or the incoming event stream ends. Shut down
- // the handler if the outgoing queue fills up.
+ // queue. Return when the request context is done/cancelled
+ // or the incoming event stream ends. Shut down the handler
+ // (by cancelling ctx) if the outgoing queue fills up.
go func() {
- send := func(e *event) {
- select {
- case queue <- e:
- default:
- stop <- errQueueFull
- }
- }
-
+ defer cancel()
ticker := time.NewTicker(h.PingTimeout)
defer ticker.Stop()
for {
- var e *event
- var ok bool
select {
- case <-stopped:
- close(queue)
+ case <-ctx.Done():
return
case <-ticker.C:
// If the outgoing queue is empty,
// socket, and prevent an idle socket
// from being closed.
if len(queue) == 0 {
- queue <- nil
+ select {
+ case queue <- []byte(`{}`):
+ default:
+ }
}
- continue
- case e, ok = <-events:
+ case e, ok := <-incoming.Channel():
if !ok {
- close(queue)
return
}
- }
- if sess.Filter(e) {
- send(e)
+ if !sess.Filter(e) {
+ continue
+ }
+ select {
+ case queue <- e:
+ default:
+ logger.WithError(errQueueFull).Error("terminate")
+ return
+ }
}
}
}()
- <-stop
- close(stopped)
+ <-ctx.Done()
+ // Snapshot the writer's stats under its lock; a write still in
+ // flight is not included.
+ statsMtx.Lock()
+ hStats = wStats
+ statsMtx.Unlock()
+ return
+}
+
+// DebugStatus summarizes the outgoing queues of all connections
+// currently registered in h.lastDelay: how many there are, the
+// smallest, largest and total number of items waiting in them, and
+// the smallest nonzero and largest "last event delay" among them.
+// It returns a pointer to an anonymous struct holding those figures.
+func (h *handler) DebugStatus() interface{} {
+ var summary struct {
+ QueueCount int
+ QueueMin int
+ QueueMax int
+ QueueTotal uint64
+ QueueDelayMin stats.Duration
+ QueueDelayMax stats.Duration
+ }
+
+ h.mtx.Lock()
+ defer h.mtx.Unlock()
+ for queue, delay := range h.lastDelay {
+ depth := len(queue)
+ summary.QueueCount++
+ summary.QueueTotal += uint64(depth)
+ switch {
+ case summary.QueueCount == 1:
+ // The first queue seen sets both bounds.
+ summary.QueueMin, summary.QueueMax = depth, depth
+ case depth < summary.QueueMin:
+ summary.QueueMin = depth
+ case depth > summary.QueueMax:
+ summary.QueueMax = depth
+ }
+ // A zero delay is the value a connection registers with
+ // before any event is written, so it is left out of the
+ // minimum.
+ if delay > 0 && (summary.QueueDelayMin == 0 || delay < summary.QueueDelayMin) {
+ summary.QueueDelayMin = delay
+ }
+ if delay > summary.QueueDelayMax {
+ summary.QueueDelayMax = delay
+ }
+ }
+ return &summary
+}
+
+// setup creates the empty lastDelay map. Handle runs it once, via
+// setupOnce, before registering the first connection's queue.
+func (h *handler) setup() {
+ h.lastDelay = map[chan interface{}]stats.Duration{}
}