Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/brown-moles-peel.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@thirdweb-dev/service-utils": minor
---

update rateLimit function
41 changes: 31 additions & 10 deletions packages/service-utils/src/core/rateLimit/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ const RATE_LIMIT_WINDOW_SECONDS = 10;

// Redis interface compatible with ioredis (Node) and upstash (Cloudflare Workers).
type IRedis = {
incr: (key: string) => Promise<number>;
expire: (key: string, ttlSeconds: number) => Promise<0 | 1>;
get: (key: string) => Promise<string | null>;
expire(key: string, seconds: number): Promise<number>;
incrBy(key: string, value: number): Promise<number>;
};

export async function rateLimit(args: {
Expand All @@ -20,8 +21,20 @@ export async function rateLimit(args: {
* @default 1.0
*/
sampleRate?: number;
/**
* The number of requests to increment by.
* @default 1
*/
increment?: number;
}): Promise<RateLimitResult> {
const { team, limitPerSecond, serviceConfig, redis, sampleRate = 1.0 } = args;
const {
team,
limitPerSecond,
serviceConfig,
redis,
sampleRate = 1.0,
increment = 1,
} = args;

const shouldSampleRequest = Math.random() < sampleRate;
if (!shouldSampleRequest) {
Expand Down Expand Up @@ -49,12 +62,8 @@ export async function rateLimit(args: {
RATE_LIMIT_WINDOW_SECONDS;
const key = `rate-limit:${serviceScope}:${team.id}:${timestampWindow}`;

// Increment and get the current request count in this window.
const requestCount = await redis.incr(key);
if (requestCount === 1) {
// For the first increment, set an expiration to clean up this key.
await redis.expire(key, RATE_LIMIT_WINDOW_SECONDS);
}
// first read the request count from redis
const requestCount = Number((await redis.get(key).catch(() => "0")) || "0");

// Get the limit for this window accounting for the sample rate.
const limitPerWindow =
Expand All @@ -71,9 +80,21 @@ export async function rateLimit(args: {
};
}

// do not await this, it just needs to execute at all
(async () =>
// always incrementBy the amount specified for the key
await redis.incrBy(key, increment).then(async () => {
// if the initial request count was 0, set the key to expire in the future
if (requestCount === 0) {
await redis.expire(key, RATE_LIMIT_WINDOW_SECONDS);
}
}))().catch(() => {
console.error("Error incrementing rate limit key", key);
});

return {
rateLimited: false,
requestCount,
requestCount: requestCount + increment,
rateLimit: limitPerWindow,
};
}
174 changes: 162 additions & 12 deletions packages/service-utils/src/core/rateLimit/rateLimit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@ import { validServiceConfig, validTeamResponse } from "../../mocks.js";
import { rateLimit } from "./index.js";

const mockRedis = {
incr: vi.fn(),
get: vi.fn(),
expire: vi.fn(),
incrBy: vi.fn(),
};

describe("rateLimit", () => {
beforeEach(() => {
// Clear mock function calls and reset any necessary state.
vi.clearAllMocks();
mockRedis.incr.mockReset();
mockRedis.get.mockReset();
mockRedis.expire.mockReset();
mockRedis.incrBy.mockReset();
});

afterEach(() => {
Expand All @@ -35,7 +37,7 @@ describe("rateLimit", () => {
});

it("should not rate limit if within limit", async () => {
mockRedis.incr.mockResolvedValue(50); // Current count is 50 requests in 10 seconds.
mockRedis.get.mockResolvedValue("50"); // Current count is 50 requests in 10 seconds.

const result = await rateLimit({
team: validTeamResponse,
Expand All @@ -46,15 +48,15 @@ describe("rateLimit", () => {

expect(result).toEqual({
rateLimited: false,
requestCount: 50,
requestCount: 51,
rateLimit: 50,
});

expect(mockRedis.expire).not.toHaveBeenCalled();
expect(mockRedis.incrBy).toHaveBeenCalledTimes(1);
});

it("should rate limit if exceeded hard limit", async () => {
mockRedis.incr.mockResolvedValue(51);
mockRedis.get.mockResolvedValue(51);

const result = await rateLimit({
team: validTeamResponse,
Expand All @@ -72,11 +74,11 @@ describe("rateLimit", () => {
errorCode: "RATE_LIMIT_EXCEEDED",
});

expect(mockRedis.expire).not.toHaveBeenCalled();
expect(mockRedis.incrBy).not.toHaveBeenCalled();
});

it("expires on the first incr request only", async () => {
mockRedis.incr.mockResolvedValue(1);
mockRedis.get.mockResolvedValue("1");

const result = await rateLimit({
team: validTeamResponse,
Expand All @@ -87,14 +89,14 @@ describe("rateLimit", () => {

expect(result).toEqual({
rateLimited: false,
requestCount: 1,
requestCount: 2,
rateLimit: 50,
});
expect(mockRedis.expire).toHaveBeenCalled();
expect(mockRedis.incrBy).toHaveBeenCalled();
});

it("enforces rate limit if sampled (hit)", async () => {
mockRedis.incr.mockResolvedValue(10);
mockRedis.get.mockResolvedValue("10");
vi.spyOn(global.Math, "random").mockReturnValue(0.08);

const result = await rateLimit({
Expand All @@ -117,7 +119,7 @@ describe("rateLimit", () => {
});

it("does not enforce rate limit if sampled (miss)", async () => {
mockRedis.incr.mockResolvedValue(10);
mockRedis.get.mockResolvedValue(10);
vi.spyOn(global.Math, "random").mockReturnValue(0.15);

const result = await rateLimit({
Expand All @@ -134,4 +136,152 @@ describe("rateLimit", () => {
rateLimit: 0,
});
});

it("should handle redis get failure gracefully", async () => {
mockRedis.get.mockRejectedValue(new Error("Redis connection error"));

const result = await rateLimit({
team: validTeamResponse,
limitPerSecond: 5,
serviceConfig: validServiceConfig,
redis: mockRedis,
});

expect(result).toEqual({
rateLimited: false,
requestCount: 1,
rateLimit: 50,
});
});

it("should handle zero requests correctly", async () => {
mockRedis.get.mockResolvedValue("0");

const result = await rateLimit({
team: validTeamResponse,
limitPerSecond: 5,
serviceConfig: validServiceConfig,
redis: mockRedis,
});

expect(result).toEqual({
rateLimited: false,
requestCount: 1,
rateLimit: 50,
});
expect(mockRedis.incrBy).toHaveBeenCalledWith(expect.any(String), 1);
});

it("should handle null response from redis", async () => {
mockRedis.get.mockResolvedValue(null);

const result = await rateLimit({
team: validTeamResponse,
limitPerSecond: 5,
serviceConfig: validServiceConfig,
redis: mockRedis,
});

expect(result).toEqual({
rateLimited: false,
requestCount: 1,
rateLimit: 50,
});
});

it("should handle very low sample rates", async () => {
mockRedis.get.mockResolvedValue("100");
vi.spyOn(global.Math, "random").mockReturnValue(0.001);

const result = await rateLimit({
team: validTeamResponse,
limitPerSecond: 5,
serviceConfig: validServiceConfig,
redis: mockRedis,
sampleRate: 0.01,
});

expect(result).toEqual({
rateLimited: true,
requestCount: 100,
rateLimit: 0.5,
status: 429,
errorMessage: expect.any(String),
errorCode: "RATE_LIMIT_EXCEEDED",
});
});

it("should handle multiple concurrent requests with redis lag", async () => {
// Mock initial state
mockRedis.get.mockResolvedValue("0");

// Mock redis.set to have 100ms delay
mockRedis.incrBy.mockImplementation(
() =>
new Promise((resolve) => {
setTimeout(() => resolve(1), 100);
}),
);

// Make 3 concurrent requests
const requests = Promise.all([
rateLimit({
team: validTeamResponse,
limitPerSecond: 5,
serviceConfig: validServiceConfig,
redis: mockRedis,
}),
rateLimit({
team: validTeamResponse,
limitPerSecond: 5,
serviceConfig: validServiceConfig,
redis: mockRedis,
}),
rateLimit({
team: validTeamResponse,
limitPerSecond: 5,
serviceConfig: validServiceConfig,
redis: mockRedis,
}),
]);

const results = await requests;
// All requests should succeed since they all see initial count of 0
for (const result of results) {
expect(result).toEqual({
rateLimited: false,
requestCount: 1,
rateLimit: 50,
});
}

// Redis set should be called 3 times
expect(mockRedis.incrBy).toHaveBeenCalledTimes(3);
});

it("should handle custom increment values", async () => {
// Mock initial state
mockRedis.get.mockResolvedValue("5");
mockRedis.incrBy.mockResolvedValue(10);

const result = await rateLimit({
team: validTeamResponse,
limitPerSecond: 20,
serviceConfig: validServiceConfig,
redis: mockRedis,
increment: 5,
});

expect(result).toEqual({
rateLimited: false,
requestCount: 10,
rateLimit: 200,
});

// Verify redis was called with correct increment
expect(mockRedis.incrBy).toHaveBeenCalledWith(
expect.stringContaining("rate-limit"),
5,
);
});
});
Loading