Merge branch '16265-security-updates' into dependabot/bundler/apps/workbench/loofah...
[arvados.git] / services / ws / handler.go
index 1c9d5ba61de8e636bb1ddd855b9d37be4aeeba1c..913b1ee8000cbd274039483df70bad7896d52df5 100644 (file)
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
 package main
 
 import (
-       "encoding/json"
+       "context"
        "io"
-       "log"
-       "net/http"
+       "sync"
        "time"
 
-       "git.curoverse.com/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/stats"
 )
 
-type wsConn interface {
-       io.ReadWriter
-       Request() *http.Request
-       SetReadDeadline(time.Time) error
-       SetWriteDeadline(time.Time) error
-}
-
+// handler serves websocket connections (see Handle) and reports
+// aggregate outgoing-queue status for them (see DebugStatus).
 type handler struct {
+       // NOTE(review): Client is not referenced by the code in this
+       // file after this change; presumably used elsewhere -- confirm.
        Client      arvados.Client
+       // PingTimeout is the period of Handle's keepalive ticker (a
+       // "{}" frame is queued at each tick if the outgoing queue is
+       // empty) and the write deadline applied to each frame.
        PingTimeout time.Duration
+       // QueueSize is the capacity of each connection's outgoing
+       // queue; Handle ends the connection if that queue fills up.
        QueueSize   int
-       NewSession  func(wsConn, arvados.Client) (session, error)
+
+       // mtx guards lastDelay.
+       mtx       sync.Mutex
+       // lastDelay maps each active connection's outgoing queue to
+       // the delay of the most recent event written from it (0 until
+       // one has been written). Allocated by setup; read by
+       // DebugStatus.
+       lastDelay map[chan interface{}]stats.Duration
+       // setupOnce makes Handle run setup exactly once.
+       setupOnce sync.Once
 }
 
