diff --git a/.changeset/wise-mirrors-hug.md b/.changeset/wise-mirrors-hug.md new file mode 100644 index 0000000000..dc2cf2a7c1 --- /dev/null +++ b/.changeset/wise-mirrors-hug.md @@ -0,0 +1,5 @@ +--- +"@trigger.dev/core": patch +--- + +Add manual checkpoint schema diff --git a/apps/coordinator/Containerfile b/apps/coordinator/Containerfile index c42b663d2c..4e7b89e0af 100644 --- a/apps/coordinator/Containerfile +++ b/apps/coordinator/Containerfile @@ -13,7 +13,7 @@ RUN find . -name "node_modules" -type d -prune -exec rm -rf '{}' + FROM node-20 AS base RUN apt-get update \ - && apt-get install -y buildah ca-certificates dumb-init docker.io \ + && apt-get install -y buildah ca-certificates dumb-init docker.io busybox \ && rm -rf /var/lib/apt/lists/* COPY --chown=node:node .gitignore .gitignore diff --git a/apps/coordinator/src/checkpointer.ts b/apps/coordinator/src/checkpointer.ts index bf82a6702c..c52060e9c6 100644 --- a/apps/coordinator/src/checkpointer.ts +++ b/apps/coordinator/src/checkpointer.ts @@ -415,6 +415,14 @@ export class Checkpointer { const buildah = new Buildah({ id: `${runId}-${shortCode}`, abortSignal: controller.signal }); const crictl = new Crictl({ id: `${runId}-${shortCode}`, abortSignal: controller.signal }); + const removeCurrentAbortController = () => { + // Ensure only the current controller is removed + if (this.#abortControllers.get(runId) === controller) { + this.#abortControllers.delete(runId); + } + controller.signal.removeEventListener("abort", onAbort); + }; + const cleanup = async () => { const metadata = { runId, @@ -424,6 +432,7 @@ export class Checkpointer { if (this.#dockerMode) { this.#logger.debug("Skipping cleanup in docker mode", metadata); + removeCurrentAbortController(); return; } @@ -436,11 +445,7 @@ export class Checkpointer { this.#logger.error("Error during cleanup", { ...metadata, error }); } - // Ensure only the current controller is removed - if (this.#abortControllers.get(runId) === controller) { - this.#abortControllers.delete(runId); - } - controller.signal.removeEventListener("abort", onAbort); + removeCurrentAbortController(); }; try { diff --git a/apps/coordinator/src/index.ts b/apps/coordinator/src/index.ts index e2a7ba3259..b493f3913f 100644 --- a/apps/coordinator/src/index.ts +++ b/apps/coordinator/src/index.ts @@ -14,7 +14,7 @@ import { ZodSocketConnection } from "@trigger.dev/core/v3/zodSocket"; import { HttpReply, getTextBody } from "@trigger.dev/core/v3/apps"; import { ChaosMonkey } from "./chaosMonkey"; import { Checkpointer } from "./checkpointer"; -import { boolFromEnv, numFromEnv } from "./util"; +import { boolFromEnv, numFromEnv, safeJsonParse } from "./util"; import { collectDefaultMetrics, register, Gauge } from "prom-client"; import { SimpleStructuredLogger } from "@trigger.dev/core/v3/utils/structuredLogger"; @@ -42,6 +42,8 @@ class CheckpointCancelError extends Error {} class TaskCoordinator { #httpServer: ReturnType; + #internalHttpServer: ReturnType; + #checkpointer = new Checkpointer({ dockerMode: !process.env.KUBERNETES_PORT, forceSimulate: boolFromEnv("FORCE_CHECKPOINT_SIMULATION", false), @@ -79,6 +81,8 @@ class TaskCoordinator { private host = "0.0.0.0" ) { this.#httpServer = this.#createHttpServer(); + this.#internalHttpServer = this.#createInternalHttpServer(); + this.#checkpointer.init(); this.#platformSocket = this.#createPlatformSocket(); @@ -653,11 +657,11 @@ class TaskCoordinator { log.error("READY_FOR_LAZY_ATTEMPT error", { error }); - await crashRun({ - name: "ReadyForLazyAttemptError", - message: - error instanceof Error ? `Unexpected error: ${error.message}` : "Unexpected error", - }); + // await crashRun({ + // name: "ReadyForLazyAttemptError", + // message: + // error instanceof Error ? `Unexpected error: ${error.message}` : "Unexpected error", + // }); return; } @@ -1368,13 +1372,236 @@ class TaskCoordinator { case "/metrics": { return reply.text(await register.metrics(), 200, register.contentType); } + default: { + return reply.empty(404); + } + } + }); + + httpServer.on("clientError", (err, socket) => { + socket.end("HTTP/1.1 400 Bad Request\r\n\r\n"); + }); + + httpServer.on("listening", () => { + logger.log("server listening on port", { port: HTTP_SERVER_PORT }); + }); + + return httpServer; + } + + #createInternalHttpServer() { + const httpServer = createServer(async (req, res) => { + logger.log(`[${req.method}]`, { url: req.url }); + + const reply = new HttpReply(res); + + switch (req.url) { case "/whoami": { return reply.text(NODE_NAME); } - case "/checkpoint": { - const body = await getTextBody(req); - // await this.#checkpointer.checkpointAndPush(body); - return reply.text(`sent restore request: ${body}`); + case "/checkpoint/duration": { + try { + const body = await getTextBody(req); + const json = safeJsonParse(body); + + if (typeof json !== "object" || !json) { + return reply.text("Invalid body", 400); + } + + if (!("runId" in json) || typeof json.runId !== "string") { + return reply.text("Missing or invalid: runId", 400); + } + + if (!("now" in json) || typeof json.now !== "number") { + return reply.text("Missing or invalid: now", 400); + } + + if (!("ms" in json) || typeof json.ms !== "number") { + return reply.text("Missing or invalid: ms", 400); + } + + let keepRunAlive = false; + if ("keepRunAlive" in json && typeof json.keepRunAlive === "boolean") { + keepRunAlive = json.keepRunAlive; + } + + let async = false; + if ("async" in json && typeof json.async === "boolean") { + async = json.async; + } + + const { runId, now, ms } = json; + + if (!runId) { + return reply.text("Missing runId", 400); + } + + const runSocket = await this.#getRunSocket(runId); + if (!runSocket) { + return reply.text("Run socket not found", 404); + } + + const { data } = runSocket; + + console.log("Manual duration checkpoint", data); + + if (async) { + reply.text("Creating checkpoint in the background", 202); + } + + const checkpoint = await this.#checkpointer.checkpointAndPush({ + runId: data.runId, + projectRef: data.projectRef, + deploymentVersion: data.deploymentVersion, + attemptNumber: data.attemptNumber ? parseInt(data.attemptNumber) : undefined, + }); + + if (!checkpoint) { + return reply.text("Failed to checkpoint", 500); + } + + if (!data.attemptFriendlyId) { + return reply.text("Socket data missing attemptFriendlyId", 500); + } + + const ack = await this.#platformSocket?.sendWithAck("CHECKPOINT_CREATED", { + version: "v1", + runId, + attemptFriendlyId: data.attemptFriendlyId, + docker: checkpoint.docker, + location: checkpoint.location, + reason: { + type: "WAIT_FOR_DURATION", + ms, + now, + }, + }); + + if (ack?.keepRunAlive || keepRunAlive) { + return reply.json({ + message: `keeping run ${runId} alive after checkpoint`, + checkpoint, + requestJson: json, + platformAck: ack, + }); + } + + runSocket.emit("REQUEST_EXIT", { + version: "v1", + }); + + return reply.json({ + message: `checkpoint created for run ${runId}`, + checkpoint, + requestJson: json, + platformAck: ack, + }); + } catch (error) { + return reply.json({ + message: `error`, + error, + }); + } + } + case "/checkpoint/manual": { + try { + const body = await getTextBody(req); + const json = safeJsonParse(body); + + if (typeof json !== "object" || !json) { + return reply.text("Invalid body", 400); + } + + if (!("runId" in json) || typeof json.runId !== "string") { + return reply.text("Missing or invalid: runId", 400); + } + + let restoreAtUnixTimeMs: number | undefined; + if ("restoreAtUnixTimeMs" in json && typeof json.restoreAtUnixTimeMs === "number") { + restoreAtUnixTimeMs = json.restoreAtUnixTimeMs; + } + + let keepRunAlive = false; + if ("keepRunAlive" in json && typeof json.keepRunAlive === "boolean") { + keepRunAlive = json.keepRunAlive; + } + + let async = false; + if ("async" in json && typeof json.async === "boolean") { + async = json.async; + } + + const { runId } = json; + + if (!runId) { + return reply.text("Missing runId", 400); + } + + const runSocket = await this.#getRunSocket(runId); + if (!runSocket) { + return reply.text("Run socket not found", 404); + } + + const { data } = runSocket; + + console.log("Manual checkpoint", data); + + if (async) { + reply.text("Creating checkpoint in the background", 202); + } + + const checkpoint = await this.#checkpointer.checkpointAndPush({ + runId: data.runId, + projectRef: data.projectRef, + deploymentVersion: data.deploymentVersion, + attemptNumber: data.attemptNumber ? parseInt(data.attemptNumber) : undefined, + }); + + if (!checkpoint) { + return reply.text("Failed to checkpoint", 500); + } + + if (!data.attemptFriendlyId) { + return reply.text("Socket data missing attemptFriendlyId", 500); + } + + const ack = await this.#platformSocket?.sendWithAck("CHECKPOINT_CREATED", { + version: "v1", + runId, + attemptFriendlyId: data.attemptFriendlyId, + docker: checkpoint.docker, + location: checkpoint.location, + reason: { + type: "MANUAL", + restoreAtUnixTimeMs, + }, + }); + + if (ack?.keepRunAlive || keepRunAlive) { + return reply.json({ + message: `keeping run ${runId} alive after checkpoint`, + checkpoint, + requestJson: json, + platformAck: ack, + }); + } + + runSocket.emit("REQUEST_EXIT", { + version: "v1", + }); + + return reply.json({ + message: `checkpoint created for run ${runId}`, + checkpoint, + requestJson: json, + platformAck: ack, + }); + } catch (error) { + return reply.json({ + message: `error`, + error, + }); + } } default: { return reply.empty(404); @@ -1387,7 +1614,7 @@ class TaskCoordinator { }); httpServer.on("listening", () => { - logger.log("server listening on port", { port: HTTP_SERVER_PORT }); + logger.log("internal server listening on port", { port: HTTP_SERVER_PORT + 100 }); }); return httpServer; @@ -1395,6 +1622,7 @@ class TaskCoordinator { listen() { this.#httpServer.listen(this.port, this.host); + this.#internalHttpServer.listen(this.port + 100, "127.0.0.1"); } } diff --git a/apps/coordinator/src/util.ts b/apps/coordinator/src/util.ts index 74bb605ee2..18464f230b 100644 --- a/apps/coordinator/src/util.ts +++ b/apps/coordinator/src/util.ts @@ -17,3 +17,15 @@ export const numFromEnv = (env: string, defaultValue: number): number => { return parseInt(value, 10); }; + +export function safeJsonParse(json?: string): unknown { + if (!json) { + return; + } + + try { + return JSON.parse(json); + } catch (e) { + return null; + } +} diff --git a/apps/webapp/app/database-types.ts b/apps/webapp/app/database-types.ts index aa1ac1d3c9..6214843f64 100644 --- a/apps/webapp/app/database-types.ts +++ b/apps/webapp/app/database-types.ts @@ -69,3 +69,11 @@ export const RuntimeEnvironmentType = { DEVELOPMENT: "DEVELOPMENT", PREVIEW: "PREVIEW", } as const satisfies Record; + +export function isTaskRunAttemptStatus(value: string): value is keyof typeof TaskRunAttemptStatus { + return Object.values(TaskRunAttemptStatus).includes(value as keyof typeof TaskRunAttemptStatus); +} + +export function isTaskRunStatus(value: string): value is keyof typeof TaskRunStatus { + return Object.values(TaskRunStatus).includes(value as keyof typeof TaskRunStatus); +} diff --git a/apps/webapp/app/v3/services/createCheckpoint.server.ts b/apps/webapp/app/v3/services/createCheckpoint.server.ts index 52492a4fbb..bbd8618898 100644 --- a/apps/webapp/app/v3/services/createCheckpoint.server.ts +++ b/apps/webapp/app/v3/services/createCheckpoint.server.ts @@ -1,4 +1,4 @@ -import { CoordinatorToPlatformMessages } from "@trigger.dev/core/v3"; +import { CoordinatorToPlatformMessages, ManualCheckpointMetadata } from "@trigger.dev/core/v3"; import type { InferSocketMessageSchema } from "@trigger.dev/core/v3/zodSocket"; import type { Checkpoint, CheckpointRestoreEvent } from "@trigger.dev/database"; import { logger } from "~/services/logger.server"; @@ -101,6 +101,19 @@ export class CreateCheckpointService extends BaseService { // setTimeout(resolve, waitSeconds * 1000); // }); + let metadata: string; + + if (params.reason.type === "MANUAL") { + metadata = JSON.stringify({ + ...params.reason, + attemptId: attempt.id, + previousAttemptStatus: attempt.status, + previousRunStatus: attempt.taskRun.status, + } satisfies ManualCheckpointMetadata); + } else { + metadata = JSON.stringify(params.reason); + } + const checkpoint = await this._prisma.checkpoint.create({ data: { friendlyId: generateFriendlyId("checkpoint"), @@ -112,7 +125,7 @@ export class CreateCheckpointService extends BaseService { location: params.location, type: params.docker ? "DOCKER" : "KUBERNETES", reason: params.reason.type, - metadata: JSON.stringify(params.reason), + metadata, imageRef, }, }); @@ -138,7 +151,17 @@ export class CreateCheckpointService extends BaseService { let checkpointEvent: CheckpointRestoreEvent | undefined; switch (reason.type) { + case "MANUAL": case "WAIT_FOR_DURATION": { + let restoreAtUnixTimeMs: number; + + if (reason.type === "MANUAL") { + // Restore immediately if not specified, useful for live migration + restoreAtUnixTimeMs = reason.restoreAtUnixTimeMs ?? Date.now(); + } else { + restoreAtUnixTimeMs = reason.now + reason.ms; + } + checkpointEvent = await eventService.checkpoint({ checkpointId: checkpoint.id, }); @@ -151,7 +174,7 @@ export class CreateCheckpointService extends BaseService { resumableAttemptId: attempt.id, checkpointEventId: checkpointEvent.id, }, - reason.now + reason.ms + restoreAtUnixTimeMs ); return { diff --git a/apps/webapp/app/v3/services/createCheckpointRestoreEvent.server.ts b/apps/webapp/app/v3/services/createCheckpointRestoreEvent.server.ts index 1981e0f8e5..63a8b6bb9a 100644 --- a/apps/webapp/app/v3/services/createCheckpointRestoreEvent.server.ts +++ b/apps/webapp/app/v3/services/createCheckpointRestoreEvent.server.ts @@ -1,6 +1,13 @@ -import type { CheckpointRestoreEvent, CheckpointRestoreEventType } from "@trigger.dev/database"; +import type { + Checkpoint, + CheckpointRestoreEvent, + CheckpointRestoreEventType, +} from "@trigger.dev/database"; import { logger } from "~/services/logger.server"; import { BaseService } from "./baseService.server"; +import { ManualCheckpointMetadata } from "@trigger.dev/core/v3"; +import { isTaskRunAttemptStatus, isTaskRunStatus, TaskRunAttemptStatus } from "~/database-types"; +import { safeJsonParse } from "~/utils/json"; interface CheckpointRestoreEventCallParams { checkpointId: string; @@ -39,6 +46,13 @@ export class CreateCheckpointRestoreEventService extends BaseService { return; } + if (params.type === "RESTORE" && checkpoint.reason === "MANUAL") { + const manualRestoreSuccess = await this.#handleManualCheckpointRestore(checkpoint); + if (!manualRestoreSuccess) { + return; + } + } + logger.debug(`Creating checkpoint/restore event`, { params }); let taskRunDependencyId: string | undefined; @@ -99,4 +113,81 @@ export class CreateCheckpointRestoreEventService extends BaseService { return checkpointEvent; } + + async #handleManualCheckpointRestore(checkpoint: Checkpoint): Promise { + const json = checkpoint.metadata ? safeJsonParse(checkpoint.metadata) : undefined; + + // We need to restore the previous run and attempt status as saved in the metadata + const metadata = ManualCheckpointMetadata.safeParse(json); + + if (!metadata.success) { + logger.error("Invalid metadata", { metadata }); + return false; + } + + const { attemptId, previousAttemptStatus, previousRunStatus } = metadata.data; + + if (!isTaskRunAttemptStatus(previousAttemptStatus)) { + logger.error("Invalid previous attempt status", { previousAttemptStatus }); + return false; + } + + if (!isTaskRunStatus(previousRunStatus)) { + logger.error("Invalid previous run status", { previousRunStatus }); + return false; + } + + try { + const updatedAttempt = await this._prisma.taskRunAttempt.update({ + where: { + id: attemptId, + }, + data: { + status: previousAttemptStatus, + taskRun: { + update: { + data: { + status: previousRunStatus, + }, + }, + }, + }, + select: { + id: true, + status: true, + taskRun: { + select: { + id: true, + status: true, + }, + }, + }, + }); + + logger.debug("Set post resume statuses after manual checkpoint", { + run: { + id: updatedAttempt.taskRun.id, + status: updatedAttempt.taskRun.status, + }, + attempt: { + id: updatedAttempt.id, + status: updatedAttempt.status, + }, + }); + + return true; + } catch (error) { + logger.error("Failed to set post resume statuses", { + error: + error instanceof Error + ? { + name: error.name, + message: error.message, + stack: error.stack, + } + : error, + }); + return false; + } + } } diff --git a/packages/core/src/v3/schemas/messages.ts b/packages/core/src/v3/schemas/messages.ts index 6147b86c3b..29b1b1a825 100644 --- a/packages/core/src/v3/schemas/messages.ts +++ b/packages/core/src/v3/schemas/messages.ts @@ -479,6 +479,11 @@ export const CoordinatorToPlatformMessages = { type: z.literal("RETRYING_AFTER_FAILURE"), attemptNumber: z.number(), }), + z.object({ + type: z.literal("MANUAL"), + /** If unspecified it will be restored immediately, e.g. for live migration */ + restoreAtUnixTimeMs: z.number().optional(), + }), ]), }), callback: z.object({ diff --git a/packages/core/src/v3/schemas/schemas.ts b/packages/core/src/v3/schemas/schemas.ts index a72112783b..8ed9b7b0be 100644 --- a/packages/core/src/v3/schemas/schemas.ts +++ b/packages/core/src/v3/schemas/schemas.ts @@ -250,3 +250,12 @@ export const TaskRunExecutionLazyAttemptPayload = z.object({ }); export type TaskRunExecutionLazyAttemptPayload = z.infer; + +export const ManualCheckpointMetadata = z.object({ + /** NOT a friendly ID */ + attemptId: z.string(), + previousRunStatus: z.string(), + previousAttemptStatus: z.string(), +}); + +export type ManualCheckpointMetadata = z.infer;