From a80839a8ac9d295980f3284c056b3c7b48431148 Mon Sep 17 00:00:00 2001 From: Eric Allam Date: Fri, 7 Feb 2025 13:20:02 +0000 Subject: [PATCH] Rate limit alerts by channel for task run alerts using generic cell rate algo --- apps/webapp/app/env.server.ts | 33 +++ apps/webapp/app/v3/GCRARateLimiter.server.ts | 171 ++++++++++++++ .../webapp/app/v3/alertsRateLimiter.server.ts | 30 +++ .../v3/services/alerts/deliverAlert.server.ts | 70 +++++- .../alerts/performDeploymentAlerts.server.ts | 15 +- .../alerts/performTaskRunAlerts.server.ts | 13 +- apps/webapp/test/GCRARateLimiter.test.ts | 217 ++++++++++++++++++ 7 files changed, 528 insertions(+), 21 deletions(-) create mode 100644 apps/webapp/app/v3/GCRARateLimiter.server.ts create mode 100644 apps/webapp/app/v3/alertsRateLimiter.server.ts create mode 100644 apps/webapp/test/GCRARateLimiter.test.ts diff --git a/apps/webapp/app/env.server.ts b/apps/webapp/app/env.server.ts index 4facbfffb2..49726cfbdc 100644 --- a/apps/webapp/app/env.server.ts +++ b/apps/webapp/app/env.server.ts @@ -308,6 +308,39 @@ const EnvironmentSchema = z.object({ ALERT_SMTP_SECURE: z.coerce.boolean().optional(), ALERT_SMTP_USER: z.string().optional(), ALERT_SMTP_PASSWORD: z.string().optional(), + ALERT_RATE_LIMITER_EMISSION_INTERVAL: z.coerce.number().int().default(2_500), + ALERT_RATE_LIMITER_BURST_TOLERANCE: z.coerce.number().int().default(10_000), + ALERT_RATE_LIMITER_REDIS_HOST: z + .string() + .optional() + .transform((v) => v ?? process.env.REDIS_HOST), + ALERT_RATE_LIMITER_REDIS_READER_HOST: z + .string() + .optional() + .transform((v) => v ?? process.env.REDIS_READER_HOST), + ALERT_RATE_LIMITER_REDIS_READER_PORT: z.coerce + .number() + .optional() + .transform( + (v) => + v ?? (process.env.REDIS_READER_PORT ? parseInt(process.env.REDIS_READER_PORT) : undefined) + ), + ALERT_RATE_LIMITER_REDIS_PORT: z.coerce + .number() + .optional() + .transform((v) => v ?? (process.env.REDIS_PORT ? parseInt(process.env.REDIS_PORT) : undefined)), + ALERT_RATE_LIMITER_REDIS_USERNAME: z + .string() + .optional() + .transform((v) => v ?? process.env.REDIS_USERNAME), + ALERT_RATE_LIMITER_REDIS_PASSWORD: z + .string() + .optional() + .transform((v) => v ?? process.env.REDIS_PASSWORD), + ALERT_RATE_LIMITER_REDIS_TLS_DISABLED: z + .string() + .default(process.env.REDIS_TLS_DISABLED ?? "false"), + ALERT_RATE_LIMITER_REDIS_CLUSTER_MODE_ENABLED: z.string().default("0"), MAX_SEQUENTIAL_INDEX_FAILURE_COUNT: z.coerce.number().default(96), diff --git a/apps/webapp/app/v3/GCRARateLimiter.server.ts b/apps/webapp/app/v3/GCRARateLimiter.server.ts new file mode 100644 index 0000000000..9e705ebbed --- /dev/null +++ b/apps/webapp/app/v3/GCRARateLimiter.server.ts @@ -0,0 +1,171 @@ +import Redis, { Cluster } from "ioredis"; + +/** + * Options for configuring the RateLimiter. + */ +export interface GCRARateLimiterOptions { + /** An instance of ioredis. */ + redis: Redis | Cluster; + /** + * A string prefix to namespace keys in Redis. + * Defaults to "ratelimit:". + */ + keyPrefix?: string; + /** + * The minimum interval between requests (the emission interval) in milliseconds. + * For example, 1000 ms for one request per second. + */ + emissionInterval: number; + /** + * The burst tolerance in milliseconds. This represents how much “credit” can be + * accumulated to allow short bursts beyond the average rate. + * For example, if you want to allow 3 requests in a burst with an emission interval of 1000 ms, + * you might set this to 3000. 
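+ * (Concretely, burst capacity = floor(burstTolerance / emissionInterval) + 1: a 3000 ms tolerance
+ * adds three requests of credit on top of the one already on schedule, so four requests can arrive
+ * back-to-back before callers must wait roughly one emission interval between further requests.)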
+ */ + burstTolerance: number; + /** + * Expiration for the Redis key in milliseconds. + * Defaults to the larger of 60 seconds or (emissionInterval + burstTolerance). + */ + keyExpiration?: number; +} + +/** + * The result of a rate limit check. + */ +export interface RateLimitResult { + /** Whether the request is allowed. */ + allowed: boolean; + /** + * If not allowed, this is the number of milliseconds the caller should wait + * before retrying. + */ + retryAfter?: number; +} + +/** + * A rate limiter using Redis and the Generic Cell Rate Algorithm (GCRA). + * + * The GCRA is implemented using a Lua script that runs atomically in Redis. + * + * When a request comes in, the algorithm: + * - Retrieves the current "Theoretical Arrival Time" (TAT) from Redis (or initializes it if missing). + * - If the current time is greater than or equal to the TAT, the request is allowed and the TAT is updated to now + emissionInterval. + * - Otherwise, if the current time plus the burst tolerance is at least the TAT, the request is allowed and the TAT is incremented. + * - If neither condition is met, the request is rejected and a Retry-After value is returned. + */ +export class GCRARateLimiter { + private redis: Redis | Cluster; + private keyPrefix: string; + private emissionInterval: number; + private burstTolerance: number; + private keyExpiration: number; + + constructor(options: GCRARateLimiterOptions) { + this.redis = options.redis; + this.keyPrefix = options.keyPrefix || "gcra:ratelimit:"; + this.emissionInterval = options.emissionInterval; + this.burstTolerance = options.burstTolerance; + // Default expiration: at least 60 seconds or the sum of emissionInterval and burstTolerance + this.keyExpiration = + options.keyExpiration || Math.max(60_000, this.emissionInterval + this.burstTolerance); + + // Define a custom Redis command 'gcra' that implements the GCRA algorithm. + // Using defineCommand ensures the Lua script is loaded once and run atomically. + this.redis.defineCommand("gcra", { + numberOfKeys: 1, + lua: ` +--[[ + GCRA Lua script + KEYS[1] - The rate limit key (e.g. "ratelimit:") + ARGV[1] - Current time in ms (number) + ARGV[2] - Emission interval in ms (number) + ARGV[3] - Burst tolerance in ms (number) + ARGV[4] - Key expiration in ms (number) + + Returns: { allowedFlag, value } + allowedFlag: 1 if allowed, 0 if rate-limited. + value: 0 when allowed; if not allowed, the number of ms to wait. +]]-- + +local key = KEYS[1] +local now = tonumber(ARGV[1]) +local emission_interval = tonumber(ARGV[2]) +local burst_tolerance = tonumber(ARGV[3]) +local expire = tonumber(ARGV[4]) + +-- Get the stored Theoretical Arrival Time (TAT) or default to 0. +local tat = tonumber(redis.call("GET", key) or 0) +if tat == 0 then + tat = now +end + +local allowed, new_tat, retry_after + +if now >= tat then + -- No delay: request is on schedule. + new_tat = now + emission_interval + allowed = true +elseif (now + burst_tolerance) >= tat then + -- Within burst capacity: allow request. + new_tat = tat + emission_interval + allowed = true +else + -- Request exceeds the allowed burst; calculate wait time. + allowed = false + retry_after = tat - (now + burst_tolerance) +end + +if allowed then + redis.call("SET", key, new_tat, "PX", expire) + return {1, 0} +else + return {0, retry_after} +end +`, + }); + } + + /** + * Checks whether a request associated with the given identifier is allowed. + * + * @param identifier A unique string identifying the subject of rate limiting (e.g. user ID, IP address, or domain). 
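+ * Internally the key is stored as `${keyPrefix}${identifier}`, and its state expires
+ * `keyExpiration` milliseconds after the last allowed request.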
+ * @returns A promise that resolves to a RateLimitResult.
+ *
+ * @example
+ * const result = await rateLimiter.check('user:12345');
+ * if (!result.allowed) {
+ * // Tell the client to retry after result.retryAfter milliseconds.
+ * }
+ */
+ async check(identifier: string): Promise<RateLimitResult> {
+ const key = `${this.keyPrefix}${identifier}`;
+ const now = Date.now();
+
+ try {
+ // Call the custom 'gcra' command.
+ // The script returns an array: [allowedFlag, value]
+ // - allowedFlag: 1 if allowed; 0 if rejected.
+ // - value: 0 when allowed; if rejected, the number of ms to wait before retrying.
+ // @ts-expect-error: The custom command is defined via defineCommand.
+ const result: [number, number] = await this.redis.gcra(
+ key,
+ now,
+ this.emissionInterval,
+ this.burstTolerance,
+ this.keyExpiration
+ );
+ const allowed = result[0] === 1;
+ if (allowed) {
+ return { allowed: true };
+ } else {
+ return { allowed: false, retryAfter: result[1] };
+ }
+ } catch (error) {
+ // In a production system you might log the error and either
+ // allow the request (fail open) or deny it (fail closed).
+ // Here we choose to propagate the error.
+ throw error;
+ }
+ }
+}
diff --git a/apps/webapp/app/v3/alertsRateLimiter.server.ts b/apps/webapp/app/v3/alertsRateLimiter.server.ts
new file mode 100644
index 0000000000..b0bedb0f13
--- /dev/null
+++ b/apps/webapp/app/v3/alertsRateLimiter.server.ts
@@ -0,0 +1,30 @@
+import { env } from "~/env.server";
+import { createRedisClient } from "~/redis.server";
+import { GCRARateLimiter } from "./GCRARateLimiter.server";
+import { singleton } from "~/utils/singleton";
+import { logger } from "~/services/logger.server";
+
+export const alertsRateLimiter = singleton("alertsRateLimiter", initializeAlertsRateLimiter);
+
+function initializeAlertsRateLimiter() {
+ const redis = createRedisClient("alerts:ratelimiter", {
+ keyPrefix: "alerts:ratelimiter:",
+ host: env.ALERT_RATE_LIMITER_REDIS_HOST,
+ port: env.ALERT_RATE_LIMITER_REDIS_PORT,
+ username: env.ALERT_RATE_LIMITER_REDIS_USERNAME,
+ password: env.ALERT_RATE_LIMITER_REDIS_PASSWORD,
+ tlsDisabled: env.ALERT_RATE_LIMITER_REDIS_TLS_DISABLED === "true",
+ clusterMode: env.ALERT_RATE_LIMITER_REDIS_CLUSTER_MODE_ENABLED === "1",
+ });
+
+ logger.debug(`🚦 Initializing alerts rate limiter at host ${env.ALERT_RATE_LIMITER_REDIS_HOST}`, {
+ emissionInterval: env.ALERT_RATE_LIMITER_EMISSION_INTERVAL,
+ burstTolerance: env.ALERT_RATE_LIMITER_BURST_TOLERANCE,
+ });
+
+ return new GCRARateLimiter({
+ redis,
+ emissionInterval: env.ALERT_RATE_LIMITER_EMISSION_INTERVAL,
+ burstTolerance: env.ALERT_RATE_LIMITER_BURST_TOLERANCE,
+ });
+}
diff --git a/apps/webapp/app/v3/services/alerts/deliverAlert.server.ts b/apps/webapp/app/v3/services/alerts/deliverAlert.server.ts
index 282a839bff..51832bcd4d 100644
--- a/apps/webapp/app/v3/services/alerts/deliverAlert.server.ts
+++ b/apps/webapp/app/v3/services/alerts/deliverAlert.server.ts
@@ -9,7 +9,7 @@ import {
 import { TaskRunError, createJsonErrorObject } from "@trigger.dev/core/v3";
 import assertNever from "assert-never";
 import { subtle } from "crypto";
-import { Prisma, PrismaClientOrTransaction, prisma } from "~/db.server";
+import { Prisma, prisma, PrismaClientOrTransaction } from "~/db.server";
 import { env } from "~/env.server";
 import {
 OrgIntegrationRepository,
@@ -25,10 +25,12 @@ import { DeploymentPresenter } from "~/presenters/v3/DeploymentPresenter.server"
 import { sendAlertEmail } from "~/services/email.server";
 import { logger } from "~/services/logger.server";
 import {
decryptSecret } from "~/services/secrets/secretStore.server"; -import { workerQueue } from "~/services/worker.server"; -import { BaseService } from "../baseService.server"; -import { FINAL_ATTEMPT_STATUSES } from "~/v3/taskStatus"; import { commonWorker } from "~/v3/commonWorker.server"; +import { FINAL_ATTEMPT_STATUSES } from "~/v3/taskStatus"; +import { BaseService } from "../baseService.server"; +import { generateFriendlyId } from "~/v3/friendlyIdentifiers"; +import { ProjectAlertType } from "@trigger.dev/database"; +import { alertsRateLimiter } from "~/v3/alertsRateLimiter.server"; type FoundAlert = Prisma.Result< typeof prisma.projectAlert, @@ -1101,6 +1103,66 @@ export class DeliverAlertService extends BaseService { availableAt: runAt, }); } + + static async createAndSendAlert( + { + channelId, + projectId, + environmentId, + alertType, + deploymentId, + taskRunId, + }: { + channelId: string; + projectId: string; + environmentId: string; + alertType: ProjectAlertType; + deploymentId?: string; + taskRunId?: string; + }, + db: PrismaClientOrTransaction + ) { + if (taskRunId) { + try { + const result = await alertsRateLimiter.check(channelId); + + if (!result.allowed) { + logger.warn("[DeliverAlert] Rate limited", { + taskRunId, + environmentId, + alertType, + channelId, + result, + }); + + return; + } + } catch (error) { + logger.error("[DeliverAlert] Rate limiter error", { + taskRunId, + environmentId, + alertType, + channelId, + error, + }); + } + } + + const alert = await db.projectAlert.create({ + data: { + friendlyId: generateFriendlyId("alert"), + channelId, + projectId, + environmentId, + status: "PENDING", + type: alertType, + workerDeploymentId: deploymentId, + taskRunId, + }, + }); + + await DeliverAlertService.enqueue(alert.id); + } } function isWebAPIPlatformError(error: unknown): error is WebAPIPlatformError { diff --git a/apps/webapp/app/v3/services/alerts/performDeploymentAlerts.server.ts b/apps/webapp/app/v3/services/alerts/performDeploymentAlerts.server.ts index 4bbe7b50cf..7d8a71c586 100644 --- a/apps/webapp/app/v3/services/alerts/performDeploymentAlerts.server.ts +++ b/apps/webapp/app/v3/services/alerts/performDeploymentAlerts.server.ts @@ -46,19 +46,16 @@ export class PerformDeploymentAlertsService extends BaseService { deployment: WorkerDeployment, alertType: ProjectAlertType ) { - const alert = await this._prisma.projectAlert.create({ - data: { - friendlyId: generateFriendlyId("alert"), + await DeliverAlertService.createAndSendAlert( + { channelId: alertChannel.id, projectId: deployment.projectId, environmentId: deployment.environmentId, - status: "PENDING", - type: alertType, - workerDeploymentId: deployment.id, + alertType, + deploymentId: deployment.id, }, - }); - - await DeliverAlertService.enqueue(alert.id); + this._prisma + ); } static async enqueue(deploymentId: string, runAt?: Date) { diff --git a/apps/webapp/app/v3/services/alerts/performTaskRunAlerts.server.ts b/apps/webapp/app/v3/services/alerts/performTaskRunAlerts.server.ts index 0706bd0192..8b88a3f9db 100644 --- a/apps/webapp/app/v3/services/alerts/performTaskRunAlerts.server.ts +++ b/apps/webapp/app/v3/services/alerts/performTaskRunAlerts.server.ts @@ -46,19 +46,16 @@ export class PerformTaskRunAlertsService extends BaseService { } async #createAndSendAlert(alertChannel: ProjectAlertChannel, run: FoundRun) { - const alert = await this._prisma.projectAlert.create({ - data: { - friendlyId: generateFriendlyId("alert"), + await DeliverAlertService.createAndSendAlert( + { channelId: alertChannel.id, 
projectId: run.projectId, environmentId: run.runtimeEnvironmentId, - status: "PENDING", - type: "TASK_RUN", + alertType: "TASK_RUN", taskRunId: run.id, }, - }); - - await DeliverAlertService.enqueue(alert.id); + this._prisma + ); } static async enqueue(runId: string, runAt?: Date) { diff --git a/apps/webapp/test/GCRARateLimiter.test.ts b/apps/webapp/test/GCRARateLimiter.test.ts new file mode 100644 index 0000000000..9c645310c0 --- /dev/null +++ b/apps/webapp/test/GCRARateLimiter.test.ts @@ -0,0 +1,217 @@ +// GCRARateLimiter.test.ts +import { redisTest } from "@internal/testcontainers"; +import { describe, expect, vi } from "vitest"; +import { GCRARateLimiter } from "../app/v3/GCRARateLimiter.server.js"; // adjust the import as needed + +// Extend the timeout to 30 seconds (as in your redis tests) +vi.setConfig({ testTimeout: 30_000 }); + +describe("GCRARateLimiter", () => { + redisTest("should allow a single request when under the rate limit", async ({ redis }) => { + const limiter = new GCRARateLimiter({ + redis, + emissionInterval: 1000, // 1 request per second on average + burstTolerance: 3000, // Allows a burst of 4 requests (3 * 1000 + 1) + keyPrefix: "test:ratelimit:", + }); + + const result = await limiter.check("user:1"); + expect(result.allowed).toBe(true); + }); + + redisTest( + "should allow bursts up to the configured limit and then reject further requests", + async ({ redis }) => { + const limiter = new GCRARateLimiter({ + redis, + emissionInterval: 1000, + burstTolerance: 3000, // With an emission interval of 1000ms, burstTolerance of 3000ms allows 4 rapid requests. + keyPrefix: "test:ratelimit:", + }); + + // Call 4 times in rapid succession (all should be allowed) + const results = await Promise.all([ + limiter.check("user:burst"), + limiter.check("user:burst"), + limiter.check("user:burst"), + limiter.check("user:burst"), + ]); + results.forEach((result) => expect(result.allowed).toBe(true)); + + // The 5th call should be rejected. + const fifthResult = await limiter.check("user:burst"); + expect(fifthResult.allowed).toBe(false); + expect(fifthResult.retryAfter).toBeGreaterThan(0); + } + ); + + redisTest("should allow a request after the required waiting period", async ({ redis }) => { + const limiter = new GCRARateLimiter({ + redis, + emissionInterval: 1000, + burstTolerance: 3000, + keyPrefix: "test:ratelimit:", + }); + + // Exhaust burst capacity with 4 rapid calls. + await limiter.check("user:wait"); + await limiter.check("user:wait"); + await limiter.check("user:wait"); + await limiter.check("user:wait"); + + // The 5th call should be rejected. + const rejection = await limiter.check("user:wait"); + expect(rejection.allowed).toBe(false); + expect(rejection.retryAfter).toBeGreaterThan(0); + + // Wait for the period specified in retryAfter (plus a small buffer) + await new Promise((resolve) => setTimeout(resolve, rejection.retryAfter! + 50)); + + // Now the next call should be allowed. + const allowedAfterWait = await limiter.check("user:wait"); + expect(allowedAfterWait.allowed).toBe(true); + }); + + redisTest("should rate limit independently for different identifiers", async ({ redis }) => { + const limiter = new GCRARateLimiter({ + redis, + emissionInterval: 1000, + burstTolerance: 3000, + keyPrefix: "test:ratelimit:", + }); + + // For "user:independent", exhaust burst capacity. 
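+ // (1 on-schedule request + 3 from the 3000 ms burst credit = 4 allowed back-to-back; the 5th should be rejected.)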
+ await limiter.check("user:independent"); + await limiter.check("user:independent"); + await limiter.check("user:independent"); + await limiter.check("user:independent"); + const rejected = await limiter.check("user:independent"); + expect(rejected.allowed).toBe(false); + + // A different identifier should start fresh. + const fresh = await limiter.check("user:different"); + expect(fresh.allowed).toBe(true); + }); + + redisTest("should gradually reduce retryAfter with time", async ({ redis }) => { + const limiter = new GCRARateLimiter({ + redis, + emissionInterval: 1000, + burstTolerance: 3000, + keyPrefix: "test:ratelimit:", + }); + + // Exhaust the burst capacity. + await limiter.check("user:gradual"); + await limiter.check("user:gradual"); + await limiter.check("user:gradual"); + await limiter.check("user:gradual"); + + const firstRejection = await limiter.check("user:gradual"); + expect(firstRejection.allowed).toBe(false); + const firstRetry = firstRejection.retryAfter!; + + // Wait 500ms, then perform another check. + await new Promise((resolve) => setTimeout(resolve, 500)); + const secondRejection = await limiter.check("user:gradual"); + // It should still be rejected but with a smaller wait time. + expect(secondRejection.allowed).toBe(false); + const secondRetry = secondRejection.retryAfter!; + expect(secondRetry).toBeLessThan(firstRetry); + }); + + redisTest("should expire the key after the TTL", async ({ redis }) => { + // For this test, override keyExpiration to a short value. + const keyExpiration = 1500; // 1.5 seconds TTL + const limiter = new GCRARateLimiter({ + redis, + emissionInterval: 100, + burstTolerance: 300, // These values are arbitrary for this test. + keyPrefix: "test:expire:", + keyExpiration, + }); + const identifier = "user:expire"; + + // Make a call to set the key. + const result = await limiter.check(identifier); + expect(result.allowed).toBe(true); + + // Immediately verify the key exists. + const key = `test:expire:${identifier}`; + let stored = await redis.get(key); + expect(stored).not.toBeNull(); + + // Wait for longer than keyExpiration. + await new Promise((resolve) => setTimeout(resolve, keyExpiration + 200)); + stored = await redis.get(key); + expect(stored).toBeNull(); + }); + + redisTest("should not share state across different key prefixes", async ({ redis }) => { + const limiter1 = new GCRARateLimiter({ + redis, + emissionInterval: 1000, + burstTolerance: 3000, + keyPrefix: "test:ratelimit1:", + }); + const limiter2 = new GCRARateLimiter({ + redis, + emissionInterval: 1000, + burstTolerance: 3000, + keyPrefix: "test:ratelimit2:", + }); + + // Exhaust the burst capacity for a given identifier in limiter1. + await limiter1.check("user:shared"); + await limiter1.check("user:shared"); + await limiter1.check("user:shared"); + await limiter1.check("user:shared"); + const rejection1 = await limiter1.check("user:shared"); + expect(rejection1.allowed).toBe(false); + + // With a different key prefix, the same identifier should be fresh. + const result2 = await limiter2.check("user:shared"); + expect(result2.allowed).toBe(true); + }); + + redisTest("should increment TAT correctly on sequential allowed requests", async ({ redis }) => { + const limiter = new GCRARateLimiter({ + redis, + emissionInterval: 1000, + burstTolerance: 3000, + keyPrefix: "test:ratelimit:", + }); + + // The first request should be allowed. + const r1 = await limiter.check("user:sequential"); + expect(r1.allowed).toBe(true); + + // Wait a bit longer than the emission interval. 
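+ // 1100 ms exceeds the 1000 ms emission interval, so the stored TAT has already passed and the request is back on schedule.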
+ await new Promise((resolve) => setTimeout(resolve, 1100)); + const r2 = await limiter.check("user:sequential"); + expect(r2.allowed).toBe(true); + }); + + redisTest("should throw an error if redis command fails", async ({ redis }) => { + const limiter = new GCRARateLimiter({ + redis, + emissionInterval: 1000, + burstTolerance: 3000, + keyPrefix: "test:ratelimit:", + }); + + // Stub redis.gcra to simulate a failure. + // @ts-expect-error + const originalGcra = redis.gcra; + // @ts-ignore + redis.gcra = vi.fn(() => { + throw new Error("Simulated Redis error"); + }); + + await expect(limiter.check("user:error")).rejects.toThrow("Simulated Redis error"); + + // Restore the original command. + // @ts-expect-error + redis.gcra = originalGcra; + }); +});
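
Usage note (illustrative sketch, not part of the patch): the snippet below shows how a caller might gate per-channel task-run alerts with the new GCRARateLimiter, mirroring the rate-limit branch of DeliverAlertService.createAndSendAlert introduced above. Note that the patch only runs the limiter when taskRunId is set, so deployment alerts are never rate limited. The standalone Redis client, the import path, and the maybeDeliverTaskRunAlert helper are assumptions for this example; the interval values mirror the new env defaults (ALERT_RATE_LIMITER_EMISSION_INTERVAL = 2500 ms, ALERT_RATE_LIMITER_BURST_TOLERANCE = 10000 ms).

import Redis from "ioredis";
import { GCRARateLimiter } from "./GCRARateLimiter.server"; // path is hypothetical for this sketch

const redis = new Redis(); // assumes a reachable Redis instance
const limiter = new GCRARateLimiter({
  redis,
  emissionInterval: 2_500, // on average, one alert per channel every 2.5 s
  burstTolerance: 10_000, // floor(10000 / 2500) + 1 = 5 back-to-back alerts per channel
  keyPrefix: "alerts:ratelimiter:",
});

// Hypothetical helper mirroring how createAndSendAlert guards task-run alerts.
async function maybeDeliverTaskRunAlert(channelId: string) {
  const result = await limiter.check(channelId);
  if (!result.allowed) {
    // Rate limited: as in the patch, the alert is simply not created or enqueued.
    console.warn(`channel ${channelId} rate limited; retry after ${result.retryAfter} ms`);
    return;
  }
  // ...create the ProjectAlert row and enqueue DeliverAlertService here...
}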