8460: Reply to unparsable messages with status:400.
[arvados.git] / services / ws / router.go
1 package main
2
3 import (
4         "database/sql"
5         "io"
6         "net/http"
7         "strconv"
8         "sync"
9         "time"
10
11         "git.curoverse.com/arvados.git/sdk/go/arvados"
12         log "github.com/Sirupsen/logrus"
13         "golang.org/x/net/websocket"
14 )
15
// wsConn lists what a session needs from its websocket connection:
// streaming reads and writes, read/write deadlines, and access to the
// underlying HTTP request (makeServer, for example, reads its Context).
type wsConn interface {
	io.Reader
	io.Writer

	// Request exposes the HTTP request associated with the connection.
	Request() *http.Request

	SetWriteDeadline(time.Time) error
	SetReadDeadline(time.Time) error
}
22
23 type router struct {
24         Config *Config
25
26         eventSource eventSource
27         mux         *http.ServeMux
28         setupOnce   sync.Once
29
30         lastReqID  int64
31         lastReqMtx sync.Mutex
32 }
33
34 type sessionFactory func(wsConn, chan<- interface{}, arvados.Client, *sql.DB) (session, error)
35
36 func (rtr *router) setup() {
37         rtr.mux = http.NewServeMux()
38         rtr.mux.Handle("/websocket", rtr.makeServer(NewSessionV0))
39         rtr.mux.Handle("/arvados/v1/events.ws", rtr.makeServer(NewSessionV1))
40 }
41
42 func (rtr *router) makeServer(newSession sessionFactory) *websocket.Server {
43         handler := &handler{
44                 PingTimeout: rtr.Config.PingTimeout.Duration(),
45                 QueueSize:   rtr.Config.ClientEventQueue,
46                 NewSession: func(ws wsConn, sendq chan<- interface{}) (session, error) {
47                         return newSession(ws, sendq, rtr.Config.Client, rtr.eventSource.DB())
48                 },
49         }
50         return &websocket.Server{
51                 Handshake: func(c *websocket.Config, r *http.Request) error {
52                         return nil
53                 },
54                 Handler: websocket.Handler(func(ws *websocket.Conn) {
55                         t0 := time.Now()
56                         sink := rtr.eventSource.NewSink()
57                         logger(ws.Request().Context()).Info("connected")
58
59                         stats := handler.Handle(ws, sink.Channel())
60
61                         logger(ws.Request().Context()).WithFields(log.Fields{
62                                 "Elapsed": time.Now().Sub(t0).Seconds(),
63                                 "Stats":   stats,
64                         }).Info("disconnect")
65
66                         sink.Stop()
67                         ws.Close()
68                 }),
69         }
70 }
71
72 func (rtr *router) newReqID() string {
73         rtr.lastReqMtx.Lock()
74         defer rtr.lastReqMtx.Unlock()
75         id := time.Now().UnixNano()
76         if id <= rtr.lastReqID {
77                 id = rtr.lastReqID + 1
78         }
79         return strconv.FormatInt(id, 36)
80 }
81
82 func (rtr *router) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
83         rtr.setupOnce.Do(rtr.setup)
84         logger := logger(req.Context()).
85                 WithField("RequestID", rtr.newReqID())
86         ctx := contextWithLogger(req.Context(), logger)
87         req = req.WithContext(ctx)
88         logger.WithFields(log.Fields{
89                 "RemoteAddr":      req.RemoteAddr,
90                 "X-Forwarded-For": req.Header.Get("X-Forwarded-For"),
91         }).Info("accept request")
92         rtr.mux.ServeHTTP(resp, req)
93 }