@@ -3,16 +3,18 @@ import { validServiceConfig, validTeamResponse } from "../../mocks.js";
33import { rateLimit } from "./index.js" ;
44
55const mockRedis = {
6- incr : vi . fn ( ) ,
6+ get : vi . fn ( ) ,
77 expire : vi . fn ( ) ,
8+ incrBy : vi . fn ( ) ,
89} ;
910
1011describe ( "rateLimit" , ( ) => {
1112 beforeEach ( ( ) => {
1213 // Clear mock function calls and reset any necessary state.
1314 vi . clearAllMocks ( ) ;
14- mockRedis . incr . mockReset ( ) ;
15+ mockRedis . get . mockReset ( ) ;
1516 mockRedis . expire . mockReset ( ) ;
17+ mockRedis . incrBy . mockReset ( ) ;
1618 } ) ;
1719
1820 afterEach ( ( ) => {
@@ -35,7 +37,7 @@ describe("rateLimit", () => {
3537 } ) ;
3638
3739 it ( "should not rate limit if within limit" , async ( ) => {
38- mockRedis . incr . mockResolvedValue ( 50 ) ; // Current count is 50 requests in 10 seconds.
40+ mockRedis . get . mockResolvedValue ( "50" ) ; // Current count is 50 requests in 10 seconds.
3941
4042 const result = await rateLimit ( {
4143 team : validTeamResponse ,
@@ -46,15 +48,15 @@ describe("rateLimit", () => {
4648
4749 expect ( result ) . toEqual ( {
4850 rateLimited : false ,
49- requestCount : 50 ,
51+ requestCount : 51 ,
5052 rateLimit : 50 ,
5153 } ) ;
5254
53- expect ( mockRedis . expire ) . not . toHaveBeenCalled ( ) ;
55+ expect ( mockRedis . incrBy ) . toHaveBeenCalledTimes ( 1 ) ;
5456 } ) ;
5557
5658 it ( "should rate limit if exceeded hard limit" , async ( ) => {
57- mockRedis . incr . mockResolvedValue ( 51 ) ;
59+ mockRedis . get . mockResolvedValue ( 51 ) ;
5860
5961 const result = await rateLimit ( {
6062 team : validTeamResponse ,
@@ -72,11 +74,11 @@ describe("rateLimit", () => {
7274 errorCode : "RATE_LIMIT_EXCEEDED" ,
7375 } ) ;
7476
75- expect ( mockRedis . expire ) . not . toHaveBeenCalled ( ) ;
77+ expect ( mockRedis . incrBy ) . not . toHaveBeenCalled ( ) ;
7678 } ) ;
7779
7880 it ( "expires on the first incr request only" , async ( ) => {
79- mockRedis . incr . mockResolvedValue ( 1 ) ;
81+ mockRedis . get . mockResolvedValue ( "1" ) ;
8082
8183 const result = await rateLimit ( {
8284 team : validTeamResponse ,
@@ -87,14 +89,14 @@ describe("rateLimit", () => {
8789
8890 expect ( result ) . toEqual ( {
8991 rateLimited : false ,
90- requestCount : 1 ,
92+ requestCount : 2 ,
9193 rateLimit : 50 ,
9294 } ) ;
93- expect ( mockRedis . expire ) . toHaveBeenCalled ( ) ;
95+ expect ( mockRedis . incrBy ) . toHaveBeenCalled ( ) ;
9496 } ) ;
9597
9698 it ( "enforces rate limit if sampled (hit)" , async ( ) => {
97- mockRedis . incr . mockResolvedValue ( 10 ) ;
99+ mockRedis . get . mockResolvedValue ( "10" ) ;
98100 vi . spyOn ( global . Math , "random" ) . mockReturnValue ( 0.08 ) ;
99101
100102 const result = await rateLimit ( {
@@ -117,7 +119,7 @@ describe("rateLimit", () => {
117119 } ) ;
118120
119121 it ( "does not enforce rate limit if sampled (miss)" , async ( ) => {
120- mockRedis . incr . mockResolvedValue ( 10 ) ;
122+ mockRedis . get . mockResolvedValue ( 10 ) ;
121123 vi . spyOn ( global . Math , "random" ) . mockReturnValue ( 0.15 ) ;
122124
123125 const result = await rateLimit ( {
@@ -134,4 +136,152 @@ describe("rateLimit", () => {
134136 rateLimit : 0 ,
135137 } ) ;
136138 } ) ;
139+
140+ it ( "should handle redis get failure gracefully" , async ( ) => {
141+ mockRedis . get . mockRejectedValue ( new Error ( "Redis connection error" ) ) ;
142+
143+ const result = await rateLimit ( {
144+ team : validTeamResponse ,
145+ limitPerSecond : 5 ,
146+ serviceConfig : validServiceConfig ,
147+ redis : mockRedis ,
148+ } ) ;
149+
150+ expect ( result ) . toEqual ( {
151+ rateLimited : false ,
152+ requestCount : 1 ,
153+ rateLimit : 50 ,
154+ } ) ;
155+ } ) ;
156+
157+ it ( "should handle zero requests correctly" , async ( ) => {
158+ mockRedis . get . mockResolvedValue ( "0" ) ;
159+
160+ const result = await rateLimit ( {
161+ team : validTeamResponse ,
162+ limitPerSecond : 5 ,
163+ serviceConfig : validServiceConfig ,
164+ redis : mockRedis ,
165+ } ) ;
166+
167+ expect ( result ) . toEqual ( {
168+ rateLimited : false ,
169+ requestCount : 1 ,
170+ rateLimit : 50 ,
171+ } ) ;
172+ expect ( mockRedis . incrBy ) . toHaveBeenCalledWith ( expect . any ( String ) , 1 ) ;
173+ } ) ;
174+
175+ it ( "should handle null response from redis" , async ( ) => {
176+ mockRedis . get . mockResolvedValue ( null ) ;
177+
178+ const result = await rateLimit ( {
179+ team : validTeamResponse ,
180+ limitPerSecond : 5 ,
181+ serviceConfig : validServiceConfig ,
182+ redis : mockRedis ,
183+ } ) ;
184+
185+ expect ( result ) . toEqual ( {
186+ rateLimited : false ,
187+ requestCount : 1 ,
188+ rateLimit : 50 ,
189+ } ) ;
190+ } ) ;
191+
192+ it ( "should handle very low sample rates" , async ( ) => {
193+ mockRedis . get . mockResolvedValue ( "100" ) ;
194+ vi . spyOn ( global . Math , "random" ) . mockReturnValue ( 0.001 ) ;
195+
196+ const result = await rateLimit ( {
197+ team : validTeamResponse ,
198+ limitPerSecond : 5 ,
199+ serviceConfig : validServiceConfig ,
200+ redis : mockRedis ,
201+ sampleRate : 0.01 ,
202+ } ) ;
203+
204+ expect ( result ) . toEqual ( {
205+ rateLimited : true ,
206+ requestCount : 100 ,
207+ rateLimit : 0.5 ,
208+ status : 429 ,
209+ errorMessage : expect . any ( String ) ,
210+ errorCode : "RATE_LIMIT_EXCEEDED" ,
211+ } ) ;
212+ } ) ;
213+
214+ it ( "should handle multiple concurrent requests with redis lag" , async ( ) => {
215+ // Mock initial state
216+ mockRedis . get . mockResolvedValue ( "0" ) ;
217+
218+ // Mock redis.set to have 100ms delay
219+ mockRedis . incrBy . mockImplementation (
220+ ( ) =>
221+ new Promise ( ( resolve ) => {
222+ setTimeout ( ( ) => resolve ( 1 ) , 100 ) ;
223+ } ) ,
224+ ) ;
225+
226+ // Make 3 concurrent requests
227+ const requests = Promise . all ( [
228+ rateLimit ( {
229+ team : validTeamResponse ,
230+ limitPerSecond : 5 ,
231+ serviceConfig : validServiceConfig ,
232+ redis : mockRedis ,
233+ } ) ,
234+ rateLimit ( {
235+ team : validTeamResponse ,
236+ limitPerSecond : 5 ,
237+ serviceConfig : validServiceConfig ,
238+ redis : mockRedis ,
239+ } ) ,
240+ rateLimit ( {
241+ team : validTeamResponse ,
242+ limitPerSecond : 5 ,
243+ serviceConfig : validServiceConfig ,
244+ redis : mockRedis ,
245+ } ) ,
246+ ] ) ;
247+
248+ const results = await requests ;
249+ // All requests should succeed since they all see initial count of 0
250+ for ( const result of results ) {
251+ expect ( result ) . toEqual ( {
252+ rateLimited : false ,
253+ requestCount : 1 ,
254+ rateLimit : 50 ,
255+ } ) ;
256+ }
257+
258+ // Redis set should be called 3 times
259+ expect ( mockRedis . incrBy ) . toHaveBeenCalledTimes ( 3 ) ;
260+ } ) ;
261+
262+ it ( "should handle custom increment values" , async ( ) => {
263+ // Mock initial state
264+ mockRedis . get . mockResolvedValue ( "5" ) ;
265+ mockRedis . incrBy . mockResolvedValue ( 10 ) ;
266+
267+ const result = await rateLimit ( {
268+ team : validTeamResponse ,
269+ limitPerSecond : 20 ,
270+ serviceConfig : validServiceConfig ,
271+ redis : mockRedis ,
272+ increment : 5 ,
273+ } ) ;
274+
275+ expect ( result ) . toEqual ( {
276+ rateLimited : false ,
277+ requestCount : 10 ,
278+ rateLimit : 200 ,
279+ } ) ;
280+
281+ // Verify redis was called with correct increment
282+ expect ( mockRedis . incrBy ) . toHaveBeenCalledWith (
283+ expect . stringContaining ( "rate-limit" ) ,
284+ 5 ,
285+ ) ;
286+ } ) ;
137287} ) ;
0 commit comments