-func (h *handler) Handle(ws wsConn, events <-chan *event) {
-       sess, err := h.NewSession(ws, h.Client)
+// handlerStats is returned by Handle: totals for the frames it wrote
+// to one client.
+type handlerStats struct {
+       QueueDelayNs time.Duration // sum of (ws.Write start - event.Ready) over written events
+       WriteDelayNs time.Duration // time spent in successful ws.Write calls
+       EventBytes   uint64        // bytes written, incl. raw []byte frames like the "{}" keepalive
+       EventCount   uint64        // frames written, incl. raw []byte frames
+}
+
+// Handle serves a single websocket connection. Frames received from
+// the client are passed to the session created by newSession (which
+// is also given the send side of the outgoing queue). Events from
+// eventSource that pass sess.Filter, plus any []byte frames on the
+// queue (such as the "{}" keepalive), are written to the client.
+//
+// Handle returns when the request context is cancelled or when any of
+// its three goroutines stops (read/write error, oversized frame,
+// sess.Receive or sess.EventMessage error, event source closed, or
+// outgoing queue full), or immediately if newSession fails. The
+// returned stats cover frames written before that point.
+//
+// Handle does not wait for its goroutines; they stop on their own
+// once the shared context is cancelled. NOTE(review): the reader is
+// presumably unblocked from ws.Read by the caller closing ws after
+// Handle returns -- confirm.
+func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsConn, chan<- interface{}) (session, error)) (hStats handlerStats) {
+       h.setupOnce.Do(h.setup)
+
+       ctx, cancel := context.WithCancel(ws.Request().Context())
+       defer cancel()
+       log := logger(ctx)
+
+       incoming := eventSource.NewSink()
+       defer incoming.Stop()
+
+       queue := make(chan interface{}, h.QueueSize)
+       h.mtx.Lock()
+       h.lastDelay[queue] = 0
+       h.mtx.Unlock()
+       defer func() {
+               h.mtx.Lock()
+               delete(h.lastDelay, queue)
+               h.mtx.Unlock()
+       }()
+
+       sess, err := newSession(ws, queue)
        if err != nil {
-               log.Printf("%s NewSession: %s", ws.Request().RemoteAddr, err)
+               log.WithError(err).Error("newSession failed")
                return
        }
 
-       queue := make(chan *event, h.QueueSize)
-
-       stopped := make(chan struct{})
-       stop := make(chan error, 5)
-
+       // Receive websocket frames from the client and pass them to
+       // sess.Receive().
        go func() {
+               defer cancel()
                buf := make([]byte, 2<<20)
                for {
                        select {
-                       case <-stopped:
+                       case <-ctx.Done():
                                return
                        default:
                        }
                        ws.SetReadDeadline(time.Now().Add(24 * 365 * time.Hour))
                        n, err := ws.Read(buf)
-                       sess.debugLogf("received frame: %q", buf[:n])
-                       if err == nil && n == len(buf) {
+                       buf := buf[:n]
+                       log.WithField("frame", string(buf[:n])).Debug("received frame")
+                       if err == nil && n == cap(buf) {
                                err = errFrameTooBig
                        }
                        if err != nil {
-                               if err != io.EOF {
-                                       sess.debugLogf("handler: read: %s", err)
+                               if err != io.EOF && ctx.Err() == nil {
+                                       log.WithError(err).Info("read error")
                                }
-                               stop <- err
                                return
                        }
-                       msg := make(map[string]interface{})
-                       err = json.Unmarshal(buf[:n], &msg)
+                       err = sess.Receive(buf)
                        if err != nil {
-                               sess.debugLogf("handler: unmarshal: %s", err)
-                               stop <- err
+                               log.WithError(err).Error("sess.Receive() failed")
                                return
                        }
-                       sess.Receive(msg, buf[:n])
                }
        }()
 
+       // The writer goroutine below may still be running (e.g.,
+       // finishing a ws.Write) after Handle returns, so it must not
+       // update the named result hStats directly: that would race
+       // with Handle's return. It accumulates into wStats under
+       // wStatsMtx instead, and Handle copies wStats out under the
+       // same lock.
+       var wStatsMtx sync.Mutex
+       var wStats handlerStats
+
+       // Take items from the outgoing queue, serialize them using
+       // sess.EventMessage() as needed, and send them to the client
+       // as websocket frames.
        go func() {
-               for e := range queue {
-                       if e == nil {
-                               ws.SetWriteDeadline(time.Now().Add(h.PingTimeout))
-                               _, err := ws.Write([]byte("{}"))
-                               if err != nil {
-                                       sess.debugLogf("handler: write {}: %s", err)
-                                       stop <- err
-                                       break
+               defer cancel()
+               for {
+                       var ok bool
+                       var data interface{}
+                       select {
+                       case <-ctx.Done():
+                               return
+                       case data, ok = <-queue:
+                               if !ok {
+                                       return
                                }
-                               continue
                        }
+                       var e *event
+                       var buf []byte
+                       var err error
+                       log := log
 
-                       buf, err := sess.EventMessage(e)
-                       if err != nil {
-                               sess.debugLogf("EventMessage %d: err %s", err)
-                               stop <- err
-                               break
-                       } else if len(buf) == 0 {
-                               sess.debugLogf("EventMessage %d: skip", e.Serial)
+                       switch data := data.(type) {
+                       case []byte:
+                               buf = data
+                       case *event:
+                               e = data
+                               log = log.WithField("serial", e.Serial)
+                               buf, err = sess.EventMessage(e)
+                               if err != nil {
+                                       log.WithError(err).Error("EventMessage failed")
+                                       return
+                               } else if len(buf) == 0 {
+                                       log.Debug("skip")
+                                       continue
+                               }
+                       default:
+                               log.WithField("data", data).Error("bad object in client queue")
                                continue
                        }
 
-                       sess.debugLogf("handler: send event %d: %q", e.Serial, buf)
+                       log.WithField("frame", string(buf)).Debug("send event")
                        ws.SetWriteDeadline(time.Now().Add(h.PingTimeout))
+                       t0 := time.Now()
                        _, err = ws.Write(buf)
                        if err != nil {
-                               sess.debugLogf("handler: write: %s", err)
-                               stop <- err
-                               break
+                               if ctx.Err() == nil {
+                                       log.WithError(err).Error("write failed")
+                               }
+                               return
                        }
-                       sess.debugLogf("handler: sent event %d", e.Serial)
-               }
-               for _ = range queue {
+                       writeDelay := time.Since(t0)
+                       log.Debug("sent")
+
+                       if e != nil {
+                               h.mtx.Lock()
+                               // If ctx was cancelled while ws.Write was in
+                               // progress, Handle may already have returned
+                               // and its deferred cleanup deleted this
+                               // queue's entry. Re-creating it would leak it
+                               // (nothing deletes it again) and DebugStatus
+                               // would keep reporting a dead queue.
+                               if _, ok := h.lastDelay[queue]; ok {
+                                       h.lastDelay[queue] = stats.Duration(time.Since(e.Ready))
+                               }
+                               h.mtx.Unlock()
+                       }
+                       wStatsMtx.Lock()
+                       if e != nil {
+                               wStats.QueueDelayNs += t0.Sub(e.Ready)
+                       }
+                       wStats.WriteDelayNs += writeDelay
+                       wStats.EventBytes += uint64(len(buf))
+                       wStats.EventCount++
+                       wStatsMtx.Unlock()
                }
        }()
 
        // Filter incoming events against the current subscription
        // list, and forward matching events to the outgoing message
-       // queue. Close the queue and return when the "stopped"
-       // channel closes or the incoming event stream ends. Shut down
-       // the handler if the outgoing queue fills up.
+       // queue. Return (cancelling ctx, which shuts down the whole
+       // handler) when ctx is done, when the incoming event stream
+       // ends, or if the outgoing queue fills up. The queue itself is
+       // never closed; the writer goroutine stops via ctx.
        go func() {
-               send := func(e *event) {
-                       select {
-                       case queue <- e:
-                       default:
-                               stop <- errQueueFull
-                       }
-               }
-
+               defer cancel()
                ticker := time.NewTicker(h.PingTimeout)
                defer ticker.Stop()
 
                for {
-                       var e *event
-                       var ok bool
                        select {
-                       case <-stopped:
-                               close(queue)
+                       case <-ctx.Done():
                                return
                        case <-ticker.C:
                                // If the outgoing queue is empty,
@@ -136,21 +174,64 @@ func (h *handler) Handle(ws wsConn, events <-chan *event) {
                                // socket, and prevent an idle socket
                                // from being closed.
                                if len(queue) == 0 {
-                                       queue <- nil
+                                       select {
+                                       case queue <- []byte(`{}`):
+                                       default:
+                                       }
                                }
-                               continue
-                       case e, ok = <-events:
+                       case e, ok := <-incoming.Channel():
                                if !ok {
-                                       close(queue)
                                        return
                                }
-                       }
-                       if sess.Filter(e) {
-                               send(e)
+                               if !sess.Filter(e) {
+                                       continue
+                               }
+                               select {
+                               case queue <- e:
+                               default:
+                                       log.WithError(errQueueFull).Error("terminate")
+                                       return
+                               }
                        }
                }
        }()
 
-       <-stop
-       close(stopped)
+       <-ctx.Done()
+       wStatsMtx.Lock()
+       hStats = wStats
+       wStatsMtx.Unlock()
+       return
+}
+
+// DebugStatus reports, for the outgoing queues of all connections
+// currently registered in h.lastDelay: how many there are; the
+// smallest, largest, and total number of items waiting in them; and
+// the smallest positive and largest last-recorded queue delay.
+func (h *handler) DebugStatus() interface{} {
+       var status struct {
+               QueueCount    int
+               QueueMin      int
+               QueueMax      int
+               QueueTotal    uint64
+               QueueDelayMin stats.Duration
+               QueueDelayMax stats.Duration
+       }
+
+       h.mtx.Lock()
+       defer h.mtx.Unlock()
+
+       for q, delay := range h.lastDelay {
+               depth := len(q)
+               first := status.QueueCount == 0
+               status.QueueCount++
+               status.QueueTotal += uint64(depth)
+               if first || depth < status.QueueMin {
+                       status.QueueMin = depth
+               }
+               if depth > status.QueueMax {
+                       status.QueueMax = depth
+               }
+               switch {
+               case delay <= 0:
+                       // Queues with no recorded delay don't count toward the minimum.
+               case status.QueueDelayMin == 0, delay < status.QueueDelayMin:
+                       status.QueueDelayMin = delay
+               }
+               if delay > status.QueueDelayMax {
+                       status.QueueDelayMax = delay
+               }
+       }
+       return &status
+}
+
+// setup allocates h's lazily-initialized state. Handle calls it
+// once, through h.setupOnce.
+func (h *handler) setup() {
+       h.lastDelay = map[chan interface{}]stats.Duration{}
+}