11870: minor update
[arvados.git] / services / ws / router.go
1 package main
2
import (
	"crypto/subtle"
	"encoding/json"
	"io"
	"net/http"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"git.curoverse.com/arvados.git/sdk/go/ctxlog"
	"github.com/Sirupsen/logrus"
	"golang.org/x/net/websocket"
)
16
// wsConn is the connection type handed to session factories. It covers
// byte-stream reads and writes, access to the HTTP request that opened
// the connection, and separate read/write deadlines.
type wsConn interface {
	io.Reader
	io.Writer
	Request() *http.Request
	SetReadDeadline(time.Time) error
	SetWriteDeadline(time.Time) error
}
23
// router implements http.Handler (see ServeHTTP) for the websocket
// service. Its request multiplexer and session handler are built lazily
// by setup on the first request.
type router struct {
	// Config supplies the ping timeout, client event queue size, API
	// client settings passed to sessions, and the management token.
	Config *wsConfig
	// eventSource supplies the per-session DB handle (DB), the
	// /_health/db check (DBHealth), and is passed to handler.Handle for
	// each websocket connection.
	eventSource eventSource
	// newPermChecker is called once per new websocket session to obtain
	// that session's permChecker.
	newPermChecker func() permChecker

	// handler and mux are populated by setup; setupOnce ensures
	// ServeHTTP runs setup at most once.
	handler   *handler
	mux       *http.ServeMux
	setupOnce sync.Once

	// Used by newReqID to generate request IDs; lastReqMtx guards
	// lastReqID.
	lastReqID  int64
	lastReqMtx sync.Mutex

	// HTTP request counters. ServeHTTP updates these fields with
	// sync/atomic operations.
	status routerDebugStatus
}
38
// routerDebugStatus holds the router's HTTP request counters, reported
// under "HTTP" by DebugStatus. ReqsReceived counts every request seen;
// ReqsActive counts requests currently in progress. ServeHTTP updates
// both with sync/atomic operations.
type routerDebugStatus struct {
	ReqsReceived, ReqsActive int64
}
43
// debugStatuser is an optional interface. When the router's eventSource
// implements it, DebugStatus includes the result under "EventSource" in
// the /debug.json report.
type debugStatuser interface{ DebugStatus() interface{} }
47
48 func (rtr *router) setup() {
49         rtr.handler = &handler{
50                 PingTimeout: rtr.Config.PingTimeout.Duration(),
51                 QueueSize:   rtr.Config.ClientEventQueue,
52         }
53         rtr.mux = http.NewServeMux()
54         rtr.mux.Handle("/websocket", rtr.makeServer(newSessionV0))
55         rtr.mux.Handle("/arvados/v1/events.ws", rtr.makeServer(newSessionV1))
56         rtr.mux.Handle("/debug.json", rtr.jsonHandler(rtr.DebugStatus))
57         rtr.mux.Handle("/status.json", rtr.jsonHandler(rtr.Status))
58
59         health := http.NewServeMux()
60         rtr.mux.Handle("/_health/", rtr.mgmtAuth(health))
61         health.Handle("/_health/ping", rtr.jsonHandler(rtr.HealthFunc(func() error { return nil })))
62         health.Handle("/_health/db", rtr.jsonHandler(rtr.HealthFunc(rtr.eventSource.DBHealth)))
63 }
64
65 func (rtr *router) makeServer(newSession sessionFactory) *websocket.Server {
66         return &websocket.Server{
67                 Handshake: func(c *websocket.Config, r *http.Request) error {
68                         return nil
69                 },
70                 Handler: websocket.Handler(func(ws *websocket.Conn) {
71                         t0 := time.Now()
72                         log := logger(ws.Request().Context())
73                         log.Info("connected")
74
75                         stats := rtr.handler.Handle(ws, rtr.eventSource,
76                                 func(ws wsConn, sendq chan<- interface{}) (session, error) {
77                                         return newSession(ws, sendq, rtr.eventSource.DB(), rtr.newPermChecker(), &rtr.Config.Client)
78                                 })
79
80                         log.WithFields(logrus.Fields{
81                                 "elapsed": time.Now().Sub(t0).Seconds(),
82                                 "stats":   stats,
83                         }).Info("disconnect")
84                         ws.Close()
85                 }),
86         }
87 }
88
89 func (rtr *router) newReqID() string {
90         rtr.lastReqMtx.Lock()
91         defer rtr.lastReqMtx.Unlock()
92         id := time.Now().UnixNano()
93         if id <= rtr.lastReqID {
94                 id = rtr.lastReqID + 1
95         }
96         return strconv.FormatInt(id, 36)
97 }
98
99 func (rtr *router) DebugStatus() interface{} {
100         s := map[string]interface{}{
101                 "HTTP":     rtr.status,
102                 "Outgoing": rtr.handler.DebugStatus(),
103         }
104         if es, ok := rtr.eventSource.(debugStatuser); ok {
105                 s["EventSource"] = es.DebugStatus()
106         }
107         return s
108 }
109
// pingResponseOK is the report returned for a passing health check. A
// single map is shared by all callers, so treat it as read-only.
var pingResponseOK = map[string]string{
	"health": "OK",
}
111
112 func (rtr *router) HealthFunc(f func() error) func() interface{} {
113         return func() interface{} {
114                 err := f()
115                 if err == nil {
116                         return pingResponseOK
117                 }
118                 return map[string]string{
119                         "health": "ERROR",
120                         "error":  err.Error(),
121                 }
122         }
123 }
124
125 func (rtr *router) Status() interface{} {
126         return map[string]interface{}{
127                 "Clients": atomic.LoadInt64(&rtr.status.ReqsActive),
128         }
129 }
130
131 func (rtr *router) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
132         rtr.setupOnce.Do(rtr.setup)
133         atomic.AddInt64(&rtr.status.ReqsReceived, 1)
134         atomic.AddInt64(&rtr.status.ReqsActive, 1)
135         defer atomic.AddInt64(&rtr.status.ReqsActive, -1)
136
137         logger := logger(req.Context()).
138                 WithField("RequestID", rtr.newReqID())
139         ctx := ctxlog.Context(req.Context(), logger)
140         req = req.WithContext(ctx)
141         logger.WithFields(logrus.Fields{
142                 "remoteAddr":      req.RemoteAddr,
143                 "reqForwardedFor": req.Header.Get("X-Forwarded-For"),
144         }).Info("accept request")
145         rtr.mux.ServeHTTP(resp, req)
146 }
147
148 func (rtr *router) mgmtAuth(h http.Handler) http.Handler {
149         return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
150                 if rtr.Config.ManagementToken == "" {
151                         http.Error(w, "disabled", http.StatusNotFound)
152                 } else if ah := r.Header.Get("Authorization"); ah == "" {
153                         http.Error(w, "authorization required", http.StatusUnauthorized)
154                 } else if ah != "Bearer "+rtr.Config.ManagementToken {
155                         http.Error(w, "authorization error", http.StatusForbidden)
156                 } else {
157                         h.ServeHTTP(w, r)
158                 }
159         })
160 }
161
162 func (rtr *router) jsonHandler(fn func() interface{}) http.Handler {
163         return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
164                 logger := logger(r.Context())
165                 w.Header().Set("Content-Type", "application/json")
166                 enc := json.NewEncoder(w)
167                 err := enc.Encode(fn())
168                 if err != nil {
169                         msg := "encode failed"
170                         logger.WithError(err).Error(msg)
171                         http.Error(w, msg, http.StatusInternalServerError)
172                 }
173         })
174 }