8460: Check permissions.
author: Tom Clegg <tom@curoverse.com>
Mon, 14 Nov 2016 15:38:14 +0000 (10:38 -0500)
committer: Tom Clegg <tom@curoverse.com>
Tue, 15 Nov 2016 07:45:55 +0000 (02:45 -0500)
sdk/go/arvados/client.go
services/ws/config.go
services/ws/event.go
services/ws/handler.go
services/ws/handler_v0.go
services/ws/handler_v1.go
services/ws/main.go
services/ws/pg.go
services/ws/proxy_client.go [new file with mode: 0644]
services/ws/router.go

index 36f4eb52ae298982dfa09ddf82b0cea08c2604f7..0c18d38974f8be6ce99460ef2713e220a2cff403 100644 (file)
@@ -41,6 +41,8 @@ type Client struct {
        // callers who use a Client to initialize an
        // arvadosclient.ArvadosClient.)
        KeepServiceURIs []string `json:",omitempty"`
+
+       dd *DiscoveryDocument
 }
 
 // The default http.Client used by a Client with Insecure==true and
@@ -198,14 +200,83 @@ func (c *Client) apiURL(path string) string {
 
 // DiscoveryDocument is the Arvados server's description of itself.
 type DiscoveryDocument struct {
-       DefaultCollectionReplication int   `json:"defaultCollectionReplication"`
-       BlobSignatureTTL             int64 `json:"blobSignatureTtl"`
+       BasePath                     string              `json:"basePath"`
+       DefaultCollectionReplication int                 `json:"defaultCollectionReplication"`
+       BlobSignatureTTL             int64               `json:"blobSignatureTtl"`
+       Schemas                      map[string]Schema   `json:"schemas"`
+       Resources                    map[string]Resource `json:"resources"`
+}
+
+// Resource is one entry of the discovery document's "resources"
+// map.
+type Resource struct {
+       // Methods maps a method name (e.g., "get") to its description.
+       Methods map[string]ResourceMethod `json:"methods"`
+}
+
+// ResourceMethod is one entry of a Resource's "methods" map.
+// PathForUUID builds a request path from Path by substituting an
+// object UUID for each "{uuid}" placeholder.
+type ResourceMethod struct {
+       HTTPMethod string         `json:"httpMethod"`
+       Path       string         `json:"path"`
+       Response   MethodResponse `json:"response"`
+}
+
+// MethodResponse is the "response" member of a ResourceMethod.
+type MethodResponse struct {
+       // Ref holds the "$ref" value. PathForUUID compares it with
+       // the keys of DiscoveryDocument.Schemas.
+       Ref string `json:"$ref"`
+}
+
+// Schema is one entry of the discovery document's "schemas" map.
+type Schema struct {
+       // UUIDPrefix is compared with uuid[6:11] by PathForUUID in
+       // order to find the schema that an object UUID belongs to.
+       UUIDPrefix string `json:"uuidPrefix"`
 }
 
 // DiscoveryDocument returns a *DiscoveryDocument. The returned object
 // should not be modified: the same object may be returned by
 // subsequent calls.
 func (c *Client) DiscoveryDocument() (*DiscoveryDocument, error) {
+       if c.dd != nil {
+               return c.dd, nil
+       }
        var dd DiscoveryDocument
-       return &dd, c.RequestAndDecode(&dd, "GET", "discovery/v1/apis/arvados/v1/rest", nil, nil)
+       err := c.RequestAndDecode(&dd, "GET", "discovery/v1/apis/arvados/v1/rest", nil, nil)
+       if err != nil {
+               return nil, err
+       }
+       c.dd = &dd
+       return c.dd, nil
+}
+
+// PathForUUID returns the API request path for calling the named
+// method (e.g., "get") on the object with the given UUID. The
+// server's discovery document is used to map the UUID's type infix
+// (uuid[6:11]) to a schema, the schema to the resource whose "get"
+// method responds with that schema, and the resource to the
+// requested method's path template.
+//
+// The returned path is the discovery document's base path followed
+// by the method path with "{uuid}" replaced by uuid, with one leading
+// slash removed if present.
+//
+// An error is returned if uuid is not 27 bytes long, if the
+// discovery document cannot be retrieved, if no schema has a
+// matching uuidPrefix, if no resource returns that schema from its
+// "get" method, or if the resource does not have the named method.
+func (c *Client) PathForUUID(method, uuid string) (string, error) {
+       if len(uuid) != 27 {
+               return "", fmt.Errorf("invalid UUID: %q", uuid)
+       }
+       dd, err := c.DiscoveryDocument()
+       if err != nil {
+               return "", err
+       }
+       infix := uuid[6:11]
+       var model string
+       for m, s := range dd.Schemas {
+               if s.UUIDPrefix == infix {
+                       model = m
+                       break
+               }
+       }
+       if model == "" {
+               return "", fmt.Errorf("unrecognized UUID infix: %q", infix)
+       }
+       var resource string
+       for r, rsc := range dd.Resources {
+               // A resource without a "get" method yields a zero
+               // ResourceMethod here, whose empty Ref never matches
+               // the (non-empty) model name.
+               if rsc.Methods["get"].Response.Ref == model {
+                       resource = r
+                       break
+               }
+       }
+       if resource == "" {
+               return "", fmt.Errorf("no resource for model: %q", model)
+       }
+       m, ok := dd.Resources[resource].Methods[method]
+       if !ok {
+               return "", fmt.Errorf("no method %q for resource %q", method, resource)
+       }
+       path := dd.BasePath + strings.Replace(m.Path, "{uuid}", uuid, -1)
+       // Use TrimPrefix instead of indexing path[0]: indexing panics
+       // if both the base path and the method path are empty.
+       return strings.TrimPrefix(path, "/"), nil
 }
index 3e3d91f292ea7969b3a844883996a88b8981213a..9c2e80a1728afba9998d8a1b40d58b1ecd311b8f 100644 (file)
@@ -23,12 +23,12 @@ func DefaultConfig() Config {
                        APIHost: "localhost:443",
                },
                Postgres: pgConfig{
-                       "dbname":          "arvados_test",
+                       "dbname":          "arvados_production",
                        "user":            "arvados",
                        "password":        "xyzzy",
                        "host":            "localhost",
                        "connect_timeout": "30",
-                       "sslmode":         "disable",
+                       "sslmode":         "require",
                },
                PingTimeout:      arvados.Duration(time.Minute),
                ClientEventQueue: 64,
index b6dda4968b83c3bd363d53a9bd18bbf80ed5c979..e34b6b4b58a6b54ddae945c84102578e309d369f 100644 (file)
@@ -7,6 +7,7 @@ import (
        "time"
 
        "git.curoverse.com/arvados.git/sdk/go/arvados"
+       "github.com/ghodss/yaml"
 )
 
 type eventSink interface {
@@ -15,11 +16,11 @@ type eventSink interface {
 }
 
 type eventSource interface {
-       NewSink(chan *event) eventSink
+       NewSink() eventSink
 }
 
 type event struct {
-       LogUUID  string
+       LogID    uint64
        Received time.Time
        Serial   uint64
 
@@ -39,18 +40,24 @@ func (e *event) Detail() *arvados.Log {
                return e.logRow
        }
        var logRow arvados.Log
-       var oldAttrs, newAttrs []byte
-       e.err = e.db.QueryRow(`SELECT id, uuid, object_uuid, object_owner_uuid, event_type, created_at, old_attributes, new_attributes FROM logs WHERE uuid = ?`, e.LogUUID).Scan(
+       var propYAML []byte
+       e.err = e.db.QueryRow(`SELECT id, uuid, object_uuid, object_owner_uuid, event_type, created_at, properties FROM logs WHERE id = $1`, e.LogID).Scan(
                &logRow.ID,
                &logRow.UUID,
                &logRow.ObjectUUID,
                &logRow.ObjectOwnerUUID,
                &logRow.EventType,
                &logRow.CreatedAt,
-               &oldAttrs,
-               &newAttrs)
+               &propYAML)
        if e.err != nil {
-               log.Printf("retrieving log row %s: %s", e.LogUUID, e.err)
+               log.Printf("retrieving log row %d: %s", e.LogID, e.err)
+               return nil
        }
+       e.err = yaml.Unmarshal(propYAML, &logRow.Properties)
+       if e.err != nil {
+               log.Printf("decoding yaml for log row %d: %s", e.LogID, e.err)
+               return nil
+       }
+       e.logRow = &logRow
        return e.logRow
 }
index fe47a62ccf3092b47924c3a903d04a4ac2534aab..ba8f945dfce27a9a341bf1b6ff35905a9a438f19 100644 (file)
@@ -14,6 +14,7 @@ type wsConn interface {
        io.ReadWriter
        Request() *http.Request
        SetReadDeadline(time.Time) error
+       SetWriteDeadline(time.Time) error
 }
 
 type timeouter interface {
index c728d121f3e21c010ba61fefc3344f3d3d65bf14..eb076b5bb7cffd77b420b81b9c7d5fefdb0d54bc 100644 (file)
@@ -7,6 +7,8 @@ import (
        "log"
        "sync"
        "time"
+
+       "git.curoverse.com/arvados.git/sdk/go/arvados"
 )
 
 var (
@@ -15,6 +17,7 @@ var (
 )
 
 type handlerV0 struct {
+       Client      arvados.Client
        PingTimeout time.Duration
        QueueSize   int
 }
@@ -29,6 +32,18 @@ func (h *handlerV0) Handle(ws wsConn, events <-chan *event) {
        mtx := sync.Mutex{}
        subscribed := make(map[string]bool)
 
+       proxyClient := NewProxyClient(h.Client)
+       {
+               err := ws.Request().ParseForm()
+               if err != nil {
+                       log.Printf("%s ParseForm: %s", ws.Request().RemoteAddr, err)
+                       return
+               }
+               token := ws.Request().Form.Get("api_token")
+               h.debugLogf(ws, "handlerV0: token = %+q", token)
+               proxyClient.SetToken(token)
+       }
+
        stopped := make(chan struct{})
        stop := make(chan error, 5)
 
@@ -40,21 +55,13 @@ func (h *handlerV0) Handle(ws wsConn, events <-chan *event) {
                                return
                        default:
                        }
-                       ws.SetReadDeadline(time.Now().Add(h.PingTimeout))
+                       ws.SetReadDeadline(time.Now().Add(24 * 365 * time.Hour))
                        n, err := ws.Read(buf)
                        h.debugLogf(ws, "received frame: %q", buf[:n])
                        if err == nil && n == len(buf) {
                                err = errFrameTooBig
                        }
                        if err, ok := err.(timeouter); ok && err.Timeout() {
-                               // If the outgoing queue is empty,
-                               // send an empty message. This can
-                               // help detect a disconnected network
-                               // socket, and prevent an idle socket
-                               // from being closed.
-                               if len(queue) == 0 {
-                                       queue <- nil
-                               }
                                continue
                        }
                        if err != nil {
@@ -80,6 +87,7 @@ func (h *handlerV0) Handle(ws wsConn, events <-chan *event) {
        go func() {
                for e := range queue {
                        if e == nil {
+                               ws.SetWriteDeadline(time.Now().Add(h.PingTimeout))
                                _, err := ws.Write([]byte("{}\n"))
                                if err != nil {
                                        h.debugLogf(ws, "handlerV0: write: %s", err)
@@ -92,7 +100,18 @@ func (h *handlerV0) Handle(ws wsConn, events <-chan *event) {
                        if detail == nil {
                                continue
                        }
-                       // FIXME: check permission
+
+                       ok, err := proxyClient.CheckReadPermission(detail.UUID)
+                       if err != nil {
+                               log.Printf("CheckReadPermission: %s", err)
+                               stop <- err
+                               break
+                       }
+                       if !ok {
+                               h.debugLogf(ws, "handlerV0: skip event %d", e.Serial)
+                               continue
+                       }
+
                        buf, err := json.Marshal(map[string]interface{}{
                                "msgID":             e.Serial,
                                "id":                detail.ID,
@@ -105,14 +124,18 @@ func (h *handlerV0) Handle(ws wsConn, events <-chan *event) {
                                log.Printf("error encoding: ", err)
                                continue
                        }
+                       h.debugLogf(ws, "handlerV0: send event %d: %q", e.Serial, buf)
+                       ws.SetWriteDeadline(time.Now().Add(h.PingTimeout))
                        _, err = ws.Write(append(buf, byte('\n')))
                        if err != nil {
                                h.debugLogf(ws, "handlerV0: write: %s", err)
                                stop <- err
                                break
                        }
+                       h.debugLogf(ws, "handlerV0: sent event %d", e.Serial)
+               }
+               for _ = range queue {
                }
-               for _ = range queue {}
        }()
 
        // Filter incoming events against the current subscription
@@ -129,6 +152,9 @@ func (h *handlerV0) Handle(ws wsConn, events <-chan *event) {
                        }
                }
 
+               ticker := time.NewTicker(h.PingTimeout)
+               defer ticker.Stop()
+
                for {
                        var e *event
                        var ok bool
@@ -136,6 +162,16 @@ func (h *handlerV0) Handle(ws wsConn, events <-chan *event) {
                        case <-stopped:
                                close(queue)
                                return
+                       case <-ticker.C:
+                               // If the outgoing queue is empty,
+                               // send an empty message. This can
+                               // help detect a disconnected network
+                               // socket, and prevent an idle socket
+                               // from being closed.
+                               if len(queue) == 0 {
+                                       queue <- nil
+                               }
+                               continue
                        case e, ok = <-events:
                                if !ok {
                                        close(queue)
index 4160d8696d48197c5f3d6439ad44ce167e53060b..1b8549e26f40b5cdfd95873acf328734b7dae661 100644 (file)
@@ -2,9 +2,12 @@ package main
 
 import (
        "time"
+
+       "git.curoverse.com/arvados.git/sdk/go/arvados"
 )
 
 type handlerV1 struct {
+       Client      arvados.Client
        PingTimeout time.Duration
        QueueSize   int
 }
index 0f978231b9658068bd8351a582c4328d3383e1e0..28662440d07162d200a862672778dd62cb590902 100644 (file)
@@ -35,18 +35,20 @@ func main() {
                return
        }
 
+       eventSource := &pgEventSource{
+               PgConfig:  cfg.Postgres,
+               QueueSize: cfg.ServerEventQueue,
+       }
        srv := &http.Server{
                Addr:           cfg.Listen,
                ReadTimeout:    time.Minute,
                WriteTimeout:   time.Minute,
                MaxHeaderBytes: 1 << 20,
                Handler: &router{
-                       Config: &cfg,
-                       eventSource: &pgEventSource{
-                               PgConfig:  cfg.Postgres,
-                               QueueSize: cfg.ServerEventQueue,
-                       },
+                       Config:      &cfg,
+                       eventSource: eventSource,
                },
        }
+       eventSource.NewSink().Stop()
        log.Fatal(srv.ListenAndServe())
 }
index 51bc92ca6706bb6c323d6039283ca7a8924aa83b..5e8e63e01fc3ec27ccffbb20875d1c5df79639db 100644 (file)
@@ -3,6 +3,7 @@ package main
 import (
        "database/sql"
        "log"
+       "strconv"
        "strings"
        "sync"
        "time"
@@ -52,15 +53,17 @@ func (ps *pgEventSource) run() {
                        // on missed events, we cannot recover from a
                        // dropped connection without breaking our
                        // promises to clients.
-                       log.Fatal(err)
+                       log.Fatalf("pgEventSource listener problem: %s", err)
                }
        })
        err = listener.Listen("logs")
        if err != nil {
                log.Fatal(err)
        }
+       debugLogf("pgEventSource listening")
        go func() {
                for _ = range time.NewTicker(time.Minute).C {
+                       debugLogf("pgEventSource listener ping")
                        listener.Ping()
                }
        }()
@@ -74,7 +77,7 @@ func (ps *pgEventSource) run() {
                        // concurrent queries would be bounded by
                        // client_count X client_queue_size.
                        e.Detail()
-                       debugLogf("%+v", e)
+                       debugLogf("event %d detail %+v", e.Serial, e.Detail())
                        ps.mtx.Lock()
                        for sink := range ps.sinks {
                                sink.channel <- e
@@ -88,33 +91,35 @@ func (ps *pgEventSource) run() {
                if pqEvent.Channel != "logs" {
                        continue
                }
+               logID, err := strconv.ParseUint(pqEvent.Extra, 10, 64)
+               if err != nil {
+                       log.Printf("bad notify payload: %+v", pqEvent)
+                       continue
+               }
                serial++
                e := &event{
-                       LogUUID:  pqEvent.Extra,
+                       LogID:    logID,
                        Received: time.Now(),
                        Serial:   serial,
                        db:       db,
                }
-               debugLogf("%+v", e)
+               debugLogf("event %d %+v", e.Serial, e)
                eventQueue <- e
                go e.Detail()
        }
 }
 
-// NewSink subscribes to the event source. If c is not nil, it will be
-// used as the event channel. Otherwise, a new channel will be
-// created. Either way, the sink channel will be returned by the
-// Channel() method of the returned eventSink. All subsequent events
-// will be sent to the sink channel. The caller must ensure events are
-// received from the sink channel as quickly as possible: when one
-// sink blocks, all other sinks also block.
-func (ps *pgEventSource) NewSink(c chan *event) eventSink {
+// NewSink subscribes to the event source. NewSink returns an
+// eventSink, whose Channel() method returns a channel: a pointer to
+// each subsequent event will be sent to that channel.
+//
+// The caller must ensure events are received from the sink channel as
+// quickly as possible because when one sink stops being ready, all
+// other sinks block.
+func (ps *pgEventSource) NewSink() eventSink {
        ps.setupOnce.Do(ps.setup)
-       if c == nil {
-               c = make(chan *event, 1)
-       }
        sink := &pgEventSink{
-               channel: c,
+               channel: make(chan *event, 1),
                source:  ps,
        }
        ps.mtx.Lock()
diff --git a/services/ws/proxy_client.go b/services/ws/proxy_client.go
new file mode 100644 (file)
index 0000000..28be2e2
--- /dev/null
@@ -0,0 +1,41 @@
+package main
+
+import (
+       "net/http"
+       "net/url"
+
+       "git.curoverse.com/arvados.git/sdk/go/arvados"
+)
+
+// proxyClient embeds an *arvados.Client and adds helpers for setting
+// the client's token and checking read permission on an object.
+type proxyClient struct {
+       *arvados.Client
+}
+
+// NewProxyClient returns a proxyClient backed by a copy of ac with
+// its AuthToken cleared. Because ac is received by value, the
+// caller's client is not modified.
+func NewProxyClient(ac arvados.Client) *proxyClient {
+       ac.AuthToken = ""
+       pc := proxyClient{Client: &ac}
+       return &pc
+}
+
+// SetToken stores token in the AuthToken field of the embedded
+// client.
+func (pc *proxyClient) SetToken(token string) {
+       pc.AuthToken = token
+}
+
+// CheckReadPermission reports whether a "get" request for the object
+// with the given UUID succeeds using the client's current token.
+//
+// It returns (true, nil) if the request succeeds, (false, nil) if the
+// request fails with an arvados.TransactionError whose status code is
+// 404 Not Found, and (false, err) for any other failure, including a
+// UUID for which no request path can be determined.
+func (pc *proxyClient) CheckReadPermission(uuid string) (bool, error) {
+       path, err := pc.PathForUUID("get", uuid)
+       if err != nil {
+               return false, err
+       }
+       // Only the uuid field is requested; the decoded response is
+       // discarded.
+       var resp map[string]interface{}
+       params := url.Values{
+               "select": {`["uuid"]`},
+       }
+       err = pc.RequestAndDecode(&resp, "GET", path, nil, params)
+       if txErr, isTxErr := err.(arvados.TransactionError); isTxErr && txErr.StatusCode == http.StatusNotFound {
+               return false, nil
+       }
+       if err != nil {
+               return false, err
+       }
+       return true, nil
+}
index 685b6132abd32bb939f35240386c2aa080114004..30f93eab3aeb301f23668b93a3cc8994bcd992c8 100644 (file)
@@ -21,10 +21,12 @@ type router struct {
 func (rtr *router) setup() {
        rtr.mux = http.NewServeMux()
        rtr.mux.Handle("/websocket", rtr.makeServer(&handlerV0{
+               Client:      rtr.Config.Client,
                PingTimeout: rtr.Config.PingTimeout.Duration(),
                QueueSize:   rtr.Config.ClientEventQueue,
        }))
        rtr.mux.Handle("/arvados/v1/events.ws", rtr.makeServer(&handlerV1{
+               Client:      rtr.Config.Client,
                PingTimeout: rtr.Config.PingTimeout.Duration(),
                QueueSize:   rtr.Config.ClientEventQueue,
        }))
@@ -37,7 +39,7 @@ func (rtr *router) makeServer(handler handler) *websocket.Server {
                },
                Handler: websocket.Handler(func(ws *websocket.Conn) {
                        log.Printf("%v accepted", ws.Request().RemoteAddr)
-                       sink := rtr.eventSource.NewSink(nil)
+                       sink := rtr.eventSource.NewSink()
                        handler.Handle(ws, sink.Channel())
                        sink.Stop()
                        ws.Close()