diff --git a/apps/webapp/test/engine/triggerTask.test.ts b/apps/webapp/test/engine/triggerTask.test.ts index ad16be44e7..df6f68184e 100644 --- a/apps/webapp/test/engine/triggerTask.test.ts +++ b/apps/webapp/test/engine/triggerTask.test.ts @@ -474,4 +474,91 @@ describe("RunEngineTriggerTaskService", () => { await engine.quit(); } ); + + containerTest("should pass concurrencyKey through to run ctx", async ({ prisma, redisOptions }) => { + const engine = new RunEngine({ + prisma, + worker: { + redis: redisOptions, + workers: 1, + tasksPerWorker: 10, + pollIntervalMs: 100, + }, + queue: { + redis: redisOptions, + }, + runLock: { + redis: redisOptions, + }, + machines: { + defaultMachine: "small-1x", + machines: { + "small-1x": { + name: "small-1x" as const, + cpu: 0.5, + memory: 0.5, + centsPerMs: 0.0001, + }, + }, + baseCostInCents: 0.0005, + }, + tracer: trace.getTracer("test", "0.0.0"), + }); + + const authenticatedEnvironment = await setupAuthenticatedEnvironment(prisma, "PRODUCTION"); + const taskIdentifier = "test-task"; + await setupBackgroundWorker(engine, authenticatedEnvironment, taskIdentifier); + const queuesManager = new DefaultQueueManager(prisma, engine); + const idempotencyKeyConcern = new IdempotencyKeyConcern( + prisma, + engine, + new MockTraceEventConcern() + ); + const triggerTaskService = new RunEngineTriggerTaskService({ + engine, + prisma, + runNumberIncrementer: new MockRunNumberIncrementer(), + payloadProcessor: new MockPayloadProcessor(), + queueConcern: queuesManager, + idempotencyKeyConcern, + validator: new MockTriggerTaskValidator(), + traceEventConcern: new MockTraceEventConcern(), + tracer: trace.getTracer("test", "0.0.0"), + metadataMaximumSize: 1024 * 1024 * 1, // 1MB + }); + + const concurrencyKey = "user-123"; + const result = await triggerTaskService.call({ + taskId: taskIdentifier, + environment: authenticatedEnvironment, + body: { + payload: { test: "test" }, + options: { + concurrencyKey, + }, + }, + }); + + expect(result).toBeDefined(); + expect(result?.run.friendlyId).toBeDefined(); + expect(result?.run.status).toBe("PENDING"); + expect(result?.isCached).toBe(false); + + const run = await prisma.taskRun.findUnique({ + where: { + id: result?.run.id, + }, + }); + + expect(run).toBeDefined(); + expect(run?.concurrencyKey).toBe(concurrencyKey); + + // Optionally, fetch the run context and check ctx.run.concurrencyKey + if (run) { + const ctx = await engine["runAttemptSystem"].resolveTaskRunContext(run.id); + expect(ctx.run.concurrencyKey).toBe(concurrencyKey); + } + + await engine.quit(); + }); }); diff --git a/internal-packages/run-engine/src/engine/systems/runAttemptSystem.ts b/internal-packages/run-engine/src/engine/systems/runAttemptSystem.ts index ce0f8abe4d..aa19b84996 100644 --- a/internal-packages/run-engine/src/engine/systems/runAttemptSystem.ts +++ b/internal-packages/run-engine/src/engine/systems/runAttemptSystem.ts @@ -197,6 +197,7 @@ export class RunAttemptSystem { traceContext: true, priorityMs: true, taskIdentifier: true, + concurrencyKey: true, runtimeEnvironment: { select: { id: true, @@ -261,6 +262,7 @@ export class RunAttemptSystem { priority: run.priorityMs === 0 ? undefined : run.priorityMs / 1_000, parentTaskRunId: run.parentTaskRunId ? RunId.toFriendlyId(run.parentTaskRunId) : undefined, rootTaskRunId: run.rootTaskRunId ? RunId.toFriendlyId(run.rootTaskRunId) : undefined, + concurrencyKey: run.concurrencyKey ?? undefined, }, attempt: { number: run.attemptNumber ?? 1, diff --git a/packages/core/src/v3/schemas/common.ts b/packages/core/src/v3/schemas/common.ts index 2928995606..f244ff2b43 100644 --- a/packages/core/src/v3/schemas/common.ts +++ b/packages/core/src/v3/schemas/common.ts @@ -222,10 +222,9 @@ export const TaskRun = z.object({ /** The priority of the run. Wih a value of 10 it will be dequeued before runs that were triggered 9 seconds before it (assuming they had no priority set). */ priority: z.number().optional(), baseCostInCents: z.number().optional(), - parentTaskRunId: z.string().optional(), rootTaskRunId: z.string().optional(), - + concurrencyKey: z.string().optional(), // These are only used during execution, not in run.ctx durationMs: z.number().optional(), costInCents: z.number().optional(),