-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrate_limiter_test.go
More file actions
91 lines (71 loc) · 2.34 KB
/
rate_limiter_test.go
File metadata and controls
91 lines (71 loc) · 2.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
package transport_api_client
import (
"context"
"errors"
"github.com/stretchr/testify/require"
"net/http"
"testing"
"time"
)
// fakeLimiter is a test double for the limiter handed to Limiter. Every Wait
// call returns the preconfigured err immediately, without blocking. This lets
// the tests simulate an "allowed" (nil) or "rejected" (non-nil) outcome
// deterministically.
type fakeLimiter struct {
	// err is returned verbatim from every Wait call; nil means "allow".
	err error
}

// Wait reports the preconfigured outcome. The context is deliberately ignored:
// the fake never blocks, so there is nothing to cancel.
func (fl fakeLimiter) Wait(_ context.Context) error {
	if fl.err != nil {
		return fl.err
	}
	return nil
}
func TestLimiterMiddleware(t *testing.T) {
t.Run("request passes when limiter allows", func(t *testing.T) {
lim := fakeLimiter{err: nil}
var called bool
next := DoerFunc(func(req *http.Request) (*http.Response, error) {
called = true
return &http.Response{StatusCode: 200, Body: http.NoBody}, nil
})
doer := Limiter(lim)(next)
req, _ := http.NewRequest("GET", "http://example.com", nil)
resp, err := doer.Do(req)
require.NoError(t, err)
require.True(t, called)
require.Equal(t, 200, resp.StatusCode)
})
t.Run("request fails when limiter blocks", func(t *testing.T) {
lim := fakeLimiter{err: errors.New("rate limit exceeded")}
var called bool
next := DoerFunc(func(req *http.Request) (*http.Response, error) {
called = true
return &http.Response{StatusCode: 200, Body: http.NoBody}, nil
})
doer := Limiter(lim)(next)
req, _ := http.NewRequest("GET", "http://example.com", nil)
resp, err := doer.Do(req)
require.Error(t, err)
require.Nil(t, resp)
require.False(t, called, "next.Do() should not be called if limiter blocks")
})
t.Run("NewDefaultLimiter allows only burst requests", func(t *testing.T) {
lim := NewDefaultLimiter(2, 2)
next := DoerFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: 200, Body: http.NoBody}, nil
})
doer := Limiter(lim)(next)
req1, _ := http.NewRequest("GET", "http://example.com/1", nil)
start1 := time.Now()
resp1, err1 := doer.Do(req1)
require.NoError(t, err1)
require.Equal(t, 200, resp1.StatusCode)
require.Less(t, time.Since(start1), 50*time.Millisecond)
req2, _ := http.NewRequest("GET", "http://example.com/2", nil)
start2 := time.Now()
resp2, err2 := doer.Do(req2)
require.NoError(t, err2)
require.Equal(t, 200, resp2.StatusCode)
require.Less(t, time.Since(start2), 50*time.Millisecond)
req3, _ := http.NewRequest("GET", "http://example.com/3", nil)
start3 := time.Now()
resp3, err3 := doer.Do(req3)
require.NoError(t, err3)
require.Equal(t, 200, resp3.StatusCode)
delay := time.Since(start3)
require.GreaterOrEqual(t, delay, 400*time.Millisecond)
})
}