8460: Support created_at filters.
[arvados.git] / services / ws / session_v0.go
index 122767b1e5c6bd46d34c658accd6a91060b9334e..2035acbd5d573efa33a7ec75895a39379809a400 100644 (file)
@@ -5,6 +5,7 @@ import (
        "errors"
        "log"
        "sync"
+       "time"
 
        "git.curoverse.com/arvados.git/sdk/go/arvados"
 )
@@ -12,30 +13,22 @@ import (
 var (
        errQueueFull   = errors.New("client queue full")
        errFrameTooBig = errors.New("frame too big")
+
+       sendObjectAttributes = []string{"state", "name"} // attributes EventMessage copies from old_/new_attributes
 )
 
-type sessionV0 struct {
-       ws          wsConn
-       permChecker permChecker
-       subscribed  map[string]bool
-       eventTypes  map[string]bool
-       mtx         sync.Mutex
-       setupOnce   sync.Once
+type v0session struct {
+       ws            wsConn
+       permChecker   permChecker
+       subscriptions []v0subscribe // appended by Receive; Filter matches events against them
+       mtx           sync.Mutex    // held by Receive and Filter around subscriptions
+       setupOnce     sync.Once
 }
 
-type v0subscribe struct {
-       Method  string
-       Filters []v0filter
-}
-
-type v0filter []interface{}
-
 func NewSessionV0(ws wsConn, ac arvados.Client) (session, error) {
-       sess := &sessionV0{
+       sess := &v0session{
                ws:          ws,
                permChecker: NewPermChecker(ac),
-               subscribed:  make(map[string]bool),
-               eventTypes:  make(map[string]bool),
        }
 
        err := ws.Request().ParseForm()
@@ -50,72 +43,30 @@ func NewSessionV0(ws wsConn, ac arvados.Client) (session, error) {
        return sess, nil
 }
 
-func (sess *sessionV0) debugLogf(s string, args ...interface{}) {
+func (sess *v0session) debugLogf(s string, args ...interface{}) { // package debugLogf with the client's RemoteAddr prepended
        args = append([]interface{}{sess.ws.Request().RemoteAddr}, args...)
        debugLogf("%s "+s, args...)
 }
 
-// If every client subscription message includes filters consisting
-// only of [["event_type","in",...]] then send only the requested
-// event types. Otherwise, clear sess.eventTypes and send all event
-// types from now on.
-func (sess *sessionV0) checkFilters(filters []v0filter) {
-       if sess.eventTypes == nil {
-               // Already received a subscription request without
-               // event_type filters.
-               return
-       }
-       eventTypes := sess.eventTypes
-       sess.eventTypes = nil
-       if len(filters) == 0 {
-               return
-       }
-       useFilters := false
-       for _, f := range filters {
-               col, ok := f[0].(string)
-               if !ok || col != "event_type" {
-                       continue
-               }
-               op, ok := f[1].(string)
-               if !ok || op != "in" {
-                       return
-               }
-               arr, ok := f[2].([]interface{})
-               if !ok {
-                       return
-               }
-               useFilters = true
-               for _, s := range arr {
-                       if s, ok := s.(string); ok {
-                               eventTypes[s] = true
-                       } else {
-                               return
-                       }
-               }
-       }
-       if useFilters {
-               sess.debugLogf("eventTypes %+v", eventTypes)
-               sess.eventTypes = eventTypes
-       }
-}
-
-func (sess *sessionV0) Receive(msg map[string]interface{}, buf []byte) {
+func (sess *v0session) Receive(msg map[string]interface{}, buf []byte) []byte {
        sess.debugLogf("received message: %+v", msg)
        var sub v0subscribe
        if err := json.Unmarshal(buf, &sub); err != nil {
                sess.debugLogf("ignored unrecognized request: %s", err)
-               return
+               return nil // unparseable: no reply at all (unknown methods get 400 below)
        }
        if sub.Method == "subscribe" {
-               sess.debugLogf("subscribing to *")
+               sub.prepare()
+               sess.debugLogf("subscription: %v", sub)
                sess.mtx.Lock()
-               sess.checkFilters(sub.Filters)
-               sess.subscribed["*"] = true
+               sess.subscriptions = append(sess.subscriptions, sub)
                sess.mtx.Unlock()
+               return []byte(`{"status":200}`) // subscription recorded
        }
+       return []byte(`{"status":400}`) // any method other than "subscribe"
 }
 
-func (sess *sessionV0) EventMessage(e *event) ([]byte, error) {
+func (sess *v0session) EventMessage(e *event) ([]byte, error) {
        detail := e.Detail()
        if detail == nil {
                return nil, nil
@@ -136,26 +87,130 @@ func (sess *sessionV0) EventMessage(e *event) ([]byte, error) {
        }
        if detail.Properties != nil && detail.Properties["text"] != nil {
                msg["properties"] = detail.Properties
+       } else {
+               msgProps := map[string]map[string]interface{}{}
+               for _, ak := range []string{"old_attributes", "new_attributes"} {
+                       eventAttrs, ok := detail.Properties[ak].(map[string]interface{})
+                       if !ok {
+                               continue
+                       }
+                       msgAttrs := map[string]interface{}{}
+                       for _, k := range sendObjectAttributes {
+                               if v, ok := eventAttrs[k]; ok {
+                                       msgAttrs[k] = v
+                               }
+                       }
+                       msgProps[ak] = msgAttrs
+               }
+               msg["properties"] = msgProps
        }
        return json.Marshal(msg)
 }
 
-func (sess *sessionV0) Filter(e *event) bool {
-       detail := e.Detail()
+func (sess *v0session) Filter(e *event) bool {
        sess.mtx.Lock()
        defer sess.mtx.Unlock()
-       switch {
-       case sess.eventTypes != nil && !sess.eventTypes[detail.EventType]:
-               return false
-       case sess.subscribed["*"]:
-               return true
-       case detail == nil:
-               return false
-       case sess.subscribed[detail.ObjectUUID]:
-               return true
-       case sess.subscribed[detail.ObjectOwnerUUID]:
-               return true
-       default:
+       for _, sub := range sess.subscriptions { // accept e if any subscription matches it
+               if sub.match(e) {
+                       return true
+               }
+       }
+       return false // no subscriptions, or none matched
+}
+
+type v0subscribe struct {
+       Method  string              // Receive acts only on "subscribe"
+       Filters []v0filter          // [column, operator, operand] triples from the request
+       funcs   []func(*event) bool // predicates built by prepare; match requires all to pass
+}
+
+type v0filter [3]interface{} // [column, operator, operand]; json.Unmarshal pads/truncates to 3 elements
+
+// match reports whether e satisfies every predicate in sub.funcs (none means
+// match all). It runs per event under the session lock, so it does not log.
+func (sub *v0subscribe) match(e *event) bool {
+       detail := e.Detail()
+       if detail == nil {
                return false
        }
+       for _, f := range sub.funcs {
+               if !f(e) {
+                       return false
+               }
+       }
+       return true
+}
+
+// prepare compiles sub.Filters into sub.funcs, the predicates an event
+// must all satisfy to match this subscription. Recognized filters:
+//
+//   ["event_type", "in", [type, ...]]  (non-string types are ignored)
+//   ["created_at", op, timestamp]      (op: >= <= > < =; RFC3339 time)
+//
+// Any other filter (unknown column or operator, wrong operand type, bad
+// timestamp) is skipped, widening the subscription instead of failing.
+// The predicates run for every event while Filter holds the session
+// lock, so they deliberately do no logging.
+func (sub *v0subscribe) prepare() {
+       for _, f := range sub.Filters {
+               col, ok := f[0].(string)
+               if !ok {
+                       continue
+               }
+               op, ok := f[1].(string)
+               if !ok {
+                       continue
+               }
+               switch col {
+               case "event_type":
+                       if op != "in" {
+                               continue
+                       }
+                       arr, ok := f[2].([]interface{})
+                       if !ok {
+                               continue
+                       }
+                       var types []string
+                       for _, v := range arr {
+                               if s, ok := v.(string); ok {
+                                       types = append(types, s)
+                               }
+                       }
+                       sub.funcs = append(sub.funcs, func(e *event) bool {
+                               et := e.Detail().EventType
+                               for _, s := range types {
+                                       if s == et {
+                                               return true
+                                       }
+                               }
+                               return false
+                       })
+               case "created_at":
+                       tstr, ok := f[2].(string)
+                       if !ok {
+                               continue
+                       }
+                       t, err := time.Parse(time.RFC3339Nano, tstr)
+                       if err != nil {
+                               debugLogf("time.Parse(%q): %s", tstr, err)
+                               continue
+                       }
+                       var pred func(*event) bool
+                       switch op {
+                       case ">=":
+                               pred = func(e *event) bool { return !e.Detail().CreatedAt.Before(t) }
+                       case "<=":
+                               pred = func(e *event) bool { return !e.Detail().CreatedAt.After(t) }
+                       case ">":
+                               pred = func(e *event) bool { return e.Detail().CreatedAt.After(t) }
+                       case "<":
+                               pred = func(e *event) bool { return e.Detail().CreatedAt.Before(t) }
+                       case "=":
+                               pred = func(e *event) bool { return e.Detail().CreatedAt.Equal(t) }
+                       default:
+                               continue
+                       }
+                       sub.funcs = append(sub.funcs, pred)
+               }
+       }
 }