Skip to content

Commit e987d55

Browse files
committed
Refactor rate limiting to use separate read/write Redis operations
1 parent 1bfcd37 commit e987d55

File tree

2 files changed

+39
-20
lines changed

2 files changed

+39
-20
lines changed

packages/service-utils/src/core/rateLimit/index.ts

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,13 @@ const RATE_LIMIT_WINDOW_SECONDS = 10;
55

66
// Redis interface compatible with ioredis (Node) and upstash (Cloudflare Workers).
77
type IRedis = {
8-
incr: (key: string) => Promise<number>;
9-
expire: (key: string, ttlSeconds: number) => Promise<0 | 1>;
8+
get: (key: string) => Promise<string | null>;
9+
set(
10+
key: string,
11+
value: string | number,
12+
secondsToken?: "EX" | "XX",
13+
seconds?: number | string,
14+
): Promise<"OK">;
1015
};
1116

1217
export async function rateLimit(args: {
@@ -20,6 +25,7 @@ export async function rateLimit(args: {
2025
* @default 1.0
2126
*/
2227
sampleRate?: number;
28+
logger?: typeof console;
2329
}): Promise<RateLimitResult> {
2430
const { team, limitPerSecond, serviceConfig, redis, sampleRate = 1.0 } = args;
2531

@@ -49,12 +55,8 @@ export async function rateLimit(args: {
4955
RATE_LIMIT_WINDOW_SECONDS;
5056
const key = `rate-limit:${serviceScope}:${team.id}:${timestampWindow}`;
5157

52-
// Increment and get the current request count in this window.
53-
const requestCount = await redis.incr(key);
54-
if (requestCount === 1) {
55-
// For the first increment, set an expiration to clean up this key.
56-
await redis.expire(key, RATE_LIMIT_WINDOW_SECONDS);
57-
}
58+
// first read the request count from redis
59+
const requestCount = Number((await redis.get(key)) || "0");
5860

5961
// Get the limit for this window accounting for the sample rate.
6062
const limitPerWindow =
@@ -71,6 +73,23 @@ export async function rateLimit(args: {
7173
};
7274
}
7375

76+
// only need to increment the request count if it's not already set
77+
// we are setting the request count, however we are not waiting on this to be complete
78+
if (requestCount === 0) {
79+
// For the first increment, set an expiration to clean up this key (EX).
80+
redis
81+
.set(key, requestCount + 1, "EX", RATE_LIMIT_WINDOW_SECONDS)
82+
?.catch((err) => {
83+
console.error("Failed to increment request count", err);
84+
});
85+
} else {
86+
// For all other increments, just increment the request count.
87+
// only set it if it already exists (XX)
88+
redis.set(key, requestCount + 1, "XX")?.catch((err) => {
89+
console.error("Failed to increment request count", err);
90+
});
91+
}
92+
7493
return {
7594
rateLimited: false,
7695
requestCount,

packages/service-utils/src/core/rateLimit/rateLimit.test.ts

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@ import { validServiceConfig, validTeamResponse } from "../../mocks.js";
33
import { rateLimit } from "./index.js";
44

55
const mockRedis = {
6-
incr: vi.fn(),
7-
expire: vi.fn(),
6+
set: vi.fn(),
7+
get: vi.fn(),
88
};
99

1010
describe("rateLimit", () => {
1111
beforeEach(() => {
1212
// Clear mock function calls and reset any necessary state.
1313
vi.clearAllMocks();
14-
mockRedis.incr.mockReset();
15-
mockRedis.expire.mockReset();
14+
mockRedis.set.mockReset();
15+
mockRedis.get.mockReset();
1616
});
1717

1818
afterEach(() => {
@@ -35,7 +35,7 @@ describe("rateLimit", () => {
3535
});
3636

3737
it("should not rate limit if within limit", async () => {
38-
mockRedis.incr.mockResolvedValue(50); // Current count is 50 requests in 10 seconds.
38+
mockRedis.get.mockResolvedValue("50"); // Current count is 50 requests in 10 seconds.
3939

4040
const result = await rateLimit({
4141
team: validTeamResponse,
@@ -50,11 +50,11 @@ describe("rateLimit", () => {
5050
rateLimit: 50,
5151
});
5252

53-
expect(mockRedis.expire).not.toHaveBeenCalled();
53+
expect(mockRedis.set).toHaveBeenCalledTimes(1);
5454
});
5555

5656
it("should rate limit if exceeded hard limit", async () => {
57-
mockRedis.incr.mockResolvedValue(51);
57+
mockRedis.get.mockResolvedValue(51);
5858

5959
const result = await rateLimit({
6060
team: validTeamResponse,
@@ -72,11 +72,11 @@ describe("rateLimit", () => {
7272
errorCode: "RATE_LIMIT_EXCEEDED",
7373
});
7474

75-
expect(mockRedis.expire).not.toHaveBeenCalled();
75+
expect(mockRedis.set).not.toHaveBeenCalled();
7676
});
7777

7878
it("expires on the first incr request only", async () => {
79-
mockRedis.incr.mockResolvedValue(1);
79+
mockRedis.get.mockResolvedValue("1");
8080

8181
const result = await rateLimit({
8282
team: validTeamResponse,
@@ -90,11 +90,11 @@ describe("rateLimit", () => {
9090
requestCount: 1,
9191
rateLimit: 50,
9292
});
93-
expect(mockRedis.expire).toHaveBeenCalled();
93+
expect(mockRedis.set).toHaveBeenCalled();
9494
});
9595

9696
it("enforces rate limit if sampled (hit)", async () => {
97-
mockRedis.incr.mockResolvedValue(10);
97+
mockRedis.get.mockResolvedValue("10");
9898
vi.spyOn(global.Math, "random").mockReturnValue(0.08);
9999

100100
const result = await rateLimit({
@@ -117,7 +117,7 @@ describe("rateLimit", () => {
117117
});
118118

119119
it("does not enforce rate limit if sampled (miss)", async () => {
120-
mockRedis.incr.mockResolvedValue(10);
120+
mockRedis.get.mockResolvedValue(10);
121121
vi.spyOn(global.Math, "random").mockReturnValue(0.15);
122122

123123
const result = await rateLimit({

0 commit comments

Comments
 (0)