Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/clean-beans-compete.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@trigger.dev/sdk": patch
---

New internal idempotency implementation for trigger and batch trigger to prevent request retries from duplicating work
38 changes: 38 additions & 0 deletions apps/webapp/app/env.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,44 @@ const EnvironmentSchema = z.object({
TRIGGER_CLI_TAG: z.string().default("latest"),

HEALTHCHECK_DATABASE_DISABLED: z.string().default("0"),

REQUEST_IDEMPOTENCY_REDIS_HOST: z
.string()
.optional()
.transform((v) => v ?? process.env.REDIS_HOST),
REQUEST_IDEMPOTENCY_REDIS_READER_HOST: z
.string()
.optional()
.transform((v) => v ?? process.env.REDIS_READER_HOST),
REQUEST_IDEMPOTENCY_REDIS_READER_PORT: z.coerce
.number()
.optional()
.transform(
(v) =>
v ?? (process.env.REDIS_READER_PORT ? parseInt(process.env.REDIS_READER_PORT) : undefined)
),
REQUEST_IDEMPOTENCY_REDIS_PORT: z.coerce
.number()
.optional()
.transform((v) => v ?? (process.env.REDIS_PORT ? parseInt(process.env.REDIS_PORT) : undefined)),
REQUEST_IDEMPOTENCY_REDIS_USERNAME: z
.string()
.optional()
.transform((v) => v ?? process.env.REDIS_USERNAME),
REQUEST_IDEMPOTENCY_REDIS_PASSWORD: z
.string()
.optional()
.transform((v) => v ?? process.env.REDIS_PASSWORD),
REQUEST_IDEMPOTENCY_REDIS_TLS_DISABLED: z
.string()
.default(process.env.REDIS_TLS_DISABLED ?? "false"),

REQUEST_IDEMPOTENCY_LOG_LEVEL: z.enum(["log", "error", "warn", "info", "debug"]).default("info"),

REQUEST_IDEMPOTENCY_TTL_IN_MS: z.coerce
.number()
.int()
.default(60_000 * 60 * 24),
});

export type Environment = z.infer<typeof EnvironmentSchema>;
Expand Down
43 changes: 37 additions & 6 deletions apps/webapp/app/routes/api.v1.tasks.$taskId.trigger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@ import {
} from "@trigger.dev/core/v3";
import { TaskRun } from "@trigger.dev/database";
import { z } from "zod";
import { prisma } from "~/db.server";
import { env } from "~/env.server";
import { EngineServiceValidationError } from "~/runEngine/concerns/errors";
import {
ApiAuthenticationResultSuccess,
AuthenticatedEnvironment,
getOneTimeUseToken,
} from "~/services/apiAuth.server";
import { ApiAuthenticationResultSuccess, getOneTimeUseToken } from "~/services/apiAuth.server";
import { logger } from "~/services/logger.server";
import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server";
import { resolveIdempotencyKeyTTL } from "~/utils/idempotencyKeys.server";
import {
handleRequestIdempotency,
saveRequestIdempotency,
} from "~/utils/requestIdempotency.server";
import { ServiceValidationError } from "~/v3/services/baseService.server";
import { OutOfEntitlementError, TriggerTaskService } from "~/v3/services/triggerTask.server";

Expand All @@ -31,6 +32,7 @@ export const HeadersSchema = z.object({
"x-trigger-worker": z.string().nullish(),
"x-trigger-client": z.string().nullish(),
"x-trigger-engine-version": RunEngineVersionSchema.nullish(),
"x-trigger-request-idempotency-key": z.string().nullish(),
traceparent: z.string().optional(),
tracestate: z.string().optional(),
});
Expand Down Expand Up @@ -60,8 +62,34 @@ const { action, loader } = createActionApiRoute(
"x-trigger-worker": isFromWorker,
"x-trigger-client": triggerClient,
"x-trigger-engine-version": engineVersion,
"x-trigger-request-idempotency-key": requestIdempotencyKey,
} = headers;

const cachedResponse = await handleRequestIdempotency(requestIdempotencyKey, {
requestType: "trigger",
findCachedEntity: async (cachedRequestId) => {
return await prisma.taskRun.findFirst({
where: {
id: cachedRequestId,
},
select: {
friendlyId: true,
},
});
},
buildResponse: (cachedRun) => ({
id: cachedRun.friendlyId,
isCached: false,
}),
buildResponseHeaders: async (responseBody, cachedEntity) => {
return await responseHeaders(cachedEntity, authentication, triggerClient);
},
});

if (cachedResponse) {
return cachedResponse;
}

const service = new TriggerTaskService();

