Merge branch '16265-security-updates' into dependabot/bundler/services/api/nokogiri...
[arvados.git] / services / ws / handler.go
index ca9231c986de0f75e0655e506b8590c1b17d3d84..913b1ee8000cbd274039483df70bad7896d52df5 100644 (file)
@@ -1,3 +1,7 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
 package main
 
 import (
@@ -6,8 +10,8 @@ import (
        "sync"
        "time"
 
-       "git.curoverse.com/arvados.git/sdk/go/arvados"
-       "git.curoverse.com/arvados.git/sdk/go/stats"
+       "git.arvados.org/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/stats"
 )
 
 type handler struct {
@@ -31,6 +35,7 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
        h.setupOnce.Do(h.setup)
 
        ctx, cancel := context.WithCancel(ws.Request().Context())
+       defer cancel()
        log := logger(ctx)
 
        incoming := eventSource.NewSink()
@@ -52,7 +57,10 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
                return
        }
 
+       // Receive websocket frames from the client and pass them to
+       // sess.Receive().
        go func() {
+               defer cancel()
                buf := make([]byte, 2<<20)
                for {
                        select {
@@ -68,22 +76,24 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
                                err = errFrameTooBig
                        }
                        if err != nil {
-                               if err != io.EOF {
+                               if err != io.EOF && ctx.Err() == nil {
                                        log.WithError(err).Info("read error")
                                }
-                               cancel()
                                return
                        }
                        err = sess.Receive(buf)
                        if err != nil {
                                log.WithError(err).Error("sess.Receive() failed")
-                               cancel()
                                return
                        }
                }
        }()
 
+       // Take items from the outgoing queue, serialize them using
+       // sess.EventMessage() as needed, and send them to the client
+       // as websocket frames.
        go func() {
+               defer cancel()
                for {
                        var ok bool
                        var data interface{}
@@ -109,8 +119,7 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
                                buf, err = sess.EventMessage(e)
                                if err != nil {
                                        log.WithError(err).Error("EventMessage failed")
-                                       cancel()
-                                       break
+                                       return
                                } else if len(buf) == 0 {
                                        log.Debug("skip")
                                        continue
@@ -125,9 +134,10 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
                        t0 := time.Now()
                        _, err = ws.Write(buf)
                        if err != nil {
-                               log.WithError(err).Error("write failed")
-                               cancel()
-                               break
+                               if ctx.Err() == nil {
+                                       log.WithError(err).Error("write failed")
+                               }
+                               return
                        }
                        log.Debug("sent")
 
@@ -149,6 +159,7 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
        // is done/cancelled or the incoming event stream ends. Shut
        // down the handler if the outgoing queue fills up.
        go func() {
+               defer cancel()
                ticker := time.NewTicker(h.PingTimeout)
                defer ticker.Stop()
 
@@ -168,10 +179,8 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
                                        default:
                                        }
                                }
-                               continue
                        case e, ok := <-incoming.Channel():
                                if !ok {
-                                       cancel()
                                        return
                                }
                                if !sess.Filter(e) {
@@ -181,7 +190,6 @@ func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsC
                                case queue <- e:
                                default:
                                        log.WithError(errQueueFull).Error("terminate")
-                                       cancel()
                                        return
                                }
                        }