Merge branch 'main' into 21386-project-loading-view
[arvados.git] / services / ws / handler.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ws
6
7 import (
8         "context"
9         "io"
10         "sync"
11         "time"
12
13         "git.arvados.org/arvados.git/sdk/go/arvados"
14         "git.arvados.org/arvados.git/sdk/go/stats"
15         "github.com/sirupsen/logrus"
16 )
17
18 type handler struct {
19         Client      arvados.Client
20         PingTimeout time.Duration
21         QueueSize   int
22
23         mtx       sync.Mutex
24         lastDelay map[chan interface{}]stats.Duration
25         setupOnce sync.Once
26 }
27
// handlerStats holds the totals accumulated by a single call to
// handler.Handle.
type handlerStats struct {
	// QueueDelayNs is the summed time between each event's Ready
	// timestamp and the start of its write. WriteDelayNs is the
	// summed time spent writing frames.
	QueueDelayNs, WriteDelayNs time.Duration
	// EventBytes and EventCount are the total size and number of
	// frames written to the client.
	EventBytes, EventCount uint64
}
34
35 func (h *handler) Handle(ws wsConn, logger logrus.FieldLogger, eventSource eventSource, newSession func(wsConn, chan<- interface{}) (session, error)) (hStats handlerStats) {
36         h.setupOnce.Do(h.setup)
37
38         ctx, cancel := context.WithCancel(ws.Request().Context())
39         defer cancel()
40
41         queue := make(chan interface{}, h.QueueSize)
42         h.mtx.Lock()
43         h.lastDelay[queue] = 0
44         h.mtx.Unlock()
45         defer func() {
46                 h.mtx.Lock()
47                 delete(h.lastDelay, queue)
48                 h.mtx.Unlock()
49         }()
50
51         sess, err := newSession(ws, queue)
52         if err != nil {
53                 logger.WithError(err).Error("newSession failed")
54                 return
55         }
56
57         // Receive websocket frames from the client and pass them to
58         // sess.Receive().
59         go func() {
60                 defer cancel()
61                 buf := make([]byte, 2<<20)
62                 for {
63                         select {
64                         case <-ctx.Done():
65                                 return
66                         default:
67                         }
68                         ws.SetReadDeadline(time.Now().Add(24 * 365 * time.Hour))
69                         n, err := ws.Read(buf)
70                         buf := buf[:n]
71                         logger.WithField("frame", string(buf[:n])).Debug("received frame")
72                         if err == nil && n == cap(buf) {
73                                 err = errFrameTooBig
74                         }
75                         if err != nil {
76                                 if err != io.EOF && ctx.Err() == nil {
77                                         logger.WithError(err).Info("read error")
78                                 }
79                                 return
80                         }
81                         err = sess.Receive(buf)
82                         if err != nil {
83                                 logger.WithError(err).Error("sess.Receive() failed")
84                                 return
85                         }
86                 }
87         }()
88
89         // Take items from the outgoing queue, serialize them using
90         // sess.EventMessage() as needed, and send them to the client
91         // as websocket frames.
92         go func() {
93                 defer cancel()
94                 for {
95                         var ok bool
96                         var data interface{}
97                         select {
98                         case <-ctx.Done():
99                                 return
100                         case data, ok = <-queue:
101                                 if !ok {
102                                         return
103                                 }
104                         }
105                         var e *event
106                         var buf []byte
107                         var err error
108                         logger := logger
109
110                         switch data := data.(type) {
111                         case []byte:
112                                 buf = data
113                         case *event:
114                                 e = data
115                                 logger = logger.WithField("serial", e.Serial)
116                                 buf, err = sess.EventMessage(e)
117                                 if err != nil {
118                                         logger.WithError(err).Error("EventMessage failed")
119                                         return
120                                 } else if len(buf) == 0 {
121                                         logger.Debug("skip")
122                                         continue
123                                 }
124                         default:
125                                 logger.WithField("data", data).Error("bad object in client queue")
126                                 continue
127                         }
128
129                         logger.WithField("frame", string(buf)).Debug("send event")
130                         ws.SetWriteDeadline(time.Now().Add(h.PingTimeout))
131                         t0 := time.Now()
132                         _, err = ws.Write(buf)
133                         if err != nil {
134                                 if ctx.Err() == nil {
135                                         logger.WithError(err).Error("write failed")
136                                 }
137                                 return
138                         }
139                         logger.Debug("sent")
140
141                         if e != nil {
142                                 hStats.QueueDelayNs += t0.Sub(e.Ready)
143                                 h.mtx.Lock()
144                                 h.lastDelay[queue] = stats.Duration(time.Since(e.Ready))
145                                 h.mtx.Unlock()
146                         }
147                         hStats.WriteDelayNs += time.Since(t0)
148                         hStats.EventBytes += uint64(len(buf))
149                         hStats.EventCount++
150                 }
151         }()
152
153         // Filter incoming events against the current subscription
154         // list, and forward matching events to the outgoing message
155         // queue. Close the queue and return when the request context
156         // is done/cancelled or the incoming event stream ends. Shut
157         // down the handler if the outgoing queue fills up.
158         go func() {
159                 defer cancel()
160                 ticker := time.NewTicker(h.PingTimeout)
161                 defer ticker.Stop()
162
163                 incoming := eventSource.NewSink()
164                 defer incoming.Stop()
165
166                 for {
167                         select {
168                         case <-ctx.Done():
169                                 return
170                         case <-ticker.C:
171                                 // If the outgoing queue is empty,
172                                 // send an empty message. This can
173                                 // help detect a disconnected network
174                                 // socket, and prevent an idle socket
175                                 // from being closed.
176                                 if len(queue) == 0 {
177                                         select {
178                                         case queue <- []byte(`{}`):
179                                         default:
180                                         }
181                                 }
182                         case e, ok := <-incoming.Channel():
183                                 if !ok {
184                                         return
185                                 }
186                                 if !sess.Filter(e) {
187                                         continue
188                                 }
189                                 select {
190                                 case queue <- e:
191                                 default:
192                                         logger.WithError(errQueueFull).Error("terminate")
193                                         return
194                                 }
195                         }
196                 }
197         }()
198
199         <-ctx.Done()
200         return
201 }
202
203 func (h *handler) DebugStatus() interface{} {
204         h.mtx.Lock()
205         defer h.mtx.Unlock()
206
207         var s struct {
208                 QueueCount    int
209                 QueueMin      int
210                 QueueMax      int
211                 QueueTotal    uint64
212                 QueueDelayMin stats.Duration
213                 QueueDelayMax stats.Duration
214         }
215         for q, lastDelay := range h.lastDelay {
216                 s.QueueCount++
217                 n := len(q)
218                 s.QueueTotal += uint64(n)
219                 if s.QueueMax < n {
220                         s.QueueMax = n
221                 }
222                 if s.QueueMin > n || s.QueueCount == 1 {
223                         s.QueueMin = n
224                 }
225                 if (s.QueueDelayMin > lastDelay || s.QueueDelayMin == 0) && lastDelay > 0 {
226                         s.QueueDelayMin = lastDelay
227                 }
228                 if s.QueueDelayMax < lastDelay {
229                         s.QueueDelayMax = lastDelay
230                 }
231         }
232         return &s
233 }
234
235 func (h *handler) setup() {
236         h.lastDelay = make(map[chan interface{}]stats.Duration)
237 }