Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/clean-beans-compete.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@trigger.dev/sdk": patch
---

New internal idempotency implementation for trigger and batch trigger to prevent request retries from duplicating work
38 changes: 38 additions & 0 deletions apps/webapp/app/env.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,44 @@ const EnvironmentSchema = z.object({
TRIGGER_CLI_TAG: z.string().default("latest"),

HEALTHCHECK_DATABASE_DISABLED: z.string().default("0"),

REQUEST_IDEMPOTENCY_REDIS_HOST: z
.string()
.optional()
.transform((v) => v ?? process.env.REDIS_HOST),
REQUEST_IDEMPOTENCY_REDIS_READER_HOST: z
.string()
.optional()
.transform((v) => v ?? process.env.REDIS_READER_HOST),
REQUEST_IDEMPOTENCY_REDIS_READER_PORT: z.coerce
.number()
.optional()
.transform(
(v) =>
v ?? (process.env.REDIS_READER_PORT ? parseInt(process.env.REDIS_READER_PORT) : undefined)
),
REQUEST_IDEMPOTENCY_REDIS_PORT: z.coerce
.number()
.optional()
.transform((v) => v ?? (process.env.REDIS_PORT ? parseInt(process.env.REDIS_PORT) : undefined)),
REQUEST_IDEMPOTENCY_REDIS_USERNAME: z
.string()
.optional()
.transform((v) => v ?? process.env.REDIS_USERNAME),
REQUEST_IDEMPOTENCY_REDIS_PASSWORD: z
.string()
.optional()
.transform((v) => v ?? process.env.REDIS_PASSWORD),
REQUEST_IDEMPOTENCY_REDIS_TLS_DISABLED: z
.string()
.default(process.env.REDIS_TLS_DISABLED ?? "false"),

REQUEST_IDEMPOTENCY_LOG_LEVEL: z.enum(["log", "error", "warn", "info", "debug"]).default("info"),

REQUEST_IDEMPOTENCY_TTL_IN_MS: z.coerce
.number()
.int()
.default(60_000 * 60 * 24),
});

export type Environment = z.infer<typeof EnvironmentSchema>;
Expand Down
43 changes: 37 additions & 6 deletions apps/webapp/app/routes/api.v1.tasks.$taskId.trigger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@ import {
} from "@trigger.dev/core/v3";
import { TaskRun } from "@trigger.dev/database";
import { z } from "zod";
import { prisma } from "~/db.server";
import { env } from "~/env.server";
import { EngineServiceValidationError } from "~/runEngine/concerns/errors";
import {
ApiAuthenticationResultSuccess,
AuthenticatedEnvironment,
getOneTimeUseToken,
} from "~/services/apiAuth.server";
import { ApiAuthenticationResultSuccess, getOneTimeUseToken } from "~/services/apiAuth.server";
import { logger } from "~/services/logger.server";
import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server";
import { resolveIdempotencyKeyTTL } from "~/utils/idempotencyKeys.server";
import {
handleRequestIdempotency,
saveRequestIdempotency,
} from "~/utils/requestIdempotency.server";
import { ServiceValidationError } from "~/v3/services/baseService.server";
import { OutOfEntitlementError, TriggerTaskService } from "~/v3/services/triggerTask.server";

Expand All @@ -31,6 +32,7 @@ export const HeadersSchema = z.object({
"x-trigger-worker": z.string().nullish(),
"x-trigger-client": z.string().nullish(),
"x-trigger-engine-version": RunEngineVersionSchema.nullish(),
"x-trigger-request-idempotency-key": z.string().nullish(),
traceparent: z.string().optional(),
tracestate: z.string().optional(),
});
Expand Down Expand Up @@ -60,8 +62,34 @@ const { action, loader } = createActionApiRoute(
"x-trigger-worker": isFromWorker,
"x-trigger-client": triggerClient,
"x-trigger-engine-version": engineVersion,
"x-trigger-request-idempotency-key": requestIdempotencyKey,
} = headers;

const cachedResponse = await handleRequestIdempotency(requestIdempotencyKey, {
requestType: "trigger",
findCachedEntity: async (cachedRequestId) => {
return await prisma.taskRun.findFirst({
where: {
id: cachedRequestId,
},
select: {
friendlyId: true,
},
});
},
buildResponse: (cachedRun) => ({
id: cachedRun.friendlyId,
isCached: false,
}),
buildResponseHeaders: async (responseBody, cachedEntity) => {
return await responseHeaders(cachedEntity, authentication, triggerClient);
},
});

if (cachedResponse) {
return cachedResponse;
}

const service = new TriggerTaskService();

