20264: Reject redirect target with userinfo.
[arvados.git] / services / ws / session_v0.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ws
6
7 import (
8         "database/sql"
9         "encoding/json"
10         "errors"
11         "reflect"
12         "sync"
13         "sync/atomic"
14         "time"
15
16         "git.arvados.org/arvados.git/sdk/go/arvados"
17         "git.arvados.org/arvados.git/sdk/go/ctxlog"
18         "github.com/sirupsen/logrus"
19 )
20
// Sentinel errors describing a full client queue and an oversized
// frame. Their use sites are outside this file.
var (
	errQueueFull   = errors.New("client queue full")
	errFrameTooBig = errors.New("frame too big")
)

// sendObjectAttributes is the allow-list of keys copied from the
// log.properties.old_attributes and log.properties.new_attributes
// hashes into messages sent to v0 clients; every other key is dropped.
var sendObjectAttributes = []string{
	"is_trashed",
	"name",
	"owner_uuid",
	"portable_data_hash",
	"state",
}

// Canned replies to subscribe/unsubscribe requests: OK when the
// request was honored, Fail otherwise.
var (
	v0subscribeOK   = []byte(`{"status":200}`)
	v0subscribeFail = []byte(`{"status":400}`)
)
39
40 type v0session struct {
41         ac            *arvados.Client
42         ws            wsConn
43         sendq         chan<- interface{}
44         db            *sql.DB
45         permChecker   permChecker
46         subscriptions []v0subscribe
47         lastMsgID     uint64
48         log           logrus.FieldLogger
49         mtx           sync.Mutex
50         setupOnce     sync.Once
51 }
52
53 // newSessionV0 returns a v0 session: a partial port of the Rails/puma
54 // implementation, with just enough functionality to support Workbench
55 // and arv-mount.
56 func newSessionV0(ws wsConn, sendq chan<- interface{}, db *sql.DB, pc permChecker, ac *arvados.Client) (session, error) {
57         sess := &v0session{
58                 sendq:       sendq,
59                 ws:          ws,
60                 db:          db,
61                 ac:          ac,
62                 permChecker: pc,
63                 log:         ctxlog.FromContext(ws.Request().Context()),
64         }
65
66         err := ws.Request().ParseForm()
67         if err != nil {
68                 sess.log.WithError(err).Error("ParseForm failed")
69                 return nil, err
70         }
71         token := ws.Request().Form.Get("api_token")
72         sess.permChecker.SetToken(token)
73         sess.log.WithField("token", token).Debug("set token")
74
75         return sess, nil
76 }
77
78 func (sess *v0session) Receive(buf []byte) error {
79         var sub v0subscribe
80         if err := json.Unmarshal(buf, &sub); err != nil {
81                 sess.log.WithError(err).Info("invalid message from client")
82         } else if sub.Method == "subscribe" {
83                 sub.prepare(sess)
84                 sess.log.WithField("sub", sub).Debug("sub prepared")
85                 sess.sendq <- v0subscribeOK
86                 sess.mtx.Lock()
87                 sess.subscriptions = append(sess.subscriptions, sub)
88                 sess.mtx.Unlock()
89                 sub.sendOldEvents(sess)
90                 return nil
91         } else if sub.Method == "unsubscribe" {
92                 sess.mtx.Lock()
93                 found := false
94                 for i, s := range sess.subscriptions {
95                         if !reflect.DeepEqual(s.Filters, sub.Filters) {
96                                 continue
97                         }
98                         copy(sess.subscriptions[i:], sess.subscriptions[i+1:])
99                         sess.subscriptions = sess.subscriptions[:len(sess.subscriptions)-1]
100                         found = true
101                         break
102                 }
103                 sess.mtx.Unlock()
104                 sess.log.WithField("sub", sub).WithField("found", found).Debug("unsubscribe")
105                 if found {
106                         sess.sendq <- v0subscribeOK
107                         return nil
108                 }
109         } else {
110                 sess.log.WithField("Method", sub.Method).Info("unknown method")
111         }
112         sess.sendq <- v0subscribeFail
113         return nil
114 }
115
116 func (sess *v0session) EventMessage(e *event) ([]byte, error) {
117         detail := e.Detail()
118         if detail == nil {
119                 return nil, nil
120         }
121
122         var permTarget string
123         if detail.EventType == "delete" {
124                 // It's pointless to check permission by reading
125                 // ObjectUUID if it has just been deleted, but if the
126                 // client has permission on the parent project then
127                 // it's OK to send the event.
128                 permTarget = detail.ObjectOwnerUUID
129         } else {
130                 permTarget = detail.ObjectUUID
131         }
132         ok, err := sess.permChecker.Check(sess.ws.Request().Context(), permTarget)
133         if err != nil || !ok {
134                 return nil, err
135         }
136
137         kind, _ := sess.ac.KindForUUID(detail.ObjectUUID)
138         msg := map[string]interface{}{
139                 "msgID":             atomic.AddUint64(&sess.lastMsgID, 1),
140                 "id":                detail.ID,
141                 "uuid":              detail.UUID,
142                 "object_uuid":       detail.ObjectUUID,
143                 "object_owner_uuid": detail.ObjectOwnerUUID,
144                 "object_kind":       kind,
145                 "event_type":        detail.EventType,
146                 "event_at":          detail.EventAt,
147         }
148         if detail.Properties != nil && detail.Properties["text"] != nil {
149                 msg["properties"] = detail.Properties
150         } else {
151                 msgProps := map[string]map[string]interface{}{}
152                 for _, ak := range []string{"old_attributes", "new_attributes"} {
153                         eventAttrs, ok := detail.Properties[ak].(map[string]interface{})
154                         if !ok {
155                                 continue
156                         }
157                         msgAttrs := map[string]interface{}{}
158                         for _, k := range sendObjectAttributes {
159                                 if v, ok := eventAttrs[k]; ok {
160                                         msgAttrs[k] = v
161                                 }
162                         }
163                         msgProps[ak] = msgAttrs
164                 }
165                 msg["properties"] = msgProps
166         }
167         return json.Marshal(msg)
168 }
169
170 func (sess *v0session) Filter(e *event) bool {
171         sess.mtx.Lock()
172         defer sess.mtx.Unlock()
173         for _, sub := range sess.subscriptions {
174                 if sub.match(sess, e) {
175                         return true
176                 }
177         }
178         return false
179 }
180
181 func (sub *v0subscribe) sendOldEvents(sess *v0session) {
182         if sub.LastLogID == 0 {
183                 return
184         }
185         sess.log.WithField("LastLogID", sub.LastLogID).Debug("sendOldEvents")
186         // Here we do a "select id" query and queue an event for every
187         // log since the given ID, then use (*event)Detail() to
188         // retrieve the whole row and decide whether to send it. This
189         // approach is very inefficient if the subscriber asks for
190         // last_log_id==1, even if the filters end up matching very
191         // few events.
192         //
193         // To mitigate this, filter on "created > 10 minutes ago" when
194         // retrieving the list of old event IDs to consider.
195         rows, err := sess.db.Query(
196                 `SELECT id FROM logs WHERE id > $1 AND created_at > $2 ORDER BY id`,
197                 sub.LastLogID,
198                 time.Now().UTC().Add(-10*time.Minute).Format(time.RFC3339Nano))
199         if err != nil {
200                 sess.log.WithError(err).Error("sendOldEvents db.Query failed")
201                 return
202         }
203
204         var ids []uint64
205         for rows.Next() {
206                 var id uint64
207                 err := rows.Scan(&id)
208                 if err != nil {
209                         sess.log.WithError(err).Error("sendOldEvents row Scan failed")
210                         continue
211                 }
212                 ids = append(ids, id)
213         }
214         if err := rows.Err(); err != nil {
215                 sess.log.WithError(err).Error("sendOldEvents db.Query failed")
216         }
217         rows.Close()
218
219         for _, id := range ids {
220                 for len(sess.sendq)*2 > cap(sess.sendq) {
221                         // Ugly... but if we fill up the whole client
222                         // queue with a backlog of old events, a
223                         // single new event will overflow it and
224                         // terminate the connection, and then the
225                         // client will probably reconnect and do the
226                         // same thing all over again.
227                         time.Sleep(100 * time.Millisecond)
228                         if sess.ws.Request().Context().Err() != nil {
229                                 // Session terminated while we were sleeping
230                                 return
231                         }
232                 }
233                 now := time.Now()
234                 e := &event{
235                         LogID:    id,
236                         Received: now,
237                         Ready:    now,
238                         db:       sess.db,
239                 }
240                 if sub.match(sess, e) {
241                         select {
242                         case sess.sendq <- e:
243                         case <-sess.ws.Request().Context().Done():
244                                 return
245                         }
246                 }
247         }
248 }
249
250 type v0subscribe struct {
251         Method    string
252         Filters   []v0filter
253         LastLogID int64 `json:"last_log_id"`
254
255         funcs []func(*event) bool
256 }
257
258 type v0filter [3]interface{}
259
260 func (sub *v0subscribe) match(sess *v0session, e *event) bool {
261         log := sess.log.WithField("LogID", e.LogID)
262         detail := e.Detail()
263         if detail == nil {
264                 log.Error("match failed, no detail")
265                 return false
266         }
267         log = log.WithField("funcs", len(sub.funcs))
268         for i, f := range sub.funcs {
269                 if !f(e) {
270                         log.WithField("func", i).Debug("match failed")
271                         return false
272                 }
273         }
274         log.Debug("match passed")
275         return true
276 }
277
278 func (sub *v0subscribe) prepare(sess *v0session) {
279         for _, f := range sub.Filters {
280                 if len(f) != 3 {
281                         continue
282                 }
283                 if col, ok := f[0].(string); ok && col == "event_type" {
284                         op, ok := f[1].(string)
285                         if !ok || op != "in" {
286                                 continue
287                         }
288                         arr, ok := f[2].([]interface{})
289                         if !ok {
290                                 continue
291                         }
292                         var strs []string
293                         for _, s := range arr {
294                                 if s, ok := s.(string); ok {
295                                         strs = append(strs, s)
296                                 }
297                         }
298                         sub.funcs = append(sub.funcs, func(e *event) bool {
299                                 for _, s := range strs {
300                                         if s == e.Detail().EventType {
301                                                 return true
302                                         }
303                                 }
304                                 return false
305                         })
306                 } else if ok && col == "created_at" {
307                         op, ok := f[1].(string)
308                         if !ok {
309                                 continue
310                         }
311                         tstr, ok := f[2].(string)
312                         if !ok {
313                                 continue
314                         }
315                         t, err := time.Parse(time.RFC3339Nano, tstr)
316                         if err != nil {
317                                 sess.log.WithField("data", tstr).WithError(err).Info("time.Parse failed")
318                                 continue
319                         }
320                         var fn func(*event) bool
321                         switch op {
322                         case ">=":
323                                 fn = func(e *event) bool {
324                                         return !e.Detail().CreatedAt.Before(t)
325                                 }
326                         case "<=":
327                                 fn = func(e *event) bool {
328                                         return !e.Detail().CreatedAt.After(t)
329                                 }
330                         case ">":
331                                 fn = func(e *event) bool {
332                                         return e.Detail().CreatedAt.After(t)
333                                 }
334                         case "<":
335                                 fn = func(e *event) bool {
336                                         return e.Detail().CreatedAt.Before(t)
337                                 }
338                         case "=":
339                                 fn = func(e *event) bool {
340                                         return e.Detail().CreatedAt.Equal(t)
341                                 }
342                         default:
343                                 sess.log.WithField("operator", op).Info("bogus operator")
344                                 continue
345                         }
346                         sub.funcs = append(sub.funcs, fn)
347                 }
348         }
349 }