12167: Pass request ID via keepclient instead of custom code.
[arvados.git] / services / ws / session_v0.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package main
6
7 import (
8         "database/sql"
9         "encoding/json"
10         "errors"
11         "reflect"
12         "sync"
13         "sync/atomic"
14         "time"
15
16         "git.curoverse.com/arvados.git/sdk/go/arvados"
17         "github.com/Sirupsen/logrus"
18 )
19
// Sentinel errors. Judging by their messages, they report a full client
// send queue and an oversized outgoing frame; neither is referenced in
// this file.
var (
	errQueueFull   = errors.New("client queue full")
	errFrameTooBig = errors.New("frame too big")
)

// sendObjectAttributes lists the only keys copied from a log entry's
// properties.old_attributes and properties.new_attributes hashes into
// messages sent to v0 clients; every other attribute is dropped.
var sendObjectAttributes = []string{
	"is_trashed",
	"name",
	"owner_uuid",
	"portable_data_hash",
	"state",
}

// Canned replies sent to the client after a subscribe/unsubscribe request
// succeeds or fails.
var (
	v0subscribeOK   = []byte(`{"status":200}`)
	v0subscribeFail = []byte(`{"status":400}`)
)
38
39 type v0session struct {
40         ac            *arvados.Client
41         ws            wsConn
42         sendq         chan<- interface{}
43         db            *sql.DB
44         permChecker   permChecker
45         subscriptions []v0subscribe
46         lastMsgID     uint64
47         log           *logrus.Entry
48         mtx           sync.Mutex
49         setupOnce     sync.Once
50 }
51
52 // newSessionV0 returns a v0 session: a partial port of the Rails/puma
53 // implementation, with just enough functionality to support Workbench
54 // and arv-mount.
55 func newSessionV0(ws wsConn, sendq chan<- interface{}, db *sql.DB, pc permChecker, ac *arvados.Client) (session, error) {
56         sess := &v0session{
57                 sendq:       sendq,
58                 ws:          ws,
59                 db:          db,
60                 ac:          ac,
61                 permChecker: pc,
62                 log:         logger(ws.Request().Context()),
63         }
64
65         err := ws.Request().ParseForm()
66         if err != nil {
67                 sess.log.WithError(err).Error("ParseForm failed")
68                 return nil, err
69         }
70         token := ws.Request().Form.Get("api_token")
71         sess.permChecker.SetToken(token)
72         sess.log.WithField("token", token).Debug("set token")
73
74         return sess, nil
75 }
76
77 func (sess *v0session) Receive(buf []byte) error {
78         var sub v0subscribe
79         if err := json.Unmarshal(buf, &sub); err != nil {
80                 sess.log.WithError(err).Info("invalid message from client")
81         } else if sub.Method == "subscribe" {
82                 sub.prepare(sess)
83                 sess.log.WithField("sub", sub).Debug("sub prepared")
84                 sess.sendq <- v0subscribeOK
85                 sess.mtx.Lock()
86                 sess.subscriptions = append(sess.subscriptions, sub)
87                 sess.mtx.Unlock()
88                 sub.sendOldEvents(sess)
89                 return nil
90         } else if sub.Method == "unsubscribe" {
91                 sess.mtx.Lock()
92                 found := false
93                 for i, s := range sess.subscriptions {
94                         if !reflect.DeepEqual(s.Filters, sub.Filters) {
95                                 continue
96                         }
97                         copy(sess.subscriptions[i:], sess.subscriptions[i+1:])
98                         sess.subscriptions = sess.subscriptions[:len(sess.subscriptions)-1]
99                         found = true
100                         break
101                 }
102                 sess.mtx.Unlock()
103                 sess.log.WithField("sub", sub).WithField("found", found).Debug("unsubscribe")
104                 if found {
105                         sess.sendq <- v0subscribeOK
106                         return nil
107                 }
108         } else {
109                 sess.log.WithField("Method", sub.Method).Info("unknown method")
110         }
111         sess.sendq <- v0subscribeFail
112         return nil
113 }
114
115 func (sess *v0session) EventMessage(e *event) ([]byte, error) {
116         detail := e.Detail()
117         if detail == nil {
118                 return nil, nil
119         }
120
121         var permTarget string
122         if detail.EventType == "delete" {
123                 // It's pointless to check permission by reading
124                 // ObjectUUID if it has just been deleted, but if the
125                 // client has permission on the parent project then
126                 // it's OK to send the event.
127                 permTarget = detail.ObjectOwnerUUID
128         } else {
129                 permTarget = detail.ObjectUUID
130         }
131         ok, err := sess.permChecker.Check(permTarget)
132         if err != nil || !ok {
133                 return nil, err
134         }
135
136         kind, _ := sess.ac.KindForUUID(detail.ObjectUUID)
137         msg := map[string]interface{}{
138                 "msgID":             atomic.AddUint64(&sess.lastMsgID, 1),
139                 "id":                detail.ID,
140                 "uuid":              detail.UUID,
141                 "object_uuid":       detail.ObjectUUID,
142                 "object_owner_uuid": detail.ObjectOwnerUUID,
143                 "object_kind":       kind,
144                 "event_type":        detail.EventType,
145                 "event_at":          detail.EventAt,
146         }
147         if detail.Properties != nil && detail.Properties["text"] != nil {
148                 msg["properties"] = detail.Properties
149         } else {
150                 msgProps := map[string]map[string]interface{}{}
151                 for _, ak := range []string{"old_attributes", "new_attributes"} {
152                         eventAttrs, ok := detail.Properties[ak].(map[string]interface{})
153                         if !ok {
154                                 continue
155                         }
156                         msgAttrs := map[string]interface{}{}
157                         for _, k := range sendObjectAttributes {
158                                 if v, ok := eventAttrs[k]; ok {
159                                         msgAttrs[k] = v
160                                 }
161                         }
162                         msgProps[ak] = msgAttrs
163                 }
164                 msg["properties"] = msgProps
165         }
166         return json.Marshal(msg)
167 }
168
169 func (sess *v0session) Filter(e *event) bool {
170         sess.mtx.Lock()
171         defer sess.mtx.Unlock()
172         for _, sub := range sess.subscriptions {
173                 if sub.match(sess, e) {
174                         return true
175                 }
176         }
177         return false
178 }
179
180 func (sub *v0subscribe) sendOldEvents(sess *v0session) {
181         if sub.LastLogID == 0 {
182                 return
183         }
184         sess.log.WithField("LastLogID", sub.LastLogID).Debug("sendOldEvents")
185         // Here we do a "select id" query and queue an event for every
186         // log since the given ID, then use (*event)Detail() to
187         // retrieve the whole row and decide whether to send it. This
188         // approach is very inefficient if the subscriber asks for
189         // last_log_id==1, even if the filters end up matching very
190         // few events.
191         //
192         // To mitigate this, filter on "created > 10 minutes ago" when
193         // retrieving the list of old event IDs to consider.
194         rows, err := sess.db.Query(
195                 `SELECT id FROM logs WHERE id > $1 AND created_at > $2 ORDER BY id`,
196                 sub.LastLogID,
197                 time.Now().UTC().Add(-10*time.Minute).Format(time.RFC3339Nano))
198         if err != nil {
199                 sess.log.WithError(err).Error("sendOldEvents db.Query failed")
200                 return
201         }
202
203         var ids []uint64
204         for rows.Next() {
205                 var id uint64
206                 err := rows.Scan(&id)
207                 if err != nil {
208                         sess.log.WithError(err).Error("sendOldEvents row Scan failed")
209                         continue
210                 }
211                 ids = append(ids, id)
212         }
213         if err := rows.Err(); err != nil {
214                 sess.log.WithError(err).Error("sendOldEvents db.Query failed")
215         }
216         rows.Close()
217
218         for _, id := range ids {
219                 for len(sess.sendq)*2 > cap(sess.sendq) {
220                         // Ugly... but if we fill up the whole client
221                         // queue with a backlog of old events, a
222                         // single new event will overflow it and
223                         // terminate the connection, and then the
224                         // client will probably reconnect and do the
225                         // same thing all over again.
226                         time.Sleep(100 * time.Millisecond)
227                         if sess.ws.Request().Context().Err() != nil {
228                                 // Session terminated while we were sleeping
229                                 return
230                         }
231                 }
232                 now := time.Now()
233                 e := &event{
234                         LogID:    id,
235                         Received: now,
236                         Ready:    now,
237                         db:       sess.db,
238                 }
239                 if sub.match(sess, e) {
240                         select {
241                         case sess.sendq <- e:
242                         case <-sess.ws.Request().Context().Done():
243                                 return
244                         }
245                 }
246         }
247 }
248
249 type v0subscribe struct {
250         Method    string
251         Filters   []v0filter
252         LastLogID int64 `json:"last_log_id"`
253
254         funcs []func(*event) bool
255 }
256
257 type v0filter [3]interface{}
258
259 func (sub *v0subscribe) match(sess *v0session, e *event) bool {
260         log := sess.log.WithField("LogID", e.LogID)
261         detail := e.Detail()
262         if detail == nil {
263                 log.Error("match failed, no detail")
264                 return false
265         }
266         log = log.WithField("funcs", len(sub.funcs))
267         for i, f := range sub.funcs {
268                 if !f(e) {
269                         log.WithField("func", i).Debug("match failed")
270                         return false
271                 }
272         }
273         log.Debug("match passed")
274         return true
275 }
276
277 func (sub *v0subscribe) prepare(sess *v0session) {
278         for _, f := range sub.Filters {
279                 if len(f) != 3 {
280                         continue
281                 }
282                 if col, ok := f[0].(string); ok && col == "event_type" {
283                         op, ok := f[1].(string)
284                         if !ok || op != "in" {
285                                 continue
286                         }
287                         arr, ok := f[2].([]interface{})
288                         if !ok {
289                                 continue
290                         }
291                         var strs []string
292                         for _, s := range arr {
293                                 if s, ok := s.(string); ok {
294                                         strs = append(strs, s)
295                                 }
296                         }
297                         sub.funcs = append(sub.funcs, func(e *event) bool {
298                                 for _, s := range strs {
299                                         if s == e.Detail().EventType {
300                                                 return true
301                                         }
302                                 }
303                                 return false
304                         })
305                 } else if ok && col == "created_at" {
306                         op, ok := f[1].(string)
307                         if !ok {
308                                 continue
309                         }
310                         tstr, ok := f[2].(string)
311                         if !ok {
312                                 continue
313                         }
314                         t, err := time.Parse(time.RFC3339Nano, tstr)
315                         if err != nil {
316                                 sess.log.WithField("data", tstr).WithError(err).Info("time.Parse failed")
317                                 continue
318                         }
319                         var fn func(*event) bool
320                         switch op {
321                         case ">=":
322                                 fn = func(e *event) bool {
323                                         return !e.Detail().CreatedAt.Before(t)
324                                 }
325                         case "<=":
326                                 fn = func(e *event) bool {
327                                         return !e.Detail().CreatedAt.After(t)
328                                 }
329                         case ">":
330                                 fn = func(e *event) bool {
331                                         return e.Detail().CreatedAt.After(t)
332                                 }
333                         case "<":
334                                 fn = func(e *event) bool {
335                                         return e.Detail().CreatedAt.Before(t)
336                                 }
337                         case "=":
338                                 fn = func(e *event) bool {
339                                         return e.Detail().CreatedAt.Equal(t)
340                                 }
341                         default:
342                                 sess.log.WithField("operator", op).Info("bogus operator")
343                                 continue
344                         }
345                         sub.funcs = append(sub.funcs, fn)
346                 }
347         }
348 }