diff --git a/.changeset/gold-melons-fetch.md b/.changeset/gold-melons-fetch.md new file mode 100644 index 0000000000..34662546fa --- /dev/null +++ b/.changeset/gold-melons-fetch.md @@ -0,0 +1,34 @@ +--- +"@trigger.dev/sdk": patch +"@trigger.dev/core": patch +--- + +Add support for specifying machine preset at trigger time. Works with any trigger function: + +```ts +// Same as usual, will use the machine preset on childTask, defaults to "small-1x" +await childTask.trigger({ message: "Hello, world!" }); + +// This will override the task's machine preset and any defaults. Works with all trigger functions. +await childTask.trigger({ message: "Hello, world!" }, { machine: "small-2x" }); +await childTask.triggerAndWait({ message: "Hello, world!" }, { machine: "small-2x" }); + +await childTask.batchTrigger([ + { payload: { message: "Hello, world!" }, options: { machine: "micro" } }, + { payload: { message: "Hello, world!" }, options: { machine: "large-1x" } }, +]); +await childTask.batchTriggerAndWait([ + { payload: { message: "Hello, world!" }, options: { machine: "micro" } }, + { payload: { message: "Hello, world!" }, options: { machine: "large-1x" } }, +]); + +await tasks.trigger( + "child", + { message: "Hello, world!" }, + { machine: "small-2x" } +); +await tasks.batchTrigger("child", [ + { payload: { message: "Hello, world!" }, options: { machine: "micro" } }, + { payload: { message: "Hello, world!" }, options: { machine: "large-1x" } }, +]); +``` diff --git a/apps/webapp/app/v3/machinePresets.server.ts b/apps/webapp/app/v3/machinePresets.server.ts index 120a235c54..23b8bfa8fa 100644 --- a/apps/webapp/app/v3/machinePresets.server.ts +++ b/apps/webapp/app/v3/machinePresets.server.ts @@ -31,6 +31,16 @@ export function machinePresetFromName(name: MachinePresetName): MachinePreset { }; } +export function machinePresetFromRun(run: { machinePreset: string | null }): MachinePreset | null { + const presetName = MachinePresetName.safeParse(run.machinePreset).data; + + if (!presetName) { + return null; + } + + return machinePresetFromName(presetName); +} + // Finds the smallest machine preset name that satisfies the given CPU and memory requirements function derivePresetNameFromValues(cpu: number, memory: number): MachinePresetName { for (const [name, preset] of Object.entries(machines)) { diff --git a/apps/webapp/app/v3/marqs/sharedQueueConsumer.server.ts b/apps/webapp/app/v3/marqs/sharedQueueConsumer.server.ts index b8cd96a6e3..39037903cd 100644 --- a/apps/webapp/app/v3/marqs/sharedQueueConsumer.server.ts +++ b/apps/webapp/app/v3/marqs/sharedQueueConsumer.server.ts @@ -43,7 +43,7 @@ import { RestoreCheckpointService } from "../services/restoreCheckpoint.server"; import { SEMINTATTRS_FORCE_RECORDING, tracer } from "../tracer.server"; import { generateJWTTokenForEnvironment } from "~/services/apiAuth.server"; import { EnvironmentVariable } from "../environmentVariables/repository"; -import { machinePresetFromConfig } from "../machinePresets.server"; +import { machinePresetFromConfig, machinePresetFromRun } from "../machinePresets.server"; import { env } from "~/env.server"; import { FINAL_ATTEMPT_STATUSES, @@ -413,7 +413,9 @@ export class SharedQueueConsumer { cliVersion: deployment.worker.cliVersion, startedAt: existingTaskRun.startedAt ?? new Date(), baseCostInCents: env.CENTS_PER_RUN, - machinePreset: machinePresetFromConfig(backgroundTask.machineConfig ?? {}).name, + machinePreset: + existingTaskRun.machinePreset ?? + machinePresetFromConfig(backgroundTask.machineConfig ?? {}).name, maxDurationInSeconds: getMaxDuration( existingTaskRun.maxDurationInSeconds, backgroundTask.maxDurationInSeconds @@ -542,8 +544,9 @@ export class SharedQueueConsumer { // Retries for workers with disabled retry checkpoints will be handled just like normal attempts } else { - const machineConfig = lockedTaskRun.lockedBy?.machineConfig; - const machine = machinePresetFromConfig(machineConfig ?? {}); + const machine = + machinePresetFromRun(lockedTaskRun) ?? + machinePresetFromConfig(lockedTaskRun.lockedBy?.machineConfig ?? {}); await this._sender.send("BACKGROUND_WORKER_MESSAGE", { backgroundWorkerId: deployment.worker.friendlyId, @@ -1077,7 +1080,9 @@ class SharedQueueTasks { const { backgroundWorkerTask, taskRun, queue } = attempt; if (!machinePreset) { - machinePreset = machinePresetFromConfig(backgroundWorkerTask.machineConfig ?? {}); + machinePreset = + machinePresetFromRun(attempt.taskRun) ?? + machinePresetFromConfig(backgroundWorkerTask.machineConfig ?? {}); } const metadata = await parsePacket({ @@ -1294,9 +1299,13 @@ class SharedQueueTasks { }, }); } + const { backgroundWorkerTask, taskRun } = attempt; - const machinePreset = machinePresetFromConfig(backgroundWorkerTask.machineConfig ?? {}); + const machinePreset = + machinePresetFromRun(attempt.taskRun) ?? + machinePresetFromConfig(backgroundWorkerTask.machineConfig ?? {}); + const execution = await this._executionFromAttempt(attempt, machinePreset); const variables = await this.#buildEnvironmentVariables( attempt.runtimeEnvironment, @@ -1432,6 +1441,7 @@ class SharedQueueTasks { machineConfig: true, }, }, + machinePreset: true, }, }); @@ -1451,7 +1461,8 @@ class SharedQueueTasks { attemptCount, }); - const machinePreset = machinePresetFromConfig(run.lockedBy?.machineConfig ?? {}); + const machinePreset = + machinePresetFromRun(run) ?? machinePresetFromConfig(run.lockedBy?.machineConfig ?? {}); const variables = await this.#buildEnvironmentVariables(environment, run.id, machinePreset); diff --git a/apps/webapp/app/v3/services/createTaskRunAttempt.server.ts b/apps/webapp/app/v3/services/createTaskRunAttempt.server.ts index 5e7afc8e37..9bcc2b4f01 100644 --- a/apps/webapp/app/v3/services/createTaskRunAttempt.server.ts +++ b/apps/webapp/app/v3/services/createTaskRunAttempt.server.ts @@ -6,7 +6,7 @@ import { AuthenticatedEnvironment } from "~/services/apiAuth.server"; import { logger } from "~/services/logger.server"; import { reportInvocationUsage } from "~/services/platform.v3.server"; import { generateFriendlyId } from "../friendlyIdentifiers"; -import { machinePresetFromConfig } from "../machinePresets.server"; +import { machinePresetFromConfig, machinePresetFromRun } from "../machinePresets.server"; import { BaseService, ServiceValidationError } from "./baseService.server"; import { CrashTaskRunService } from "./crashTaskRun.server"; import { ExpireEnqueuedRunService } from "./expireEnqueuedRun.server"; @@ -173,7 +173,9 @@ export class CreateTaskRunAttemptService extends BaseService { }); } - const machinePreset = machinePresetFromConfig(taskRun.lockedBy.machineConfig ?? {}); + const machinePreset = + machinePresetFromRun(taskRun) ?? + machinePresetFromConfig(taskRun.lockedBy.machineConfig ?? {}); const metadata = await parsePacket({ data: taskRun.metadata ?? undefined, diff --git a/apps/webapp/app/v3/services/restoreCheckpoint.server.ts b/apps/webapp/app/v3/services/restoreCheckpoint.server.ts index 9a9b82210a..a04f46e626 100644 --- a/apps/webapp/app/v3/services/restoreCheckpoint.server.ts +++ b/apps/webapp/app/v3/services/restoreCheckpoint.server.ts @@ -1,7 +1,7 @@ import { type Checkpoint } from "@trigger.dev/database"; import { logger } from "~/services/logger.server"; import { socketIo } from "../handleSocketIo.server"; -import { machinePresetFromConfig } from "../machinePresets.server"; +import { machinePresetFromConfig, machinePresetFromRun } from "../machinePresets.server"; import { BaseService } from "./baseService.server"; import { CreateCheckpointRestoreEventService } from "./createCheckpointRestoreEvent.server"; import { isRestorableAttemptStatus, isRestorableRunStatus } from "../taskStatus"; @@ -24,6 +24,7 @@ export class RestoreCheckpointService extends BaseService { run: { select: { status: true, + machinePreset: true, }, }, attempt: { @@ -69,8 +70,9 @@ export class RestoreCheckpointService extends BaseService { return; } - const { machineConfig } = checkpoint.attempt.backgroundWorkerTask; - const machine = machinePresetFromConfig(machineConfig ?? {}); + const machine = + machinePresetFromRun(checkpoint.run) ?? + machinePresetFromConfig(checkpoint.attempt.backgroundWorkerTask.machineConfig ?? {}); const restoreEvent = await this._prisma.checkpointRestoreEvent.findFirst({ where: { diff --git a/apps/webapp/app/v3/services/triggerTask.server.ts b/apps/webapp/app/v3/services/triggerTask.server.ts index b3eb7b55dc..c0d1cce189 100644 --- a/apps/webapp/app/v3/services/triggerTask.server.ts +++ b/apps/webapp/app/v3/services/triggerTask.server.ts @@ -415,6 +415,7 @@ export class TriggerTaskService extends BaseService { : undefined, runTags: bodyTags, oneTimeUseToken: options.oneTimeUseToken, + machinePreset: body.options?.machine, }, }); diff --git a/docs/v3-openapi.yaml b/docs/v3-openapi.yaml index 6c33b5bdda..293a54b577 100644 --- a/docs/v3-openapi.yaml +++ b/docs/v3-openapi.yaml @@ -1595,6 +1595,18 @@ components: We recommend prefixing tags with a namespace using an underscore or colon, like `user_1234567` or `org:9876543`. Stripe uses underscores. items: type: string + machine: + type: string + enum: + - micro + - small-1x + - small-2x + - medium-1x + - medium-2x + - large-1x + - large-2x + example: "small-2x" + description: The machine preset to use for this run. This will override the task's machine preset and any defaults. TTL: type: - string diff --git a/packages/core/src/v3/schemas/api.ts b/packages/core/src/v3/schemas/api.ts index f1c1fb15cd..2164ed7ef0 100644 --- a/packages/core/src/v3/schemas/api.ts +++ b/packages/core/src/v3/schemas/api.ts @@ -2,6 +2,7 @@ import { z } from "zod"; import { DeserializedJsonSchema } from "../../schemas/json.js"; import { FlushedRunMetadata, + MachinePresetName, RunMetadataChangeOperation, SerializedError, TaskRunError, @@ -91,6 +92,7 @@ export const TriggerTaskRequestBody = z.object({ metadata: z.any(), metadataType: z.string().optional(), maxDuration: z.number().optional(), + machine: MachinePresetName.optional(), }) .optional(), }); @@ -131,6 +133,7 @@ export const BatchTriggerTaskItem = z.object({ metadataType: z.string().optional(), maxDuration: z.number().optional(), parentAttempt: z.string().optional(), + machine: MachinePresetName.optional(), }) .optional(), }); diff --git a/packages/core/src/v3/types/tasks.ts b/packages/core/src/v3/types/tasks.ts index 8db01adc81..c8b82e60e3 100644 --- a/packages/core/src/v3/types/tasks.ts +++ b/packages/core/src/v3/types/tasks.ts @@ -1,11 +1,10 @@ -import type { Schema as AISchema } from "ai"; -import { z } from "zod"; import { SerializableJson } from "../../schemas/json.js"; import { TriggerApiRequestOptions } from "../apiClient/index.js"; import { RunTags } from "../schemas/api.js"; import { MachineCpu, MachineMemory, + MachinePresetName, RetryOptions, TaskMetadata, TaskRunContext, @@ -220,41 +219,36 @@ type CommonTaskOptions< }); * ``` */ - machine?: { - /** vCPUs. The default is 0.5. - * - * Possible values: - * - 0.25 - * - 0.5 - * - 1 - * - 2 - * - 4 - * @deprecated use preset instead - */ - cpu?: MachineCpu; - /** In GBs of RAM. The default is 1. - * - * Possible values: - * - 0.25 - * - 0.5 - * - 1 - * - 2 - * - 4 - * - 8 - * * @deprecated use preset instead - */ - memory?: MachineMemory; - - /** Preset to use for the machine. Defaults to small-1x */ - preset?: - | "micro" - | "small-1x" - | "small-2x" - | "medium-1x" - | "medium-2x" - | "large-1x" - | "large-2x"; - }; + machine?: + | { + /** vCPUs. The default is 0.5. + * + * Possible values: + * - 0.25 + * - 0.5 + * - 1 + * - 2 + * - 4 + * @deprecated use preset instead + */ + cpu?: MachineCpu; + /** In GBs of RAM. The default is 1. + * + * Possible values: + * - 0.25 + * - 0.5 + * - 1 + * - 2 + * - 4 + * - 8 + * * @deprecated use preset instead + */ + memory?: MachineMemory; + + /** Preset to use for the machine. Defaults to small-1x */ + preset?: MachinePresetName; + } + | MachinePresetName; /** * The maximum duration in compute-time seconds that a task run is allowed to run. If the task run exceeds this duration, it will be stopped. @@ -775,6 +769,11 @@ export type TriggerOptions = { * Minimum value is 5 seconds */ maxDuration?: number; + + /** + * The machine preset to use for this run. This will override the task's machine preset and any defaults. + */ + machine?: MachinePresetName; }; export type TriggerAndWaitOptions = Omit; diff --git a/packages/trigger-sdk/src/v3/shared.ts b/packages/trigger-sdk/src/v3/shared.ts index f7513f8c49..4c747f6240 100644 --- a/packages/trigger-sdk/src/v3/shared.ts +++ b/packages/trigger-sdk/src/v3/shared.ts @@ -63,7 +63,6 @@ import type { TaskOutputHandle, TaskPayload, TaskRunResult, - TaskRunResultFromTask, TaskSchema, TaskWithSchema, TaskWithSchemaOptions, @@ -75,6 +74,7 @@ import type { TriggerOptions, AnyTaskRunResult, BatchTriggerAndWaitOptions, + BatchTriggerTaskV2RequestBody, } from "@trigger.dev/core/v3"; export type { @@ -204,7 +204,7 @@ export function createTask< description: params.description, queue: params.queue, retry: params.retry ? { ...defaultRetryOptions, ...params.retry } : undefined, - machine: params.machine, + machine: typeof params.machine === "string" ? { preset: params.machine } : params.machine, maxDuration: params.maxDuration, fns: { run: params.run, @@ -350,7 +350,7 @@ export function createSchemaTask< description: params.description, queue: params.queue, retry: params.retry ? { ...defaultRetryOptions, ...params.retry } : undefined, - machine: params.machine, + machine: typeof params.machine === "string" ? { preset: params.machine } : params.machine, maxDuration: params.maxDuration, fns: { run: params.run, @@ -613,8 +613,9 @@ export async function batchTriggerById( parentAttempt: taskContext.ctx?.attempt.id, metadata: item.options?.metadata, maxDuration: item.options?.maxDuration, + machine: item.options?.machine, }, - }; + } satisfies BatchTriggerTaskV2RequestBody["items"][0]; }) ), }, @@ -786,8 +787,9 @@ export async function batchTriggerByIdAndWait( maxAttempts: item.options?.maxAttempts, metadata: item.options?.metadata, maxDuration: item.options?.maxDuration, + machine: item.options?.machine, }, - }; + } satisfies BatchTriggerTaskV2RequestBody["items"][0]; }) ), dependentAttempt: ctx.attempt.id, @@ -947,8 +949,9 @@ export async function batchTriggerTasks( parentAttempt: taskContext.ctx?.attempt.id, metadata: item.options?.metadata, maxDuration: item.options?.maxDuration, + machine: item.options?.machine, }, - }; + } satisfies BatchTriggerTaskV2RequestBody["items"][0]; }) ), }, @@ -1122,8 +1125,9 @@ export async function batchTriggerAndWaitTasks( parentAttempt: taskContext.ctx?.attempt.id, metadata: options?.metadata, maxDuration: options?.maxDuration, + machine: options?.machine, }, }, { @@ -1259,8 +1264,9 @@ async function batchTrigger_internal( parentAttempt: taskContext.ctx?.attempt.id, metadata: item.options?.metadata, maxDuration: item.options?.maxDuration, + machine: item.options?.machine, }, - }; + } satisfies BatchTriggerTaskV2RequestBody["items"][0]; }) ), }, @@ -1352,6 +1358,7 @@ async function triggerAndWait_internal