8460: Fix "send to closed channel" race by using context lib to release goroutines.
author Tom Clegg <tom@curoverse.com>
Fri, 18 Nov 2016 23:07:15 +0000 (18:07 -0500)
committer Tom Clegg <tom@curoverse.com>
Fri, 18 Nov 2016 23:07:15 +0000 (18:07 -0500)
services/ws/handler.go
services/ws/session_v0.go

index ab25805c6c00e0490085f79b6be1e928b35bd6b6..91a77022d6030aef29bef64b15ad6951674188ba 100644 (file)
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "context"
        "io"
        "time"
 
@@ -22,7 +23,8 @@ type handlerStats struct {
 }
 
 func (h *handler) Handle(ws wsConn, incoming <-chan *event) (stats handlerStats) {
-       log := logger(ws.Request().Context())
+       ctx, cancel := context.WithCancel(ws.Request().Context())
+       log := logger(ctx)
        queue := make(chan interface{}, h.QueueSize)
        sess, err := h.NewSession(ws, queue)
        if err != nil {
@@ -30,14 +32,11 @@ func (h *handler) Handle(ws wsConn, incoming <-chan *event) (stats handlerStats)
                return
        }
 
-       stopped := make(chan struct{})
-       stop := make(chan error, 5)
-
        go func() {
                buf := make([]byte, 2<<20)
                for {
                        select {
-                       case <-stopped:
+                       case <-ctx.Done():
                                return
                        default:
                        }
@@ -52,19 +51,30 @@ func (h *handler) Handle(ws wsConn, incoming <-chan *event) (stats handlerStats)
                                if err != io.EOF {
                                        log.WithError(err).Info("read error")
                                }
-                               stop <- err
+                               cancel()
                                return
                        }
                        err = sess.Receive(buf)
                        if err != nil {
-                               stop <- err
+                               log.WithError(err).Error("sess.Receive() failed")
+                               cancel()
                                return
                        }
                }
        }()
 
        go func() {
-               for data := range queue {
+               for {
+                       var ok bool
+                       var data interface{}
+                       select {
+                       case <-ctx.Done():
+                               return
+                       case data, ok = <-queue:
+                               if !ok {
+                                       return
+                               }
+                       }
                        var e *event
                        var buf []byte
                        var err error
@@ -79,7 +89,7 @@ func (h *handler) Handle(ws wsConn, incoming <-chan *event) (stats handlerStats)
                                buf, err = sess.EventMessage(e)
                                if err != nil {
                                        log.WithError(err).Error("EventMessage failed")
-                                       stop <- err
+                                       cancel()
                                        break
                                } else if len(buf) == 0 {
                                        log.Debug("skip")
@@ -96,7 +106,7 @@ func (h *handler) Handle(ws wsConn, incoming <-chan *event) (stats handlerStats)
                        _, err = ws.Write(buf)
                        if err != nil {
                                log.WithError(err).Error("write failed")
-                               stop <- err
+                               cancel()
                                break
                        }
                        log.Debug("sent")
@@ -108,25 +118,20 @@ func (h *handler) Handle(ws wsConn, incoming <-chan *event) (stats handlerStats)
                        stats.EventBytes += uint64(len(buf))
                        stats.EventCount++
                }
-               for _ = range queue {
-                       // Ensure queue can't fill up and block other
-                       // goroutines after we hit a write error.
-               }
        }()
 
        // Filter incoming events against the current subscription
        // list, and forward matching events to the outgoing message
-       // queue. Close the queue and return when the "stopped"
-       // channel closes or the incoming event stream ends. Shut down
-       // the handler if the outgoing queue fills up.
+       // queue. Close the queue and return when the request context
+       // is done/cancelled or the incoming event stream ends. Shut
+       // down the handler if the outgoing queue fills up.
        go func() {
                ticker := time.NewTicker(h.PingTimeout)
                defer ticker.Stop()
 
                for {
                        select {
-                       case <-stopped:
-                               close(queue)
+                       case <-ctx.Done():
                                return
                        case <-ticker.C:
                                // If the outgoing queue is empty,
@@ -135,12 +140,15 @@ func (h *handler) Handle(ws wsConn, incoming <-chan *event) (stats handlerStats)
                                // socket, and prevent an idle socket
                                // from being closed.
                                if len(queue) == 0 {
-                                       queue <- []byte(`{}`)
+                                       select {
+                                       case queue <- []byte(`{}`):
+                                       default:
+                                       }
                                }
                                continue
                        case e, ok := <-incoming:
                                if !ok {
-                                       close(queue)
+                                       cancel()
                                        return
                                }
                                if !sess.Filter(e) {
@@ -149,14 +157,14 @@ func (h *handler) Handle(ws wsConn, incoming <-chan *event) (stats handlerStats)
                                select {
                                case queue <- e:
                                default:
-                                       stop <- errQueueFull
+                                       log.WithError(errQueueFull).Error("terminate")
+                                       cancel()
+                                       return
                                }
                        }
                }
        }()
 
-       <-stop
-       close(stopped)
-
+       <-ctx.Done()
        return
 }
index 2bcce6073e5ffc575934d7e5ec33c55bc6616c25..29a7adec82a7fcd62905e08d9018bb983b174cd7 100644 (file)
@@ -171,7 +171,11 @@ func (sub *v0subscribe) sendOldEvents(sess *v0session) {
                        db:       sess.db,
                }
                if sub.match(sess, e) {
-                       sess.sendq <- e
+                       select {
+                       case sess.sendq <- e:
+                       case <-sess.ws.Request().Context().Done():
+                               return
+                       }
                }
        }
        if err := rows.Err(); err != nil {