10903: support need_transaction for job and pipeline_instance cancel methods.
[arvados.git] / services / ws / handler.go
1 package main
2
3 import (
4         "context"
5         "io"
6         "sync"
7         "time"
8
9         "git.curoverse.com/arvados.git/sdk/go/arvados"
10         "git.curoverse.com/arvados.git/sdk/go/stats"
11 )
12
13 type handler struct {
14         Client      arvados.Client
15         PingTimeout time.Duration
16         QueueSize   int
17
18         mtx       sync.Mutex
19         lastDelay map[chan interface{}]stats.Duration
20         setupOnce sync.Once
21 }
22
// handlerStats summarizes the work done by one call to Handle.
type handlerStats struct {
	// QueueDelayNs is the total time events spent between
	// becoming ready and the start of their write. WriteDelayNs is
	// the total time spent writing frames to the client.
	QueueDelayNs, WriteDelayNs time.Duration

	// EventBytes is the total size of all frames written, and
	// EventCount is the number of frames written.
	EventBytes, EventCount uint64
}
29
30 func (h *handler) Handle(ws wsConn, eventSource eventSource, newSession func(wsConn, chan<- interface{}) (session, error)) (hStats handlerStats) {
31         h.setupOnce.Do(h.setup)
32
33         ctx, cancel := context.WithCancel(ws.Request().Context())
34         defer cancel()
35         log := logger(ctx)
36
37         incoming := eventSource.NewSink()
38         defer incoming.Stop()
39
40         queue := make(chan interface{}, h.QueueSize)
41         h.mtx.Lock()
42         h.lastDelay[queue] = 0
43         h.mtx.Unlock()
44         defer func() {
45                 h.mtx.Lock()
46                 delete(h.lastDelay, queue)
47                 h.mtx.Unlock()
48         }()
49
50         sess, err := newSession(ws, queue)
51         if err != nil {
52                 log.WithError(err).Error("newSession failed")
53                 return
54         }
55
56         // Receive websocket frames from the client and pass them to
57         // sess.Receive().
58         go func() {
59                 buf := make([]byte, 2<<20)
60                 for {
61                         select {
62                         case <-ctx.Done():
63                                 return
64                         default:
65                         }
66                         ws.SetReadDeadline(time.Now().Add(24 * 365 * time.Hour))
67                         n, err := ws.Read(buf)
68                         buf := buf[:n]
69                         log.WithField("frame", string(buf[:n])).Debug("received frame")
70                         if err == nil && n == cap(buf) {
71                                 err = errFrameTooBig
72                         }
73                         if err != nil {
74                                 if err != io.EOF {
75                                         log.WithError(err).Info("read error")
76                                 }
77                                 cancel()
78                                 return
79                         }
80                         err = sess.Receive(buf)
81                         if err != nil {
82                                 log.WithError(err).Error("sess.Receive() failed")
83                                 cancel()
84                                 return
85                         }
86                 }
87         }()
88
89         // Take items from the outgoing queue, serialize them using
90         // sess.EventMessage() as needed, and send them to the client
91         // as websocket frames.
92         go func() {
93                 for {
94                         var ok bool
95                         var data interface{}
96                         select {
97                         case <-ctx.Done():
98                                 return
99                         case data, ok = <-queue:
100                                 if !ok {
101                                         return
102                                 }
103                         }
104                         var e *event
105                         var buf []byte
106                         var err error
107                         log := log
108
109                         switch data := data.(type) {
110                         case []byte:
111                                 buf = data
112                         case *event:
113                                 e = data
114                                 log = log.WithField("serial", e.Serial)
115                                 buf, err = sess.EventMessage(e)
116                                 if err != nil {
117                                         log.WithError(err).Error("EventMessage failed")
118                                         cancel()
119                                         break
120                                 } else if len(buf) == 0 {
121                                         log.Debug("skip")
122                                         continue
123                                 }
124                         default:
125                                 log.WithField("data", data).Error("bad object in client queue")
126                                 continue
127                         }
128
129                         log.WithField("frame", string(buf)).Debug("send event")
130                         ws.SetWriteDeadline(time.Now().Add(h.PingTimeout))
131                         t0 := time.Now()
132                         _, err = ws.Write(buf)
133                         if err != nil {
134                                 log.WithError(err).Error("write failed")
135                                 cancel()
136                                 break
137                         }
138                         log.Debug("sent")
139
140                         if e != nil {
141                                 hStats.QueueDelayNs += t0.Sub(e.Ready)
142                                 h.mtx.Lock()
143                                 h.lastDelay[queue] = stats.Duration(time.Since(e.Ready))
144                                 h.mtx.Unlock()
145                         }
146                         hStats.WriteDelayNs += time.Since(t0)
147                         hStats.EventBytes += uint64(len(buf))
148                         hStats.EventCount++
149                 }
150         }()
151
152         // Filter incoming events against the current subscription
153         // list, and forward matching events to the outgoing message
154         // queue. Close the queue and return when the request context
155         // is done/cancelled or the incoming event stream ends. Shut
156         // down the handler if the outgoing queue fills up.
157         go func() {
158                 ticker := time.NewTicker(h.PingTimeout)
159                 defer ticker.Stop()
160
161                 for {
162                         select {
163                         case <-ctx.Done():
164                                 return
165                         case <-ticker.C:
166                                 // If the outgoing queue is empty,
167                                 // send an empty message. This can
168                                 // help detect a disconnected network
169                                 // socket, and prevent an idle socket
170                                 // from being closed.
171                                 if len(queue) == 0 {
172                                         select {
173                                         case queue <- []byte(`{}`):
174                                         default:
175                                         }
176                                 }
177                                 continue
178                         case e, ok := <-incoming.Channel():
179                                 if !ok {
180                                         cancel()
181                                         return
182                                 }
183                                 if !sess.Filter(e) {
184                                         continue
185                                 }
186                                 select {
187                                 case queue <- e:
188                                 default:
189                                         log.WithError(errQueueFull).Error("terminate")
190                                         cancel()
191                                         return
192                                 }
193                         }
194                 }
195         }()
196
197         <-ctx.Done()
198         return
199 }
200
201 func (h *handler) DebugStatus() interface{} {
202         h.mtx.Lock()
203         defer h.mtx.Unlock()
204
205         var s struct {
206                 QueueCount    int
207                 QueueMin      int
208                 QueueMax      int
209                 QueueTotal    uint64
210                 QueueDelayMin stats.Duration
211                 QueueDelayMax stats.Duration
212         }
213         for q, lastDelay := range h.lastDelay {
214                 s.QueueCount++
215                 n := len(q)
216                 s.QueueTotal += uint64(n)
217                 if s.QueueMax < n {
218                         s.QueueMax = n
219                 }
220                 if s.QueueMin > n || s.QueueCount == 1 {
221                         s.QueueMin = n
222                 }
223                 if (s.QueueDelayMin > lastDelay || s.QueueDelayMin == 0) && lastDelay > 0 {
224                         s.QueueDelayMin = lastDelay
225                 }
226                 if s.QueueDelayMax < lastDelay {
227                         s.QueueDelayMax = lastDelay
228                 }
229         }
230         return &s
231 }
232
233 func (h *handler) setup() {
234         h.lastDelay = make(map[chan interface{}]stats.Duration)
235 }