diff --git a/.changeset/open-ends-melt.md b/.changeset/open-ends-melt.md new file mode 100644 index 00000000000..73b0be5dde8 --- /dev/null +++ b/.changeset/open-ends-melt.md @@ -0,0 +1,5 @@ +--- +"@thirdweb-dev/service-utils": patch +--- + +Update rate limit to sliding window diff --git a/packages/service-utils/src/core/rateLimit/index.ts b/packages/service-utils/src/core/rateLimit/index.ts index e52cfc7a86b..7ca2ae43745 100644 --- a/packages/service-utils/src/core/rateLimit/index.ts +++ b/packages/service-utils/src/core/rateLimit/index.ts @@ -1,49 +1,34 @@ import type { CoreServiceConfig, TeamResponse } from "../api.js"; import type { RateLimitResult } from "./types.js"; -const RATE_LIMIT_WINDOW_SECONDS = 10; +const SLIDING_WINDOW_SECONDS = 10; // Redis interface compatible with ioredis (Node) and upstash (Cloudflare Workers). type IRedis = { - get: (key: string) => Promise; - expire(key: string, seconds: number): Promise; incrby(key: string, value: number): Promise; + mget(keys: string[]): Promise<(string | null)[]>; + expire(key: string, seconds: number): Promise; }; +/** + * Increments the request count for this team and returns whether the team has hit their rate limit. + * Uses a sliding 10 second window. + * @param args + * @returns + */ export async function rateLimit(args: { team: TeamResponse; limitPerSecond: number; serviceConfig: CoreServiceConfig; redis: IRedis; - /** - * Sample requests to reduce load on Redis. - * This scales down the request count and the rate limit threshold. - * @default 1.0 - */ - sampleRate?: number; /** * The number of requests to increment by. * @default 1 */ increment?: number; }): Promise { - const { - team, - limitPerSecond, - serviceConfig, - redis, - sampleRate = 1.0, - increment = 1, - } = args; - - const shouldSampleRequest = Math.random() < sampleRate; - if (!shouldSampleRequest) { - return { - rateLimited: false, - requestCount: 0, - rateLimit: 0, - }; - } + const { team, limitPerSecond, serviceConfig, redis, increment = 1 } = args; + const { serviceScope } = serviceConfig; if (limitPerSecond === 0) { // No rate limit is provided. Assume the request is not rate limited. @@ -54,47 +39,59 @@ export async function rateLimit(args: { }; } - const serviceScope = serviceConfig.serviceScope; - - // Gets the 10-second window for the current timestamp. - const timestampWindow = - Math.floor(Date.now() / (1000 * RATE_LIMIT_WINDOW_SECONDS)) * - RATE_LIMIT_WINDOW_SECONDS; - const key = `rate-limit:${serviceScope}:${team.id}:${timestampWindow}`; + // Enforce rate limit: sum the total requests in the last `SLIDING_WINDOW_SECONDS` seconds. + const currentSecond = Math.floor(Date.now() / 1000); + const keys = Array.from({ length: SLIDING_WINDOW_SECONDS }, (_, i) => + getRequestCountAtSecondCacheKey(serviceScope, team.id, currentSecond - i), + ); + const counts = await redis.mget(keys); + const totalCount = counts.reduce( + (sum, count) => sum + (count ? Number.parseInt(count) : 0), + 0, + ); - // first read the request count from redis - const requestCount = Number((await redis.get(key).catch(() => "0")) || "0"); + const limitPerWindow = limitPerSecond * SLIDING_WINDOW_SECONDS; - // Get the limit for this window accounting for the sample rate. - const limitPerWindow = - limitPerSecond * sampleRate * RATE_LIMIT_WINDOW_SECONDS; - - if (requestCount > limitPerWindow) { + if (totalCount > limitPerWindow) { return { rateLimited: true, - requestCount, + requestCount: totalCount, rateLimit: limitPerWindow, status: 429, - errorMessage: `You've exceeded your ${serviceScope} rate limit at ${limitPerSecond} reqs/sec. To get higher rate limits, contact us at https://thirdweb.com/contact-us.`, + errorMessage: `You've exceeded your ${serviceScope} rate limit at ${limitPerSecond} reqs/sec. Please upgrade your plan to get higher rate limits.`, errorCode: "RATE_LIMIT_EXCEEDED", }; } - // do not await this, it just needs to execute at all - (async () => - // always incrementBy the amount specified for the key - await redis.incrby(key, increment).then(async () => { - // if the initial request count was 0, set the key to expire in the future - if (requestCount === 0) { - await redis.expire(key, RATE_LIMIT_WINDOW_SECONDS); + // Non-blocking: increment the request count for the current second. + (async () => { + try { + const key = getRequestCountAtSecondCacheKey( + serviceScope, + team.id, + currentSecond, + ); + await redis.incrby(key, increment); + // If this is the first time setting this key, expire it after the sliding window is past. + if (counts[0] === null) { + await redis.expire(key, SLIDING_WINDOW_SECONDS + 1); } - }))().catch(() => { - console.error("Error incrementing rate limit key", key); - }); + } catch (error) { + console.error("Error updating rate limit key:", error); + } + })(); return { rateLimited: false, - requestCount: requestCount + increment, + requestCount: totalCount + increment, rateLimit: limitPerWindow, }; } + +function getRequestCountAtSecondCacheKey( + serviceScope: CoreServiceConfig["serviceScope"], + teamId: string, + second: number, +) { + return `rate-limit:${serviceScope}:${teamId}:${second}`; +} diff --git a/packages/service-utils/src/core/rateLimit/rateLimit.test.ts b/packages/service-utils/src/core/rateLimit/rateLimit.test.ts index f4f324831b1..3161bd796d4 100644 --- a/packages/service-utils/src/core/rateLimit/rateLimit.test.ts +++ b/packages/service-utils/src/core/rateLimit/rateLimit.test.ts @@ -2,26 +2,26 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { validServiceConfig, validTeamResponse } from "../../mocks.js"; import { rateLimit } from "./index.js"; +const SLIDING_WINDOW_SECONDS = 10; + const mockRedis = { - get: vi.fn(), + mget: vi.fn(), expire: vi.fn(), incrby: vi.fn(), }; describe("rateLimit", () => { beforeEach(() => { - // Clear mock function calls and reset any necessary state. vi.clearAllMocks(); - mockRedis.get.mockReset(); - mockRedis.expire.mockReset(); - mockRedis.incrby.mockReset(); + // Mock current time to a fixed value + vi.setSystemTime(new Date("2024-01-01T00:00:00Z")); }); afterEach(() => { - vi.spyOn(global.Math, "random").mockRestore(); + vi.useRealTimers(); }); - it("should not rate limit if service scope is not in rate limits", async () => { + it("should not rate limit if limitPerSecond is 0", async () => { const result = await rateLimit({ team: validTeamResponse, limitPerSecond: 0, @@ -34,254 +34,97 @@ describe("rateLimit", () => { requestCount: 0, rateLimit: 0, }); + expect(mockRedis.mget).not.toHaveBeenCalled(); }); - it("should not rate limit if within limit", async () => { - mockRedis.get.mockResolvedValue("50"); // Current count is 50 requests in 10 seconds. - - const result = await rateLimit({ - team: validTeamResponse, - limitPerSecond: 5, - serviceConfig: validServiceConfig, - redis: mockRedis, - }); - - expect(result).toEqual({ - rateLimited: false, - requestCount: 51, - rateLimit: 50, - }); - - expect(mockRedis.incrby).toHaveBeenCalledTimes(1); - }); - - it("should rate limit if exceeded hard limit", async () => { - mockRedis.get.mockResolvedValue(51); + it("should check last 10 seconds of requests", async () => { + const currentSecond = Math.floor(Date.now() / 1000); + mockRedis.mget.mockResolvedValue([ + null, // current second + "5", + null, + "3", + "1", + null, + "17", + null, + "5", + null, + ]); const result = await rateLimit({ team: validTeamResponse, - limitPerSecond: 5, + limitPerSecond: 10, serviceConfig: validServiceConfig, redis: mockRedis, }); - expect(result).toEqual({ - rateLimited: true, - requestCount: 51, - rateLimit: 50, - status: 429, - errorMessage: `You've exceeded your storage rate limit at 5 reqs/sec. To get higher rate limits, contact us at https://thirdweb.com/contact-us.`, - errorCode: "RATE_LIMIT_EXCEEDED", - }); - - expect(mockRedis.incrby).not.toHaveBeenCalled(); - }); - - it("expires on the first incr request only", async () => { - mockRedis.get.mockResolvedValue("1"); - - const result = await rateLimit({ - team: validTeamResponse, - limitPerSecond: 5, - serviceConfig: validServiceConfig, - redis: mockRedis, - }); + // Verify correct keys are checked + const expectedKeys = Array.from( + { length: SLIDING_WINDOW_SECONDS }, + (_, i) => `rate-limit:storage:1:${currentSecond - i}`, + ); + expect(mockRedis.mget).toHaveBeenCalledWith(expectedKeys); - expect(result).toEqual({ - rateLimited: false, - requestCount: 2, - rateLimit: 50, - }); - expect(mockRedis.incrby).toHaveBeenCalled(); + expect(result.requestCount).toBe(32); + expect(result.rateLimit).toBe(100); + expect(result.rateLimited).toBe(false); }); - it("enforces rate limit if sampled (hit)", async () => { - mockRedis.get.mockResolvedValue("10"); - vi.spyOn(global.Math, "random").mockReturnValue(0.08); + it("should rate limit when total count exceeds limit", async () => { + // 101 total requests + mockRedis.mget.mockResolvedValue(["50", "51"]); const result = await rateLimit({ team: validTeamResponse, - limitPerSecond: 5, + limitPerSecond: 10, serviceConfig: validServiceConfig, redis: mockRedis, - sampleRate: 0.1, }); - expect(result).toEqual({ + expect(result).toMatchObject({ rateLimited: true, - requestCount: 10, - rateLimit: 5, + requestCount: 101, + rateLimit: 100, status: 429, - errorMessage: - "You've exceeded your storage rate limit at 5 reqs/sec. To get higher rate limits, contact us at https://thirdweb.com/contact-us.", errorCode: "RATE_LIMIT_EXCEEDED", }); }); - it("does not enforce rate limit if sampled (miss)", async () => { - mockRedis.get.mockResolvedValue(10); - vi.spyOn(global.Math, "random").mockReturnValue(0.15); - - const result = await rateLimit({ + it("should set expiry only when current second count is 0", async () => { + // First case: current second has no requests + mockRedis.mget.mockResolvedValueOnce([null, ...Array(9).fill("5")]); + await rateLimit({ team: validTeamResponse, - limitPerSecond: 5, + limitPerSecond: 10, serviceConfig: validServiceConfig, redis: mockRedis, - sampleRate: 0.1, }); + expect(mockRedis.expire).toHaveBeenCalled(); - expect(result).toEqual({ - rateLimited: false, - requestCount: 0, - rateLimit: 0, - }); - }); + mockRedis.expire.mockClear(); - it("should handle redis get failure gracefully", async () => { - mockRedis.get.mockRejectedValue(new Error("Redis connection error")); - - const result = await rateLimit({ + // Second case: current second already has requests + mockRedis.mget.mockResolvedValueOnce(["5", ...Array(9).fill("5")]); + await rateLimit({ team: validTeamResponse, - limitPerSecond: 5, + limitPerSecond: 10, serviceConfig: validServiceConfig, redis: mockRedis, }); - - expect(result).toEqual({ - rateLimited: false, - requestCount: 1, - rateLimit: 50, - }); + expect(mockRedis.expire).not.toHaveBeenCalled(); }); - it("should handle zero requests correctly", async () => { - mockRedis.get.mockResolvedValue("0"); - + it("should increment by the amount provided", async () => { + mockRedis.mget.mockResolvedValueOnce(["5"]); const result = await rateLimit({ team: validTeamResponse, - limitPerSecond: 5, + limitPerSecond: 10, serviceConfig: validServiceConfig, redis: mockRedis, + increment: 3, }); - - expect(result).toEqual({ - rateLimited: false, - requestCount: 1, - rateLimit: 50, - }); - expect(mockRedis.incrby).toHaveBeenCalledWith(expect.any(String), 1); - }); - - it("should handle null response from redis", async () => { - mockRedis.get.mockResolvedValue(null); - - const result = await rateLimit({ - team: validTeamResponse, - limitPerSecond: 5, - serviceConfig: validServiceConfig, - redis: mockRedis, - }); - - expect(result).toEqual({ - rateLimited: false, - requestCount: 1, - rateLimit: 50, - }); - }); - - it("should handle very low sample rates", async () => { - mockRedis.get.mockResolvedValue("100"); - vi.spyOn(global.Math, "random").mockReturnValue(0.001); - - const result = await rateLimit({ - team: validTeamResponse, - limitPerSecond: 5, - serviceConfig: validServiceConfig, - redis: mockRedis, - sampleRate: 0.01, - }); - - expect(result).toEqual({ - rateLimited: true, - requestCount: 100, - rateLimit: 0.5, - status: 429, - errorMessage: expect.any(String), - errorCode: "RATE_LIMIT_EXCEEDED", - }); - }); - - it("should handle multiple concurrent requests with redis lag", async () => { - // Mock initial state - mockRedis.get.mockResolvedValue("0"); - - // Mock redis.set to have 100ms delay - mockRedis.incrby.mockImplementation( - () => - new Promise((resolve) => { - setTimeout(() => resolve(1), 100); - }), - ); - - // Make 3 concurrent requests - const requests = Promise.all([ - rateLimit({ - team: validTeamResponse, - limitPerSecond: 5, - serviceConfig: validServiceConfig, - redis: mockRedis, - }), - rateLimit({ - team: validTeamResponse, - limitPerSecond: 5, - serviceConfig: validServiceConfig, - redis: mockRedis, - }), - rateLimit({ - team: validTeamResponse, - limitPerSecond: 5, - serviceConfig: validServiceConfig, - redis: mockRedis, - }), - ]); - - const results = await requests; - // All requests should succeed since they all see initial count of 0 - for (const result of results) { - expect(result).toEqual({ - rateLimited: false, - requestCount: 1, - rateLimit: 50, - }); - } - - // Redis set should be called 3 times - expect(mockRedis.incrby).toHaveBeenCalledTimes(3); - }); - - it("should handle custom increment values", async () => { - // Mock initial state - mockRedis.get.mockResolvedValue("5"); - mockRedis.incrby.mockResolvedValue(10); - - const result = await rateLimit({ - team: validTeamResponse, - limitPerSecond: 20, - serviceConfig: validServiceConfig, - redis: mockRedis, - increment: 5, - }); - - expect(result).toEqual({ - rateLimited: false, - requestCount: 10, - rateLimit: 200, - }); - - // Verify redis was called with correct increment - expect(mockRedis.incrby).toHaveBeenCalledWith( - expect.stringContaining("rate-limit"), - 5, - ); + expect(mockRedis.incrby).toHaveBeenCalledWith(expect.anything(), 3); + expect(result.requestCount).toBe(8); }); });