Merge branch 'master' of git.curoverse.com:arvados into 11876-r-sdk
[arvados.git] / services / ws / router.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package main
6
7 import (
8         "encoding/json"
9         "io"
10         "net/http"
11         "strconv"
12         "sync"
13         "sync/atomic"
14         "time"
15
16         "git.curoverse.com/arvados.git/sdk/go/ctxlog"
17         "git.curoverse.com/arvados.git/sdk/go/health"
18         "github.com/Sirupsen/logrus"
19         "golang.org/x/net/websocket"
20 )
21
// wsConn is the connection type handed to session factories (see the
// factory closure in makeServer). *websocket.Conn satisfies it: it provides
// reads and writes, access to the originating HTTP request, and separate
// read/write deadlines.
type wsConn interface {
	io.ReadWriter
	Request() *http.Request
	SetReadDeadline(time.Time) error
	SetWriteDeadline(time.Time) error
}
28
// router is the top-level http.Handler for the websocket service. On the
// first ServeHTTP call it builds its handler and ServeMux (see setup), and
// it keeps request counters for the status and debug endpoints.
type router struct {
	// Config, eventSource, and newPermChecker are read by setup and
	// makeServer; presumably they are set by whoever constructs the
	// router, before it serves requests — TODO confirm.
	Config         *wsConfig
	eventSource    eventSource
	newPermChecker func() permChecker

	// handler and mux are populated by setup, which ServeHTTP runs
	// exactly once through setupOnce.
	handler   *handler
	mux       *http.ServeMux
	setupOnce sync.Once

	// lastReqID is consulted by newReqID while holding lastReqMtx.
	lastReqID  int64
	lastReqMtx sync.Mutex

	// status counters are updated with sync/atomic in ServeHTTP.
	// NOTE(review): 64-bit atomic operations need 64-bit alignment on
	// 32-bit platforms, which Go only guarantees for the first word of an
	// allocated struct; consider moving status to the top — TODO confirm.
	status routerDebugStatus
}
43
// routerDebugStatus holds the HTTP request counters maintained by
// router.ServeHTTP. It is reported under "HTTP" in /debug.json, and
// ReqsActive is also reported as "Clients" in /status.json.
type routerDebugStatus struct {
	ReqsReceived int64 // requests received since the router was created
	ReqsActive   int64 // requests currently inside ServeHTTP
}
48
// debugStatuser is implemented by components that can report internal
// state for /debug.json. router.DebugStatus includes the event source's
// report when the event source implements this interface.
type debugStatuser interface {
	DebugStatus() interface{}
}
52
53 func (rtr *router) setup() {
54         rtr.handler = &handler{
55                 PingTimeout: rtr.Config.PingTimeout.Duration(),
56                 QueueSize:   rtr.Config.ClientEventQueue,
57         }
58         rtr.mux = http.NewServeMux()
59         rtr.mux.Handle("/websocket", rtr.makeServer(newSessionV0))
60         rtr.mux.Handle("/arvados/v1/events.ws", rtr.makeServer(newSessionV1))
61         rtr.mux.Handle("/debug.json", rtr.jsonHandler(rtr.DebugStatus))
62         rtr.mux.Handle("/status.json", rtr.jsonHandler(rtr.Status))
63
64         rtr.mux.Handle("/_health/", &health.Handler{
65                 Token:  rtr.Config.ManagementToken,
66                 Prefix: "/_health/",
67                 Routes: health.Routes{
68                         "db": rtr.eventSource.DBHealth,
69                 },
70                 Log: func(r *http.Request, err error) {
71                         if err != nil {
72                                 logger(r.Context()).WithError(err).Error("error")
73                         }
74                 },
75         })
76 }
77
78 func (rtr *router) makeServer(newSession sessionFactory) *websocket.Server {
79         return &websocket.Server{
80                 Handshake: func(c *websocket.Config, r *http.Request) error {
81                         return nil
82                 },
83                 Handler: websocket.Handler(func(ws *websocket.Conn) {
84                         t0 := time.Now()
85                         log := logger(ws.Request().Context())
86                         log.Info("connected")
87
88                         stats := rtr.handler.Handle(ws, rtr.eventSource,
89                                 func(ws wsConn, sendq chan<- interface{}) (session, error) {
90                                         return newSession(ws, sendq, rtr.eventSource.DB(), rtr.newPermChecker(), &rtr.Config.Client)
91                                 })
92
93                         log.WithFields(logrus.Fields{
94                                 "elapsed": time.Now().Sub(t0).Seconds(),
95                                 "stats":   stats,
96                         }).Info("disconnect")
97                         ws.Close()
98                 }),
99         }
100 }
101
102 func (rtr *router) newReqID() string {
103         rtr.lastReqMtx.Lock()
104         defer rtr.lastReqMtx.Unlock()
105         id := time.Now().UnixNano()
106         if id <= rtr.lastReqID {
107                 id = rtr.lastReqID + 1
108         }
109         return strconv.FormatInt(id, 36)
110 }
111
112 func (rtr *router) DebugStatus() interface{} {
113         s := map[string]interface{}{
114                 "HTTP":     rtr.status,
115                 "Outgoing": rtr.handler.DebugStatus(),
116         }
117         if es, ok := rtr.eventSource.(debugStatuser); ok {
118                 s["EventSource"] = es.DebugStatus()
119         }
120         return s
121 }
122
123 func (rtr *router) Status() interface{} {
124         return map[string]interface{}{
125                 "Clients": atomic.LoadInt64(&rtr.status.ReqsActive),
126                 "Version": version,
127         }
128 }
129
130 func (rtr *router) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
131         rtr.setupOnce.Do(rtr.setup)
132         atomic.AddInt64(&rtr.status.ReqsReceived, 1)
133         atomic.AddInt64(&rtr.status.ReqsActive, 1)
134         defer atomic.AddInt64(&rtr.status.ReqsActive, -1)
135
136         logger := logger(req.Context()).
137                 WithField("RequestID", rtr.newReqID())
138         ctx := ctxlog.Context(req.Context(), logger)
139         req = req.WithContext(ctx)
140         logger.WithFields(logrus.Fields{
141                 "remoteAddr":      req.RemoteAddr,
142                 "reqForwardedFor": req.Header.Get("X-Forwarded-For"),
143         }).Info("accept request")
144         rtr.mux.ServeHTTP(resp, req)
145 }
146
147 func (rtr *router) jsonHandler(fn func() interface{}) http.Handler {
148         return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
149                 logger := logger(r.Context())
150                 w.Header().Set("Content-Type", "application/json")
151                 enc := json.NewEncoder(w)
152                 err := enc.Encode(fn())
153                 if err != nil {
154                         msg := "encode failed"
155                         logger.WithError(err).Error(msg)
156                         http.Error(w, msg, http.StatusInternalServerError)
157                 }
158         })
159 }