21285: Use separate request limit/queue for gateway tunnel requests.
author Tom Clegg <tom@curii.com>
Wed, 27 Dec 2023 22:50:45 +0000 (17:50 -0500)
committer Tom Clegg <tom@curii.com>
Wed, 27 Dec 2023 22:50:45 +0000 (17:50 -0500)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

lib/config/config.default.yml
lib/config/export.go
lib/service/cmd.go
lib/service/cmd_test.go
sdk/go/arvados/config.go
sdk/go/httpserver/request_limiter.go
sdk/go/httpserver/request_limiter_test.go

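In short: instead of one global concurrency limit, controller-like services now route each
incoming request to one of two httpserver.RequestQueue instances, so long-lived gateway
tunnel connections no longer compete with ordinary API requests for slots. A minimal sketch
of the new wiring (illustrative only, not part of the commit; it mirrors the requestLimiter
helper added to lib/service/cmd.go below, with example limits and the module import path
assumed from the repository layout):

    package main

    import (
        "log"
        "net/http"
        "strings"

        "git.arvados.org/arvados.git/sdk/go/httpserver"
    )

    func newLimiter(handler http.Handler) http.Handler {
        // Ordinary API requests: bounded concurrency plus a bounded wait queue.
        rqAPI := &httpserver.RequestQueue{
            Label:         "api",
            MaxConcurrent: 64,  // cf. MaxConcurrentRequests
            MaxQueue:      128, // cf. MaxQueuedRequests (example value)
        }
        // Gateway tunnels (container gateways, "container shell"): separate,
        // larger limit and no queue -- excess connections fail fast with 503.
        rqTunnel := &httpserver.RequestQueue{
            Label:         "tunnel",
            MaxConcurrent: 1000, // cf. MaxGatewayTunnels
            MaxQueue:      0,
        }
        return &httpserver.RequestLimiter{
            Handler: handler,
            Queue: func(req *http.Request) *httpserver.RequestQueue {
                if strings.HasPrefix(req.URL.Path, "/arvados/v1/connect/") {
                    return rqTunnel
                }
                return rqAPI
            },
        }
    }

    func main() {
        mux := http.NewServeMux()
        mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok\n")) })
        log.Fatal(http.ListenAndServe(":8080", newLimiter(mux)))
    }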
diff --git a/lib/config/config.default.yml b/lib/config/config.default.yml
index 05bc1309cdde1ef269306fb634dbe4e6595914ee..3924090ca9e0c47c44179cf4f82d141c337e5df6 100644
@@ -231,6 +231,10 @@ Clusters:
       # also effectively limited by MaxConcurrentRailsRequests (see
       # below) because most controller requests proxy through to the
       # RailsAPI service.
+      #
+      # HTTP proxies and load balancers downstream of arvados services
+      # should be configured to allow at least {MaxConcurrentRequests +
+      # MaxQueuedRequests + MaxGatewayTunnels} concurrent requests.
       MaxConcurrentRequests: 64
 
       # Maximum number of concurrent requests to process concurrently
@@ -250,6 +254,12 @@ Clusters:
       # the incoming request queue before returning 503.
       MaxQueueTimeForLockRequests: 2s
 
+      # Maximum number of active gateway tunnel connections. A slot is
+      # consumed by each running container, and by each incoming
+      # "container shell" connection. These do not count toward
+      # MaxConcurrentRequests.
+      MaxGatewayTunnels: 1000
+
       # Fraction of MaxConcurrentRequests that can be "log create"
       # messages at any given time.  This is to prevent logging
       # updates from crowding out more important requests.
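For example, with the values shown here (MaxConcurrentRequests: 64, MaxGatewayTunnels: 1000)
and a hypothetical MaxQueuedRequests of 128, a downstream proxy would need to allow at least
64 + 128 + 1000 = 1192 concurrent requests to this service.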
diff --git a/lib/config/export.go b/lib/config/export.go
index e51e6fc32cdeb03909f84eb3f24cbfcb8351b31a..4b6c142ff2e29f41bcf2b843ac6479b54dd436aa 100644
@@ -70,6 +70,7 @@ var whitelist = map[string]bool{
        "API.LogCreateRequestFraction":             false,
        "API.MaxConcurrentRailsRequests":           false,
        "API.MaxConcurrentRequests":                false,
+       "API.MaxGatewayTunnels":                    false,
        "API.MaxIndexDatabaseRead":                 false,
        "API.MaxItemsPerResponse":                  true,
        "API.MaxKeepBlobBuffers":                   false,
diff --git a/lib/service/cmd.go b/lib/service/cmd.go
index 725f86f3bda5c2a82476615ba9ecd6e7a9b7a4fa..e40b47acbbb6d4bb1b96a9de1bf2d625e28b2430 100644
@@ -148,32 +148,13 @@ func (c *command) RunCommand(prog string, args []string, stdin io.Reader, stdout
                return 1
        }
 
-       maxReqs := cluster.API.MaxConcurrentRequests
-       if maxRails := cluster.API.MaxConcurrentRailsRequests; maxRails > 0 &&
-               (maxRails < maxReqs || maxReqs == 0) &&
-               strings.HasSuffix(prog, "controller") {
-               // Ideally, we would accept up to
-               // MaxConcurrentRequests, and apply the
-               // MaxConcurrentRailsRequests limit only for requests
-               // that require calling upstream to RailsAPI. But for
-               // now we make the simplifying assumption that every
-               // controller request causes an upstream RailsAPI
-               // request.
-               maxReqs = maxRails
-       }
        instrumented := httpserver.Instrument(reg, log,
                httpserver.HandlerWithDeadline(cluster.API.RequestTimeout.Duration(),
                        httpserver.AddRequestIDs(
                                httpserver.Inspect(reg, cluster.ManagementToken,
                                        httpserver.LogRequests(
                                                interceptHealthReqs(cluster.ManagementToken, handler.CheckHealth,
-                                                       &httpserver.RequestLimiter{
-                                                               Handler:                    handler,
-                                                               MaxConcurrent:              maxReqs,
-                                                               MaxQueue:                   cluster.API.MaxQueuedRequests,
-                                                               MaxQueueTimeForMinPriority: cluster.API.MaxQueueTimeForLockRequests.Duration(),
-                                                               Priority:                   c.requestPriority,
-                                                               Registry:                   reg}))))))
+                                                       c.requestLimiter(handler, cluster, reg)))))))
        srv := &httpserver.Server{
                Server: http.Server{
                        Handler:     ifCollectionInHost(instrumented, instrumented.ServeAPI(cluster.ManagementToken, instrumented)),
@@ -212,7 +193,7 @@ func (c *command) RunCommand(prog string, args []string, stdin io.Reader, stdout
                <-handler.Done()
                srv.Close()
        }()
-       go c.requestQueueDumpCheck(cluster, maxReqs, prog, reg, &srv.Server, logger)
+       go c.requestQueueDumpCheck(cluster, prog, reg, &srv.Server, logger)
        err = srv.Wait()
        if err != nil {
                return 1
@@ -221,12 +202,13 @@ func (c *command) RunCommand(prog string, args []string, stdin io.Reader, stdout
 }
 
 // If SystemLogs.RequestQueueDumpDirectory is set, monitor the
-// server's incoming HTTP request queue size. When it exceeds 90% of
-// API.MaxConcurrentRequests, write the /_inspect/requests data to a
-// JSON file in the specified directory.
-func (c *command) requestQueueDumpCheck(cluster *arvados.Cluster, maxReqs int, prog string, reg *prometheus.Registry, srv *http.Server, logger logrus.FieldLogger) {
+// server's incoming HTTP request limiters. When the number of
+// concurrent requests in any queue ("api" or "tunnel") exceeds 90% of
+// its maximum slots, write the /_inspect/requests data to a JSON file
+// in the specified directory.
+func (c *command) requestQueueDumpCheck(cluster *arvados.Cluster, prog string, reg *prometheus.Registry, srv *http.Server, logger logrus.FieldLogger) {
        outdir := cluster.SystemLogs.RequestQueueDumpDirectory
-       if outdir == "" || cluster.ManagementToken == "" || maxReqs < 1 {
+       if outdir == "" || cluster.ManagementToken == "" {
                return
        }
        logger = logger.WithField("worker", "RequestQueueDump")
@@ -237,16 +219,29 @@ func (c *command) requestQueueDumpCheck(cluster *arvados.Cluster, maxReqs int, p
                        logger.WithError(err).Warn("error getting metrics")
                        continue
                }
-               dump := false
+               cur := map[string]int{} // queue label => current
+               max := map[string]int{} // queue label => max
                for _, mf := range mfs {
-                       if mf.Name != nil && *mf.Name == "arvados_concurrent_requests" && len(mf.Metric) == 1 {
-                               n := int(mf.Metric[0].GetGauge().GetValue())
-                               if n > 0 && n >= maxReqs*9/10 {
-                                       dump = true
-                                       break
+                       for _, m := range mf.GetMetric() {
+                               for _, ml := range m.GetLabel() {
+                                       if ml.GetName() == "queue" {
+                                               n := int(m.GetGauge().GetValue())
+                                               if name := mf.GetName(); name == "arvados_concurrent_requests" {
+                                                       cur[*ml.Value] = n
+                                               } else if name == "arvados_max_concurrent_requests" {
+                                                       max[*ml.Value] = n
+                                               }
+                                       }
                                }
                        }
                }
+               dump := false
+               for queue, n := range cur {
+                       if n > 0 && max[queue] > 0 && n >= max[queue]*9/10 {
+                               dump = true
+                               break
+                       }
+               }
                if dump {
                        req, err := http.NewRequest("GET", "/_inspect/requests", nil)
                        if err != nil {
@@ -269,6 +264,48 @@ func (c *command) requestQueueDumpCheck(cluster *arvados.Cluster, maxReqs int, p
        }
 }
 
+// Set up a httpserver.RequestLimiter with separate queues/streams for
+// API requests (obeying MaxConcurrentRequests etc) and gateway tunnel
+// requests (obeying MaxGatewayTunnels).
+func (c *command) requestLimiter(handler http.Handler, cluster *arvados.Cluster, reg *prometheus.Registry) http.Handler {
+       maxReqs := cluster.API.MaxConcurrentRequests
+       if maxRails := cluster.API.MaxConcurrentRailsRequests; maxRails > 0 &&
+               (maxRails < maxReqs || maxReqs == 0) &&
+               c.svcName == arvados.ServiceNameController {
+               // Ideally, we would accept up to
+               // MaxConcurrentRequests, and apply the
+               // MaxConcurrentRailsRequests limit only for requests
+               // that require calling upstream to RailsAPI. But for
+               // now we make the simplifying assumption that every
+               // controller request causes an upstream RailsAPI
+               // request.
+               maxReqs = maxRails
+       }
+       rqAPI := &httpserver.RequestQueue{
+               Label:                      "api",
+               MaxConcurrent:              maxReqs,
+               MaxQueue:                   cluster.API.MaxQueuedRequests,
+               MaxQueueTimeForMinPriority: cluster.API.MaxQueueTimeForLockRequests.Duration(),
+       }
+       rqTunnel := &httpserver.RequestQueue{
+               Label:         "tunnel",
+               MaxConcurrent: cluster.API.MaxGatewayTunnels,
+               MaxQueue:      0,
+       }
+       return &httpserver.RequestLimiter{
+               Handler:  handler,
+               Priority: c.requestPriority,
+               Registry: reg,
+               Queue: func(req *http.Request) *httpserver.RequestQueue {
+                       if strings.HasPrefix(req.URL.Path, "/arvados/v1/connect/") {
+                               return rqTunnel
+                       } else {
+                               return rqAPI
+                       }
+               },
+       }
+}
+
 func (c *command) requestPriority(req *http.Request, queued time.Time) int64 {
        switch {
        case req.Method == http.MethodPost && strings.HasPrefix(req.URL.Path, "/arvados/v1/containers/") && strings.HasSuffix(req.URL.Path, "/lock"):
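Because the tunnel queue is created with MaxQueue: 0, a tunnel request that arrives when all
MaxGatewayTunnels slots are busy is rejected immediately with 503 rather than waiting, and it
never occupies an "api" slot. A self-contained sketch of that behavior (not part of the
commit; it assumes the RequestLimiter/RequestQueue API shown in the sdk/go/httpserver diff
further down):

    package main

    import (
        "fmt"
        "net/http"
        "net/http/httptest"
        "time"

        "git.arvados.org/arvados.git/sdk/go/httpserver"
    )

    func main() {
        release := make(chan struct{})
        // Handler that blocks until released, keeping its slot occupied.
        busy := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { <-release })

        rqTunnel := &httpserver.RequestQueue{Label: "tunnel", MaxConcurrent: 1, MaxQueue: 0}
        rl := &httpserver.RequestLimiter{
            Handler: busy,
            Queue:   func(*http.Request) *httpserver.RequestQueue { return rqTunnel },
        }

        // Occupy the only tunnel slot.
        go rl.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/arvados/v1/connect/x", nil))
        time.Sleep(100 * time.Millisecond) // give the first request time to reach the handler

        // With MaxQueue == 0 there is nowhere to wait: this responds 503 immediately.
        resp := httptest.NewRecorder()
        rl.ServeHTTP(resp, httptest.NewRequest("GET", "/arvados/v1/connect/y", nil))
        fmt.Println(resp.Code) // expected: 503

        close(release)
    }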
diff --git a/lib/service/cmd_test.go b/lib/service/cmd_test.go
index 08b3a239dc2583c5da4271465cc550521dc54a79..0266752f383ef861802fe0ad718f1e35d1e0ba9d 100644
@@ -17,6 +17,8 @@ import (
        "net/url"
        "os"
        "strings"
+       "sync"
+       "sync/atomic"
        "testing"
        "time"
 
@@ -198,15 +200,15 @@ func (*Suite) TestCommand(c *check.C) {
        c.Check(stderr.String(), check.Matches, `(?ms).*"msg":"CheckHealth called".*`)
 }
 
-func (s *Suite) TestDumpRequestsKeepweb(c *check.C) {
-       s.testDumpRequests(c, arvados.ServiceNameKeepweb, "MaxConcurrentRequests")
+func (s *Suite) TestRequestLimitsAndDumpRequests_Keepweb(c *check.C) {
+       s.testRequestLimitAndDumpRequests(c, arvados.ServiceNameKeepweb, "MaxConcurrentRequests")
 }
 
-func (s *Suite) TestDumpRequestsController(c *check.C) {
-       s.testDumpRequests(c, arvados.ServiceNameController, "MaxConcurrentRailsRequests")
+func (s *Suite) TestRequestLimitsAndDumpRequests_Controller(c *check.C) {
+       s.testRequestLimitAndDumpRequests(c, arvados.ServiceNameController, "MaxConcurrentRailsRequests")
 }
 
-func (*Suite) testDumpRequests(c *check.C, serviceName arvados.ServiceName, maxReqsConfigKey string) {
+func (*Suite) testRequestLimitAndDumpRequests(c *check.C, serviceName arvados.ServiceName, maxReqsConfigKey string) {
        defer func(orig time.Duration) { requestQueueDumpCheckInterval = orig }(requestQueueDumpCheckInterval)
        requestQueueDumpCheckInterval = time.Second / 10
 
@@ -218,6 +220,7 @@ func (*Suite) testDumpRequests(c *check.C, serviceName arvados.ServiceName, maxR
        defer cf.Close()
 
        max := 24
+       maxTunnels := 30
        fmt.Fprintf(cf, `
 Clusters:
  zzzzz:
@@ -225,7 +228,8 @@ Clusters:
   ManagementToken: bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb
   API:
    `+maxReqsConfigKey+`: %d
-   MaxQueuedRequests: 0
+   MaxQueuedRequests: 1
+   MaxGatewayTunnels: %d
   SystemLogs: {RequestQueueDumpDirectory: %q}
   Services:
    Controller:
@@ -234,14 +238,18 @@ Clusters:
    WebDAV:
     ExternalURL: "http://localhost:`+port+`"
     InternalURLs: {"http://localhost:`+port+`": {}}
-`, max, tmpdir)
+`, max, maxTunnels, tmpdir)
        cf.Close()
 
        started := make(chan bool, max+1)
        hold := make(chan bool)
        handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-               started <- true
-               <-hold
+               if strings.HasPrefix(r.URL.Path, "/arvados/v1/connect") {
+                       <-hold
+               } else {
+                       started <- true
+                       <-hold
+               }
        })
        healthCheck := make(chan bool, 1)
        ctx, cancel := context.WithCancel(context.Background())
@@ -267,15 +275,50 @@ Clusters:
        }
        client := http.Client{}
        deadline := time.Now().Add(time.Second * 2)
+       var activeReqs sync.WaitGroup
+
+       // Start some API reqs
+       var apiResp200, apiResp503 int64
        for i := 0; i < max+1; i++ {
+               activeReqs.Add(1)
                go func() {
+                       defer activeReqs.Done()
                        resp, err := client.Get("http://localhost:" + port + "/testpath")
                        for err != nil && strings.Contains(err.Error(), "dial tcp") && deadline.After(time.Now()) {
                                time.Sleep(time.Second / 100)
                                resp, err = client.Get("http://localhost:" + port + "/testpath")
                        }
                        if c.Check(err, check.IsNil) {
-                               c.Logf("resp StatusCode %d", resp.StatusCode)
+                               if resp.StatusCode == http.StatusOK {
+                                       atomic.AddInt64(&apiResp200, 1)
+                               } else if resp.StatusCode == http.StatusServiceUnavailable {
+                                       atomic.AddInt64(&apiResp503, 1)
+                               }
+                       }
+               }()
+       }
+
+       // Start some gateway tunnel reqs that don't count toward our
+       // API req limit
+       extraTunnelReqs := 20
+       var tunnelResp200, tunnelResp503 int64
+       for i := 0; i < maxTunnels+extraTunnelReqs; i++ {
+               activeReqs.Add(1)
+               go func() {
+                       defer activeReqs.Done()
+                       resp, err := client.Get("http://localhost:" + port + "/arvados/v1/connect/...")
+                       for err != nil && strings.Contains(err.Error(), "dial tcp") && deadline.After(time.Now()) {
+                               time.Sleep(time.Second / 100)
+                               resp, err = client.Get("http://localhost:" + port + "/arvados/v1/connect/...")
+                       }
+                       if c.Check(err, check.IsNil) {
+                               if resp.StatusCode == http.StatusOK {
+                                       atomic.AddInt64(&tunnelResp200, 1)
+                               } else if resp.StatusCode == http.StatusServiceUnavailable {
+                                       atomic.AddInt64(&tunnelResp503, 1)
+                               } else {
+                                       c.Errorf("tunnel response code %d", resp.StatusCode)
+                               }
                        }
                }()
        }
@@ -300,6 +343,20 @@ Clusters:
                var loaded []struct{ URL string }
                err = json.Unmarshal(j, &loaded)
                c.Check(err, check.IsNil)
+
+               for i := 0; i < len(loaded); i++ {
+                       if strings.HasPrefix(loaded[i].URL, "/arvados/v1/connect/") {
+                               // Filter out a gateway tunnel req
+                               // that doesn't count toward our API
+                               // req limit
+                               if i < len(loaded)-1 {
+                                       copy(loaded[i:], loaded[i+1:])
+                                       i--
+                               }
+                               loaded = loaded[:len(loaded)-1]
+                       }
+               }
+
                if len(loaded) < max {
                        // Dumped when #requests was >90% but <100% of
                        // limit. If we stop now, we won't be able to
@@ -309,7 +366,7 @@ Clusters:
                        c.Logf("loaded dumped requests, but len %d < max %d -- still waiting", len(loaded), max)
                        continue
                }
-               c.Check(loaded, check.HasLen, max)
+               c.Check(loaded, check.HasLen, max+1)
                c.Check(loaded[0].URL, check.Equals, "/testpath")
                break
        }
@@ -328,7 +385,8 @@ Clusters:
                c.Check(err, check.IsNil)
                switch path {
                case "/metrics":
-                       c.Check(string(buf), check.Matches, `(?ms).*arvados_concurrent_requests `+fmt.Sprintf("%d", max)+`\n.*`)
+                       c.Check(string(buf), check.Matches, `(?ms).*arvados_concurrent_requests{queue="api"} `+fmt.Sprintf("%d", max)+`\n.*`)
+                       c.Check(string(buf), check.Matches, `(?ms).*arvados_queued_requests{priority="normal",queue="api"} 1\n.*`)
                case "/_inspect/requests":
                        c.Check(string(buf), check.Matches, `(?ms).*"URL":"/testpath".*`)
                default:
@@ -336,6 +394,11 @@ Clusters:
                }
        }
        close(hold)
+       activeReqs.Wait()
+       c.Check(int(apiResp200), check.Equals, max+1)
+       c.Check(int(apiResp503), check.Equals, 0)
+       c.Check(int(tunnelResp200), check.Equals, maxTunnels)
+       c.Check(int(tunnelResp503), check.Equals, extraTunnelReqs)
        cancel()
 }
 
diff --git a/sdk/go/arvados/config.go b/sdk/go/arvados/config.go
index 6301ed047a1dbfca82b3c717926a2f05415aa291..16d789daf5163f49bc6fe4770565fcde9325fd35 100644
@@ -102,6 +102,7 @@ type Cluster struct {
                MaxConcurrentRailsRequests       int
                MaxConcurrentRequests            int
                MaxQueuedRequests                int
+               MaxGatewayTunnels                int
                MaxQueueTimeForLockRequests      Duration
                LogCreateRequestFraction         float64
                MaxKeepBlobBuffers               int
diff --git a/sdk/go/httpserver/request_limiter.go b/sdk/go/httpserver/request_limiter.go
index 9d501ab0ebfa7db908a2886d4b208973c8606863..1e3316ed487d17ca2eade2655ad3bfb04c8c6851 100644
@@ -34,13 +34,8 @@ const metricsUpdateInterval = time.Second
 type RequestLimiter struct {
        Handler http.Handler
 
-       // Maximum number of requests being handled at once. Beyond
-       // this limit, requests will be queued.
-       MaxConcurrent int
-
-       // Maximum number of requests in the queue. Beyond this limit,
-       // the lowest priority requests will return 503.
-       MaxQueue int
+       // Queue determines which queue a request is assigned to.
+       Queue func(req *http.Request) *RequestQueue
 
        // Priority determines queue ordering. Requests with higher
        // priority are handled first. Requests with equal priority
@@ -48,11 +43,6 @@ type RequestLimiter struct {
        // handled FIFO.
        Priority func(req *http.Request, queued time.Time) int64
 
-       // Return 503 for any request for which Priority() returns
-       // MinPriority if it spends longer than this in the queue
-       // before starting processing.
-       MaxQueueTimeForMinPriority time.Duration
-
        // "concurrent_requests", "max_concurrent_requests",
        // "queued_requests", and "max_queued_requests" metrics are
        // registered with Registry, if it is not nil.
@@ -63,11 +53,32 @@ type RequestLimiter struct {
        mQueueTimeout *prometheus.SummaryVec
        mQueueUsage   *prometheus.GaugeVec
        mtx           sync.Mutex
-       handling      int
-       queue         queue
+       rqs           map[*RequestQueue]bool // all RequestQueues in use
+}
+
+type RequestQueue struct {
+       // Label for metrics. No two queues should have the same label.
+       Label string
+
+       // Maximum number of requests being handled at once. Beyond
+       // this limit, requests will be queued.
+       MaxConcurrent int
+
+       // Maximum number of requests in the queue. Beyond this limit,
+       // the lowest priority requests will return 503.
+       MaxQueue int
+
+       // Return 503 for any request for which Priority() returns
+       // MinPriority if it spends longer than this in the queue
+       // before starting processing.
+       MaxQueueTimeForMinPriority time.Duration
+
+       queue    queue
+       handling int
 }
 
 type qent struct {
+       rq       *RequestQueue
        queued   time.Time
        priority int64
        heappos  int
@@ -121,101 +132,96 @@ func (h *queue) remove(i int) {
 
 func (rl *RequestLimiter) setup() {
        if rl.Registry != nil {
-               rl.Registry.MustRegister(prometheus.NewGaugeFunc(
-                       prometheus.GaugeOpts{
-                               Namespace: "arvados",
-                               Name:      "concurrent_requests",
-                               Help:      "Number of requests in progress",
-                       },
-                       func() float64 {
-                               rl.mtx.Lock()
-                               defer rl.mtx.Unlock()
-                               return float64(rl.handling)
-                       },
-               ))
-               rl.Registry.MustRegister(prometheus.NewGaugeFunc(
-                       prometheus.GaugeOpts{
-                               Namespace: "arvados",
-                               Name:      "max_concurrent_requests",
-                               Help:      "Maximum number of concurrent requests",
-                       },
-                       func() float64 { return float64(rl.MaxConcurrent) },
-               ))
+               mCurrentReqs := prometheus.NewGaugeVec(prometheus.GaugeOpts{
+                       Namespace: "arvados",
+                       Name:      "concurrent_requests",
+                       Help:      "Number of requests in progress",
+               }, []string{"queue"})
+               rl.Registry.MustRegister(mCurrentReqs)
+               mMaxReqs := prometheus.NewGaugeVec(prometheus.GaugeOpts{
+                       Namespace: "arvados",
+                       Name:      "max_concurrent_requests",
+                       Help:      "Maximum number of concurrent requests",
+               }, []string{"queue"})
+               rl.Registry.MustRegister(mMaxReqs)
+               mMaxQueue := prometheus.NewGaugeVec(prometheus.GaugeOpts{
+                       Namespace: "arvados",
+                       Name:      "max_queued_requests",
+                       Help:      "Maximum number of queued requests",
+               }, []string{"queue"})
+               rl.Registry.MustRegister(mMaxQueue)
                rl.mQueueUsage = prometheus.NewGaugeVec(prometheus.GaugeOpts{
                        Namespace: "arvados",
                        Name:      "queued_requests",
                        Help:      "Number of requests in queue",
-               }, []string{"priority"})
+               }, []string{"queue", "priority"})
                rl.Registry.MustRegister(rl.mQueueUsage)
-               rl.Registry.MustRegister(prometheus.NewGaugeFunc(
-                       prometheus.GaugeOpts{
-                               Namespace: "arvados",
-                               Name:      "max_queued_requests",
-                               Help:      "Maximum number of queued requests",
-                       },
-                       func() float64 { return float64(rl.MaxQueue) },
-               ))
                rl.mQueueDelay = prometheus.NewSummaryVec(prometheus.SummaryOpts{
                        Namespace:  "arvados",
                        Name:       "queue_delay_seconds",
                        Help:       "Time spent in the incoming request queue before start of processing",
                        Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.95: 0.005, 0.99: 0.001},
-               }, []string{"priority"})
+               }, []string{"queue", "priority"})
                rl.Registry.MustRegister(rl.mQueueDelay)
                rl.mQueueTimeout = prometheus.NewSummaryVec(prometheus.SummaryOpts{
                        Namespace:  "arvados",
                        Name:       "queue_timeout_seconds",
                        Help:       "Time spent in the incoming request queue before client timed out or disconnected",
                        Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.95: 0.005, 0.99: 0.001},
-               }, []string{"priority"})
+               }, []string{"queue", "priority"})
                rl.Registry.MustRegister(rl.mQueueTimeout)
                go func() {
                        for range time.NewTicker(metricsUpdateInterval).C {
-                               var low, normal, high int
                                rl.mtx.Lock()
-                               for _, ent := range rl.queue {
-                                       switch {
-                                       case ent.priority < 0:
-                                               low++
-                                       case ent.priority > 0:
-                                               high++
-                                       default:
-                                               normal++
+                               for rq := range rl.rqs {
+                                       var low, normal, high int
+                                       for _, ent := range rq.queue {
+                                               switch {
+                                               case ent.priority < 0:
+                                                       low++
+                                               case ent.priority > 0:
+                                                       high++
+                                               default:
+                                                       normal++
+                                               }
                                        }
+                                       mCurrentReqs.WithLabelValues(rq.Label).Set(float64(rq.handling))
+                                       mMaxReqs.WithLabelValues(rq.Label).Set(float64(rq.MaxConcurrent))
+                                       mMaxQueue.WithLabelValues(rq.Label).Set(float64(rq.MaxQueue))
+                                       rl.mQueueUsage.WithLabelValues(rq.Label, "low").Set(float64(low))
+                                       rl.mQueueUsage.WithLabelValues(rq.Label, "normal").Set(float64(normal))
+                                       rl.mQueueUsage.WithLabelValues(rq.Label, "high").Set(float64(high))
                                }
                                rl.mtx.Unlock()
-                               rl.mQueueUsage.WithLabelValues("low").Set(float64(low))
-                               rl.mQueueUsage.WithLabelValues("normal").Set(float64(normal))
-                               rl.mQueueUsage.WithLabelValues("high").Set(float64(high))
                        }
                }()
        }
 }
 
 // caller must have lock
-func (rl *RequestLimiter) runqueue() {
+func (rq *RequestQueue) runqueue() {
        // Handle entries from the queue as capacity permits
-       for len(rl.queue) > 0 && (rl.MaxConcurrent == 0 || rl.handling < rl.MaxConcurrent) {
-               rl.handling++
-               ent := rl.queue.removeMax()
+       for len(rq.queue) > 0 && (rq.MaxConcurrent == 0 || rq.handling < rq.MaxConcurrent) {
+               rq.handling++
+               ent := rq.queue.removeMax()
                ent.ready <- true
        }
 }
 
 // If the queue is too full, fail and remove the lowest-priority
 // entry. Caller must have lock. Queue must not be empty.
-func (rl *RequestLimiter) trimqueue() {
-       if len(rl.queue) <= rl.MaxQueue {
+func (rq *RequestQueue) trimqueue() {
+       if len(rq.queue) <= rq.MaxQueue {
                return
        }
        min := 0
-       for i := range rl.queue {
-               if i == 0 || rl.queue.Less(min, i) {
+       for i := range rq.queue {
+               if i == 0 || rq.queue.Less(min, i) {
                        min = i
                }
        }
-       rl.queue[min].ready <- false
-       rl.queue.remove(min)
+       rq.queue[min].ready <- false
+       rq.queue.remove(min)
 }
 
 func (rl *RequestLimiter) enqueue(req *http.Request) *qent {
@@ -227,19 +233,24 @@ func (rl *RequestLimiter) enqueue(req *http.Request) *qent {
                priority = rl.Priority(req, qtime)
        }
        ent := &qent{
+               rq:       rl.Queue(req),
                queued:   qtime,
                priority: priority,
                ready:    make(chan bool, 1),
                heappos:  -1,
        }
-       if rl.MaxConcurrent == 0 || rl.MaxConcurrent > rl.handling {
+       if rl.rqs == nil {
+               rl.rqs = map[*RequestQueue]bool{}
+       }
+       rl.rqs[ent.rq] = true
+       if ent.rq.MaxConcurrent == 0 || ent.rq.MaxConcurrent > ent.rq.handling {
                // fast path, skip the queue
-               rl.handling++
+               ent.rq.handling++
                ent.ready <- true
                return ent
        }
-       rl.queue.add(ent)
-       rl.trimqueue()
+       ent.rq.queue.add(ent)
+       ent.rq.trimqueue()
        return ent
 }
 
@@ -247,7 +258,7 @@ func (rl *RequestLimiter) remove(ent *qent) {
        rl.mtx.Lock()
        defer rl.mtx.Unlock()
        if ent.heappos >= 0 {
-               rl.queue.remove(ent.heappos)
+               ent.rq.queue.remove(ent.heappos)
                ent.ready <- false
        }
 }
@@ -255,14 +266,14 @@ func (rl *RequestLimiter) remove(ent *qent) {
 func (rl *RequestLimiter) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
        rl.setupOnce.Do(rl.setup)
        ent := rl.enqueue(req)
-       SetResponseLogFields(req.Context(), logrus.Fields{"priority": ent.priority})
+       SetResponseLogFields(req.Context(), logrus.Fields{"priority": ent.priority, "queue": ent.rq.Label})
        if ent.priority == MinPriority {
                // Note that MaxQueueTime==0 does not cancel a req
                // that skips the queue, because in that case
                // rl.enqueue() has already fired ready<-true and
                // rl.remove() is a no-op.
                go func() {
-                       time.Sleep(rl.MaxQueueTimeForMinPriority)
+                       time.Sleep(ent.rq.MaxQueueTimeForMinPriority)
                        rl.remove(ent)
                }()
        }
@@ -273,7 +284,7 @@ func (rl *RequestLimiter) ServeHTTP(resp http.ResponseWriter, req *http.Request)
                // we still need to wait for ent.ready, because
                // sometimes runqueue() will have already decided to
                // send true before our rl.remove() call, and in that
-               // case we'll need to decrement rl.handling below.
+               // case we'll need to decrement ent.rq.handling below.
                ok = <-ent.ready
        case ok = <-ent.ready:
        }
@@ -298,7 +309,7 @@ func (rl *RequestLimiter) ServeHTTP(resp http.ResponseWriter, req *http.Request)
                default:
                        qlabel = "normal"
                }
-               series.WithLabelValues(qlabel).Observe(time.Now().Sub(ent.queued).Seconds())
+               series.WithLabelValues(ent.rq.Label, qlabel).Observe(time.Now().Sub(ent.queued).Seconds())
        }
 
        if !ok {
@@ -308,9 +319,9 @@ func (rl *RequestLimiter) ServeHTTP(resp http.ResponseWriter, req *http.Request)
        defer func() {
                rl.mtx.Lock()
                defer rl.mtx.Unlock()
-               rl.handling--
+               ent.rq.handling--
                // unblock the next waiting request
-               rl.runqueue()
+               ent.rq.runqueue()
        }()
        rl.Handler.ServeHTTP(resp, req)
 }
diff --git a/sdk/go/httpserver/request_limiter_test.go b/sdk/go/httpserver/request_limiter_test.go
index 55f13b4625fdf1c637dfba721b4ee8f00af2ecc3..7366e1426ba5831b1ebdc551cda7c332bdf0446e 100644
@@ -34,7 +34,11 @@ func newTestHandler() *testHandler {
 
 func (s *Suite) TestRequestLimiter1(c *check.C) {
        h := newTestHandler()
-       l := RequestLimiter{MaxConcurrent: 1, Handler: h}
+       rq := &RequestQueue{
+               MaxConcurrent: 1}
+       l := RequestLimiter{
+               Queue:   func(*http.Request) *RequestQueue { return rq },
+               Handler: h}
        var wg sync.WaitGroup
        resps := make([]*httptest.ResponseRecorder, 10)
        for i := 0; i < 10; i++ {
@@ -94,7 +98,11 @@ func (s *Suite) TestRequestLimiter1(c *check.C) {
 
 func (*Suite) TestRequestLimiter10(c *check.C) {
        h := newTestHandler()
-       l := RequestLimiter{MaxConcurrent: 10, Handler: h}
+       rq := &RequestQueue{
+               MaxConcurrent: 10}
+       l := RequestLimiter{
+               Queue:   func(*http.Request) *RequestQueue { return rq },
+               Handler: h}
        var wg sync.WaitGroup
        for i := 0; i < 10; i++ {
                wg.Add(1)
@@ -114,29 +122,32 @@ func (*Suite) TestRequestLimiter10(c *check.C) {
 
 func (*Suite) TestRequestLimiterQueuePriority(c *check.C) {
        h := newTestHandler()
-       rl := RequestLimiter{
+       rq := &RequestQueue{
                MaxConcurrent: 1000,
                MaxQueue:      200,
-               Handler:       h,
+       }
+       rl := RequestLimiter{
+               Handler: h,
+               Queue:   func(*http.Request) *RequestQueue { return rq },
                Priority: func(r *http.Request, _ time.Time) int64 {
                        p, _ := strconv.ParseInt(r.Header.Get("Priority"), 10, 64)
                        return p
                }}
 
        c.Logf("starting initial requests")
-       for i := 0; i < rl.MaxConcurrent; i++ {
+       for i := 0; i < rq.MaxConcurrent; i++ {
                go func() {
                        rl.ServeHTTP(httptest.NewRecorder(), &http.Request{Header: http.Header{"No-Priority": {"x"}}})
                }()
        }
        c.Logf("waiting for initial requests to consume all MaxConcurrent slots")
-       for i := 0; i < rl.MaxConcurrent; i++ {
+       for i := 0; i < rq.MaxConcurrent; i++ {
                <-h.inHandler
        }
 
-       c.Logf("starting %d priority=MinPriority requests (should respond 503 immediately)", rl.MaxQueue)
+       c.Logf("starting %d priority=MinPriority requests (should respond 503 immediately)", rq.MaxQueue)
        var wgX sync.WaitGroup
-       for i := 0; i < rl.MaxQueue; i++ {
+       for i := 0; i < rq.MaxQueue; i++ {
                wgX.Add(1)
                go func() {
                        defer wgX.Done()
@@ -147,13 +158,13 @@ func (*Suite) TestRequestLimiterQueuePriority(c *check.C) {
        }
        wgX.Wait()
 
-       c.Logf("starting %d priority=MinPriority requests (should respond 503 after 100 ms)", rl.MaxQueue)
+       c.Logf("starting %d priority=MinPriority requests (should respond 503 after 100 ms)", rq.MaxQueue)
        // Usage docs say the caller isn't allowed to change fields
        // after first use, but we secretly know it's OK to change
        // this field on the fly as long as no requests are arriving
        // concurrently.
-       rl.MaxQueueTimeForMinPriority = time.Millisecond * 100
-       for i := 0; i < rl.MaxQueue; i++ {
+       rq.MaxQueueTimeForMinPriority = time.Millisecond * 100
+       for i := 0; i < rq.MaxQueue; i++ {
                wgX.Add(1)
                go func() {
                        defer wgX.Done()
@@ -162,17 +173,17 @@ func (*Suite) TestRequestLimiterQueuePriority(c *check.C) {
                        rl.ServeHTTP(resp, &http.Request{Header: http.Header{"Priority": {fmt.Sprintf("%d", MinPriority)}}})
                        c.Check(resp.Code, check.Equals, http.StatusServiceUnavailable)
                        elapsed := time.Since(t0)
-                       c.Check(elapsed > rl.MaxQueueTimeForMinPriority, check.Equals, true)
-                       c.Check(elapsed < rl.MaxQueueTimeForMinPriority*10, check.Equals, true)
+                       c.Check(elapsed > rq.MaxQueueTimeForMinPriority, check.Equals, true)
+                       c.Check(elapsed < rq.MaxQueueTimeForMinPriority*10, check.Equals, true)
                }()
        }
        wgX.Wait()
 
-       c.Logf("starting %d priority=1 and %d priority=1 requests", rl.MaxQueue, rl.MaxQueue)
+       c.Logf("starting %d priority=1 and %d priority=1 requests", rq.MaxQueue, rq.MaxQueue)
        var wg1, wg2 sync.WaitGroup
-       wg1.Add(rl.MaxQueue)
-       wg2.Add(rl.MaxQueue)
-       for i := 0; i < rl.MaxQueue*2; i++ {
+       wg1.Add(rq.MaxQueue)
+       wg2.Add(rq.MaxQueue)
+       for i := 0; i < rq.MaxQueue*2; i++ {
                i := i
                go func() {
                        pri := (i & 1) + 1
@@ -192,12 +203,12 @@ func (*Suite) TestRequestLimiterQueuePriority(c *check.C) {
        wg1.Wait()
 
        c.Logf("allowing initial requests to proceed")
-       for i := 0; i < rl.MaxConcurrent; i++ {
+       for i := 0; i < rq.MaxConcurrent; i++ {
                h.okToProceed <- struct{}{}
        }
 
        c.Logf("allowing queued priority=2 requests to proceed")
-       for i := 0; i < rl.MaxQueue; i++ {
+       for i := 0; i < rq.MaxQueue; i++ {
                <-h.inHandler
                h.okToProceed <- struct{}{}
        }