diff --git a/apps/webapp/app/env.server.ts b/apps/webapp/app/env.server.ts
index 9136e44a1a..5bd44f35a3 100644
--- a/apps/webapp/app/env.server.ts
+++ b/apps/webapp/app/env.server.ts
@@ -428,6 +428,10 @@ const EnvironmentSchema = z.object({
   RUN_ENGINE_PROCESS_WORKER_QUEUE_DEBOUNCE_MS: z.coerce.number().int().default(200),
   RUN_ENGINE_DEQUEUE_BLOCKING_TIMEOUT_SECONDS: z.coerce.number().int().default(10),
   RUN_ENGINE_MASTER_QUEUE_CONSUMERS_INTERVAL_MS: z.coerce.number().int().default(500),
+  RUN_ENGINE_CONCURRENCY_SWEEPER_SCAN_SCHEDULE: z.string().optional(),
+  RUN_ENGINE_CONCURRENCY_SWEEPER_PROCESS_MARKED_SCHEDULE: z.string().optional(),
+  RUN_ENGINE_CONCURRENCY_SWEEPER_SCAN_JITTER_IN_MS: z.coerce.number().int().optional(),
+  RUN_ENGINE_CONCURRENCY_SWEEPER_PROCESS_MARKED_JITTER_IN_MS: z.coerce.number().int().optional(),
   RUN_ENGINE_RUN_LOCK_DURATION: z.coerce.number().int().default(5000),
   RUN_ENGINE_RUN_LOCK_AUTOMATIC_EXTENSION_THRESHOLD: z.coerce.number().int().default(1000),
@@ -593,6 +597,7 @@ const EnvironmentSchema = z.object({
   RUN_ENGINE_WORKER_ENABLED: z.string().default("1"),
   RUN_ENGINE_WORKER_LOG_LEVEL: z.enum(["log", "error", "warn", "info", "debug"]).default("info"),
+  RUN_ENGINE_RUN_QUEUE_LOG_LEVEL: z.enum(["log", "error", "warn", "info", "debug"]).default("info"),

   /** How long should the presence ttl last */
   DEV_PRESENCE_SSE_TIMEOUT: z.coerce.number().int().default(30_000),
diff --git a/apps/webapp/app/v3/runEngine.server.ts b/apps/webapp/app/v3/runEngine.server.ts
index 559c363988..6c9a11c2a8 100644
--- a/apps/webapp/app/v3/runEngine.server.ts
+++ b/apps/webapp/app/v3/runEngine.server.ts
@@ -1,10 +1,10 @@
 import { RunEngine } from "@internal/run-engine";
-import { defaultMachine } from "~/services/platform.v3.server";
-import { prisma } from "~/db.server";
+import { $replica, prisma } from "~/db.server";
 import { env } from "~/env.server";
+import { defaultMachine } from "~/services/platform.v3.server";
 import { singleton } from "~/utils/singleton";
 import { allMachines } from "./machinePresets.server";
-import { tracer, meter } from "./tracer.server";
+import { meter, tracer } from "./tracer.server";

 export const engine = singleton("RunEngine", createRunEngine);

@@ -13,6 +13,7 @@ export type { RunEngine };
 function createRunEngine() {
   const engine = new RunEngine({
     prisma,
+    readOnlyPrisma: $replica,
     logLevel: env.RUN_ENGINE_WORKER_LOG_LEVEL,
     worker: {
       disabled: env.RUN_ENGINE_WORKER_ENABLED === "0",
@@ -39,6 +40,7 @@ function createRunEngine() {
     },
     queue: {
       defaultEnvConcurrency: env.DEFAULT_ENV_EXECUTION_CONCURRENCY_LIMIT,
+      logLevel: env.RUN_ENGINE_RUN_QUEUE_LOG_LEVEL,
       redis: {
         keyPrefix: "engine:",
         port: env.RUN_ENGINE_RUN_QUEUE_REDIS_PORT ?? undefined,
@@ -64,6 +66,12 @@ function createRunEngine() {
       dequeueBlockingTimeoutSeconds: env.RUN_ENGINE_DEQUEUE_BLOCKING_TIMEOUT_SECONDS,
      masterQueueConsumersIntervalMs: env.RUN_ENGINE_MASTER_QUEUE_CONSUMERS_INTERVAL_MS,
      masterQueueConsumersDisabled: env.RUN_ENGINE_WORKER_ENABLED === "0",
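+      // Optional overrides; when unset the RunQueue falls back to its built-in
+      // defaults (scan every 10 minutes, process marked runs every 5 minutes).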
+      concurrencySweeper: {
+        scanSchedule: env.RUN_ENGINE_CONCURRENCY_SWEEPER_SCAN_SCHEDULE,
+        processMarkedSchedule: env.RUN_ENGINE_CONCURRENCY_SWEEPER_PROCESS_MARKED_SCHEDULE,
+        scanJitterInMs: env.RUN_ENGINE_CONCURRENCY_SWEEPER_SCAN_JITTER_IN_MS,
+        processMarkedJitterInMs: env.RUN_ENGINE_CONCURRENCY_SWEEPER_PROCESS_MARKED_JITTER_IN_MS,
+      },
     },
     runLock: {
       redis: {
diff --git a/apps/webapp/test/engine/triggerTask.test.ts b/apps/webapp/test/engine/triggerTask.test.ts
index 8aa0fddb79..d90ae99b84 100644
--- a/apps/webapp/test/engine/triggerTask.test.ts
+++ b/apps/webapp/test/engine/triggerTask.test.ts
@@ -3,6 +3,7 @@ import { describe, expect, vi } from "vitest";

 // Mock the db prisma client
 vi.mock("~/db.server", () => ({
   prisma: {},
+  $replica: {},
 }));

 vi.mock("~/services/platform.v3.server", async (importOriginal) => {
diff --git a/internal-packages/run-engine/src/engine/index.ts b/internal-packages/run-engine/src/engine/index.ts
index aca188cd2d..8ed9febb31 100644
--- a/internal-packages/run-engine/src/engine/index.ts
+++ b/internal-packages/run-engine/src/engine/index.ts
@@ -16,6 +16,7 @@ import {
   Prisma,
   PrismaClient,
   PrismaClientOrTransaction,
+  PrismaReplicaClient,
   TaskRun,
   TaskRunExecutionSnapshot,
   Waitpoint,
@@ -50,6 +51,7 @@ import { TtlSystem } from "./systems/ttlSystem.js";
 import { WaitpointSystem } from "./systems/waitpointSystem.js";
 import { EngineWorker, HeartbeatTimeouts, RunEngineOptions, TriggerParams } from "./types.js";
 import { workerCatalog } from "./workerCatalog.js";
+import { getFinalRunStatuses, isFinalRunStatus } from "./statuses.js";

 export class RunEngine {
   private runLockRedis: Redis;
@@ -61,6 +63,7 @@ export class RunEngine {
   private heartbeatTimeouts: HeartbeatTimeouts;
   prisma: PrismaClient;
+  readOnlyPrisma: PrismaReplicaClient;
   runQueue: RunQueue;
   eventBus: EventBus = new EventEmitter();
   executionSnapshotSystem: ExecutionSnapshotSystem;
@@ -79,6 +82,7 @@ export class RunEngine {
   constructor(private readonly options: RunEngineOptions) {
     this.logger = options.logger ?? new Logger("RunEngine", this.options.logLevel ?? "info");
     this.prisma = options.prisma;
+    this.readOnlyPrisma = options.readOnlyPrisma ?? this.prisma;
     this.runLockRedis = createRedisClient(
       {
         ...options.runLock.redis,
@@ -123,7 +127,7 @@ export class RunEngine {
         defaultEnvConcurrencyLimit: options.queue?.defaultEnvConcurrency ?? 10,
       }),
       defaultEnvConcurrency: options.queue?.defaultEnvConcurrency ?? 10,
-      logger: new Logger("RunQueue", this.options.logLevel ?? "info"),
+      logger: new Logger("RunQueue", options.queue?.logLevel ?? "info"),
       redis: { ...options.queue.redis, keyPrefix: `${options.queue.redis.keyPrefix}runqueue:` },
       retryOptions: options.queue?.retryOptions,
       workerOptions: {
@@ -133,6 +137,13 @@ export class RunEngine {
         immediatePollIntervalMs: options.worker.immediatePollIntervalMs,
         shutdownTimeoutMs: options.worker.shutdownTimeoutMs,
       },
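+      // Schedules pass straight through to the RunQueue; the callback decides
+      // which candidate runs are actually finished (see #concurrencySweeperCallback).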
"info"), redis: { ...options.queue.redis, keyPrefix: `${options.queue.redis.keyPrefix}runqueue:` }, retryOptions: options.queue?.retryOptions, workerOptions: { @@ -133,6 +137,13 @@ export class RunEngine { immediatePollIntervalMs: options.worker.immediatePollIntervalMs, shutdownTimeoutMs: options.worker.shutdownTimeoutMs, }, + concurrencySweeper: { + scanSchedule: options.queue?.concurrencySweeper?.scanSchedule, + processMarkedSchedule: options.queue?.concurrencySweeper?.processMarkedSchedule, + scanJitterInMs: options.queue?.concurrencySweeper?.scanJitterInMs, + processMarkedJitterInMs: options.queue?.concurrencySweeper?.processMarkedJitterInMs, + callback: this.#concurrencySweeperCallback.bind(this), + }, shardCount: options.queue?.shardCount, masterQueueConsumersDisabled: options.queue?.masterQueueConsumersDisabled, masterQueueConsumersIntervalMs: options.queue?.masterQueueConsumersIntervalMs, @@ -1329,4 +1340,44 @@ export class RunEngine { } }); } + + async #concurrencySweeperCallback( + runIds: string[] + ): Promise> { + const runs = await this.readOnlyPrisma.taskRun.findMany({ + where: { + id: { in: runIds }, + completedAt: { + lte: new Date(Date.now() - 1000 * 60 * 10), // This only finds runs that were completed more than 10 minutes ago + }, + organizationId: { + not: null, + }, + status: { + in: getFinalRunStatuses(), + }, + }, + select: { + id: true, + status: true, + organizationId: true, + }, + }); + + // Log the finished runs + for (const run of runs) { + this.logger.info("Concurrency sweeper callback found finished run", { + runId: run.id, + orgId: run.organizationId, + status: run.status, + }); + } + + return runs + .filter((run) => !!run.organizationId) + .map((run) => ({ + id: run.id, + orgId: run.organizationId!, + })); + } } diff --git a/internal-packages/run-engine/src/engine/statuses.ts b/internal-packages/run-engine/src/engine/statuses.ts index f8a66240d3..36f0825013 100644 --- a/internal-packages/run-engine/src/engine/statuses.ts +++ b/internal-packages/run-engine/src/engine/statuses.ts @@ -41,21 +41,25 @@ export function isInitialState(status: TaskRunExecutionStatus): boolean { return startedStatuses.includes(status); } -export function isFinalRunStatus(status: TaskRunStatus): boolean { - const finalStatuses: TaskRunStatus[] = [ - "CANCELED", - "INTERRUPTED", - "COMPLETED_SUCCESSFULLY", - "COMPLETED_WITH_ERRORS", - "SYSTEM_FAILURE", - "CRASHED", - "EXPIRED", - "TIMED_OUT", - ]; +const finalStatuses: TaskRunStatus[] = [ + "CANCELED", + "INTERRUPTED", + "COMPLETED_SUCCESSFULLY", + "COMPLETED_WITH_ERRORS", + "SYSTEM_FAILURE", + "CRASHED", + "EXPIRED", + "TIMED_OUT", +]; +export function isFinalRunStatus(status: TaskRunStatus): boolean { return finalStatuses.includes(status); } +export function getFinalRunStatuses(): TaskRunStatus[] { + return finalStatuses; +} + export function canReleaseConcurrency(status: TaskRunExecutionStatus): boolean { const releaseableStatuses: TaskRunExecutionStatus[] = ["SUSPENDED", "EXECUTING_WITH_WAITPOINTS"]; return releaseableStatuses.includes(status); diff --git a/internal-packages/run-engine/src/engine/types.ts b/internal-packages/run-engine/src/engine/types.ts index 286545c85e..87612b2bfa 100644 --- a/internal-packages/run-engine/src/engine/types.ts +++ b/internal-packages/run-engine/src/engine/types.ts @@ -8,7 +8,7 @@ import { RetryOptions, RunChainState, } from "@trigger.dev/core/v3"; -import { PrismaClient } from "@trigger.dev/database"; +import { PrismaClient, PrismaReplicaClient } from "@trigger.dev/database"; import { 
+  async #concurrencySweeperCallback(
+    runIds: string[]
+  ): Promise<Array<{ id: string; orgId: string }>> {
+    const runs = await this.readOnlyPrisma.taskRun.findMany({
+      where: {
+        id: { in: runIds },
+        completedAt: {
+          lte: new Date(Date.now() - 1000 * 60 * 10), // Only runs completed more than 10 minutes ago
+        },
+        organizationId: {
+          not: null,
+        },
+        status: {
+          in: getFinalRunStatuses(),
+        },
+      },
+      select: {
+        id: true,
+        status: true,
+        organizationId: true,
+      },
+    });
+
+    // Log the finished runs
+    for (const run of runs) {
+      this.logger.info("Concurrency sweeper callback found finished run", {
+        runId: run.id,
+        orgId: run.organizationId,
+        status: run.status,
+      });
+    }
+
+    return runs
+      .filter((run) => !!run.organizationId)
+      .map((run) => ({
+        id: run.id,
+        orgId: run.organizationId!,
+      }));
+  }
 }
diff --git a/internal-packages/run-engine/src/engine/statuses.ts b/internal-packages/run-engine/src/engine/statuses.ts
index f8a66240d3..36f0825013 100644
--- a/internal-packages/run-engine/src/engine/statuses.ts
+++ b/internal-packages/run-engine/src/engine/statuses.ts
@@ -41,21 +41,25 @@ export function isInitialState(status: TaskRunExecutionStatus): boolean {
   return startedStatuses.includes(status);
 }

-export function isFinalRunStatus(status: TaskRunStatus): boolean {
-  const finalStatuses: TaskRunStatus[] = [
-    "CANCELED",
-    "INTERRUPTED",
-    "COMPLETED_SUCCESSFULLY",
-    "COMPLETED_WITH_ERRORS",
-    "SYSTEM_FAILURE",
-    "CRASHED",
-    "EXPIRED",
-    "TIMED_OUT",
-  ];
+const finalStatuses: TaskRunStatus[] = [
+  "CANCELED",
+  "INTERRUPTED",
+  "COMPLETED_SUCCESSFULLY",
+  "COMPLETED_WITH_ERRORS",
+  "SYSTEM_FAILURE",
+  "CRASHED",
+  "EXPIRED",
+  "TIMED_OUT",
+];

+export function isFinalRunStatus(status: TaskRunStatus): boolean {
   return finalStatuses.includes(status);
 }

+export function getFinalRunStatuses(): TaskRunStatus[] {
+  return finalStatuses;
+}
+
 export function canReleaseConcurrency(status: TaskRunExecutionStatus): boolean {
   const releaseableStatuses: TaskRunExecutionStatus[] = ["SUSPENDED", "EXECUTING_WITH_WAITPOINTS"];
   return releaseableStatuses.includes(status);
diff --git a/internal-packages/run-engine/src/engine/types.ts b/internal-packages/run-engine/src/engine/types.ts
index 286545c85e..87612b2bfa 100644
--- a/internal-packages/run-engine/src/engine/types.ts
+++ b/internal-packages/run-engine/src/engine/types.ts
@@ -8,7 +8,7 @@ import {
   RetryOptions,
   RunChainState,
 } from "@trigger.dev/core/v3";
-import { PrismaClient } from "@trigger.dev/database";
+import { PrismaClient, PrismaReplicaClient } from "@trigger.dev/database";
 import { FairQueueSelectionStrategyOptions } from "../run-queue/fairQueueSelectionStrategy.js";
 import { MinimalAuthenticatedEnvironment } from "../shared/index.js";
 import { workerCatalog } from "./workerCatalog.js";
@@ -17,6 +17,7 @@ import { LockRetryConfig } from "./locking.js";

 export type RunEngineOptions = {
   prisma: PrismaClient;
+  readOnlyPrisma?: PrismaReplicaClient;
   worker: {
     disabled?: boolean;
     redis: RedisOptions;
@@ -38,11 +39,18 @@ export type RunEngineOptions = {
     workerOptions?: WorkerConcurrencyOptions;
     retryOptions?: RetryOptions;
     defaultEnvConcurrency?: number;
+    logLevel?: LogLevel;
     queueSelectionStrategyOptions?: Pick<
       FairQueueSelectionStrategyOptions,
       "parentQueueLimit" | "tracer" | "biases" | "reuseSnapshotCount" | "maximumEnvCount"
     >;
     dequeueBlockingTimeoutSeconds?: number;
+    concurrencySweeper?: {
+      scanSchedule?: string;
+      processMarkedSchedule?: string;
+      scanJitterInMs?: number;
+      processMarkedJitterInMs?: number;
+    };
   };
   runLock: {
     redis: RedisOptions;
diff --git a/internal-packages/run-engine/src/run-queue/index.ts b/internal-packages/run-engine/src/run-queue/index.ts
index f5772a9d6c..0b6d5c1d1d 100644
--- a/internal-packages/run-engine/src/run-queue/index.ts
+++ b/internal-packages/run-engine/src/run-queue/index.ts
@@ -36,11 +36,13 @@ import {
   type Result,
 } from "@internal/redis";
 import { MessageNotFoundError } from "./errors.js";
-import { tryCatch } from "@trigger.dev/core";
+import { promiseWithResolvers, tryCatch } from "@trigger.dev/core";
 import { setInterval } from "node:timers/promises";
 import { nanoid } from "nanoid";
-import { Worker, type WorkerConcurrencyOptions } from "@trigger.dev/redis-worker";
+import { CronSchema, Worker, type WorkerConcurrencyOptions } from "@trigger.dev/redis-worker";
 import { z } from "zod";
+import { Readable } from "node:stream";
+import { setTimeout } from "node:timers/promises";

 const SemanticAttributes = {
   QUEUE: "runqueue.queue",
@@ -78,14 +80,31 @@ export type RunQueueOptions = {
   };
   meter?: Meter;
   dequeueBlockingTimeoutSeconds?: number;
+  concurrencySweeper?: {
+    scanSchedule?: string;
+    scanJitterInMs?: number;
+    processMarkedSchedule?: string;
+    processMarkedJitterInMs?: number;
+    callback: ConcurrencySweeperCallback;
+  };
 };

+export interface ConcurrencySweeperCallback {
+  (runIds: string[]): Promise<Array<{ id: string; orgId: string }>>;
+}
+
 type DequeuedMessage = {
   messageId: string;
   messageScore: string;
   message: OutputPayload;
 };

+type MarkedRun = {
+  orgId: string;
+  messageId: string;
+  score: number;
+};
+
 const defaultRetrySettings = {
   maxAttempts: 12,
   factor: 2,
@@ -102,6 +121,24 @@ const workerCatalog = {
     }),
     visibilityTimeoutMs: 30_000,
   },
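+  // Two cron jobs drive the concurrency sweeper: scanConcurrencySets walks the
+  // current-concurrency sets looking for finished runs, and processMarkedRuns
+  // acknowledges whatever the scan marked.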
+  scanConcurrencySets: {
+    schema: CronSchema,
+    visibilityTimeoutMs: 60_000 * 5,
+    cron: "*/10 * * * *",
+    jitterInMs: 60_000,
+    retry: {
+      maxAttempts: 1,
+    },
+  },
+  processMarkedRuns: {
+    schema: CronSchema,
+    visibilityTimeoutMs: 60_000 * 5,
+    cron: "*/5 * * * *",
+    jitterInMs: 30_000,
+    retry: {
+      maxAttempts: 1,
+    },
+  },
 };

 /**
@@ -112,7 +149,7 @@ export class RunQueue {
   private subscriber: Redis;
   private luaDebugSubscriber: Redis;
   private logger: Logger;
-  private redis: Redis;
+  public redis: Redis;
   public keys: RunQueueKeyProducer;
   private queueSelectionStrategy: RunQueueSelectionStrategy;
   private shardCount: number;
@@ -121,7 +158,7 @@ export class RunQueue {
   private _observableWorkerQueues: Set<string> = new Set();
   private _meter: Meter;

-  constructor(private readonly options: RunQueueOptions) {
+  constructor(public readonly options: RunQueueOptions) {
     this.shardCount = options.shardCount ?? 2;
     this.retryOptions = options.retryOptions ?? defaultRetrySettings;
     this.redis = createRedisClient(options.redis, {
@@ -170,16 +207,40 @@ export class RunQueue {
         ...options.redis,
         keyPrefix: `${options.redis.keyPrefix}:worker`,
       },
-      catalog: workerCatalog,
+      catalog: {
+        ...workerCatalog,
+        scanConcurrencySets: {
+          ...workerCatalog.scanConcurrencySets,
+          cron: options.concurrencySweeper?.scanSchedule ?? workerCatalog.scanConcurrencySets.cron,
+          jitterInMs:
+            options.concurrencySweeper?.scanJitterInMs ??
+            workerCatalog.scanConcurrencySets.jitterInMs,
+        },
+        processMarkedRuns: {
+          ...workerCatalog.processMarkedRuns,
+          cron:
+            options.concurrencySweeper?.processMarkedSchedule ??
+            workerCatalog.processMarkedRuns.cron,
+          jitterInMs:
+            options.concurrencySweeper?.processMarkedJitterInMs ??
+            workerCatalog.processMarkedRuns.jitterInMs,
+        },
+      },
       concurrency: options.workerOptions?.concurrency,
       pollIntervalMs: options.workerOptions?.pollIntervalMs ?? 1000,
       immediatePollIntervalMs: options.workerOptions?.immediatePollIntervalMs ?? 100,
       shutdownTimeoutMs: options.workerOptions?.shutdownTimeoutMs ?? 10_000,
-      logger: new Logger("RunQueueWorker", options.logLevel ?? "log"),
+      logger: new Logger("RunQueueWorker", options.logLevel ?? "info"),
       jobs: {
         processQueueForWorkerQueue: async (job) => {
           await this.#processQueueForWorkerQueue(job.payload.queueKey, job.payload.environmentId);
         },
+        scanConcurrencySets: async (job) => {
+          await this.scanConcurrencySets();
+        },
+        processMarkedRuns: async (job) => {
+          await this.processMarkedRuns();
+        },
       },
     });
@@ -404,38 +465,7 @@ export class RunQueue {
   }

   public async readMessage(orgId: string, messageId: string) {
-    return this.#trace(
-      "readMessage",
-      async (span) => {
-        const rawMessage = await this.redis.get(this.keys.messageKey(orgId, messageId));
-
-        if (!rawMessage) {
-          return;
-        }
-
-        const message = OutputPayload.safeParse(JSON.parse(rawMessage));
-
-        if (!message.success) {
-          this.logger.error(`[${this.name}] Failed to parse message`, {
-            messageId,
-            error: message.error,
-            service: this.name,
-          });
-
-          return;
-        }
-
-        return message.data;
-      },
-      {
-        attributes: {
-          [SEMATTRS_MESSAGING_OPERATION]: "receive",
-          [SEMATTRS_MESSAGE_ID]: messageId,
-          [SEMATTRS_MESSAGING_SYSTEM]: "marqs",
-          [SemanticAttributes.RUN_ID]: messageId,
-        },
-      }
-    );
+    return this.readMessageFromKey(this.keys.messageKey(orgId, messageId));
   }

   public async readMessageFromKey(messageKey: string) {
@@ -448,24 +478,34 @@
         return;
       }

-      const message = OutputPayload.safeParse(JSON.parse(rawMessage));
+      const deserializedMessage = safeJsonParse(rawMessage);
+
+      const message = OutputPayload.safeParse(deserializedMessage);

       if (!message.success) {
         this.logger.error(`[${this.name}] Failed to parse message`, {
           messageKey,
           error: message.error,
           service: this.name,
+          deserializedMessage,
         });

-        return;
+        return deserializedMessage as OutputPayload;
       }

+      span.setAttributes({
+        [SemanticAttributes.QUEUE]: message.data.queue,
+        [SemanticAttributes.RUN_ID]: message.data.runId,
+        [SemanticAttributes.CONCURRENCY_KEY]: message.data.concurrencyKey,
+        [SemanticAttributes.WORKER_QUEUE]: this.#getWorkerQueueFromMessage(message.data),
+      });
+
       return message.data;
     },
     {
       attributes: {
         [SEMATTRS_MESSAGING_OPERATION]: "receive",
-        [SEMATTRS_MESSAGING_SYSTEM]: "marqs",
+        [SEMATTRS_MESSAGING_SYSTEM]: "runqueue",
       },
     }
   );
@@ -896,6 +936,35 @@ export class RunQueue {
     return await this.redis.lrange(workerQueueKey, 0, -1);
   }
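+
+  // Backed by ioredis scanStream: a cursor-based SCAN with a MATCH pattern and a
+  // TYPE filter, so large keyspaces are walked incrementally without blocking Redis.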
+  /**
+   * Create a scan stream for queue current concurrency keys
+   */
+  public currentConcurrencyScanStream(
+    count: number = 10,
+    onEnd?: () => void,
+    onError?: (error: Error) => void
+  ): { stream: Readable; redis: Redis } {
+    const pattern = this.keys.currentConcurrencySetKeyScanPattern();
+    const stream = this.redis.scanStream({
+      match: pattern,
+      count,
+      type: "set",
+    });
+
+    if (onEnd) {
+      stream.on("end", onEnd);
+    }
+
+    if (onError) {
+      stream.on("error", onError);
+    }
+
+    return {
+      stream,
+      redis: this.redis,
+    };
+  }
+
   private async handleRedriveMessage(channel: string, message: string) {
     try {
       const { runId, envId, projectId, orgId } = JSON.parse(message) as any;
@@ -1651,6 +1720,249 @@ export class RunQueue {
     return blockingClient;
   }

+  // Runs on the scanConcurrencySets cron schedule (every 10 minutes by default)
+  private async scanConcurrencySets() {
+    if (this.abortController.signal.aborted) {
+      return;
+    }
+
+    this.logger.debug("Scanning concurrency sets for completed runs");
+
+    const stats = {
+      streamCallbacks: 0,
+      processedKeys: 0,
+    };
+
+    const { promise, resolve, reject } = promiseWithResolvers<typeof stats>();
+
+    const { stream, redis } = this.currentConcurrencyScanStream(
+      10,
+      () => {
+        this.logger.debug("Concurrency scan stream closed", { stats });
+
+        resolve(stats);
+      },
+      (error) => {
+        this.logger.error("Concurrency scan stream error", {
+          stats,
+          error: {
+            name: error.name,
+            message: error.message,
+            stack: error.stack,
+          },
+        });
+
+        reject(error);
+      }
+    );
+
+    stream.on("data", async (keys: string[]) => {
+      if (!keys || keys.length === 0) {
+        return;
+      }
+
+      stream.pause();
+
+      if (this.abortController.signal.aborted) {
+        stream.destroy();
+        return;
+      }
+
+      stats.streamCallbacks++;
+
+      const uniqueKeys = Array.from(new Set(keys)).map((key) =>
+        key.replace(redis.options.keyPrefix ?? "", "")
+      );
+
+      if (uniqueKeys.length === 0) {
+        stream.resume();
+        return;
+      }
+
+      this.logger.debug("Processing concurrency keys from stream", {
+        keys: uniqueKeys,
+      });
+
+      stats.processedKeys += uniqueKeys.length;
+
+      await Promise.allSettled(uniqueKeys.map((key) => this.processConcurrencySet(key))).finally(
+        () => {
+          stream.resume();
+        }
+      );
+    });
+
+    return promise;
+  }
+
+  private async processConcurrencySet(concurrencyKey: string) {
+    const stream = this.redis.sscanStream(concurrencyKey, {
+      count: 100,
+    });
+
+    const { promise, resolve, reject } = promiseWithResolvers<void>();
+
+    stream.on("end", () => {
+      resolve();
+    });
+
+    stream.on("error", (error) => {
+      this.logger.error("Error in sscanStream for concurrency set", {
+        concurrencyKey,
+        error,
+      });
+
+      reject(error);
+    });
+
+    stream.on("data", async (runIds: string[]) => {
+      stream.pause();
+
+      if (this.abortController.signal.aborted) {
+        stream.destroy();
+        return;
+      }
+
+      if (!runIds || runIds.length === 0) {
+        stream.resume();
+        return;
+      }
+
+      const deduplicatedRunIds = Array.from(new Set(runIds));
+
+      const [processError] = await tryCatch(
+        this.processCurrentConcurrencyRunIds(concurrencyKey, deduplicatedRunIds)
+      );
+
+      if (processError) {
+        this.logger.error("Error processing concurrency set", {
+          concurrencyKey,
+          runIds,
+          error: processError,
+        });
+      }
+
+      stream.resume();
+    });
+
+    return promise;
+  }
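+
+  // Two-phase sweep: runs reported as finished are only *marked* here; the
+  // separate processMarkedRuns job performs the actual acknowledgements.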
+  private async processCurrentConcurrencyRunIds(concurrencyKey: string, runIds: string[]) {
+    this.logger.debug(`Processing concurrency set with ${runIds.length} runs`, {
+      concurrencyKey,
+      runIds: runIds.slice(0, 5), // Log first 5 for debugging
+    });
+
+    // Call the callback to determine which runs are completed
+    const completedRuns = await this.options.concurrencySweeper?.callback(runIds);
+
+    if (!completedRuns || completedRuns.length === 0) {
+      this.logger.debug("No completed runs found in concurrency set", { concurrencyKey });
+      return;
+    }
+
+    this.logger.debug(`Found ${completedRuns.length} completed runs to mark for ack`, {
+      concurrencyKey,
+      completedRunIds: completedRuns.map((r) => r.id).slice(0, 5),
+    });
+
+    // Mark the completed runs for acknowledgment
+    await this.markRunsForAck(completedRuns);
+  }
+
+  private async markRunsForAck(completedRuns: Array<{ id: string; orgId: string }>) {
+    const markedForAckKey = this.keys.markedForAckKey();
+
+    // Prepare ZADD arguments: alternating score (current time) and member ("orgId:runId")
+    const args: Array<string | number> = [];
+    for (const run of completedRuns) {
+      this.logger.info("Marking run for acknowledgment", {
+        orgId: run.orgId,
+        runId: run.id,
+      });
+
+      args.push(Date.now());
+      args.push(`${run.orgId}:${run.id}`);
+    }
+
+    const count = await this.redis.zadd(markedForAckKey, ...args);
+
+    this.logger.debug(`Marked ${count} runs for acknowledgment`, {
+      markedForAckKey,
+      count,
+    });
+  }
+
+  // Runs on the processMarkedRuns cron schedule (every 5 minutes by default)
+  private async processMarkedRuns() {
+    if (this.abortController.signal.aborted) {
+      return;
+    }
+
+    try {
+      const markedForAckKey = this.keys.markedForAckKey();
+      const results = await this.redis.getMarkedRunsForAck(markedForAckKey, "100");
+
+      if (results.length === 0) {
+        return;
+      }
+
+      const markedRuns: MarkedRun[] = [];
+
+      // Parse results: [orgId1, messageId1, score1, orgId2, messageId2, score2, ...]
+      for (let i = 0; i < results.length; i += 3) {
+        markedRuns.push({
+          orgId: results[i],
+          messageId: results[i + 1],
+          score: Number(results[i + 2]),
+        });
+      }
+
+      this.logger.debug(`Processing ${markedRuns.length} marked runs for acknowledgment`, {
+        markedRuns,
+      });
+
+      for (const run of markedRuns) {
+        const [processError] = await tryCatch(this.processMarkedRun(run));
+
+        if (processError) {
+          this.logger.error("Error processing marked run", {
+            error: processError,
+            orgId: run.orgId,
+            messageId: run.messageId,
+          });
+        }
+      }
+
+      const shouldProcessMoreRuns = (await this.redis.zcard(markedForAckKey)) > 0;
+
+      if (shouldProcessMoreRuns) {
+        await setTimeout(1000);
+        await this.processMarkedRuns();
+      }
+    } catch (error) {
+      this.logger.error("Error processing marked runs", { error });
+    }
+  }
+
+  async processMarkedRun(run: MarkedRun) {
+    this.logger.info("Acknowledging marked run", {
+      orgId: run.orgId,
+      messageId: run.messageId,
+    });
+
+    await this.acknowledgeMessage(run.orgId, run.messageId, {
+      skipDequeueProcessing: true,
+      removeFromWorkerQueue: false,
+    });
+  }
+
   #registerCommands() {
     this.redis.defineCommand("migrateLegacyMasterQueues", {
       numberOfKeys: 1,
@@ -2020,6 +2332,77 @@ local envConcurrencyLimit = ARGV[1]

 redis.call('SET', envConcurrencyLimitKey, envConcurrencyLimit)
 `,
     });
+
+    this.redis.defineCommand("markCompletedRunsForAck", {
+      numberOfKeys: 1,
+      lua: `
+-- Keys:
+local markedForAckKey = KEYS[1]
+
+-- Args: alternating orgId, messageId pairs
+local currentTime = tonumber(redis.call('TIME')[1]) * 1000
+
+for i = 1, #ARGV, 2 do
+  local orgId = ARGV[i]
+  local messageId = ARGV[i + 1]
+  local markedValue = orgId .. ':' .. messageId
+
+  redis.call('ZADD', markedForAckKey, currentTime, markedValue)
+end
+
+return #ARGV / 2
+      `,
+    });
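+
+    // Reads and removes a batch in one atomic script, so two concurrent workers
+    // can never acknowledge the same marked run twice.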
+    this.redis.defineCommand("getMarkedRunsForAck", {
+      numberOfKeys: 1,
+      lua: `
+-- Keys:
+local markedForAckKey = KEYS[1]
+
+-- Args:
+local maxCount = tonumber(ARGV[1] or '10')
+
+-- Get the oldest marked runs
+local markedRuns = redis.call('ZRANGE', markedForAckKey, 0, maxCount - 1, 'WITHSCORES')
+
+local results = {}
+for i = 1, #markedRuns, 2 do
+  local markedValue = markedRuns[i]
+  local score = markedRuns[i + 1]
+
+  -- Parse orgId:messageId
+  local colonIndex = string.find(markedValue, ':')
+  if colonIndex then
+    local orgId = string.sub(markedValue, 1, colonIndex - 1)
+    local messageId = string.sub(markedValue, colonIndex + 1)
+
+    table.insert(results, orgId)
+    table.insert(results, messageId)
+    table.insert(results, score)
+  end
+end
+
+-- Remove the processed items
+if #results > 0 then
+  local itemsToRemove = {}
+  for i = 1, #markedRuns, 2 do
+    table.insert(itemsToRemove, markedRuns[i])
+  end
+  redis.call('ZREM', markedForAckKey, unpack(itemsToRemove))
+end
+
+return results
+      `,
+    });
   }
 }
+
+function safeJsonParse(rawMessage: string): unknown {
+  try {
+    return JSON.parse(rawMessage);
+  } catch (e) {
+    return undefined;
+  }
+}
@@ -2145,5 +2528,11 @@ declare module "@internal/redis" {
       keyPrefix: string,
       ...queueNames: string[]
     ): Result<void, Context>;
+
+    getMarkedRunsForAck(
+      markedForAckKey: string,
+      maxCount: string,
+      callback?: Callback<string[]>
+    ): Result<string[], Context>;
   }
 }
diff --git a/internal-packages/run-engine/src/run-queue/keyProducer.ts b/internal-packages/run-engine/src/run-queue/keyProducer.ts
index 49e165ad90..d2ac500f7a 100644
--- a/internal-packages/run-engine/src/run-queue/keyProducer.ts
+++ b/internal-packages/run-engine/src/run-queue/keyProducer.ts
@@ -219,9 +219,16 @@ export class RunQueueFullKeyProducer implements RunQueueKeyProducer {
     }
   }

   deadLetterQueueKeyFromQueue(queue: string): string {
-    const descriptor = this.descriptorFromQueue(queue);
+    const { orgId, projectId, envId } = this.descriptorFromQueue(queue);
+    return this.deadLetterQueueKey({ orgId, projectId, envId });
+  }
+
+  markedForAckKey(): string {
+    return "markedForAck";
+  }

-    return this.deadLetterQueueKey(descriptor);
+  currentConcurrencySetKeyScanPattern(): string {
+    return `*:${constants.ENV_PART}:*:queue:*:${constants.CURRENT_CONCURRENCY_PART}`;
   }

   descriptorFromQueue(queue: string): QueueDescriptor {
diff --git a/internal-packages/run-engine/src/run-queue/tests/concurrencySweeper.test.ts b/internal-packages/run-engine/src/run-queue/tests/concurrencySweeper.test.ts
new file mode 100644
index 0000000000..97fef5303d
--- /dev/null
+++ b/internal-packages/run-engine/src/run-queue/tests/concurrencySweeper.test.ts
@@ -0,0 +1,171 @@
+import { redisTest } from "@internal/testcontainers";
+import { trace } from "@internal/tracing";
+import { Logger } from "@trigger.dev/core/logger";
+import { describe } from "node:test";
+import { setTimeout } from "node:timers/promises";
+import { FairQueueSelectionStrategy } from "../fairQueueSelectionStrategy.js";
+import { RunQueue } from "../index.js";
+import { RunQueueFullKeyProducer } from "../keyProducer.js";
+import { InputPayload } from "../types.js";
+
+const testOptions = {
+  name: "rq",
+  tracer: trace.getTracer("rq"),
+  workers: 1,
+  defaultEnvConcurrency: 25,
+  logger: new Logger("RunQueue", "warn"),
+  retryOptions: {
+    maxAttempts: 5,
+    factor: 1.1,
+    minTimeoutInMs: 100,
+    maxTimeoutInMs: 1_000,
+    randomize: true,
+  },
+  keys: new RunQueueFullKeyProducer(),
+};
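+
+// Fixtures: one DEV environment and two messages on the same queue; the sweeper
+// callback in the test below reports only the first run as completed.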
+const authenticatedEnvDev = {
+  id: "e1234",
+  type: "DEVELOPMENT" as const,
+  maximumConcurrencyLimit: 10,
+  project: { id: "p1234" },
+  organization: { id: "o1234" },
+};
+
+const messageDev: InputPayload = {
+  runId: "r4321",
+  taskIdentifier: "task/my-task",
+  orgId: "o1234",
+  projectId: "p1234",
+  environmentId: "e4321",
+  environmentType: "DEVELOPMENT",
+  queue: "task/my-task",
+  timestamp: Date.now(),
+  attempt: 0,
+};
+
+const messageDev2: InputPayload = {
+  ...messageDev,
+  runId: "r4322",
+};
+
+vi.setConfig({ testTimeout: 60_000 });
+
+describe("RunQueue Concurrency Sweeper", () => {
+  redisTest(
+    "should process queue current concurrency sets and mark runs for ack if they are completed",
+    async ({ redisContainer }) => {
+      let enableConcurrencySweeper = false;
+
+      const queue = new RunQueue({
+        ...testOptions,
+        queueSelectionStrategy: new FairQueueSelectionStrategy({
+          redis: {
+            keyPrefix: "runqueue:test:",
+            host: redisContainer.getHost(),
+            port: redisContainer.getPort(),
+          },
+          keys: testOptions.keys,
+        }),
+        redis: {
+          keyPrefix: "runqueue:test:",
+          host: redisContainer.getHost(),
+          port: redisContainer.getPort(),
+        },
+        concurrencySweeper: {
+          scanSchedule: "* * * * * *", // Every second
+          scanJitterInMs: 5,
+          processMarkedSchedule: "* * * * * *", // Every second
+          processMarkedJitterInMs: 5,
+          callback: async (runIds) => {
+            if (!enableConcurrencySweeper) {
+              return [];
+            }
+
+            return [{ id: messageDev.runId, orgId: "o1234" }];
+          },
+        },
+      });
+
+      try {
+        // enqueue messages
+        const enqueueResult = await queue.enqueueMessage({
+          env: authenticatedEnvDev,
+          message: messageDev,
+          workerQueue: authenticatedEnvDev.id,
+        });
+
+        expect(enqueueResult).toBe(undefined);
+
+        const enqueueResult2 = await queue.enqueueMessage({
+          env: authenticatedEnvDev,
+          message: messageDev2,
+          workerQueue: authenticatedEnvDev.id,
+        });
+
+        expect(enqueueResult2).toBe(undefined);
+
+        // queue length
+        const result2 = await queue.lengthOfQueue(authenticatedEnvDev, messageDev.queue);
+        expect(result2).toBe(2);
+
+        const envQueueLength2 = await queue.lengthOfEnvQueue(authenticatedEnvDev);
+        expect(envQueueLength2).toBe(2);
+
+        // concurrencies
+        const queueConcurrency = await queue.currentConcurrencyOfQueue(
+          authenticatedEnvDev,
+          messageDev.queue
+        );
+        expect(queueConcurrency).toBe(0);
+
+        const envConcurrency = await queue.currentConcurrencyOfEnvironment(authenticatedEnvDev);
+        expect(envConcurrency).toBe(0);
+
+        // dequeue messages
+        const dequeued = await queue.dequeueMessageFromWorkerQueue(
+          "test_12345",
+          authenticatedEnvDev.id
+        );
+
+        expect(dequeued).toBeDefined();
+        expect(dequeued?.messageId).toEqual(messageDev.runId);
+
+        const dequeued2 = await queue.dequeueMessageFromWorkerQueue(
+          "test_12345",
+          authenticatedEnvDev.id
+        );
+        expect(dequeued2).toBeDefined();
+        expect(dequeued2?.messageId).toEqual(messageDev2.runId);
+
+        // queue concurrency should be 2
+        const queueConcurrency2 = await queue.currentConcurrencyOfQueue(
+          authenticatedEnvDev,
+          messageDev.queue
+        );
+        expect(queueConcurrency2).toBe(2);
+
+        // env concurrency should be 2
+        const envConcurrency2 = await queue.currentConcurrencyOfEnvironment(authenticatedEnvDev);
+        expect(envConcurrency2).toBe(2);
+
+        enableConcurrencySweeper = true;
+
+        await setTimeout(3_000); // Now a run is "completed" and should be removed from the concurrency set
+
+        // queue concurrency should be 1 (the completed run was swept)
+        const queueConcurrency3 = await queue.currentConcurrencyOfQueue(
+          authenticatedEnvDev,
+          messageDev.queue
+        );
+        expect(queueConcurrency3).toBe(1);
+        // env concurrency should be 1
+        const envConcurrency3 = await queue.currentConcurrencyOfEnvironment(authenticatedEnvDev);
+        expect(envConcurrency3).toBe(1);
+      } finally {
+        await queue.quit();
+      }
+    }
+  );
+});
diff --git a/internal-packages/run-engine/src/run-queue/types.ts b/internal-packages/run-engine/src/run-queue/types.ts
index 68431f4ede..38420ce897 100644
--- a/internal-packages/run-engine/src/run-queue/types.ts
+++ b/internal-packages/run-engine/src/run-queue/types.ts
@@ -96,6 +96,10 @@ export interface RunQueueKeyProducer {
   deadLetterQueueKey(env: MinimalAuthenticatedEnvironment): string;
   deadLetterQueueKey(env: EnvDescriptor): string;
   deadLetterQueueKeyFromQueue(queue: string): string;
+
+  // Concurrency sweeper methods
+  markedForAckKey(): string;
+  currentConcurrencySetKeyScanPattern(): string;
 }

 export type EnvQueues = {
diff --git a/packages/redis-worker/package.json b/packages/redis-worker/package.json
index e572b095f4..2f64432e13 100644
--- a/packages/redis-worker/package.json
+++ b/packages/redis-worker/package.json
@@ -27,7 +27,8 @@
     "lodash.omit": "^4.5.0",
     "nanoid": "^5.0.7",
     "p-limit": "^6.2.0",
-    "zod": "3.23.8"
+    "zod": "3.23.8",
+    "cron-parser": "^4.9.0"
   },
   "devDependencies": {
     "@internal/redis": "workspace:*",
diff --git a/packages/redis-worker/src/cron.test.ts b/packages/redis-worker/src/cron.test.ts
new file mode 100644
index 0000000000..9a607234a4
--- /dev/null
+++ b/packages/redis-worker/src/cron.test.ts
@@ -0,0 +1,130 @@
+import { redisTest } from "@internal/testcontainers";
+import { Logger } from "@trigger.dev/core/logger";
+import { describe } from "node:test";
+import { expect } from "vitest";
+import { Worker, CronSchema } from "./worker.js";
+import { setTimeout } from "node:timers/promises";
+
+describe("Worker with cron", () => {
+  redisTest(
+    "process items on the cron schedule",
+    { timeout: 180_000 },
+    async ({ redisContainer }) => {
+      const processedItems: CronSchema[] = [];
+      const worker = new Worker({
+        name: "test-worker",
+        redisOptions: {
+          host: redisContainer.getHost(),
+          port: redisContainer.getPort(),
+          password: redisContainer.getPassword(),
+        },
+        catalog: {
+          cronJob: {
+            cron: "*/5 * * * * *", // Every 5 seconds
+            schema: CronSchema,
+            visibilityTimeoutMs: 5000,
+            retry: { maxAttempts: 3 },
+            jitterInMs: 100,
+          },
+        },
+        jobs: {
+          cronJob: async ({ payload }) => {
+            await setTimeout(30); // Simulate work
+            processedItems.push(payload);
+          },
+        },
+        concurrency: {
+          workers: 2,
+          tasksPerWorker: 3,
+        },
+        logger: new Logger("test", "debug"),
+      }).start();
+
+      await setTimeout(6_000);
+
+      expect(processedItems.length).toBe(1);
+
+      const firstItem = processedItems[0];
+
+      expect(firstItem?.timestamp).toBeGreaterThan(0);
+      expect(firstItem?.lastTimestamp).toBeUndefined();
+      expect(firstItem?.cron).toBe("*/5 * * * * *");
+
+      await setTimeout(6_000);
+
+      expect(processedItems.length).toBeGreaterThanOrEqual(2);
+
+      const secondItem = processedItems[1];
+      expect(secondItem?.timestamp).toBeGreaterThan(firstItem!.timestamp);
+      expect(secondItem?.lastTimestamp).toBe(firstItem?.timestamp);
+      expect(secondItem?.cron).toBe("*/5 * * * * *");
+
+      await worker.stop();
+    }
+  );
+
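+  // The handler below throws on some executions; rescheduling happens on both the
+  // ack path and the dead-letter path, so errors must not stop future cron ticks.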
+  redisTest(
+    "continues processing cron items even when job handler throws errors",
+    { timeout: 180_000 },
+    async ({ redisContainer }) => {
+      const processedItems: CronSchema[] = [];
+      let executionCount = 0;
+
+      const worker = new Worker({
+        name: "test-worker-error",
+        redisOptions: {
+          host: redisContainer.getHost(),
+          port: redisContainer.getPort(),
+          password: redisContainer.getPassword(),
+        },
+        catalog: {
+          cronJob: {
+            cron: "*/3 * * * * *", // Every 3 seconds
+            schema: CronSchema,
+            visibilityTimeoutMs: 5000,
+            retry: { maxAttempts: 1 }, // Only try once to fail faster
+            jitterInMs: 100,
+          },
+        },
+        jobs: {
+          cronJob: async ({ payload }) => {
+            executionCount++;
+            await setTimeout(30); // Simulate work
+
+            // Throw an error on the first and third executions
+            if (executionCount === 1 || executionCount === 3) {
+              throw new Error(`Simulated error on execution ${executionCount}`);
+            }
+
+            processedItems.push(payload);
+          },
+        },
+        concurrency: {
+          workers: 2,
+          tasksPerWorker: 3,
+        },
+        logger: new Logger("test", "debug"),
+      }).start();
+
+      // Wait long enough for 4 executions (12 seconds + buffer)
+      await setTimeout(14_000);
+
+      // Should have at least 4 executions total
+      expect(executionCount).toBeGreaterThanOrEqual(4);
+
+      // Should have at least 2 successful items (executions 2 and 4)
+      expect(processedItems.length).toBeGreaterThanOrEqual(2);
+
+      // Some executions failed (execution count > successful count), which proves
+      // that errors occurred but cron scheduling continued
+      expect(executionCount).toBeGreaterThan(processedItems.length);
+
+      // Successful executions still have the correct payload structure
+      const firstSuccessful = processedItems[0];
+      expect(firstSuccessful?.timestamp).toBeGreaterThan(0);
+      expect(firstSuccessful?.cron).toBe("*/3 * * * * *");
+
+      await worker.stop();
+    }
+  );
+});
diff --git a/packages/redis-worker/src/worker.ts b/packages/redis-worker/src/worker.ts
index 69a1bbf8b8..3542f305fe 100644
--- a/packages/redis-worker/src/worker.ts
+++ b/packages/redis-worker/src/worker.ts
@@ -19,15 +19,28 @@ import { nanoid } from "nanoid";
 import pLimit from "p-limit";
 import { z } from "zod";
 import { AnyQueueItem, SimpleQueue } from "./queue.js";
+import { parseExpression } from "cron-parser";
+
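+// Payload carried by every cron-scheduled job: the cron expression, this tick's
+// scheduled timestamp, and the previous tick's timestamp (undefined on first run).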
+export const CronSchema = z.object({
+  cron: z.string(),
+  lastTimestamp: z.number().optional(),
+  timestamp: z.number(),
+});
+
+export type CronSchema = z.infer<typeof CronSchema>;

 export type WorkerCatalog = {
   [key: string]: {
     schema: z.ZodFirstPartySchemaTypes | z.ZodDiscriminatedUnion<string, any>;
     visibilityTimeoutMs: number;
     retry?: RetryOptions;
+    cron?: string;
+    jitterInMs?: number;
   };
 };

+type WorkerCatalogItem = WorkerCatalog[keyof WorkerCatalog];
+
 type QueueCatalogFromWorkerCatalog<Catalog extends WorkerCatalog> = {
   [K in keyof Catalog]: Catalog[K]["schema"];
 };
@@ -204,6 +217,12 @@ class Worker {
   public start() {
     const { workers, tasksPerWorker } = this.concurrency;

+    this.logger.info("Starting worker", {
+      workers,
+      tasksPerWorker,
+      concurrency: this.concurrency,
+    });
+
     // Launch a number of "worker loops" on the main thread.
     for (let i = 0; i < workers; i++) {
       this.workerLoops.push(this.runWorkerLoop(`worker-${nanoid(12)}`, tasksPerWorker, i, workers));
@@ -219,7 +238,9 @@ class Worker {
         });
       },
     });
+
     this.setupSubscriber();
+    this.setupCron();

     return this;
   }
@@ -496,6 +517,11 @@ class Worker {
       return;
     }

+    if (!catalogItem) {
+      this.logger.error(`No catalog item found for job type: ${job}`);
+      return;
+    }
+
     await startSpan(
       this.tracer,
       "processItem",
@@ -513,6 +539,10 @@ class Worker {

         // On success, acknowledge the item.
         await this.queue.ack(id, deduplicationKey);
+
+        if (catalogItem.cron) {
+          await this.rescheduleCronJob(job, catalogItem, item);
+        }
       },
       {
         kind: SpanKind.CONSUMER,
@@ -560,7 +590,13 @@ class Worker {
             attempt: newAttempt,
             errorMessage,
           });
+
           await this.queue.moveToDeadLetterQueue(id, errorMessage);
+
+          if (catalogItem.cron) {
+            await this.rescheduleCronJob(job, catalogItem, item);
+          }
+
           return;
         }
@@ -622,6 +658,113 @@ class Worker {
     return new Promise((resolve) => setTimeout(resolve, ms));
   }

+  private setupCron() {
+    const cronJobs = Object.entries(this.options.catalog).filter(([_, value]) => value.cron);
+
+    if (cronJobs.length === 0) {
+      return;
+    }
+
+    this.logger.info("Setting up cron jobs", {
+      cronJobs: cronJobs.map(([job, value]) => ({
+        job,
+        cron: value.cron,
+        jitterInMs: value.jitterInMs,
+      })),
+    });
+
+    // For each cron job, enqueue a run for the next timestamp of its cron expression.
+    const enqueuePromises = cronJobs.map(([job, value]) =>
+      this.enqueueCronJob(value.cron!, job, value.jitterInMs)
+    );
+
+    Promise.allSettled(enqueuePromises).then((results) => {
+      results.forEach((result) => {
+        if (result.status === "fulfilled") {
+          this.logger.info("Enqueued cron job", { result: result.value });
+        } else {
+          this.logger.error("Failed to enqueue cron job", { reason: result.reason });
+        }
+      });
+    });
+  }
+
+  private async enqueueCronJob(cron: string, job: string, jitter?: number, lastTimestamp?: Date) {
+    const scheduledAt = this.calculateNextScheduledAt(cron, lastTimestamp);
+    const identifier = [job, this.timestampIdentifier(scheduledAt)].join(":");
+    // Apply jitter: shift availableAt by a random offset in [-jitter/2, +jitter/2] around scheduledAt
+    const availableAt = jitter
+      ? new Date(scheduledAt.getTime() + Math.random() * jitter - jitter / 2)
+      : scheduledAt;
+
+    const enqueued = await this.enqueueOnce({
+      id: identifier,
+      job,
+      payload: {
+        timestamp: scheduledAt.getTime(),
+        lastTimestamp: lastTimestamp?.getTime(),
+        cron,
+      },
+      availableAt,
+    });
+
+    this.logger.info("Enqueued cron job", {
+      identifier,
+      cron,
+      job,
+      scheduledAt,
+      enqueued,
+      availableAt,
+    });
+
+    return {
+      identifier,
+      cron,
+      job,
+      scheduledAt,
+      enqueued,
+    };
+  }
+
+  private async rescheduleCronJob(job: string, catalogItem: WorkerCatalogItem, item: CronSchema) {
+    if (!catalogItem.cron) {
+      return;
+    }
+
+    return this.enqueueCronJob(
+      catalogItem.cron,
+      job,
+      catalogItem.jitterInMs,
+      new Date(item.timestamp)
+    );
+  }
+
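+  // If the next occurrence computed from lastTimestamp is already in the past
+  // (e.g. after downtime), recompute from "now" so old ticks are not replayed.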
+  private calculateNextScheduledAt(cron: string, lastTimestamp?: Date): Date {
+    const scheduledAt = parseExpression(cron, {
+      currentDate: lastTimestamp,
+    })
+      .next()
+      .toDate();
+
+    // If scheduledAt is in the past, calculate the next occurrence from the current time
+    if (scheduledAt < new Date()) {
+      return this.calculateNextScheduledAt(cron);
+    }
+
+    return scheduledAt;
+  }
+
+  private timestampIdentifier(timestamp: Date) {
+    const year = timestamp.getUTCFullYear();
+    const month = timestamp.getUTCMonth();
+    const day = timestamp.getUTCDate();
+    const hour = timestamp.getUTCHours();
+    const minute = timestamp.getUTCMinutes();
+    const second = timestamp.getUTCSeconds();
+
+    return `${year}-${month}-${day}-${hour}-${minute}-${second}`;
+  }
+
   private setupSubscriber() {
     const channel = `${this.options.name}:redrive`;
     this.subscriber?.subscribe(channel, (err) => {
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index be39db8422..0963682054 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -1660,6 +1660,9 @@ importers:
     '@trigger.dev/core':
       specifier: workspace:4.0.0-v4-beta.21
       version: link:../core
+    cron-parser:
+      specifier: ^4.9.0
+      version: 4.9.0
     lodash.omit:
       specifier: ^4.5.0
       version: 4.5.0
diff --git a/references/hello-world/trigger.config.ts b/references/hello-world/trigger.config.ts
index fde795ccce..c3c6aea9e4 100644
--- a/references/hello-world/trigger.config.ts
+++ b/references/hello-world/trigger.config.ts
@@ -5,7 +5,7 @@ export default defineConfig({
   compatibilityFlags: ["run_engine_v2"],
   project: "proj_rrkpdguyagvsoktglnod",
   logLevel: "log",
-  maxDuration: 60,
+  maxDuration: 3600,
   retries: {
     enabledInDev: true,
     default: {