+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+
package httpserver
import (
"net/http"
+ "sync/atomic"
+
+ "github.com/prometheus/client_golang/prometheus"
)
// RequestCounter is an http.Handler that tracks the number of
// requests in progress.
type RequestCounter interface {
	http.Handler

	// Current returns the number of requests in progress.
	Current() int

	// Max returns the maximum number of concurrent requests that
	// will be accepted. A return value of 0 means no limit is
	// enforced: every request is accepted.
	Max() int
}
+
// limiterHandler implements RequestCounter. When cap(requests) > 0,
// the number of tokens buffered in requests is the number of requests
// in progress. Otherwise no limit is enforced and count holds the
// in-progress total instead.
type limiterHandler struct {
	// count must stay the first field. It is updated with
	// atomic.AddInt64/LoadInt64, which require 64-bit alignment, and
	// on 32-bit platforms (386, ARM, 32-bit MIPS) only the first word
	// of an allocated struct is guaranteed to be 64-bit aligned.
	count    int64 // only used if cap(requests)==0
	requests chan struct{}
	handler  http.Handler
}
-func NewRequestLimiter(maxRequests int, handler http.Handler) http.Handler {
- return &limiterHandler{
+// NewRequestLimiter returns a RequestCounter that delegates up to
+// maxRequests at a time to the given handler, and responds 503 to all
+// incoming requests beyond that limit.
+//
+// "concurrent_requests" and "max_concurrent_requests" metrics are
+// registered with the given reg, if reg is not nil.
+func NewRequestLimiter(maxRequests int, handler http.Handler, reg *prometheus.Registry) RequestCounter {
+ h := &limiterHandler{
requests: make(chan struct{}, maxRequests),
handler: handler,
}
+ if reg != nil {
+ reg.MustRegister(prometheus.NewGaugeFunc(
+ prometheus.GaugeOpts{
+ Namespace: "arvados",
+ Name: "concurrent_requests",
+ Help: "Number of requests in progress",
+ },
+ func() float64 { return float64(h.Current()) },
+ ))
+ reg.MustRegister(prometheus.NewGaugeFunc(
+ prometheus.GaugeOpts{
+ Namespace: "arvados",
+ Name: "max_concurrent_requests",
+ Help: "Maximum number of concurrent requests",
+ },
+ func() float64 { return float64(h.Max()) },
+ ))
+ }
+ return h
+}
+
+func (h *limiterHandler) Current() int {
+ if cap(h.requests) == 0 {
+ return int(atomic.LoadInt64(&h.count))
+ }
+ return len(h.requests)
+}
+
+func (h *limiterHandler) Max() int {
+ return cap(h.requests)
}
func (h *limiterHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+ if cap(h.requests) == 0 {
+ atomic.AddInt64(&h.count, 1)
+ defer atomic.AddInt64(&h.count, -1)
+ h.handler.ServeHTTP(resp, req)
+ return
+ }
select {
case h.requests <- struct{}{}:
default: