8460: Merge branch 'master' into 8460-websocket-go
[arvados.git] / services / ws / session_v0.go
1 package main
2
3 import (
4         "database/sql"
5         "encoding/json"
6         "errors"
7         "log"
8         "sync"
9         "time"
10
11         "git.curoverse.com/arvados.git/sdk/go/arvados"
12 )
13
// Sentinel errors. They are not referenced by the code in this
// section of the file.
var (
	errQueueFull   = errors.New("client queue full")
	errFrameTooBig = errors.New("frame too big")
)

// sendObjectAttributes lists the keys that EventMessage copies out of an
// event's old_attributes/new_attributes when the event has no "text"
// property; all other attributes are withheld from v0 clients.
var sendObjectAttributes = []string{"state", "name"}

// Canned v0 replies to a client request: v0subscribeOK is sent for an
// accepted "subscribe" request, v0subscribeFail for any other method.
var (
	v0subscribeOK   = []byte(`{"status":200}`)
	v0subscribeFail = []byte(`{"status":400}`)
)
23
24 type v0session struct {
25         ws            wsConn
26         db            *sql.DB
27         permChecker   permChecker
28         subscriptions []v0subscribe
29         mtx           sync.Mutex
30         setupOnce     sync.Once
31 }
32
33 func NewSessionV0(ws wsConn, ac arvados.Client, db *sql.DB) (session, error) {
34         sess := &v0session{
35                 ws:          ws,
36                 db:          db,
37                 permChecker: NewPermChecker(ac),
38         }
39
40         err := ws.Request().ParseForm()
41         if err != nil {
42                 log.Printf("%s ParseForm: %s", ws.Request().RemoteAddr, err)
43                 return nil, err
44         }
45         token := ws.Request().Form.Get("api_token")
46         sess.permChecker.SetToken(token)
47         sess.debugLogf("token = %+q", token)
48
49         return sess, nil
50 }
51
52 func (sess *v0session) debugLogf(s string, args ...interface{}) {
53         args = append([]interface{}{sess.ws.Request().RemoteAddr}, args...)
54         debugLogf("%s "+s, args...)
55 }
56
57 func (sess *v0session) Receive(msg map[string]interface{}, buf []byte) [][]byte {
58         sess.debugLogf("received message: %+v", msg)
59         var sub v0subscribe
60         if err := json.Unmarshal(buf, &sub); err != nil {
61                 sess.debugLogf("ignored unrecognized request: %s", err)
62                 return nil
63         }
64         if sub.Method == "subscribe" {
65                 sub.prepare()
66                 sess.debugLogf("subscription: %v", sub)
67                 sess.mtx.Lock()
68                 sess.subscriptions = append(sess.subscriptions, sub)
69                 sess.mtx.Unlock()
70
71                 return append([][]byte{v0subscribeOK}, sub.getOldEvents(sess)...)
72         }
73         return [][]byte{v0subscribeFail}
74 }
75
76 func (sess *v0session) EventMessage(e *event) ([]byte, error) {
77         detail := e.Detail()
78         if detail == nil {
79                 return nil, nil
80         }
81
82         ok, err := sess.permChecker.Check(detail.ObjectUUID)
83         if err != nil || !ok {
84                 return nil, err
85         }
86
87         msg := map[string]interface{}{
88                 "msgID":             e.Serial,
89                 "id":                detail.ID,
90                 "uuid":              detail.UUID,
91                 "object_uuid":       detail.ObjectUUID,
92                 "object_owner_uuid": detail.ObjectOwnerUUID,
93                 "event_type":        detail.EventType,
94         }
95         if detail.Properties != nil && detail.Properties["text"] != nil {
96                 msg["properties"] = detail.Properties
97         } else {
98                 msgProps := map[string]map[string]interface{}{}
99                 for _, ak := range []string{"old_attributes", "new_attributes"} {
100                         eventAttrs, ok := detail.Properties[ak].(map[string]interface{})
101                         if !ok {
102                                 continue
103                         }
104                         msgAttrs := map[string]interface{}{}
105                         for _, k := range sendObjectAttributes {
106                                 if v, ok := eventAttrs[k]; ok {
107                                         msgAttrs[k] = v
108                                 }
109                         }
110                         msgProps[ak] = msgAttrs
111                 }
112                 msg["properties"] = msgProps
113         }
114         return json.Marshal(msg)
115 }
116
117 func (sess *v0session) Filter(e *event) bool {
118         sess.mtx.Lock()
119         defer sess.mtx.Unlock()
120         for _, sub := range sess.subscriptions {
121                 if sub.match(e) {
122                         return true
123                 }
124         }
125         return false
126 }
127
128 func (sub *v0subscribe) getOldEvents(sess *v0session) (msgs [][]byte) {
129         if sub.LastLogID == 0 {
130                 return
131         }
132         debugLogf("getOldEvents(%d)", sub.LastLogID)
133         // Here we do a "select id" query and queue an event for every
134         // log since the given ID, then use (*event)Detail() to
135         // retrieve the whole row and decide whether to send it. This
136         // approach is very inefficient if the subscriber asks for
137         // last_log_id==1, even if the filters end up matching very
138         // few events.
139         //
140         // To mitigate this, filter on "created > 10 minutes ago" when
141         // retrieving the list of old event IDs to consider.
142         rows, err := sess.db.Query(
143                 `SELECT id FROM logs WHERE id > $1 AND created_at > $2 ORDER BY id`,
144                 sub.LastLogID,
145                 time.Now().UTC().Add(-10*time.Minute).Format(time.RFC3339Nano))
146         if err != nil {
147                 errorLogf("db.Query: %s", err)
148                 return
149         }
150         for rows.Next() {
151                 var id uint64
152                 err := rows.Scan(&id)
153                 if err != nil {
154                         errorLogf("Scan: %s", err)
155                         continue
156                 }
157                 e := &event{
158                         LogID:    id,
159                         Received: time.Now(),
160                         db:       sess.db,
161                 }
162                 if !sub.match(e) {
163                         debugLogf("skip old event %+v", e)
164                         continue
165                 }
166                 msg, err := sess.EventMessage(e)
167                 if err != nil {
168                         debugLogf("event marshal: %s", err)
169                         continue
170                 }
171                 debugLogf("old event: %s", string(msg))
172                 msgs = append(msgs, msg)
173         }
174         if err := rows.Err(); err != nil {
175                 errorLogf("db.Query: %s", err)
176         }
177         return
178 }
179
180 type v0subscribe struct {
181         Method    string
182         Filters   []v0filter
183         LastLogID int64 `json:"last_log_id"`
184
185         funcs []func(*event) bool
186 }
187
188 type v0filter [3]interface{}
189
190 func (sub *v0subscribe) match(e *event) bool {
191         detail := e.Detail()
192         if detail == nil {
193                 debugLogf("match(%d): failed on no detail", e.LogID)
194                 return false
195         }
196         for i, f := range sub.funcs {
197                 if !f(e) {
198                         debugLogf("match(%d): failed on func %d", e.LogID, i)
199                         return false
200                 }
201         }
202         debugLogf("match(%d): passed %d funcs", e.LogID, len(sub.funcs))
203         return true
204 }
205
206 func (sub *v0subscribe) prepare() {
207         for _, f := range sub.Filters {
208                 if len(f) != 3 {
209                         continue
210                 }
211                 if col, ok := f[0].(string); ok && col == "event_type" {
212                         op, ok := f[1].(string)
213                         if !ok || op != "in" {
214                                 continue
215                         }
216                         arr, ok := f[2].([]interface{})
217                         if !ok {
218                                 continue
219                         }
220                         var strs []string
221                         for _, s := range arr {
222                                 if s, ok := s.(string); ok {
223                                         strs = append(strs, s)
224                                 }
225                         }
226                         sub.funcs = append(sub.funcs, func(e *event) bool {
227                                 debugLogf("event_type func: %v in %v", e.Detail().EventType, strs)
228                                 for _, s := range strs {
229                                         if s == e.Detail().EventType {
230                                                 return true
231                                         }
232                                 }
233                                 return false
234                         })
235                 } else if ok && col == "created_at" {
236                         op, ok := f[1].(string)
237                         if !ok {
238                                 continue
239                         }
240                         tstr, ok := f[2].(string)
241                         if !ok {
242                                 continue
243                         }
244                         t, err := time.Parse(time.RFC3339Nano, tstr)
245                         if err != nil {
246                                 debugLogf("time.Parse(%q): %s", tstr, err)
247                                 continue
248                         }
249                         switch op {
250                         case ">=":
251                                 sub.funcs = append(sub.funcs, func(e *event) bool {
252                                         debugLogf("created_at func: %v >= %v", e.Detail().CreatedAt, t)
253                                         return !e.Detail().CreatedAt.Before(t)
254                                 })
255                         case "<=":
256                                 sub.funcs = append(sub.funcs, func(e *event) bool {
257                                         debugLogf("created_at func: %v <= %v", e.Detail().CreatedAt, t)
258                                         return !e.Detail().CreatedAt.After(t)
259                                 })
260                         case ">":
261                                 sub.funcs = append(sub.funcs, func(e *event) bool {
262                                         debugLogf("created_at func: %v > %v", e.Detail().CreatedAt, t)
263                                         return e.Detail().CreatedAt.After(t)
264                                 })
265                         case "<":
266                                 sub.funcs = append(sub.funcs, func(e *event) bool {
267                                         debugLogf("created_at func: %v < %v", e.Detail().CreatedAt, t)
268                                         return e.Detail().CreatedAt.Before(t)
269                                 })
270                         case "=":
271                                 sub.funcs = append(sub.funcs, func(e *event) bool {
272                                         debugLogf("created_at func: %v = %v", e.Detail().CreatedAt, t)
273                                         return e.Detail().CreatedAt.Equal(t)
274                                 })
275                         }
276                 }
277         }
278 }