diff --git a/.changeset/brown-moles-peel.md b/.changeset/brown-moles-peel.md new file mode 100644 index 00000000000..49013d4478d --- /dev/null +++ b/.changeset/brown-moles-peel.md @@ -0,0 +1,5 @@ +--- +"@thirdweb-dev/service-utils": minor +--- + +update rateLimit function diff --git a/packages/service-utils/src/core/rateLimit/index.ts b/packages/service-utils/src/core/rateLimit/index.ts index 89584379f67..578e37ca35c 100644 --- a/packages/service-utils/src/core/rateLimit/index.ts +++ b/packages/service-utils/src/core/rateLimit/index.ts @@ -5,8 +5,9 @@ const RATE_LIMIT_WINDOW_SECONDS = 10; // Redis interface compatible with ioredis (Node) and upstash (Cloudflare Workers). type IRedis = { - incr: (key: string) => Promise; - expire: (key: string, ttlSeconds: number) => Promise<0 | 1>; + get: (key: string) => Promise; + expire(key: string, seconds: number): Promise; + incrBy(key: string, value: number): Promise; }; export async function rateLimit(args: { @@ -20,8 +21,20 @@ export async function rateLimit(args: { * @default 1.0 */ sampleRate?: number; + /** + * The number of requests to increment by. + * @default 1 + */ + increment?: number; }): Promise { - const { team, limitPerSecond, serviceConfig, redis, sampleRate = 1.0 } = args; + const { + team, + limitPerSecond, + serviceConfig, + redis, + sampleRate = 1.0, + increment = 1, + } = args; const shouldSampleRequest = Math.random() < sampleRate; if (!shouldSampleRequest) { @@ -49,12 +62,8 @@ export async function rateLimit(args: { RATE_LIMIT_WINDOW_SECONDS; const key = `rate-limit:${serviceScope}:${team.id}:${timestampWindow}`; - // Increment and get the current request count in this window. - const requestCount = await redis.incr(key); - if (requestCount === 1) { - // For the first increment, set an expiration to clean up this key. - await redis.expire(key, RATE_LIMIT_WINDOW_SECONDS); - } + // first read the request count from redis + const requestCount = Number((await redis.get(key).catch(() => "0")) || "0"); // Get the limit for this window accounting for the sample rate. const limitPerWindow = @@ -71,9 +80,21 @@ export async function rateLimit(args: { }; } + // do not await this, it just needs to execute at all + (async () => + // always incrementBy the amount specified for the key + await redis.incrBy(key, increment).then(async () => { + // if the initial request count was 0, set the key to expire in the future + if (requestCount === 0) { + await redis.expire(key, RATE_LIMIT_WINDOW_SECONDS); + } + }))().catch(() => { + console.error("Error incrementing rate limit key", key); + }); + return { rateLimited: false, - requestCount, + requestCount: requestCount + increment, rateLimit: limitPerWindow, }; } diff --git a/packages/service-utils/src/core/rateLimit/rateLimit.test.ts b/packages/service-utils/src/core/rateLimit/rateLimit.test.ts index 04832eccb4c..1b2153dcf90 100644 --- a/packages/service-utils/src/core/rateLimit/rateLimit.test.ts +++ b/packages/service-utils/src/core/rateLimit/rateLimit.test.ts @@ -3,16 +3,18 @@ import { validServiceConfig, validTeamResponse } from "../../mocks.js"; import { rateLimit } from "./index.js"; const mockRedis = { - incr: vi.fn(), + get: vi.fn(), expire: vi.fn(), + incrBy: vi.fn(), }; describe("rateLimit", () => { beforeEach(() => { // Clear mock function calls and reset any necessary state. vi.clearAllMocks(); - mockRedis.incr.mockReset(); + mockRedis.get.mockReset(); mockRedis.expire.mockReset(); + mockRedis.incrBy.mockReset(); }); afterEach(() => { @@ -35,7 +37,7 @@ describe("rateLimit", () => { }); it("should not rate limit if within limit", async () => { - mockRedis.incr.mockResolvedValue(50); // Current count is 50 requests in 10 seconds. + mockRedis.get.mockResolvedValue("50"); // Current count is 50 requests in 10 seconds. const result = await rateLimit({ team: validTeamResponse, @@ -46,15 +48,15 @@ describe("rateLimit", () => { expect(result).toEqual({ rateLimited: false, - requestCount: 50, + requestCount: 51, rateLimit: 50, }); - expect(mockRedis.expire).not.toHaveBeenCalled(); + expect(mockRedis.incrBy).toHaveBeenCalledTimes(1); }); it("should rate limit if exceeded hard limit", async () => { - mockRedis.incr.mockResolvedValue(51); + mockRedis.get.mockResolvedValue(51); const result = await rateLimit({ team: validTeamResponse, @@ -72,11 +74,11 @@ describe("rateLimit", () => { errorCode: "RATE_LIMIT_EXCEEDED", }); - expect(mockRedis.expire).not.toHaveBeenCalled(); + expect(mockRedis.incrBy).not.toHaveBeenCalled(); }); it("expires on the first incr request only", async () => { - mockRedis.incr.mockResolvedValue(1); + mockRedis.get.mockResolvedValue("1"); const result = await rateLimit({ team: validTeamResponse, @@ -87,14 +89,14 @@ describe("rateLimit", () => { expect(result).toEqual({ rateLimited: false, - requestCount: 1, + requestCount: 2, rateLimit: 50, }); - expect(mockRedis.expire).toHaveBeenCalled(); + expect(mockRedis.incrBy).toHaveBeenCalled(); }); it("enforces rate limit if sampled (hit)", async () => { - mockRedis.incr.mockResolvedValue(10); + mockRedis.get.mockResolvedValue("10"); vi.spyOn(global.Math, "random").mockReturnValue(0.08); const result = await rateLimit({ @@ -117,7 +119,7 @@ describe("rateLimit", () => { }); it("does not enforce rate limit if sampled (miss)", async () => { - mockRedis.incr.mockResolvedValue(10); + mockRedis.get.mockResolvedValue(10); vi.spyOn(global.Math, "random").mockReturnValue(0.15); const result = await rateLimit({ @@ -134,4 +136,152 @@ describe("rateLimit", () => { rateLimit: 0, }); }); + + it("should handle redis get failure gracefully", async () => { + mockRedis.get.mockRejectedValue(new Error("Redis connection error")); + + const result = await rateLimit({ + team: validTeamResponse, + limitPerSecond: 5, + serviceConfig: validServiceConfig, + redis: mockRedis, + }); + + expect(result).toEqual({ + rateLimited: false, + requestCount: 1, + rateLimit: 50, + }); + }); + + it("should handle zero requests correctly", async () => { + mockRedis.get.mockResolvedValue("0"); + + const result = await rateLimit({ + team: validTeamResponse, + limitPerSecond: 5, + serviceConfig: validServiceConfig, + redis: mockRedis, + }); + + expect(result).toEqual({ + rateLimited: false, + requestCount: 1, + rateLimit: 50, + }); + expect(mockRedis.incrBy).toHaveBeenCalledWith(expect.any(String), 1); + }); + + it("should handle null response from redis", async () => { + mockRedis.get.mockResolvedValue(null); + + const result = await rateLimit({ + team: validTeamResponse, + limitPerSecond: 5, + serviceConfig: validServiceConfig, + redis: mockRedis, + }); + + expect(result).toEqual({ + rateLimited: false, + requestCount: 1, + rateLimit: 50, + }); + }); + + it("should handle very low sample rates", async () => { + mockRedis.get.mockResolvedValue("100"); + vi.spyOn(global.Math, "random").mockReturnValue(0.001); + + const result = await rateLimit({ + team: validTeamResponse, + limitPerSecond: 5, + serviceConfig: validServiceConfig, + redis: mockRedis, + sampleRate: 0.01, + }); + + expect(result).toEqual({ + rateLimited: true, + requestCount: 100, + rateLimit: 0.5, + status: 429, + errorMessage: expect.any(String), + errorCode: "RATE_LIMIT_EXCEEDED", + }); + }); + + it("should handle multiple concurrent requests with redis lag", async () => { + // Mock initial state + mockRedis.get.mockResolvedValue("0"); + + // Mock redis.set to have 100ms delay + mockRedis.incrBy.mockImplementation( + () => + new Promise((resolve) => { + setTimeout(() => resolve(1), 100); + }), + ); + + // Make 3 concurrent requests + const requests = Promise.all([ + rateLimit({ + team: validTeamResponse, + limitPerSecond: 5, + serviceConfig: validServiceConfig, + redis: mockRedis, + }), + rateLimit({ + team: validTeamResponse, + limitPerSecond: 5, + serviceConfig: validServiceConfig, + redis: mockRedis, + }), + rateLimit({ + team: validTeamResponse, + limitPerSecond: 5, + serviceConfig: validServiceConfig, + redis: mockRedis, + }), + ]); + + const results = await requests; + // All requests should succeed since they all see initial count of 0 + for (const result of results) { + expect(result).toEqual({ + rateLimited: false, + requestCount: 1, + rateLimit: 50, + }); + } + + // Redis set should be called 3 times + expect(mockRedis.incrBy).toHaveBeenCalledTimes(3); + }); + + it("should handle custom increment values", async () => { + // Mock initial state + mockRedis.get.mockResolvedValue("5"); + mockRedis.incrBy.mockResolvedValue(10); + + const result = await rateLimit({ + team: validTeamResponse, + limitPerSecond: 20, + serviceConfig: validServiceConfig, + redis: mockRedis, + increment: 5, + }); + + expect(result).toEqual({ + rateLimited: false, + requestCount: 10, + rateLimit: 200, + }); + + // Verify redis was called with correct increment + expect(mockRedis.incrBy).toHaveBeenCalledWith( + expect.stringContaining("rate-limit"), + 5, + ); + }); });