try {
Expand Down Expand Up @@ -104,6 +132,8 @@ const { action, loader } = createActionApiRoute(
return json({ error: "Task not found" }, { status: 404 });
}

await saveRequestIdempotency(requestIdempotencyKey, "trigger", result.run.id);

const $responseHeaders = await responseHeaders(result.run, authentication, triggerClient);

return json(
Expand All @@ -113,6 +143,7 @@ const { action, loader } = createActionApiRoute(
},
{
headers: $responseHeaders,
status: 200,
}
);
} catch (error) {
Expand All @@ -132,7 +163,7 @@ const { action, loader } = createActionApiRoute(
);

async function responseHeaders(
run: TaskRun,
run: Pick<TaskRun, "friendlyId">,
authentication: ApiAuthenticationResultSuccess,
triggerClient?: string | null
): Promise<Record<string, string>> {
Expand Down
45 changes: 43 additions & 2 deletions apps/webapp/app/routes/api.v2.tasks.batch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@ import {
BatchTriggerTaskV3Response,
generateJWT,
} from "@trigger.dev/core/v3";
import { prisma } from "~/db.server";
import { env } from "~/env.server";
import { RunEngineBatchTriggerService } from "~/runEngine/services/batchTrigger.server";
import { AuthenticatedEnvironment, getOneTimeUseToken } from "~/services/apiAuth.server";
import { logger } from "~/services/logger.server";
import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server";
import {
handleRequestIdempotency,
saveRequestIdempotency,
} from "~/utils/requestIdempotency.server";
import { ServiceValidationError } from "~/v3/services/baseService.server";
import { BatchProcessingStrategy } from "~/v3/services/batchTriggerV3.server";
import { OutOfEntitlementError } from "~/v3/services/triggerTask.server";
import { HeadersSchema } from "./api.v1.tasks.$taskId.trigger";
import { RunEngineBatchTriggerService } from "~/runEngine/services/batchTrigger.server";

const { action, loader } = createActionApiRoute(
{
Expand Down Expand Up @@ -53,6 +58,7 @@ const { action, loader } = createActionApiRoute(
"x-trigger-client": triggerClient,
"x-trigger-engine-version": engineVersion,
"batch-processing-strategy": batchProcessingStrategy,
"x-trigger-request-idempotency-key": requestIdempotencyKey,
traceparent,
tracestate,
} = headers;
Expand All @@ -67,15 +73,47 @@ const { action, loader } = createActionApiRoute(
traceparent,
tracestate,
batchProcessingStrategy,
requestIdempotencyKey,
});

const cachedResponse = await handleRequestIdempotency(requestIdempotencyKey, {
requestType: "batch-trigger",
findCachedEntity: async (cachedRequestId) => {
return await prisma.batchTaskRun.findFirst({
where: {
id: cachedRequestId,
runtimeEnvironmentId: authentication.environment.id,
},
select: {
friendlyId: true,
runCount: true,
},
});
},
buildResponse: (cachedBatch) => ({
id: cachedBatch.friendlyId,
runCount: cachedBatch.runCount,
}),
buildResponseHeaders: async (responseBody, cachedEntity) => {
return await responseHeaders(responseBody, authentication.environment, triggerClient);
},
});

if (cachedResponse) {
return cachedResponse;
}

const traceContext =
traceparent && isFromWorker // If the request is from a worker, we should pass the trace context
? { traceparent, tracestate }
: undefined;

const service = new RunEngineBatchTriggerService(batchProcessingStrategy ?? undefined);

service.onBatchTaskRunCreated.attachOnce(async (batch) => {
await saveRequestIdempotency(requestIdempotencyKey, "batch-trigger", batch.id);
});

try {
const batch = await service.call(authentication.environment, body, {
triggerVersion: triggerVersion ?? undefined,
Expand All @@ -90,7 +128,10 @@ const { action, loader } = createActionApiRoute(
triggerClient
);

return json(batch, { status: 202, headers: $responseHeaders });
return json(batch, {
status: 202,
headers: $responseHeaders,
});
} catch (error) {
logger.error("Batch trigger error", {
error: {
Expand Down
8 changes: 7 additions & 1 deletion apps/webapp/app/runEngine/services/batchTrigger.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import {
} from "@trigger.dev/core/v3";
import { BatchId, RunId } from "@trigger.dev/core/v3/isomorphic";
import { BatchTaskRun, Prisma } from "@trigger.dev/database";
import { Evt } from "evt";
import { z } from "zod";
import { $transaction, prisma, PrismaClientOrTransaction } from "~/db.server";
import { prisma, PrismaClientOrTransaction } from "~/db.server";
import { env } from "~/env.server";
import { AuthenticatedEnvironment } from "~/services/apiAuth.server";
import { logger } from "~/services/logger.server";
Expand Down Expand Up @@ -51,6 +52,7 @@ export type BatchTriggerTaskServiceOptions = {
*/
export class RunEngineBatchTriggerService extends WithRunEngine {
private _batchProcessingStrategy: BatchProcessingStrategy;
public onBatchTaskRunCreated: Evt<BatchTaskRun> = new Evt();

constructor(
batchProcessingStrategy?: BatchProcessingStrategy,
Expand Down Expand Up @@ -168,6 +170,8 @@ export class RunEngineBatchTriggerService extends WithRunEngine {
},
});

this.onBatchTaskRunCreated.post(batch);

if (body.parentRunId && body.resumeParentOnCompletion) {
await this._engine.blockRunWithCreatedBatch({
runId: RunId.fromFriendlyId(body.parentRunId),
Expand Down Expand Up @@ -259,6 +263,8 @@ export class RunEngineBatchTriggerService extends WithRunEngine {
},
});

this.onBatchTaskRunCreated.post(batch);

if (body.parentRunId && body.resumeParentOnCompletion) {
await this._engine.blockRunWithCreatedBatch({
runId: RunId.fromFriendlyId(body.parentRunId),
Expand Down
124 changes: 124 additions & 0 deletions apps/webapp/app/services/requestIdempotency.server.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import { Logger, LogLevel } from "@trigger.dev/core/logger";
import { createCache, DefaultStatefulContext, Namespace, Cache as UnkeyCache } from "@unkey/cache";
import { MemoryStore } from "@unkey/cache/stores";
import { RedisCacheStore } from "./unkey/redisCacheStore.server";
import { RedisWithClusterOptions } from "~/redis.server";
import { validate as uuidValidate, version as uuidVersion } from "uuid";
import { startActiveSpan } from "~/v3/tracer.server";

export type RequestIdempotencyServiceOptions<TTypes extends string> = {
types: TTypes[];
redis: RedisWithClusterOptions;
logger?: Logger;
logLevel?: LogLevel;
ttlInMs?: number;
};

const DEFAULT_TTL_IN_MS = 60_000 * 60 * 24;

type RequestIdempotencyCacheEntry = {
id: string;
};

export class RequestIdempotencyService<TTypes extends string> {
private readonly logger: Logger;
private readonly cache: UnkeyCache<{ requests: RequestIdempotencyCacheEntry }>;

constructor(private readonly options: RequestIdempotencyServiceOptions<TTypes>) {
this.logger =
options.logger ?? new Logger("RequestIdempotencyService", options.logLevel ?? "info");

const keyPrefix = options.redis.keyPrefix
? `request-idempotency:${options.redis.keyPrefix}`
: "request-idempotency:";

const ctx = new DefaultStatefulContext();
const memory = new MemoryStore({ persistentMap: new Map() });
const redisCacheStore = new RedisCacheStore({
name: "request-idempotency",
connection: {
keyPrefix: keyPrefix,
...options.redis,
},
});

// This cache holds the rate limit configuration for each org, so we don't have to fetch it every request
const cache = createCache({
requests: new Namespace<RequestIdempotencyCacheEntry>(ctx, {
stores: [memory, redisCacheStore],
fresh: options.ttlInMs ?? DEFAULT_TTL_IN_MS,
stale: options.ttlInMs ?? DEFAULT_TTL_IN_MS,
}),
});

this.cache = cache;
}

async checkRequest(type: TTypes, requestIdempotencyKey: string) {
if (!this.#validateRequestId(requestIdempotencyKey)) {
this.logger.warn("RequestIdempotency: invalid requestIdempotencyKey", {
requestIdempotencyKey,
});

return undefined;
}

return startActiveSpan("RequestIdempotency.checkRequest()", async (span) => {
span.setAttribute("request_id", requestIdempotencyKey);
span.setAttribute("type", type);

const key = `${type}:${requestIdempotencyKey}`;
const result = await this.cache.requests.get(key);

this.logger.debug("RequestIdempotency: checking request", {
type,
requestIdempotencyKey,
key,
result,
});

return result.val ? result.val : undefined;
});
}

async saveRequest(
type: TTypes,
requestIdempotencyKey: string,
value: RequestIdempotencyCacheEntry
) {
if (!this.#validateRequestId(requestIdempotencyKey)) {
this.logger.warn("RequestIdempotency: invalid requestIdempotencyKey", {
requestIdempotencyKey,
});
return undefined;
}

const key = `${type}:${requestIdempotencyKey}`;
const result = await this.cache.requests.set(key, value);

if (result.err) {
this.logger.error("RequestIdempotency: error saving request", {
key,
error: result.err,
});
} else {
this.logger.debug("RequestIdempotency: saved request", {
type,
requestIdempotencyKey,
key,
value,
});
}

return result;
}

// The requestIdempotencyKey should be a valid UUID
#validateRequestId(requestIdempotencyKey: string): boolean {
return isValidV4UUID(requestIdempotencyKey);
}
}

function isValidV4UUID(uuid: string): boolean {
return uuidValidate(uuid) && uuidVersion(uuid) === 4;
}
Loading