Merge branch '19973-dispatch-throttle' into main
[arvados.git] / services / ws / handler.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ws
6
7 import (
8         "context"
9         "io"
10         "sync"
11         "time"
12
13         "git.arvados.org/arvados.git/sdk/go/arvados"
14         "git.arvados.org/arvados.git/sdk/go/stats"
15         "github.com/sirupsen/logrus"
16 )
17
18 type handler struct {
19         Client      arvados.Client
20         PingTimeout time.Duration
21         QueueSize   int
22
23         mtx       sync.Mutex
24         lastDelay map[chan interface{}]stats.Duration
25         setupOnce sync.Once
26 }
27
// handlerStats summarizes the traffic sent to one client connection.
type handlerStats struct {
	// QueueDelayNs is the total time events spent between becoming
	// ready and the start of their write to the client.
	QueueDelayNs time.Duration
	// WriteDelayNs is the total time spent in websocket writes.
	WriteDelayNs time.Duration

	// EventBytes and EventCount count the frames written to the
	// client (including keepalive messages) and their total size.
	EventBytes uint64
	EventCount uint64
}
34
35 func (h *handler) Handle(ws wsConn, logger logrus.FieldLogger, eventSource eventSource, newSession func(wsConn, chan<- interface{}) (session, error)) (hStats handlerStats) {
36         h.setupOnce.Do(h.setup)
37
38         ctx, cancel := context.WithCancel(ws.Request().Context())
39         defer cancel()
40
41         incoming := eventSource.NewSink()
42         defer incoming.Stop()
43
44         queue := make(chan interface{}, h.QueueSize)
45         h.mtx.Lock()
46         h.lastDelay[queue] = 0
47         h.mtx.Unlock()
48         defer func() {
49                 h.mtx.Lock()
50                 delete(h.lastDelay, queue)
51                 h.mtx.Unlock()
52         }()
53
54         sess, err := newSession(ws, queue)
55         if err != nil {
56                 logger.WithError(err).Error("newSession failed")
57                 return
58         }
59
60         // Receive websocket frames from the client and pass them to
61         // sess.Receive().
62         go func() {
63                 defer cancel()
64                 buf := make([]byte, 2<<20)
65                 for {
66                         select {
67                         case <-ctx.Done():
68                                 return
69                         default:
70                         }
71                         ws.SetReadDeadline(time.Now().Add(24 * 365 * time.Hour))
72                         n, err := ws.Read(buf)
73                         buf := buf[:n]
74                         logger.WithField("frame", string(buf[:n])).Debug("received frame")
75                         if err == nil && n == cap(buf) {
76                                 err = errFrameTooBig
77                         }
78                         if err != nil {
79                                 if err != io.EOF && ctx.Err() == nil {
80                                         logger.WithError(err).Info("read error")
81                                 }
82                                 return
83                         }
84                         err = sess.Receive(buf)
85                         if err != nil {
86                                 logger.WithError(err).Error("sess.Receive() failed")
87                                 return
88                         }
89                 }
90         }()
91
92         // Take items from the outgoing queue, serialize them using
93         // sess.EventMessage() as needed, and send them to the client
94         // as websocket frames.
95         go func() {
96                 defer cancel()
97                 for {
98                         var ok bool
99                         var data interface{}
100                         select {
101                         case <-ctx.Done():
102                                 return
103                         case data, ok = <-queue:
104                                 if !ok {
105                                         return
106                                 }
107                         }
108                         var e *event
109                         var buf []byte
110                         var err error
111                         logger := logger
112
113                         switch data := data.(type) {
114                         case []byte:
115                                 buf = data
116                         case *event:
117                                 e = data
118                                 logger = logger.WithField("serial", e.Serial)
119                                 buf, err = sess.EventMessage(e)
120                                 if err != nil {
121                                         logger.WithError(err).Error("EventMessage failed")
122                                         return
123                                 } else if len(buf) == 0 {
124                                         logger.Debug("skip")
125                                         continue
126                                 }
127                         default:
128                                 logger.WithField("data", data).Error("bad object in client queue")
129                                 continue
130                         }
131
132                         logger.WithField("frame", string(buf)).Debug("send event")
133                         ws.SetWriteDeadline(time.Now().Add(h.PingTimeout))
134                         t0 := time.Now()
135                         _, err = ws.Write(buf)
136                         if err != nil {
137                                 if ctx.Err() == nil {
138                                         logger.WithError(err).Error("write failed")
139                                 }
140                                 return
141                         }
142                         logger.Debug("sent")
143
144                         if e != nil {
145                                 hStats.QueueDelayNs += t0.Sub(e.Ready)
146                                 h.mtx.Lock()
147                                 h.lastDelay[queue] = stats.Duration(time.Since(e.Ready))
148                                 h.mtx.Unlock()
149                         }
150                         hStats.WriteDelayNs += time.Since(t0)
151                         hStats.EventBytes += uint64(len(buf))
152                         hStats.EventCount++
153                 }
154         }()
155
156         // Filter incoming events against the current subscription
157         // list, and forward matching events to the outgoing message
158         // queue. Close the queue and return when the request context
159         // is done/cancelled or the incoming event stream ends. Shut
160         // down the handler if the outgoing queue fills up.
161         go func() {
162                 defer cancel()
163                 ticker := time.NewTicker(h.PingTimeout)
164                 defer ticker.Stop()
165
166                 for {
167                         select {
168                         case <-ctx.Done():
169                                 return
170                         case <-ticker.C:
171                                 // If the outgoing queue is empty,
172                                 // send an empty message. This can
173                                 // help detect a disconnected network
174                                 // socket, and prevent an idle socket
175                                 // from being closed.
176                                 if len(queue) == 0 {
177                                         select {
178                                         case queue <- []byte(`{}`):
179                                         default:
180                                         }
181                                 }
182                         case e, ok := <-incoming.Channel():
183                                 if !ok {
184                                         return
185                                 }
186                                 if !sess.Filter(e) {
187                                         continue
188                                 }
189                                 select {
190                                 case queue <- e:
191                                 default:
192                                         logger.WithError(errQueueFull).Error("terminate")
193                                         return
194                                 }
195                         }
196                 }
197         }()
198
199         <-ctx.Done()
200         return
201 }
202
203 func (h *handler) DebugStatus() interface{} {
204         h.mtx.Lock()
205         defer h.mtx.Unlock()
206
207         var s struct {
208                 QueueCount    int
209                 QueueMin      int
210                 QueueMax      int
211                 QueueTotal    uint64
212                 QueueDelayMin stats.Duration
213                 QueueDelayMax stats.Duration
214         }
215         for q, lastDelay := range h.lastDelay {
216                 s.QueueCount++
217                 n := len(q)
218                 s.QueueTotal += uint64(n)
219                 if s.QueueMax < n {
220                         s.QueueMax = n
221                 }
222                 if s.QueueMin > n || s.QueueCount == 1 {
223                         s.QueueMin = n
224                 }
225                 if (s.QueueDelayMin > lastDelay || s.QueueDelayMin == 0) && lastDelay > 0 {
226                         s.QueueDelayMin = lastDelay
227                 }
228                 if s.QueueDelayMax < lastDelay {
229                         s.QueueDelayMax = lastDelay
230                 }
231         }
232         return &s
233 }
234
235 func (h *handler) setup() {
236         h.lastDelay = make(map[chan interface{}]stats.Duration)
237 }