Merge branch '9066-max-requests'
[arvados.git] / sdk / go / httpserver / request_limiter_test.go
1 package httpserver
2
3 import (
4         "net/http"
5         "net/http/httptest"
6         "sync"
7         "testing"
8         "time"
9 )
10
// testHandler is an http.Handler whose requests park in ServeHTTP until the
// test explicitly lets them go, so tests can control exactly how many
// requests are "in flight" at once.
type testHandler struct {
	// inHandler receives one value each time a request enters ServeHTTP;
	// okToProceed must be sent one value to let a parked request return.
	// Both are unbuffered, so every signal is a rendezvous with the test.
	inHandler, okToProceed chan struct{}
}

// ServeHTTP announces that a request has arrived, then blocks until the
// test releases it. It writes nothing to resp and ignores req.
func (th *testHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
	th.inHandler <- struct{}{}
	<-th.okToProceed
}

// newTestHandler returns a testHandler with fresh unbuffered channels.
//
// NOTE(review): maxReqs is currently ignored; it is kept only so existing
// call sites keep compiling.
func newTestHandler(maxReqs int) *testHandler {
	th := new(testHandler)
	th.inHandler = make(chan struct{})
	th.okToProceed = make(chan struct{})
	return th
}
27
28 func TestRequestLimiter1(t *testing.T) {
29         h := newTestHandler(10)
30         l := NewRequestLimiter(1, h)
31         var wg sync.WaitGroup
32         resps := make([]*httptest.ResponseRecorder, 10)
33         for i := 0; i < 10; i++ {
34                 wg.Add(1)
35                 resps[i] = httptest.NewRecorder()
36                 go func(i int) {
37                         l.ServeHTTP(resps[i], &http.Request{})
38                         wg.Done()
39                 }(i)
40         }
41         done := make(chan struct{})
42         go func() {
43                 // Make sure one request has entered the handler
44                 <-h.inHandler
45                 // Make sure all unsuccessful requests finish (but don't wait
46                 // for the one that's still waiting for okToProceed)
47                 wg.Add(-1)
48                 wg.Wait()
49                 // Wait for the last goroutine
50                 wg.Add(1)
51                 h.okToProceed <- struct{}{}
52                 wg.Wait()
53                 done <- struct{}{}
54         }()
55         select {
56         case <-done:
57         case <-time.After(10 * time.Second):
58                 t.Fatal("test timed out, probably deadlocked")
59         }
60         n200 := 0
61         n503 := 0
62         for i := 0; i < 10; i++ {
63                 switch resps[i].Code {
64                 case 200:
65                         n200++
66                 case 503:
67                         n503++
68                 default:
69                         t.Fatalf("Unexpected response code %d", resps[i].Code)
70                 }
71         }
72         if n200 != 1 || n503 != 9 {
73                 t.Fatalf("Got %d 200 responses, %d 503 responses (expected 1, 9)", n200, n503)
74         }
75         // Now that all 10 are finished, an 11th request should
76         // succeed.
77         go func() {
78                 <-h.inHandler
79                 h.okToProceed <- struct{}{}
80         }()
81         resp := httptest.NewRecorder()
82         l.ServeHTTP(resp, &http.Request{})
83         if resp.Code != 200 {
84                 t.Errorf("Got status %d on 11th request, want 200", resp.Code)
85         }
86 }
87
88 func TestRequestLimiter10(t *testing.T) {
89         h := newTestHandler(10)
90         l := NewRequestLimiter(10, h)
91         var wg sync.WaitGroup
92         for i := 0; i < 10; i++ {
93                 wg.Add(1)
94                 go func() {
95                         l.ServeHTTP(httptest.NewRecorder(), &http.Request{})
96                         wg.Done()
97                 }()
98                 // Make sure the handler starts before we initiate the
99                 // next request, but don't let it finish yet.
100                 <-h.inHandler
101         }
102         for i := 0; i < 10; i++ {
103                 h.okToProceed <- struct{}{}
104         }
105         wg.Wait()
106 }