]> git.arvados.org - arvados.git/blob - services/ws/router.go
22845: added test
[arvados.git] / services / ws / router.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ws
6
7 import (
8         "io"
9         "net/http"
10         "sync"
11         "sync/atomic"
12         "time"
13
14         "git.arvados.org/arvados.git/sdk/go/arvados"
15         "git.arvados.org/arvados.git/sdk/go/ctxlog"
16         "git.arvados.org/arvados.git/sdk/go/health"
17         "git.arvados.org/arvados.git/sdk/go/httpserver"
18         "github.com/prometheus/client_golang/prometheus"
19         "github.com/sirupsen/logrus"
20         "golang.org/x/net/websocket"
21 )
22
// wsConn is the subset of a websocket connection that the session
// handler relies on: byte-level reading and writing, access to the
// originating HTTP request, and per-direction I/O deadlines.
type wsConn interface {
	// Stream access.
	io.Reader
	io.Writer

	// Request returns the HTTP request that opened the connection.
	Request() *http.Request

	// Deadlines for blocking reads and writes, respectively.
	SetReadDeadline(time.Time) error
	SetWriteDeadline(time.Time) error
}
29
30 type router struct {
31         client         *arvados.Client
32         cluster        *arvados.Cluster
33         eventSource    eventSource
34         newPermChecker func() permChecker
35
36         handler   *handler
37         mux       *http.ServeMux
38         setupOnce sync.Once
39         done      chan struct{}
40         reg       *prometheus.Registry
41 }
42
43 func (rtr *router) setup() {
44         mSockets := prometheus.NewGaugeVec(prometheus.GaugeOpts{
45                 Namespace: "arvados",
46                 Subsystem: "ws",
47                 Name:      "sockets",
48                 Help:      "Number of connected sockets",
49         }, []string{"version"})
50         rtr.reg.MustRegister(mSockets)
51
52         rtr.handler = &handler{
53                 PingTimeout: time.Duration(rtr.cluster.API.SendTimeout),
54                 QueueSize:   rtr.cluster.API.WebsocketClientEventQueue,
55         }
56         rtr.mux = http.NewServeMux()
57         rtr.mux.Handle("/websocket", rtr.makeServer(newSessionV0, mSockets.WithLabelValues("0")))
58         rtr.mux.Handle("/arvados/v1/events.ws", rtr.makeServer(newSessionV1, mSockets.WithLabelValues("1")))
59         rtr.mux.Handle("/_health/", &health.Handler{
60                 Token:  rtr.cluster.ManagementToken,
61                 Prefix: "/_health/",
62                 Routes: health.Routes{
63                         "db": rtr.eventSource.DBHealth,
64                 },
65                 Log: func(r *http.Request, err error) {
66                         if err != nil {
67                                 ctxlog.FromContext(r.Context()).WithError(err).Error("error")
68                         }
69                 },
70         })
71 }
72
73 func exemptFromDeadline(h http.Handler) http.Handler {
74         return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
75                 httpserver.ExemptFromDeadline(req)
76                 h.ServeHTTP(w, req)
77         })
78 }
79
80 func (rtr *router) makeServer(newSession sessionFactory, gauge prometheus.Gauge) http.Handler {
81         var connected int64
82         return exemptFromDeadline(&websocket.Server{
83                 Handshake: func(c *websocket.Config, r *http.Request) error {
84                         return nil
85                 },
86                 Handler: websocket.Handler(func(ws *websocket.Conn) {
87                         t0 := time.Now()
88                         logger := ctxlog.FromContext(ws.Request().Context())
89                         atomic.AddInt64(&connected, 1)
90                         gauge.Set(float64(atomic.LoadInt64(&connected)))
91
92                         stats := rtr.handler.Handle(ws, logger, rtr.eventSource,
93                                 func(ws wsConn, sendq chan<- interface{}) (session, error) {
94                                         return newSession(ws, sendq, rtr.eventSource.DB(), rtr.newPermChecker(), rtr.client)
95                                 })
96
97                         logger.WithFields(logrus.Fields{
98                                 "elapsed": time.Now().Sub(t0).Seconds(),
99                                 "stats":   stats,
100                         }).Info("client disconnected")
101                         ws.Close()
102                         atomic.AddInt64(&connected, -1)
103                         gauge.Set(float64(atomic.LoadInt64(&connected)))
104                 }),
105         })
106 }
107
108 func (rtr *router) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
109         rtr.setupOnce.Do(rtr.setup)
110         rtr.mux.ServeHTTP(httpserver.ResponseControllerShim{ResponseWriter: resp}, req)
111 }
112
113 func (rtr *router) CheckHealth() error {
114         rtr.setupOnce.Do(rtr.setup)
115         return rtr.eventSource.DBHealth()
116 }
117
118 func (rtr *router) Done() <-chan struct{} {
119         return rtr.done
120 }