18870: Need to declare NODES as array
[arvados.git] / services / ws / handler.go
index 72291900fac52c4192f4db971e29bd736be5c189..912643ad97c6374006b3fd4b00f90d340157d687 100644 (file)
@@ -1,4 +1,8 @@
-package main
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package ws
 
 import (
        "context"
@@ -6,8 +10,9 @@ import (
        "sync"
        "time"
 
-       "git.curoverse.com/arvados.git/sdk/go/arvados"
-       "git.curoverse.com/arvados.git/sdk/go/stats"
+       "git.arvados.org/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/stats"
+       "github.com/sirupsen/logrus"
 )
 
 type handler struct {
@@ -27,12 +32,11 @@ type handlerStats struct {
        EventCount   uint64
 }
 
-func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsConn, chan<- interface{}) (session, error)) (hStats handlerStats) {
+func (h *handler) Handle(ws wsConn, logger logrus.FieldLogger, eventSource eventSource, newSession func(wsConn, chan<- interface{}) (session, error)) (hStats handlerStats) {
        h.setupOnce.Do(h.setup)
 
        ctx, cancel := context.WithCancel(ws.Request().Context())
        defer cancel()
-       log := logger(ctx)
 
        incoming := eventSource.NewSink()
        defer incoming.Stop()
@@ -49,13 +53,14 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
 
        sess, err := newSession(ws, queue)
        if err != nil {
-               log.WithError(err).Error("newSession failed")
+               logger.WithError(err).Error("newSession failed")
                return
        }
 
        // Receive websocket frames from the client and pass them to
        // sess.Receive().
        go func() {
+               defer cancel()
                buf := make([]byte, 2<<20)
                for {
                        select {
@@ -66,21 +71,19 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
                        ws.SetReadDeadline(time.Now().Add(24 * 365 * time.Hour))
                        n, err := ws.Read(buf)
                        buf := buf[:n]
-                       log.WithField("frame", string(buf[:n])).Debug("received frame")
+                       logger.WithField("frame", string(buf[:n])).Debug("received frame")
                        if err == nil && n == cap(buf) {
                                err = errFrameTooBig
                        }
                        if err != nil {
-                               if err != io.EOF {
-                                       log.WithError(err).Info("read error")
+                               if err != io.EOF && ctx.Err() == nil {
+                                       logger.WithError(err).Info("read error")
                                }
-                               cancel()
                                return
                        }
                        err = sess.Receive(buf)
                        if err != nil {
-                               log.WithError(err).Error("sess.Receive() failed")
-                               cancel()
+                               logger.WithError(err).Error("sess.Receive() failed")
                                return
                        }
                }
@@ -90,6 +93,7 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
        // sess.EventMessage() as needed, and send them to the client
        // as websocket frames.
        go func() {
+               defer cancel()
                for {
                        var ok bool
                        var data interface{}
@@ -104,38 +108,38 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
                        var e *event
                        var buf []byte
                        var err error
-                       log := log
+                       logger := logger
 
                        switch data := data.(type) {
                        case []byte:
                                buf = data
                        case *event:
                                e = data
-                               log = log.WithField("serial", e.Serial)
+                               logger = logger.WithField("serial", e.Serial)
                                buf, err = sess.EventMessage(e)
                                if err != nil {
-                                       log.WithError(err).Error("EventMessage failed")
-                                       cancel()
-                                       break
+                                       logger.WithError(err).Error("EventMessage failed")
+                                       return
                                } else if len(buf) == 0 {
-                                       log.Debug("skip")
+                                       logger.Debug("skip")
                                        continue
                                }
                        default:
-                               log.WithField("data", data).Error("bad object in client queue")
+                               logger.WithField("data", data).Error("bad object in client queue")
                                continue
                        }
 
-                       log.WithField("frame", string(buf)).Debug("send event")
+                       logger.WithField("frame", string(buf)).Debug("send event")
                        ws.SetWriteDeadline(time.Now().Add(h.PingTimeout))
                        t0 := time.Now()
                        _, err = ws.Write(buf)
                        if err != nil {
-                               log.WithError(err).Error("write failed")
-                               cancel()
-                               break
+                               if ctx.Err() == nil {
+                                       logger.WithError(err).Error("write failed")
+                               }
+                               return
                        }
-                       log.Debug("sent")
+                       logger.Debug("sent")
 
                        if e != nil {
                                hStats.QueueDelayNs += t0.Sub(e.Ready)
@@ -155,6 +159,7 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
        // is done/cancelled or the incoming event stream ends. Shut
        // down the handler if the outgoing queue fills up.
        go func() {
+               defer cancel()
                ticker := time.NewTicker(h.PingTimeout)
                defer ticker.Stop()
 
@@ -174,10 +179,8 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
                                        default:
                                        }
                                }
-                               continue
                        case e, ok := <-incoming.Channel():
                                if !ok {
-                                       cancel()
                                        return
                                }
                                if !sess.Filter(e) {
@@ -186,8 +189,7 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
                                select {
                                case queue <- e:
                                default:
-                                       log.WithError(errQueueFull).Error("terminate")
-                                       cancel()
+                                       logger.WithError(errQueueFull).Error("terminate")
                                        return
                                }
                        }