try {
Expand Down Expand Up @@ -104,6 +132,8 @@ const { action, loader } = createActionApiRoute(
return json({ error: "Task not found" }, { status: 404 });
}

await saveRequestIdempotency(requestIdempotencyKey, "trigger", result.run.id);

const $responseHeaders = await responseHeaders(result.run, authentication, triggerClient);

return json(
Expand All @@ -113,6 +143,7 @@ const { action, loader } = createActionApiRoute(
},
{
headers: $responseHeaders,
status: 200,
}
);
} catch (error) {
Expand All @@ -132,7 +163,7 @@ const { action, loader } = createActionApiRoute(
);

async function responseHeaders(
run: TaskRun,
run: Pick<TaskRun, "friendlyId">,
authentication: ApiAuthenticationResultSuccess,
triggerClient?: string | null
): Promise<Record<string, string>> {
Expand Down
45 changes: 43 additions & 2 deletions apps/webapp/app/routes/api.v2.tasks.batch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@ import {
BatchTriggerTaskV3Response,
generateJWT,
} from "@trigger.dev/core/v3";
import { prisma } from "~/db.server";
import { env } from "~/env.server";
import { RunEngineBatchTriggerService } from "~/runEngine/services/batchTrigger.server";
import { AuthenticatedEnvironment, getOneTimeUseToken } from "~/services/apiAuth.server";
import { logger } from "~/services/logger.server";
import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server";
import {
handleRequestIdempotency,
saveRequestIdempotency,
} from "~/utils/requestIdempotency.server";
import { ServiceValidationError } from "~/v3/services/baseService.server";
import { BatchProcessingStrategy } from "~/v3/services/batchTriggerV3.server";
import { OutOfEntitlementError } from "~/v3/services/triggerTask.server";
import { HeadersSchema } from "./api.v1.tasks.$taskId.trigger";
import { RunEngineBatchTriggerService } from "~/runEngine/services/batchTrigger.server";

const { action, loader } = createActionApiRoute(
{
Expand Down Expand Up @@ -53,6 +58,7 @@ const { action, loader } = createActionApiRoute(
"x-trigger-client": triggerClient,
"x-trigger-engine-version": engineVersion,
"batch-processing-strategy": batchProcessingStrategy,
"x-trigger-request-idempotency-key": requestIdempotencyKey,
traceparent,
tracestate,
} = headers;
Expand All @@ -67,15 +73,47 @@ const { action, loader } = createActionApiRoute(
traceparent,
tracestate,
batchProcessingStrategy,
requestIdempotencyKey,
});

const cachedResponse = await handleRequestIdempotency(requestIdempotencyKey, {
requestType: "batch-trigger",
findCachedEntity: async (cachedRequestId) => {
return await prisma.batchTaskRun.findFirst({
where: {
id: cachedRequestId,
runtimeEnvironmentId: authentication.environment.id,
},
select: {
friendlyId: true,
runCount: true,
},
});
},
buildResponse: (cachedBatch) => ({
id: cachedBatch.friendlyId,
runCount: cachedBatch.runCount,
}),
buildResponseHeaders: async (responseBody, cachedEntity) => {
return await responseHeaders(responseBody, authentication.environment, triggerClient);
},
});

if (cachedResponse) {
return cachedResponse;
}

const traceContext =
traceparent && isFromWorker // If the request is from a worker, we should pass the trace context
? { traceparent, tracestate }
: undefined;

const service = new RunEngineBatchTriggerService(batchProcessingStrategy ?? undefined);

service.onBatchTaskRunCreated.attachOnce(async (batch) => {
await saveRequestIdempotency(requestIdempotencyKey, "batch-trigger", batch.id);
});

try {
const batch = await service.call(authentication.environment, body, {
triggerVersion: triggerVersion ?? undefined,
Expand All @@ -90,7 +128,10 @@ const { action, loader } = createActionApiRoute(
triggerClient
);

return json(batch, { status: 202, headers: $responseHeaders });
return json(batch, {
status: 202,
headers: $responseHeaders,
});
} catch (error) {
logger.error("Batch trigger error", {
error: {
Expand Down
8 changes: 7 additions & 1 deletion apps/webapp/app/runEngine/services/batchTrigger.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import {
} from "@trigger.dev/core/v3";
import { BatchId, RunId } from "@trigger.dev/core/v3/isomorphic";
import { BatchTaskRun, Prisma } from "@trigger.dev/database";
import { Evt } from "evt";
import { z } from "zod";
import { $transaction, prisma, PrismaClientOrTransaction } from "~/db.server";
import { prisma, PrismaClientOrTransaction } from "~/db.server";
import { env } from "~/env.server";
import { AuthenticatedEnvironment } from "~/services/apiAuth.server";
import { logger } from "~/services/logger.server";
Expand Down Expand Up @@ -51,6 +52,7 @@ export type BatchTriggerTaskServiceOptions = {
*/
export class RunEngineBatchTriggerService extends WithRunEngine {
private _batchProcessingStrategy: BatchProcessingStrategy;
public onBatchTaskRunCreated: Evt<BatchTaskRun> = new Evt();

constructor(
batchProcessingStrategy?: BatchProcessingStrategy,
Expand Down Expand Up @@ -168,6 +170,8 @@ export class RunEngineBatchTriggerService extends WithRunEngine {
},
});

this.onBatchTaskRunCreated.post(batch);

if (body.parentRunId && body.resumeParentOnCompletion) {
await this._engine.blockRunWithCreatedBatch({
runId: RunId.fromFriendlyId(body.parentRunId),
Expand Down Expand Up @@ -259,6 +263,8 @@ export class RunEngineBatchTriggerService extends WithRunEngine {
},
});

this.onBatchTaskRunCreated.post(batch);

if (body.parentRunId && body.resumeParentOnCompletion) {
await this._engine.blockRunWithCreatedBatch({
runId: RunId.fromFriendlyId(body.parentRunId),
Expand Down
124 changes: 124 additions & 0 deletions apps/webapp/app/services/requestIdempotency.server.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import { Logger, LogLevel } from "@trigger.dev/core/logger";
import { createCache, DefaultStatefulContext, Namespace, Cache as UnkeyCache } from "@unkey/cache";
import { MemoryStore } from "@unkey/cache/stores";
import { RedisCacheStore } from "./unkey/redisCacheStore.server";
import { RedisWithClusterOptions } from "~/redis.server";
import { validate as uuidValidate, version as uuidVersion } from "uuid";
import { startActiveSpan } from "~/v3/tracer.server";

export type RequestIdempotencyServiceOptions<TTypes extends string> = {
types: TTypes[];
redis: RedisWithClusterOptions;
logger?: Logger;
logLevel?: LogLevel;
ttlInMs?: number;
};

const DEFAULT_TTL_IN_MS = 60_000 * 60 * 24;

type RequestIdempotencyCacheEntry = {
id: string;
};

export class RequestIdempotencyService<TTypes extends string> {
private readonly logger: Logger;
private readonly cache: UnkeyCache<{ requests: RequestIdempotencyCacheEntry }>;

constructor(private readonly options: RequestIdempotencyServiceOptions<TTypes>) {
this.logger =
options.logger ?? new Logger("RequestIdempotencyService", options.logLevel ?? "info");

const keyPrefix = options.redis.keyPrefix
? `request-idempotency:${options.redis.keyPrefix}`
: "request-idempotency:";

const ctx = new DefaultStatefulContext();
const memory = new MemoryStore({ persistentMap: new Map() });
const redisCacheStore = new RedisCacheStore({
name: "request-idempotency",
connection: {
keyPrefix: keyPrefix,
...options.redis,
},
});

// This cache holds the rate limit configuration for each org, so we don't have to fetch it every request
const cache = createCache({
requests: new Namespace<RequestIdempotencyCacheEntry>(ctx, {
stores: [memory, redisCacheStore],
fresh: options.ttlInMs ?? DEFAULT_TTL_IN_MS,
stale: options.ttlInMs ?? DEFAULT_TTL_IN_MS,
}),
});

this.cache = cache;
}

async checkRequest(type: TTypes, requestIdempotencyKey: string) {
if (!this.#validateRequestId(requestIdempotencyKey)) {
this.logger.warn("RequestIdempotency: invalid requestIdempotencyKey", {
requestIdempotencyKey,
});

return undefined;
}

return startActiveSpan("RequestIdempotency.checkRequest()", async (span) => {
span.setAttribute("request_id", requestIdempotencyKey);
span.setAttribute("type", type);

const key = `${type}:${requestIdempotencyKey}`;
const result = await this.cache.requests.get(key);

this.logger.debug("RequestIdempotency: checking request", {
type,
requestIdempotencyKey,
key,
result,
});

return result.val ? result.val : undefined;
});
}

async saveRequest(
type: TTypes,
requestIdempotencyKey: string,
value: RequestIdempotencyCacheEntry
) {
if (!this.#validateRequestId(requestIdempotencyKey)) {
this.logger.warn("RequestIdempotency: invalid requestIdempotencyKey", {
requestIdempotencyKey,
});
return undefined;
}

const key = `${type}:${requestIdempotencyKey}`;
const result = await this.cache.requests.set(key, value);

if (result.err) {
this.logger.error("RequestIdempotency: error saving request", {
key,
error: result.err,
});
} else {
this.logger.debug("RequestIdempotency: saved request", {
type,
requestIdempotencyKey,
key,
value,
});
}

return result;
}

// The requestIdempotencyKey should be a valid UUID
#validateRequestId(requestIdempotencyKey: string): boolean {
return isValidV4UUID(requestIdempotencyKey);
}
}

function isValidV4UUID(uuid: string): boolean {
return uuidValidate(uuid) && uuidVersion(uuid) === 4;
}
Loading