diff --git a/examples/cloudflare-workers-hono/src/index.ts b/examples/cloudflare-workers-hono/src/index.ts index a313a477a2..1860cb083f 100644 --- a/examples/cloudflare-workers-hono/src/index.ts +++ b/examples/cloudflare-workers-hono/src/index.ts @@ -1,21 +1,21 @@ -// import { type Client, createHandler } from "@rivetkit/cloudflare-workers"; -// import { Hono } from "hono"; -// import { registry } from "./registry"; -// -// // Setup router -// const app = new Hono<{ Bindings: { RIVET: Client } }>(); -// -// // Example HTTP endpoint -// app.post("/increment/:name", async (c) => { -// const client = c.env.RIVET; -// -// const name = c.req.param("name"); -// -// const counter = client.counter.getOrCreate(name); -// const newCount = await counter.increment(1); -// -// return c.text(`New Count: ${newCount}`); -// }); -// -// const { handler, ActorHandler } = createHandler(registry, { fetch: app.fetch }); -// export { handler as default, ActorHandler }; +import { type Client, createHandler } from "@rivetkit/cloudflare-workers"; +import { Hono } from "hono"; +import { registry } from "./registry"; + +// Setup router +const app = new Hono<{ Bindings: { RIVET: Client } }>(); + +// Example HTTP endpoint +app.post("/increment/:name", async (c) => { + const client = c.env.RIVET; + + const name = c.req.param("name"); + + const counter = client.counter.getOrCreate(name); + const newCount = await counter.increment(1); + + return c.text(`New Count: ${newCount}`); +}); + +const { handler, ActorHandler } = createHandler(registry, { fetch: app.fetch }); +export { handler as default, ActorHandler }; diff --git a/examples/cloudflare-workers-hono/wrangler.json b/examples/cloudflare-workers-hono/wrangler.json index 29b055cf3f..f5b84c4ef6 100644 --- a/examples/cloudflare-workers-hono/wrangler.json +++ b/examples/cloudflare-workers-hono/wrangler.json @@ -6,7 +6,7 @@ "migrations": [ { "tag": "v1", - "new_classes": ["ActorHandler"] + "new_sqlite_classes": ["ActorHandler"] } ], "durable_objects": { diff --git a/examples/cloudflare-workers-inline-client/src/index.ts b/examples/cloudflare-workers-inline-client/src/index.ts index 8f055c4366..0c7db7d9fc 100644 --- a/examples/cloudflare-workers-inline-client/src/index.ts +++ b/examples/cloudflare-workers-inline-client/src/index.ts @@ -1,43 +1,43 @@ -// import { createInlineClient } from "@rivetkit/cloudflare-workers"; -// import { registry } from "./registry"; -// -// const { -// client, -// fetch: rivetFetch, -// ActorHandler, -// } = createInlineClient(registry); -// -// // IMPORTANT: Your Durable Object must be exported here -// export { ActorHandler }; -// -// export default { -// fetch: async (request, env, ctx) => { -// const url = new URL(request.url); -// -// // Custom request handler -// if ( -// request.method === "POST" && -// url.pathname.startsWith("/increment/") -// ) { -// const name = url.pathname.slice("/increment/".length); -// -// const counter = client.counter.getOrCreate(name); -// const newCount = await counter.increment(1); -// -// return new Response(`New Count: ${newCount}`, { -// headers: { "Content-Type": "text/plain" }, -// }); -// } -// -// // Optional: If you want to access Rivet Actors publicly, mount the path -// if (url.pathname.startsWith("/rivet")) { -// const strippedPath = url.pathname.substring("/rivet".length); -// url.pathname = strippedPath; -// console.log("URL", url.toString()); -// const modifiedRequest = new Request(url.toString(), request); -// return rivetFetch(modifiedRequest, env, ctx); -// } -// -// return new Response("Not Found", { status: 404 }); -// }, -// } satisfies ExportedHandler; +import { createInlineClient } from "@rivetkit/cloudflare-workers"; +import { registry } from "./registry"; + +const { + client, + fetch: rivetFetch, + ActorHandler, +} = createInlineClient(registry); + +// IMPORTANT: Your Durable Object must be exported here +export { ActorHandler }; + +export default { + fetch: async (request, env, ctx) => { + const url = new URL(request.url); + + // Custom request handler + if ( + request.method === "POST" && + url.pathname.startsWith("/increment/") + ) { + const name = url.pathname.slice("/increment/".length); + + const counter = client.counter.getOrCreate(name); + const newCount = await counter.increment(1); + + return new Response(`New Count: ${newCount}`, { + headers: { "Content-Type": "text/plain" }, + }); + } + + // Optional: If you want to access Rivet Actors publicly, mount the path + if (url.pathname.startsWith("/rivet")) { + const strippedPath = url.pathname.substring("/rivet".length); + url.pathname = strippedPath; + console.log("URL", url.toString()); + const modifiedRequest = new Request(url.toString(), request); + return rivetFetch(modifiedRequest, env, ctx); + } + + return new Response("Not Found", { status: 404 }); + }, +} satisfies ExportedHandler; diff --git a/examples/cloudflare-workers/scripts/client.ts b/examples/cloudflare-workers/scripts/client.ts index 5f9990e20b..fe64e00077 100644 --- a/examples/cloudflare-workers/scripts/client.ts +++ b/examples/cloudflare-workers/scripts/client.ts @@ -11,7 +11,7 @@ async function main() { try { // Create counter instance - const counter = client.counter.getOrCreate("demo"); + const counter = client.counter.getOrCreate("demo").connect(); // Increment counter console.log("Incrementing counter 'demo'..."); diff --git a/examples/cloudflare-workers/src/index.ts b/examples/cloudflare-workers/src/index.ts index 672e271c82..48c0cc626d 100644 --- a/examples/cloudflare-workers/src/index.ts +++ b/examples/cloudflare-workers/src/index.ts @@ -1,5 +1,5 @@ -// import { createHandler } from "@rivetkit/cloudflare-workers"; -// import { registry } from "./registry"; -// -// const { handler, ActorHandler } = createHandler(registry); -// export { handler as default, ActorHandler }; +import { createHandler } from "@rivetkit/cloudflare-workers"; +import { registry } from "./registry"; + +const { handler, ActorHandler } = createHandler(registry); +export { handler as default, ActorHandler }; diff --git a/examples/next-js/package.json b/examples/next-js/package.json index 241afe77e0..f62dbd6066 100644 --- a/examples/next-js/package.json +++ b/examples/next-js/package.json @@ -15,6 +15,8 @@ "react-dom": "19.1.0", "next": "16.1.1", "@rivetkit/next-js": "*", + "@hono/node-server": "1.14.2", + "@hono/node-ws": "1.3.0", "rivetkit": "*" }, "devDependencies": { diff --git a/examples/next-js/src/app/api/rivet/[...all]/route.ts b/examples/next-js/src/app/api/rivet/[...all]/route.ts index 82d553a30d..a9ebd3e0a7 100644 --- a/examples/next-js/src/app/api/rivet/[...all]/route.ts +++ b/examples/next-js/src/app/api/rivet/[...all]/route.ts @@ -1,7 +1,6 @@ -// import { toNextHandler } from "@rivetkit/next-js"; -// import { registry } from "@/rivet/registry"; -// -// export const maxDuration = 300; -// -// export const { GET, POST, PUT, PATCH, HEAD, OPTIONS } = toNextHandler(registry); -export const GET = () => "foo"; +import { toNextHandler } from "@rivetkit/next-js"; +import { registry } from "@/rivet/registry"; + +export const maxDuration = 300; + +export const { GET, POST, PUT, PATCH, HEAD, OPTIONS } = toNextHandler(registry); diff --git a/examples/next-js/src/components/Counter.tsx b/examples/next-js/src/components/Counter.tsx index 2ed06b99d6..01032d50cf 100644 --- a/examples/next-js/src/components/Counter.tsx +++ b/examples/next-js/src/components/Counter.tsx @@ -13,21 +13,20 @@ export const { useActor } = createRivetKit({ export function Counter() { const [counterId, setCounterId] = useState("default"); const [count, setCount] = useState(0); - const [isConnected, setIsConnected] = useState(false); const counter = useActor({ name: "counter", key: [counterId], }); + // Use connStatus from the hook instead of tracking connection state manually + const isConnected = counter.connStatus === "connected"; + useEffect(() => { - if (counter.connection) { - setIsConnected(true); + if (counter.connection && isConnected) { counter.connection.getCount().then(setCount); - } else { - setIsConnected(false); } - }, [counter.connection]); + }, [counter.connection, isConnected]); counter.useEvent("newCount", (newCount: number) => { setCount(newCount); diff --git a/examples/next-js/tsconfig.json b/examples/next-js/tsconfig.json index 7df89e76da..b575f7dac7 100644 --- a/examples/next-js/tsconfig.json +++ b/examples/next-js/tsconfig.json @@ -1,27 +1,41 @@ { - "compilerOptions": { - "target": "ES2017", - "lib": ["dom", "dom.iterable", "esnext"], - "allowJs": true, - "skipLibCheck": true, - "strict": true, - "noEmit": true, - "esModuleInterop": true, - "module": "esnext", - "moduleResolution": "bundler", - "resolveJsonModule": true, - "isolatedModules": true, - "jsx": "preserve", - "incremental": true, - "plugins": [ - { - "name": "next" - } - ], - "paths": { - "@/*": ["./src/*"] - } - }, - "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"], - "exclude": ["node_modules"] + "compilerOptions": { + "target": "ES2017", + "lib": [ + "dom", + "dom.iterable", + "esnext" + ], + "allowJs": true, + "skipLibCheck": true, + "strict": true, + "noEmit": true, + "esModuleInterop": true, + "module": "esnext", + "moduleResolution": "bundler", + "resolveJsonModule": true, + "isolatedModules": true, + "jsx": "react-jsx", + "incremental": true, + "plugins": [ + { + "name": "next" + } + ], + "paths": { + "@/*": [ + "./src/*" + ] + } + }, + "include": [ + "next-env.d.ts", + "**/*.ts", + "**/*.tsx", + ".next/types/**/*.ts", + ".next/dev/types/**/*.ts" + ], + "exclude": [ + "node_modules" + ] } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index cb90bb72e6..bfb3b7790f 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1250,6 +1250,12 @@ importers: examples/next-js: dependencies: + '@hono/node-server': + specifier: 1.14.2 + version: 1.14.2(hono@4.11.3) + '@hono/node-ws': + specifier: 1.3.0 + version: 1.3.0(@hono/node-server@1.14.2(hono@4.11.3))(hono@4.11.3) '@rivetkit/next-js': specifier: workspace:* version: link:../../rivetkit-typescript/packages/next-js @@ -2215,7 +2221,7 @@ importers: version: 3.13.12(react-dom@19.1.1(react@19.1.1))(react@19.1.1) '@uiw/codemirror-extensions-basic-setup': specifier: ^4.25.1 - version: 4.25.1(@codemirror/autocomplete@6.19.0)(@codemirror/commands@6.9.0)(@codemirror/language@6.11.3)(@codemirror/lint@6.9.0)(@codemirror/search@6.5.11)(@codemirror/state@6.5.2)(@codemirror/view@6.38.2) + version: 4.25.1(@codemirror/autocomplete@6.19.0)(@codemirror/commands@6.8.1)(@codemirror/language@6.11.3)(@codemirror/lint@6.9.0)(@codemirror/search@6.5.11)(@codemirror/state@6.5.2)(@codemirror/view@6.38.2) '@uiw/codemirror-theme-github': specifier: ^4.25.1 version: 4.25.1(@codemirror/language@6.11.3)(@codemirror/state@6.5.2)(@codemirror/view@6.38.2) @@ -5312,6 +5318,12 @@ packages: resolution: {integrity: sha512-hLpID6NCs8+stbz935UyvyGOXY44oLBSOy7ZEpwXxj977A/0U41iihDQllDoCJrxtbe06DnDgwPOn6/xnRJ71w==} deprecated: Starting with v0.73.0, this package is bundled directly inside @hey-api/openapi-ts. + '@hono/node-server@1.14.2': + resolution: {integrity: sha512-GHjpOeHYbr9d1vkID2sNUYkl5IxumyhDrUJB7wBp7jvqYwPFt+oNKsAPBRcdSbV7kIrXhouLE199ks1QcK4r7A==} + engines: {node: '>=18.14.1'} + peerDependencies: + hono: ^4 + '@hono/node-server@1.19.1': resolution: {integrity: sha512-h44e5s+ByUriaRIbeS/C74O8v90m0A95luyYQGMF7KEn96KkYMXO7bZAwombzTpjQTU4e0TkU8U1WBIXlwuwtA==} engines: {node: '>=18.14.1'} @@ -17660,6 +17672,10 @@ snapshots: '@hey-api/client-fetch@0.5.7': {} + '@hono/node-server@1.14.2(hono@4.11.3)': + dependencies: + hono: 4.11.3 + '@hono/node-server@1.19.1(hono@4.9.8)': dependencies: hono: 4.9.8 @@ -17690,11 +17706,20 @@ snapshots: - bufferutil - utf-8-validate + '@hono/node-ws@1.3.0(@hono/node-server@1.14.2(hono@4.11.3))(hono@4.11.3)': + dependencies: + '@hono/node-server': 1.14.2(hono@4.11.3) + hono: 4.11.3 + ws: 8.19.0 + transitivePeerDependencies: + - bufferutil + - utf-8-validate + '@hono/node-ws@1.3.0(@hono/node-server@1.19.1(hono@4.9.8))(hono@4.9.8)': dependencies: '@hono/node-server': 1.19.1(hono@4.9.8) hono: 4.9.8 - ws: 8.18.3 + ws: 8.19.0 transitivePeerDependencies: - bufferutil - utf-8-validate @@ -17703,7 +17728,7 @@ snapshots: dependencies: '@hono/node-server': 1.19.7(hono@4.11.3) hono: 4.11.3 - ws: 8.18.3 + ws: 8.19.0 transitivePeerDependencies: - bufferutil - utf-8-validate @@ -17712,7 +17737,7 @@ snapshots: dependencies: '@hono/node-server': 1.19.7(hono@4.9.8) hono: 4.9.8 - ws: 8.18.3 + ws: 8.19.0 transitivePeerDependencies: - bufferutil - utf-8-validate @@ -20991,16 +21016,6 @@ snapshots: '@codemirror/state': 6.5.2 '@codemirror/view': 6.38.2 - '@uiw/codemirror-extensions-basic-setup@4.25.1(@codemirror/autocomplete@6.19.0)(@codemirror/commands@6.9.0)(@codemirror/language@6.11.3)(@codemirror/lint@6.9.0)(@codemirror/search@6.5.11)(@codemirror/state@6.5.2)(@codemirror/view@6.38.2)': - dependencies: - '@codemirror/autocomplete': 6.19.0 - '@codemirror/commands': 6.9.0 - '@codemirror/language': 6.11.3 - '@codemirror/lint': 6.9.0 - '@codemirror/search': 6.5.11 - '@codemirror/state': 6.5.2 - '@codemirror/view': 6.38.2 - '@uiw/codemirror-theme-github@4.25.1(@codemirror/language@6.11.3)(@codemirror/state@6.5.2)(@codemirror/view@6.38.2)': dependencies: '@uiw/codemirror-themes': 4.25.1(@codemirror/language@6.11.3)(@codemirror/state@6.5.2)(@codemirror/view@6.38.2) diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts b/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts index 627cae6fba..65307b2703 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts @@ -1,340 +1,334 @@ -// import invariant from "invariant"; -// import type { -// ActorKey, -// ActorRouter, -// AnyActorInstance as CoreAnyActorInstance, -// RegistryConfig, -// RunConfig, -// } from "rivetkit"; -// import { lookupInRegistry } from "rivetkit"; -// import type { Client } from "rivetkit/client"; -// import type { -// ActorDriver, -// AnyActorInstance, -// ManagerDriver, -// } from "rivetkit/driver-helpers"; -// import { promiseWithResolvers } from "rivetkit/utils"; -// import { parseActorId } from "./actor-id"; -// import { kvDelete, kvGet, kvListPrefix, kvPut } from "./actor-kv"; -// import { GLOBAL_KV_KEYS } from "./global-kv"; -// import { getCloudflareAmbientEnv } from "./handler"; -// -// interface DurableObjectGlobalState { -// ctx: DurableObjectState; -// env: unknown; -// } -// -// /** -// * Cloudflare DO can have multiple DO running within the same global scope. -// * -// * This allows for storing the actor context globally and looking it up by ID in `CloudflareActorsActorDriver`. -// */ -// export class CloudflareDurableObjectGlobalState { -// // Map of actor ID -> DO state -// #dos: Map = new Map(); -// -// // WeakMap of DO state -> ActorGlobalState for proper GC -// #actors: WeakMap = new WeakMap(); -// -// getDOState(doId: string): DurableObjectGlobalState { -// const state = this.#dos.get(doId); -// invariant( -// state !== undefined, -// "durable object state not in global state", -// ); -// return state; -// } -// -// setDOState(doId: string, state: DurableObjectGlobalState) { -// this.#dos.set(doId, state); -// } -// -// getActorState(ctx: DurableObjectState): ActorGlobalState | undefined { -// return this.#actors.get(ctx); -// } -// -// setActorState(ctx: DurableObjectState, actorState: ActorGlobalState): void { -// this.#actors.set(ctx, actorState); -// } -// } -// -// export interface DriverContext { -// state: DurableObjectState; -// } -// -// interface InitializedData { -// name: string; -// key: ActorKey; -// generation: number; -// } -// -// interface LoadedActor { -// actorRouter: ActorRouter; -// actorDriver: ActorDriver; -// generation: number; -// } -// -// // Actor global state to track running instances -// export class ActorGlobalState { -// // Initialization state -// initialized?: InitializedData; -// -// // Loaded actor state -// actor?: LoadedActor; -// actorInstance?: AnyActorInstance; -// actorPromise?: ReturnType>; -// -// /** -// * Indicates if `startDestroy` has been called. -// * -// * This is stored in memory instead of SQLite since the destroy may be cancelled. -// * -// * See the corresponding `destroyed` property in SQLite metadata. -// */ -// destroying: boolean = false; -// -// reset() { -// this.initialized = undefined; -// this.actor = undefined; -// this.actorInstance = undefined; -// this.actorPromise = undefined; -// this.destroying = false; -// } -// } -// -// export class CloudflareActorsActorDriver implements ActorDriver { -// #registryConfig: RegistryConfig; -// #runConfig: RunConfig; -// #managerDriver: ManagerDriver; -// #inlineClient: Client; -// #globalState: CloudflareDurableObjectGlobalState; -// -// constructor( -// registryConfig: RegistryConfig, -// runConfig: RunConfig, -// managerDriver: ManagerDriver, -// inlineClient: Client, -// globalState: CloudflareDurableObjectGlobalState, -// ) { -// this.#registryConfig = registryConfig; -// this.#runConfig = runConfig; -// this.#managerDriver = managerDriver; -// this.#inlineClient = inlineClient; -// this.#globalState = globalState; -// } -// -// #getDOCtx(actorId: string) { -// // Parse actor ID to get DO ID -// const [doId] = parseActorId(actorId); -// return this.#globalState.getDOState(doId).ctx; -// } -// -// async loadActor(actorId: string): Promise { -// // Parse actor ID to get DO ID and generation -// const [doId, expectedGeneration] = parseActorId(actorId); -// -// // Get the DO state -// const doState = this.#globalState.getDOState(doId); -// -// // Check if actor is already loaded -// let actorState = this.#globalState.getActorState(doState.ctx); -// if (actorState?.actorInstance) { -// // Actor is already loaded, return it -// return actorState.actorInstance; -// } -// -// // Create new actor state if it doesn't exist -// if (!actorState) { -// actorState = new ActorGlobalState(); -// actorState.actorPromise = promiseWithResolvers(); -// this.#globalState.setActorState(doState.ctx, actorState); -// } else if (actorState.actorPromise) { -// // Another request is already loading this actor, wait for it -// await actorState.actorPromise.promise; -// if (!actorState.actorInstance) { -// throw new Error( -// `Actor ${actorId} failed to load in concurrent request`, -// ); -// } -// return actorState.actorInstance; -// } -// -// // Load actor metadata -// const sql = doState.ctx.storage.sql; -// const cursor = sql.exec( -// "SELECT name, key, destroyed, generation FROM _rivetkit_metadata LIMIT 1", -// ); -// const result = cursor.raw().next(); -// -// if (result.done || !result.value) { -// throw new Error( -// `Actor ${actorId} is not initialized - missing metadata`, -// ); -// } -// -// const name = result.value[0] as string; -// const key = JSON.parse(result.value[1] as string) as string[]; -// const destroyed = result.value[2] as number; -// const generation = result.value[3] as number; -// -// // Check if actor is destroyed -// if (destroyed) { -// throw new Error(`Actor ${actorId} is destroyed`); -// } -// -// // Check if generation matches -// if (generation !== expectedGeneration) { -// throw new Error( -// `Actor ${actorId} generation mismatch - expected ${expectedGeneration}, got ${generation}`, -// ); -// } -// -// // Create actor instance -// const definition = lookupInRegistry(this.#registryConfig, name); -// actorState.actorInstance = definition.instantiate(); -// -// // Start actor -// await actorState.actorInstance.start( -// this, -// this.#inlineClient, -// actorId, -// name, -// key, -// "unknown", // TODO: Support regions in Cloudflare -// ); -// -// // Finish -// actorState.actorPromise?.resolve(); -// actorState.actorPromise = undefined; -// -// return actorState.actorInstance; -// } -// -// getContext(actorId: string): DriverContext { -// // Parse actor ID to get DO ID -// const [doId] = parseActorId(actorId); -// const state = this.#globalState.getDOState(doId); -// return { state: state.ctx }; -// } -// -// async setAlarm(actor: AnyActorInstance, timestamp: number): Promise { -// await this.#getDOCtx(actor.id).storage.setAlarm(timestamp); -// } -// -// async getDatabase(actorId: string): Promise { -// return this.#getDOCtx(actorId).storage.sql; -// } -// -// // Batch KV operations -// async kvBatchPut( -// actorId: string, -// entries: [Uint8Array, Uint8Array][], -// ): Promise { -// const sql = this.#getDOCtx(actorId).storage.sql; -// -// for (const [key, value] of entries) { -// kvPut(sql, key, value); -// } -// } -// -// async kvBatchGet( -// actorId: string, -// keys: Uint8Array[], -// ): Promise<(Uint8Array | null)[]> { -// const sql = this.#getDOCtx(actorId).storage.sql; -// -// const results: (Uint8Array | null)[] = []; -// for (const key of keys) { -// results.push(kvGet(sql, key)); -// } -// -// return results; -// } -// -// async kvBatchDelete(actorId: string, keys: Uint8Array[]): Promise { -// const sql = this.#getDOCtx(actorId).storage.sql; -// -// for (const key of keys) { -// kvDelete(sql, key); -// } -// } -// -// async kvListPrefix( -// actorId: string, -// prefix: Uint8Array, -// ): Promise<[Uint8Array, Uint8Array][]> { -// const sql = this.#getDOCtx(actorId).storage.sql; -// -// return kvListPrefix(sql, prefix); -// } -// -// startDestroy(actorId: string): void { -// // Parse actor ID to get DO ID and generation -// const [doId, generation] = parseActorId(actorId); -// -// // Get the DO state -// const doState = this.#globalState.getDOState(doId); -// const actorState = this.#globalState.getActorState(doState.ctx); -// -// // Actor not loaded, nothing to destroy -// if (!actorState?.actorInstance) { -// return; -// } -// -// // Check if already destroying -// if (actorState.destroying) { -// return; -// } -// actorState.destroying = true; -// -// // Spawn onStop in background -// this.#callOnStopAsync(actorId, doId, actorState.actorInstance); -// } -// -// async #callOnStopAsync( -// actorId: string, -// doId: string, -// actor: CoreAnyActorInstance, -// ) { -// // Stop -// await actor.onStop("destroy"); -// -// // Remove state -// const doState = this.#globalState.getDOState(doId); -// const sql = doState.ctx.storage.sql; -// sql.exec("UPDATE _rivetkit_metadata SET destroyed = 1 WHERE 1=1"); -// sql.exec("DELETE FROM _rivetkit_kv_storage"); -// -// // Clear any scheduled alarms -// await doState.ctx.storage.deleteAlarm(); -// -// // Delete from ACTOR_KV in the background - use full actorId including generation -// const env = getCloudflareAmbientEnv(); -// doState.ctx.waitUntil( -// env.ACTOR_KV.delete(GLOBAL_KV_KEYS.actorMetadata(actorId)), -// ); -// -// // Reset global state using the DO context -// const actorHandle = this.#globalState.getActorState(doState.ctx); -// actorHandle?.reset(); -// } -// } -// -// export function createCloudflareActorsActorDriverBuilder( -// globalState: CloudflareDurableObjectGlobalState, -// ) { -// return ( -// config: RegistryConfig, -// runConfig: RunConfig, -// managerDriver: ManagerDriver, -// inlineClient: Client, -// ) => { -// return new CloudflareActorsActorDriver( -// config, -// runConfig, -// managerDriver, -// inlineClient, -// globalState, -// ); -// }; -// } +import invariant from "invariant"; +import type { + ActorKey, + ActorRouter, + AnyActorInstance as CoreAnyActorInstance, + RegistryConfig, +} from "rivetkit"; +import { lookupInRegistry } from "rivetkit"; +import type { Client } from "rivetkit/client"; +import type { + ActorDriver, + AnyActorInstance, + ManagerDriver, +} from "rivetkit/driver-helpers"; +import { promiseWithResolvers } from "rivetkit/utils"; +import { kvDelete, kvGet, kvListPrefix, kvPut } from "./actor-kv"; +import { GLOBAL_KV_KEYS } from "./global-kv"; +import { getCloudflareAmbientEnv } from "./handler"; +import { parseActorId } from "./actor-id"; + +interface DurableObjectGlobalState { + ctx: DurableObjectState; + env: unknown; +} + +/** + * Cloudflare DO can have multiple DO running within the same global scope. + * + * This allows for storing the actor context globally and looking it up by ID in `CloudflareActorsActorDriver`. + */ +export class CloudflareDurableObjectGlobalState { + // Map of actor ID -> DO state + #dos: Map = new Map(); + + // WeakMap of DO state -> ActorGlobalState for proper GC + #actors: WeakMap = new WeakMap(); + + getDOState(doId: string): DurableObjectGlobalState { + const state = this.#dos.get(doId); + invariant( + state !== undefined, + "durable object state not in global state", + ); + return state; + } + + setDOState(doId: string, state: DurableObjectGlobalState) { + this.#dos.set(doId, state); + } + + getActorState(ctx: DurableObjectState): ActorGlobalState | undefined { + return this.#actors.get(ctx); + } + + setActorState(ctx: DurableObjectState, actorState: ActorGlobalState): void { + this.#actors.set(ctx, actorState); + } +} + +export interface DriverContext { + state: DurableObjectState; +} + +interface InitializedData { + name: string; + key: ActorKey; + generation: number; +} + +interface LoadedActor { + actorRouter: ActorRouter; + actorDriver: ActorDriver; + generation: number; +} + +// Actor global state to track running instances +export class ActorGlobalState { + // Initialization state + initialized?: InitializedData; + + // Loaded actor state + actor?: LoadedActor; + actorInstance?: AnyActorInstance; + actorPromise?: ReturnType>; + + /** + * Indicates if `startDestroy` has been called. + * + * This is stored in memory instead of SQLite since the destroy may be cancelled. + * + * See the corresponding `destroyed` property in SQLite metadata. + */ + destroying: boolean = false; + + reset() { + this.initialized = undefined; + this.actor = undefined; + this.actorInstance = undefined; + this.actorPromise = undefined; + this.destroying = false; + } +} + +export class CloudflareActorsActorDriver implements ActorDriver { + #registryConfig: RegistryConfig; + #managerDriver: ManagerDriver; + #inlineClient: Client; + #globalState: CloudflareDurableObjectGlobalState; + + constructor( + registryConfig: RegistryConfig, + managerDriver: ManagerDriver, + inlineClient: Client, + globalState: CloudflareDurableObjectGlobalState, + ) { + this.#registryConfig = registryConfig; + this.#managerDriver = managerDriver; + this.#inlineClient = inlineClient; + this.#globalState = globalState; + } + + #getDOCtx(actorId: string) { + // Parse actor ID to get DO ID + const [doId] = parseActorId(actorId); + return this.#globalState.getDOState(doId).ctx; + } + + async loadActor(actorId: string): Promise { + // Parse actor ID to get DO ID and generation + const [doId, expectedGeneration] = parseActorId(actorId); + + // Get the DO state + const doState = this.#globalState.getDOState(doId); + + // Check if actor is already loaded + let actorState = this.#globalState.getActorState(doState.ctx); + if (actorState?.actorInstance) { + // Actor is already loaded, return it + return actorState.actorInstance; + } + + // Create new actor state if it doesn't exist + if (!actorState) { + actorState = new ActorGlobalState(); + actorState.actorPromise = promiseWithResolvers(); + this.#globalState.setActorState(doState.ctx, actorState); + } else if (actorState.actorPromise) { + // Another request is already loading this actor, wait for it + await actorState.actorPromise.promise; + if (!actorState.actorInstance) { + throw new Error( + `Actor ${actorId} failed to load in concurrent request`, + ); + } + return actorState.actorInstance; + } + + // Load actor metadata + const sql = doState.ctx.storage.sql; + const cursor = sql.exec( + "SELECT name, key, destroyed, generation FROM _rivetkit_metadata LIMIT 1", + ); + const result = cursor.raw().next(); + + if (result.done || !result.value) { + throw new Error( + `Actor ${actorId} is not initialized - missing metadata`, + ); + } + + const name = result.value[0] as string; + const key = JSON.parse(result.value[1] as string) as string[]; + const destroyed = result.value[2] as number; + const generation = result.value[3] as number; + + // Check if actor is destroyed + if (destroyed) { + throw new Error(`Actor ${actorId} is destroyed`); + } + + // Check if generation matches + if (generation !== expectedGeneration) { + throw new Error( + `Actor ${actorId} generation mismatch - expected ${expectedGeneration}, got ${generation}`, + ); + } + + // Create actor instance + const definition = lookupInRegistry(this.#registryConfig, name); + actorState.actorInstance = definition.instantiate(); + + // Start actor + await actorState.actorInstance.start( + this, + this.#inlineClient, + actorId, + name, + key, + "unknown", // TODO: Support regions in Cloudflare + ); + + // Finish + actorState.actorPromise?.resolve(); + actorState.actorPromise = undefined; + + return actorState.actorInstance; + } + + getContext(actorId: string): DriverContext { + // Parse actor ID to get DO ID + const [doId] = parseActorId(actorId); + const state = this.#globalState.getDOState(doId); + return { state: state.ctx }; + } + + async setAlarm(actor: AnyActorInstance, timestamp: number): Promise { + await this.#getDOCtx(actor.id).storage.setAlarm(timestamp); + } + + async getDatabase(actorId: string): Promise { + return this.#getDOCtx(actorId).storage.sql; + } + + // Batch KV operations + async kvBatchPut( + actorId: string, + entries: [Uint8Array, Uint8Array][], + ): Promise { + const sql = this.#getDOCtx(actorId).storage.sql; + + for (const [key, value] of entries) { + kvPut(sql, key, value); + } + } + + async kvBatchGet( + actorId: string, + keys: Uint8Array[], + ): Promise<(Uint8Array | null)[]> { + const sql = this.#getDOCtx(actorId).storage.sql; + + const results: (Uint8Array | null)[] = []; + for (const key of keys) { + results.push(kvGet(sql, key)); + } + + return results; + } + + async kvBatchDelete(actorId: string, keys: Uint8Array[]): Promise { + const sql = this.#getDOCtx(actorId).storage.sql; + + for (const key of keys) { + kvDelete(sql, key); + } + } + + async kvListPrefix( + actorId: string, + prefix: Uint8Array, + ): Promise<[Uint8Array, Uint8Array][]> { + const sql = this.#getDOCtx(actorId).storage.sql; + + return kvListPrefix(sql, prefix); + } + + startDestroy(actorId: string): void { + // Parse actor ID to get DO ID and generation + const [doId, generation] = parseActorId(actorId); + + // Get the DO state + const doState = this.#globalState.getDOState(doId); + const actorState = this.#globalState.getActorState(doState.ctx); + + // Actor not loaded, nothing to destroy + if (!actorState?.actorInstance) { + return; + } + + // Check if already destroying + if (actorState.destroying) { + return; + } + actorState.destroying = true; + + // Spawn onStop in background + this.#callOnStopAsync(actorId, doId, actorState.actorInstance); + } + + async #callOnStopAsync( + actorId: string, + doId: string, + actor: CoreAnyActorInstance, + ) { + // Stop + await actor.onStop("destroy"); + + // Remove state + const doState = this.#globalState.getDOState(doId); + const sql = doState.ctx.storage.sql; + sql.exec("UPDATE _rivetkit_metadata SET destroyed = 1 WHERE 1=1"); + sql.exec("DELETE FROM _rivetkit_kv_storage"); + + // Clear any scheduled alarms + await doState.ctx.storage.deleteAlarm(); + + // Delete from ACTOR_KV in the background - use full actorId including generation + const env = getCloudflareAmbientEnv(); + doState.ctx.waitUntil( + env.ACTOR_KV.delete(GLOBAL_KV_KEYS.actorMetadata(actorId)), + ); + + // Reset global state using the DO context + const actorHandle = this.#globalState.getActorState(doState.ctx); + actorHandle?.reset(); + } +} + +export function createCloudflareActorsActorDriverBuilder( + globalState: CloudflareDurableObjectGlobalState, +) { + return ( + config: RegistryConfig, + managerDriver: ManagerDriver, + inlineClient: Client, + ) => { + return new CloudflareActorsActorDriver( + config, + managerDriver, + inlineClient, + globalState, + ); + }; +} diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/actor-handler-do.ts b/rivetkit-typescript/packages/cloudflare-workers/src/actor-handler-do.ts index fc84e11003..2d97b44c88 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/actor-handler-do.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/actor-handler-do.ts @@ -1,438 +1,437 @@ -// import { DurableObject, env } from "cloudflare:workers"; -// import type { ExecutionContext } from "hono"; -// import invariant from "invariant"; -// import type { ActorKey, ActorRouter, Registry, RunConfig } from "rivetkit"; -// import { createActorRouter, createClientWithDriver } from "rivetkit"; -// import type { ActorDriver, ManagerDriver } from "rivetkit/driver-helpers"; -// import { getInitialActorKvState } from "rivetkit/driver-helpers"; -// import { stringifyError } from "rivetkit/utils"; -// import { -// ActorGlobalState, -// CloudflareDurableObjectGlobalState, -// createCloudflareActorsActorDriverBuilder, -// } from "./actor-driver"; -// import { buildActorId, parseActorId } from "./actor-id"; -// import { kvPut } from "./actor-kv"; -// import { GLOBAL_KV_KEYS } from "./global-kv"; -// import type { Bindings } from "./handler"; -// import { getCloudflareAmbientEnv } from "./handler"; -// import { logger } from "./log"; -// -// export interface ActorHandlerInterface extends DurableObject { -// create(req: ActorInitRequest): Promise; -// getMetadata(): Promise< -// | { -// actorId: string; -// name: string; -// key: ActorKey; -// destroying: boolean; -// } -// | undefined -// >; -// } -// -// export interface ActorInitRequest { -// name: string; -// key: ActorKey; -// input?: unknown; -// allowExisting: boolean; -// } -// export type ActorInitResponse = -// | { success: { actorId: string; created: boolean } } -// | { error: { actorAlreadyExists: true } }; -// -// export type DurableObjectConstructor = new ( -// ...args: ConstructorParameters> -// ) => DurableObject; -// -// export function createActorDurableObject( -// registry: Registry, -// rootRunConfig: RunConfig, -// ): DurableObjectConstructor { -// const globalState = new CloudflareDurableObjectGlobalState(); -// -// // Configure to use the runner role instead of server role -// const runConfig = Object.assign({}, rootRunConfig, { role: "runner" }); -// -// /** -// * Startup steps: -// * 1. If not already created call `initialize`, otherwise check KV to ensure it's initialized -// * 2. Load actor -// * 3. Start service requests -// */ -// return class ActorHandler -// extends DurableObject -// implements ActorHandlerInterface -// { -// /** -// * This holds a strong reference to ActorGlobalState. -// * CloudflareDurableObjectGlobalState holds a weak reference so we can -// * access it elsewhere. -// **/ -// #state: ActorGlobalState; -// -// constructor( -// ...args: ConstructorParameters> -// ) { -// super(...args); -// -// // Initialize SQL table for key-value storage -// // -// // We do this instead of using the native KV storage so we can store blob keys. The native CF KV API only supports string keys. -// this.ctx.storage.sql.exec(` -// CREATE TABLE IF NOT EXISTS _rivetkit_kv_storage( -// key BLOB PRIMARY KEY, -// value BLOB -// ); -// `); -// -// // Initialize SQL table for actor metadata -// // -// // id always equals 1 in order to ensure that there's always exactly 1 row in this table -// this.ctx.storage.sql.exec(` -// CREATE TABLE IF NOT EXISTS _rivetkit_metadata( -// id INTEGER PRIMARY KEY CHECK (id = 1), -// name TEXT NOT NULL, -// key TEXT NOT NULL, -// destroyed INTEGER DEFAULT 0, -// generation INTEGER DEFAULT 0 -// ); -// `); -// -// // Get or create the actor state from the global WeakMap -// const state = globalState.getActorState(this.ctx); -// if (state) { -// this.#state = state; -// } else { -// this.#state = new ActorGlobalState(); -// globalState.setActorState(this.ctx, this.#state); -// } -// } -// -// async #loadActor() { -// invariant(this.#state, "State should be initialized"); -// -// // Check if initialized -// if (!this.#state.initialized) { -// // Query SQL for initialization data -// const cursor = this.ctx.storage.sql.exec( -// "SELECT name, key, destroyed, generation FROM _rivetkit_metadata WHERE id = 1", -// ); -// const result = cursor.raw().next(); -// -// if (!result.done && result.value) { -// const name = result.value[0] as string; -// const key = JSON.parse( -// result.value[1] as string, -// ) as ActorKey; -// const destroyed = result.value[2] as number; -// const generation = result.value[3] as number; -// -// // Only initialize if not destroyed -// if (!destroyed) { -// logger().debug({ -// msg: "already initialized", -// name, -// key, -// generation, -// }); -// -// this.#state.initialized = { name, key, generation }; -// } else { -// logger().debug("actor is destroyed, cannot load"); -// throw new Error("Actor is destroyed"); -// } -// } else { -// logger().debug("not initialized"); -// throw new Error("Actor is not initialized"); -// } -// } -// -// // Check if already loaded -// if (this.#state.actor) { -// // Assert that the cached actor has the correct generation -// // This will catch any cases where #state.actor has a stale generation -// invariant( -// !this.#state.initialized || -// this.#state.actor.generation === -// this.#state.initialized.generation, -// `Stale actor cached: actor generation ${this.#state.actor.generation} != initialized generation ${this.#state.initialized?.generation}. This should not happen.`, -// ); -// return this.#state.actor; -// } -// -// if (!this.#state.initialized) throw new Error("Not initialized"); -// -// // Register DO with global state first -// // HACK: This leaks the DO context, but DO does not provide a native way -// // of knowing when the DO shuts down. We're making a broad assumption -// // that DO will boot a new isolate frequenlty enough that this is not an issue. -// const actorId = this.ctx.id.toString(); -// globalState.setDOState(actorId, { ctx: this.ctx, env: env }); -// -// // Configure actor driver -// invariant(runConfig.driver, "runConfig.driver"); -// runConfig.driver.actor = -// createCloudflareActorsActorDriverBuilder(globalState); -// -// // Create manager driver (we need this for the actor router) -// const managerDriver = runConfig.driver.manager( -// registry.config, -// runConfig, -// ); -// -// // Create inline client -// const inlineClient = createClientWithDriver( -// managerDriver, -// runConfig, -// ); -// -// // Create actor driver -// const actorDriver = runConfig.driver.actor( -// registry.config, -// runConfig, -// managerDriver, -// inlineClient, -// ); -// -// // Create actor router -// const actorRouter = createActorRouter( -// runConfig, -// actorDriver, -// false, -// ); -// -// // Save actor with generation -// this.#state.actor = { -// actorRouter, -// actorDriver, -// generation: this.#state.initialized.generation, -// }; -// -// // Build actor ID with generation for loading -// const actorIdWithGen = buildActorId( -// actorId, -// this.#state.initialized.generation, -// ); -// -// // Initialize the actor instance with proper metadata -// // This ensures the actor driver knows about this actor -// await actorDriver.loadActor(actorIdWithGen); -// -// return this.#state.actor; -// } -// -// /** RPC called to get actor metadata without creating it */ -// async getMetadata(): Promise< -// | { -// actorId: string; -// name: string; -// key: ActorKey; -// destroying: boolean; -// } -// | undefined -// > { -// // Query the metadata -// const cursor = this.ctx.storage.sql.exec( -// "SELECT name, key, destroyed, generation FROM _rivetkit_metadata WHERE id = 1", -// ); -// const result = cursor.raw().next(); -// -// if (!result.done && result.value) { -// const name = result.value[0] as string; -// const key = JSON.parse(result.value[1] as string) as ActorKey; -// const destroyed = result.value[2] as number; -// const generation = result.value[3] as number; -// -// // Check if destroyed -// if (destroyed) { -// logger().debug({ -// msg: "getMetadata: actor is destroyed", -// name, -// key, -// generation, -// }); -// return undefined; -// } -// -// // Build actor ID with generation -// const doId = this.ctx.id.toString(); -// const actorId = buildActorId(doId, generation); -// const destroying = -// globalState.getActorState(this.ctx)?.destroying ?? false; -// -// logger().debug({ -// msg: "getMetadata: found actor metadata", -// actorId, -// name, -// key, -// generation, -// destroying, -// }); -// -// return { actorId, name, key, destroying }; -// } -// -// logger().debug({ -// msg: "getMetadata: no metadata found", -// }); -// return undefined; -// } -// -// /** RPC called by the manager to create a DO. Can optionally allow existing actors. */ -// async create(req: ActorInitRequest): Promise { -// // Check if actor exists -// const checkCursor = this.ctx.storage.sql.exec( -// "SELECT destroyed, generation FROM _rivetkit_metadata WHERE id = 1", -// ); -// const checkResult = checkCursor.raw().next(); -// -// let created = false; -// let generation = 0; -// -// if (!checkResult.done && checkResult.value) { -// const destroyed = checkResult.value[0] as number; -// generation = checkResult.value[1] as number; -// -// if (!destroyed) { -// // Actor exists and is not destroyed -// if (!req.allowExisting) { -// // Fail if not allowing existing actors -// logger().debug({ -// msg: "create failed: actor already exists", -// name: req.name, -// key: req.key, -// generation, -// }); -// return { error: { actorAlreadyExists: true } }; -// } -// -// // Return existing actor -// logger().debug({ -// msg: "actor already exists", -// key: req.key, -// generation, -// }); -// const doId = this.ctx.id.toString(); -// const actorId = buildActorId(doId, generation); -// return { success: { actorId, created: false } }; -// } -// -// // Actor exists but is destroyed - resurrect with incremented generation -// generation = generation + 1; -// created = true; -// -// // Clear stale actor from previous generation -// // This is necessary because the DO instance may still be in memory -// // with the old #state.actor field from before the destroy -// if (this.#state) { -// this.#state.actor = undefined; -// } -// -// logger().debug({ -// msg: "resurrecting destroyed actor", -// key: req.key, -// oldGeneration: generation - 1, -// newGeneration: generation, -// }); -// } else { -// // No actor exists - will create with generation 0 -// generation = 0; -// created = true; -// logger().debug({ -// msg: "creating new actor", -// key: req.key, -// generation, -// }); -// } -// -// // Perform upsert - either inserts new or updates destroyed actor -// this.ctx.storage.sql.exec( -// `INSERT INTO _rivetkit_metadata (id, name, key, destroyed, generation) -// VALUES (1, ?, ?, 0, ?) -// ON CONFLICT(id) DO UPDATE SET -// name = excluded.name, -// key = excluded.key, -// destroyed = 0, -// generation = excluded.generation`, -// req.name, -// JSON.stringify(req.key), -// generation, -// ); -// -// this.#state.initialized = { -// name: req.name, -// key: req.key, -// generation, -// }; -// -// // Build actor ID with generation -// const doId = this.ctx.id.toString(); -// const actorId = buildActorId(doId, generation); -// -// // Initialize storage and update KV when created or resurrected -// if (created) { -// // Initialize persist data in KV storage -// initializeActorKvStorage(this.ctx.storage.sql, req.input); -// -// // Update metadata in the background -// const env = getCloudflareAmbientEnv(); -// const actorData = { name: req.name, key: req.key, generation }; -// this.ctx.waitUntil( -// env.ACTOR_KV.put( -// GLOBAL_KV_KEYS.actorMetadata(actorId), -// JSON.stringify(actorData), -// ), -// ); -// } -// -// // Preemptively load actor so the lifecycle hooks are called -// await this.#loadActor(); -// -// logger().debug({ -// msg: created -// ? "actor created/resurrected" -// : "returning existing actor", -// actorId, -// created, -// generation, -// }); -// -// return { success: { actorId, created } }; -// } -// -// async fetch(request: Request): Promise { -// const { actorRouter, generation } = await this.#loadActor(); -// -// // Build actor ID with generation -// const doId = this.ctx.id.toString(); -// const actorId = buildActorId(doId, generation); -// -// return await actorRouter.fetch(request, { -// actorId, -// }); -// } -// -// async alarm(): Promise { -// const { actorDriver, generation } = await this.#loadActor(); -// -// // Build actor ID with generation -// const doId = this.ctx.id.toString(); -// const actorId = buildActorId(doId, generation); -// -// // Load the actor instance and trigger alarm -// const actor = await actorDriver.loadActor(actorId); -// await actor.onAlarm(); -// } -// }; -// } -// -// function initializeActorKvStorage( -// sql: SqlStorage, -// input: unknown | undefined, -// ): void { -// const initialKvState = getInitialActorKvState(input); -// for (const [key, value] of initialKvState) { -// kvPut(sql, key, value); -// } -// } +import { DurableObject, env } from "cloudflare:workers"; +import type { ExecutionContext } from "hono"; +import invariant from "invariant"; +import type { ActorKey, ActorRouter, Registry, RegistryConfig } from "rivetkit"; +import { createActorRouter, createClientWithDriver } from "rivetkit"; +import type { ActorDriver, ManagerDriver } from "rivetkit/driver-helpers"; +import { getInitialActorKvState } from "rivetkit/driver-helpers"; +import type { GetUpgradeWebSocket } from "rivetkit/utils"; +import { stringifyError } from "rivetkit/utils"; +import { + ActorGlobalState, + CloudflareDurableObjectGlobalState, + createCloudflareActorsActorDriverBuilder, +} from "./actor-driver"; +import { buildActorId, parseActorId } from "./actor-id"; +import { kvGet, kvPut } from "./actor-kv"; +import { GLOBAL_KV_KEYS } from "./global-kv"; +import type { Bindings } from "./handler"; +import { getCloudflareAmbientEnv } from "./handler"; +import { logger } from "./log"; +import { CloudflareActorsManagerDriver } from "./manager-driver"; + +export interface ActorHandlerInterface extends DurableObject { + create(req: ActorInitRequest): Promise; + getMetadata(): Promise< + | { + actorId: string; + name: string; + key: ActorKey; + destroying: boolean; + } + | undefined + >; + managerKvGet(key: Uint8Array): Promise; +} + +export interface ActorInitRequest { + name: string; + key: ActorKey; + input?: unknown; + allowExisting: boolean; +} +export type ActorInitResponse = + | { success: { actorId: string; created: boolean } } + | { error: { actorAlreadyExists: true } }; + +export type DurableObjectConstructor = new ( + ...args: ConstructorParameters> +) => DurableObject; + +export function createActorDurableObject( + registry: Registry, + getUpgradeWebSocket: GetUpgradeWebSocket, +): DurableObjectConstructor { + const globalState = new CloudflareDurableObjectGlobalState(); + const parsedConfig = registry.parseConfig(); + + /** + * Startup steps: + * 1. If not already created call `initialize`, otherwise check KV to ensure it's initialized + * 2. Load actor + * 3. Start service requests + */ + return class ActorHandler + extends DurableObject + implements ActorHandlerInterface + { + /** + * This holds a strong reference to ActorGlobalState. + * CloudflareDurableObjectGlobalState holds a weak reference so we can + * access it elsewhere. + **/ + #state: ActorGlobalState; + + constructor( + ...args: ConstructorParameters> + ) { + super(...args); + + // Initialize SQL table for key-value storage + // + // We do this instead of using the native KV storage so we can store blob keys. The native CF KV API only supports string keys. + this.ctx.storage.sql.exec(` + CREATE TABLE IF NOT EXISTS _rivetkit_kv_storage( + key BLOB PRIMARY KEY, + value BLOB + ); + `); + + // Initialize SQL table for actor metadata + // + // id always equals 1 in order to ensure that there's always exactly 1 row in this table + this.ctx.storage.sql.exec(` + CREATE TABLE IF NOT EXISTS _rivetkit_metadata( + id INTEGER PRIMARY KEY CHECK (id = 1), + name TEXT NOT NULL, + key TEXT NOT NULL, + destroyed INTEGER DEFAULT 0, + generation INTEGER DEFAULT 0 + ); + `); + + // Get or create the actor state from the global WeakMap + const state = globalState.getActorState(this.ctx); + if (state) { + this.#state = state; + } else { + this.#state = new ActorGlobalState(); + globalState.setActorState(this.ctx, this.#state); + } + } + + async #loadActor() { + invariant(this.#state, "State should be initialized"); + + // Check if initialized + if (!this.#state.initialized) { + // Query SQL for initialization data + const cursor = this.ctx.storage.sql.exec( + "SELECT name, key, destroyed, generation FROM _rivetkit_metadata WHERE id = 1", + ); + const result = cursor.raw().next(); + + if (!result.done && result.value) { + const name = result.value[0] as string; + const key = JSON.parse( + result.value[1] as string, + ) as ActorKey; + const destroyed = result.value[2] as number; + const generation = result.value[3] as number; + + // Only initialize if not destroyed + if (!destroyed) { + logger().debug({ + msg: "already initialized", + name, + key, + generation, + }); + + this.#state.initialized = { name, key, generation }; + } else { + logger().debug("actor is destroyed, cannot load"); + throw new Error("Actor is destroyed"); + } + } else { + logger().debug("not initialized"); + throw new Error("Actor is not initialized"); + } + } + + // Check if already loaded + if (this.#state.actor) { + // Assert that the cached actor has the correct generation + // This will catch any cases where #state.actor has a stale generation + invariant( + !this.#state.initialized || + this.#state.actor.generation === + this.#state.initialized.generation, + `Stale actor cached: actor generation ${this.#state.actor.generation} != initialized generation ${this.#state.initialized?.generation}. This should not happen.`, + ); + return this.#state.actor; + } + + if (!this.#state.initialized) throw new Error("Not initialized"); + + // Register DO with global state first + // HACK: This leaks the DO context, but DO does not provide a native way + // of knowing when the DO shuts down. We're making a broad assumption + // that DO will boot a new isolate frequenlty enough that this is not an issue. + const actorId = this.ctx.id.toString(); + globalState.setDOState(actorId, { ctx: this.ctx, env: env }); + + // Create manager driver + const managerDriver = new CloudflareActorsManagerDriver(); + + // Create inline client + const inlineClient = createClientWithDriver(managerDriver); + + // Create actor driver builder + const actorDriverBuilder = + createCloudflareActorsActorDriverBuilder(globalState); + + // Create actor driver + const actorDriver = actorDriverBuilder( + parsedConfig, + managerDriver, + inlineClient, + ); + + // Create actor router + const actorRouter = createActorRouter( + parsedConfig, + actorDriver, + getUpgradeWebSocket, + registry.config.test?.enabled ?? false, + ); + + // Save actor with generation + this.#state.actor = { + actorRouter, + actorDriver, + generation: this.#state.initialized.generation, + }; + + // Build actor ID with generation for loading + const actorIdWithGen = buildActorId( + actorId, + this.#state.initialized.generation, + ); + + // Initialize the actor instance with proper metadata + // This ensures the actor driver knows about this actor + await actorDriver.loadActor(actorIdWithGen); + + return this.#state.actor; + } + + /** RPC called to get actor metadata without creating it */ + async getMetadata(): Promise< + | { + actorId: string; + name: string; + key: ActorKey; + destroying: boolean; + } + | undefined + > { + // Query the metadata + const cursor = this.ctx.storage.sql.exec( + "SELECT name, key, destroyed, generation FROM _rivetkit_metadata WHERE id = 1", + ); + const result = cursor.raw().next(); + + if (!result.done && result.value) { + const name = result.value[0] as string; + const key = JSON.parse(result.value[1] as string) as ActorKey; + const destroyed = result.value[2] as number; + const generation = result.value[3] as number; + + // Check if destroyed + if (destroyed) { + logger().debug({ + msg: "getMetadata: actor is destroyed", + name, + key, + generation, + }); + return undefined; + } + + // Build actor ID with generation + const doId = this.ctx.id.toString(); + const actorId = buildActorId(doId, generation); + const destroying = + globalState.getActorState(this.ctx)?.destroying ?? false; + + logger().debug({ + msg: "getMetadata: found actor metadata", + actorId, + name, + key, + generation, + destroying, + }); + + return { actorId, name, key, destroying }; + } + + logger().debug({ + msg: "getMetadata: no metadata found", + }); + return undefined; + } + + /** RPC called by ManagerDriver.kvGet to read from KV. */ + async managerKvGet(key: Uint8Array): Promise { + return kvGet(this.ctx.storage.sql, key); + } + + /** RPC called by the manager to create a DO. Can optionally allow existing actors. */ + async create(req: ActorInitRequest): Promise { + // Check if actor exists + const checkCursor = this.ctx.storage.sql.exec( + "SELECT destroyed, generation FROM _rivetkit_metadata WHERE id = 1", + ); + const checkResult = checkCursor.raw().next(); + + let created = false; + let generation = 0; + + if (!checkResult.done && checkResult.value) { + const destroyed = checkResult.value[0] as number; + generation = checkResult.value[1] as number; + + if (!destroyed) { + // Actor exists and is not destroyed + if (!req.allowExisting) { + // Fail if not allowing existing actors + logger().debug({ + msg: "create failed: actor already exists", + name: req.name, + key: req.key, + generation, + }); + return { error: { actorAlreadyExists: true } }; + } + + // Return existing actor + logger().debug({ + msg: "actor already exists", + key: req.key, + generation, + }); + const doId = this.ctx.id.toString(); + const actorId = buildActorId(doId, generation); + return { success: { actorId, created: false } }; + } + + // Actor exists but is destroyed - resurrect with incremented generation + generation = generation + 1; + created = true; + + // Clear stale actor from previous generation + // This is necessary because the DO instance may still be in memory + // with the old #state.actor field from before the destroy + if (this.#state) { + this.#state.actor = undefined; + } + + logger().debug({ + msg: "resurrecting destroyed actor", + key: req.key, + oldGeneration: generation - 1, + newGeneration: generation, + }); + } else { + // No actor exists - will create with generation 0 + generation = 0; + created = true; + logger().debug({ + msg: "creating new actor", + key: req.key, + generation, + }); + } + + // Perform upsert - either inserts new or updates destroyed actor + this.ctx.storage.sql.exec( + `INSERT INTO _rivetkit_metadata (id, name, key, destroyed, generation) + VALUES (1, ?, ?, 0, ?) + ON CONFLICT(id) DO UPDATE SET + name = excluded.name, + key = excluded.key, + destroyed = 0, + generation = excluded.generation`, + req.name, + JSON.stringify(req.key), + generation, + ); + + this.#state.initialized = { + name: req.name, + key: req.key, + generation, + }; + + // Build actor ID with generation + const doId = this.ctx.id.toString(); + const actorId = buildActorId(doId, generation); + + // Initialize storage and update KV when created or resurrected + if (created) { + // Initialize persist data in KV storage + initializeActorKvStorage(this.ctx.storage.sql, req.input); + + // Update metadata in the background + const env = getCloudflareAmbientEnv(); + const actorData = { name: req.name, key: req.key, generation }; + this.ctx.waitUntil( + env.ACTOR_KV.put( + GLOBAL_KV_KEYS.actorMetadata(actorId), + JSON.stringify(actorData), + ), + ); + } + + // Preemptively load actor so the lifecycle hooks are called + await this.#loadActor(); + + logger().debug({ + msg: created + ? "actor created/resurrected" + : "returning existing actor", + actorId, + created, + generation, + }); + + return { success: { actorId, created } }; + } + + async fetch(request: Request): Promise { + const { actorRouter, generation } = await this.#loadActor(); + + // Build actor ID with generation + const doId = this.ctx.id.toString(); + const actorId = buildActorId(doId, generation); + + return await actorRouter.fetch(request, { + actorId, + }); + } + + async alarm(): Promise { + const { actorDriver, generation } = await this.#loadActor(); + + // Build actor ID with generation + const doId = this.ctx.id.toString(); + const actorId = buildActorId(doId, generation); + + // Load the actor instance and trigger alarm + const actor = await actorDriver.loadActor(actorId); + await actor.onAlarm(); + } + }; +} + +function initializeActorKvStorage( + sql: SqlStorage, + input: unknown | undefined, +): void { + const initialKvState = getInitialActorKvState(input); + for (const [key, value] of initialKvState) { + kvPut(sql, key, value); + } +} diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/actor-id.ts b/rivetkit-typescript/packages/cloudflare-workers/src/actor-id.ts index 92b3cd5841..a4e850f48f 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/actor-id.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/actor-id.ts @@ -1,38 +1,38 @@ -// /** -// * Actor ID utilities for managing actor IDs with generation tracking. -// * -// * Actor IDs are formatted as: `{doId}:{generation}` -// * This allows tracking actor resurrection and preventing stale references. -// */ -// -// /** -// * Build an actor ID from a Durable Object ID and generation number. -// * @param doId The Durable Object ID -// * @param generation The generation number (increments on resurrection) -// * @returns The formatted actor ID -// */ -// export function buildActorId(doId: string, generation: number): string { -// return `${doId}:${generation}`; -// } -// -// /** -// * Parse an actor ID into its components. -// * @param actorId The actor ID to parse -// * @returns A tuple of [doId, generation] -// * @throws Error if the actor ID format is invalid -// */ -// export function parseActorId(actorId: string): [string, number] { -// const parts = actorId.split(":"); -// if (parts.length !== 2) { -// throw new Error(`Invalid actor ID format: ${actorId}`); -// } -// -// const [doId, generationStr] = parts; -// const generation = parseInt(generationStr, 10); -// -// if (Number.isNaN(generation)) { -// throw new Error(`Invalid generation number in actor ID: ${actorId}`); -// } -// -// return [doId, generation]; -// } +/** + * Actor ID utilities for managing actor IDs with generation tracking. + * + * Actor IDs are formatted as: `{doId}:{generation}` + * This allows tracking actor resurrection and preventing stale references. + */ + +/** + * Build an actor ID from a Durable Object ID and generation number. + * @param doId The Durable Object ID + * @param generation The generation number (increments on resurrection) + * @returns The formatted actor ID + */ +export function buildActorId(doId: string, generation: number): string { + return `${doId}:${generation}`; +} + +/** + * Parse an actor ID into its components. + * @param actorId The actor ID to parse + * @returns A tuple of [doId, generation] + * @throws Error if the actor ID format is invalid + */ +export function parseActorId(actorId: string): [string, number] { + const parts = actorId.split(":"); + if (parts.length !== 2) { + throw new Error(`Invalid actor ID format: ${actorId}`); + } + + const [doId, generationStr] = parts; + const generation = parseInt(generationStr, 10); + + if (Number.isNaN(generation)) { + throw new Error(`Invalid generation number in actor ID: ${actorId}`); + } + + return [doId, generation]; +} diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/actor-kv.ts b/rivetkit-typescript/packages/cloudflare-workers/src/actor-kv.ts index 730f3bc5ac..2891d823c5 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/actor-kv.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/actor-kv.ts @@ -1,71 +1,71 @@ -// export function kvGet(sql: SqlStorage, key: Uint8Array): Uint8Array | null { -// const cursor = sql.exec( -// "SELECT value FROM _rivetkit_kv_storage WHERE key = ?", -// key, -// ); -// const result = cursor.raw().next(); -// -// if (!result.done && result.value) { -// return toUint8Array(result.value[0]); -// } -// return null; -// } -// -// export function kvPut( -// sql: SqlStorage, -// key: Uint8Array, -// value: Uint8Array, -// ): void { -// sql.exec( -// "INSERT OR REPLACE INTO _rivetkit_kv_storage (key, value) VALUES (?, ?)", -// key, -// value, -// ); -// } -// -// export function kvDelete(sql: SqlStorage, key: Uint8Array): void { -// sql.exec("DELETE FROM _rivetkit_kv_storage WHERE key = ?", key); -// } -// -// export function kvListPrefix( -// sql: SqlStorage, -// prefix: Uint8Array, -// ): [Uint8Array, Uint8Array][] { -// const cursor = sql.exec("SELECT key, value FROM _rivetkit_kv_storage"); -// const entries: [Uint8Array, Uint8Array][] = []; -// -// for (const row of cursor.raw()) { -// const key = toUint8Array(row[0]); -// const value = toUint8Array(row[1]); -// -// // Check if key starts with prefix -// if (hasPrefix(key, prefix)) { -// entries.push([key, value]); -// } -// } -// -// return entries; -// } -// -// // Helper function to convert SqlStorageValue to Uint8Array -// function toUint8Array( -// value: string | number | ArrayBuffer | Uint8Array | null, -// ): Uint8Array { -// if (value instanceof Uint8Array) { -// return value; -// } -// if (value instanceof ArrayBuffer) { -// return new Uint8Array(value); -// } -// throw new Error( -// `Unexpected SQL value type: ${typeof value} (${value?.constructor?.name})`, -// ); -// } -// -// function hasPrefix(arr: Uint8Array, prefix: Uint8Array): boolean { -// if (prefix.length > arr.length) return false; -// for (let i = 0; i < prefix.length; i++) { -// if (arr[i] !== prefix[i]) return false; -// } -// return true; -// } +export function kvGet(sql: SqlStorage, key: Uint8Array): Uint8Array | null { + const cursor = sql.exec( + "SELECT value FROM _rivetkit_kv_storage WHERE key = ?", + key, + ); + const result = cursor.raw().next(); + + if (!result.done && result.value) { + return toUint8Array(result.value[0]); + } + return null; +} + +export function kvPut( + sql: SqlStorage, + key: Uint8Array, + value: Uint8Array, +): void { + sql.exec( + "INSERT OR REPLACE INTO _rivetkit_kv_storage (key, value) VALUES (?, ?)", + key, + value, + ); +} + +export function kvDelete(sql: SqlStorage, key: Uint8Array): void { + sql.exec("DELETE FROM _rivetkit_kv_storage WHERE key = ?", key); +} + +export function kvListPrefix( + sql: SqlStorage, + prefix: Uint8Array, +): [Uint8Array, Uint8Array][] { + const cursor = sql.exec("SELECT key, value FROM _rivetkit_kv_storage"); + const entries: [Uint8Array, Uint8Array][] = []; + + for (const row of cursor.raw()) { + const key = toUint8Array(row[0]); + const value = toUint8Array(row[1]); + + // Check if key starts with prefix + if (hasPrefix(key, prefix)) { + entries.push([key, value]); + } + } + + return entries; +} + +// Helper function to convert SqlStorageValue to Uint8Array +function toUint8Array( + value: string | number | ArrayBuffer | Uint8Array | null, +): Uint8Array { + if (value instanceof Uint8Array) { + return value; + } + if (value instanceof ArrayBuffer) { + return new Uint8Array(value); + } + throw new Error( + `Unexpected SQL value type: ${typeof value} (${value?.constructor?.name})`, + ); +} + +function hasPrefix(arr: Uint8Array, prefix: Uint8Array): boolean { + if (prefix.length > arr.length) return false; + for (let i = 0; i < prefix.length; i++) { + if (arr[i] !== prefix[i]) return false; + } + return true; +} diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/config.ts b/rivetkit-typescript/packages/cloudflare-workers/src/config.ts index d4a9f27693..8d502a6cdb 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/config.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/config.ts @@ -1,21 +1,24 @@ -// import type { Client } from "rivetkit"; -// import { RunConfigSchema } from "rivetkit/driver-helpers"; -// import { z } from "zod"; -// -// const ConfigSchemaBase = RunConfigSchema.removeDefault() -// .omit({ driver: true, getUpgradeWebSocket: true }) -// .extend({ -// /** Path that the Rivet manager API will be mounted. */ -// managerPath: z.string().optional().default("/rivet"), -// -// fetch: z -// .custom< -// ExportedHandlerFetchHandler<{ RIVET: Client }, unknown> -// >() -// .optional(), -// }); -// export const ConfigSchema = ConfigSchemaBase.default(() => -// ConfigSchemaBase.parse({}), -// ); -// export type InputConfig = z.input; -// export type Config = z.infer; +import type { Client } from "rivetkit"; +import { z } from "zod"; + +const ConfigSchemaBase = z.object({ + /** Path that the Rivet manager API will be mounted. */ + managerPath: z.string().optional().default("/api/rivet"), + + /** Runner key for authentication. */ + runnerKey: z.string().optional(), + + /** Disable the welcome message. */ + noWelcome: z.boolean().optional().default(false), + + fetch: z + .custom< + ExportedHandlerFetchHandler<{ RIVET: Client }, unknown> + >() + .optional(), +}); +export const ConfigSchema = ConfigSchemaBase.default(() => + ConfigSchemaBase.parse({}), +); +export type InputConfig = z.input; +export type Config = z.infer; diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/global-kv.ts b/rivetkit-typescript/packages/cloudflare-workers/src/global-kv.ts index d54b27e2e7..e9c205a2d4 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/global-kv.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/global-kv.ts @@ -1,6 +1,6 @@ -// /** KV keys for using Workers KV to store actor metadata globally. */ -// export const GLOBAL_KV_KEYS = { -// actorMetadata: (actorId: string): string => { -// return `actor:${actorId}:metadata`; -// }, -// }; +/** KV keys for using Workers KV to store actor metadata globally. */ +export const GLOBAL_KV_KEYS = { + actorMetadata: (actorId: string): string => { + return `actor:${actorId}:metadata`; + }, +}; diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/handler.ts b/rivetkit-typescript/packages/cloudflare-workers/src/handler.ts index ba140470dd..b26f6c5a32 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/handler.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/handler.ts @@ -1,132 +1,145 @@ -// import { env } from "cloudflare:workers"; -// import type { Client, Registry, RunConfig } from "rivetkit"; -// import { -// type ActorHandlerInterface, -// createActorDurableObject, -// type DurableObjectConstructor, -// } from "./actor-handler-do"; -// import { type Config, ConfigSchema, type InputConfig } from "./config"; -// import { CloudflareActorsManagerDriver } from "./manager-driver"; -// import { upgradeWebSocket } from "./websocket"; -// -// /** Cloudflare Workers env */ -// export interface Bindings { -// ACTOR_KV: KVNamespace; -// ACTOR_DO: DurableObjectNamespace; -// } -// -// /** -// * Stores the env for the current request. Required since some contexts like the inline client driver does not have access to the Hono context. -// * -// * Use getCloudflareAmbientEnv unless using CF_AMBIENT_ENV.run. -// */ -// export function getCloudflareAmbientEnv(): Bindings { -// return env as unknown as Bindings; -// } -// -// export interface InlineOutput> { -// /** Client to communicate with the actors. */ -// client: Client; -// -// /** Fetch handler to manually route requests to the Rivet manager API. */ -// fetch: (request: Request, ...args: any) => Response | Promise; -// -// config: Config; -// -// ActorHandler: DurableObjectConstructor; -// } -// -// export interface HandlerOutput { -// handler: ExportedHandler; -// ActorHandler: DurableObjectConstructor; -// } -// -// /** -// * Creates an inline client for accessing Rivet Actors privately without a public manager API. -// * -// * If you want to expose a public manager API, either: -// * -// * - Use `createHandler` to expose the Rivet API on `/rivet` -// * - Forward Rivet API requests to `InlineOutput::fetch` -// */ -// export function createInlineClient>( -// registry: R, -// inputConfig?: InputConfig, -// ): InlineOutput { -// // HACK: Cloudflare does not support using `crypto.randomUUID()` before start, so we pass a default value -// // -// // Runner key is not used on Cloudflare -// inputConfig = { ...inputConfig, runnerKey: "" }; -// -// // Parse config -// const config = ConfigSchema.parse(inputConfig); -// -// // Create config -// const runConfig = { -// ...config, -// noWelcome: true, -// driver: { -// name: "cloudflare-workers", -// manager: () => new CloudflareActorsManagerDriver(), -// // HACK: We can't build the actor driver until we're inside the Durable Object -// actor: undefined as any, -// }, -// getUpgradeWebSocket: () => upgradeWebSocket, -// } satisfies RunConfig; -// -// // Create Durable Object -// const ActorHandler = createActorDurableObject(registry, runConfig); -// -// // Create server -// const { client, fetch } = registry.start(runConfig); -// -// return { client, fetch, config, ActorHandler }; -// } -// -// /** -// * Creates a handler to be exported from a Cloudflare Worker. -// * -// * This will automatically expose the Rivet manager API on `/rivet`. -// * -// * This includes a `fetch` handler and `ActorHandler` Durable Object. -// */ -// export function createHandler>( -// registry: R, -// inputConfig?: InputConfig, -// ): HandlerOutput { -// const { client, fetch, config, ActorHandler } = createInlineClient( -// registry, -// inputConfig, -// ); -// -// // Create Cloudflare handler -// const handler = { -// fetch: async (request, cfEnv, ctx) => { -// const url = new URL(request.url); -// -// // Inject Rivet env -// const env = Object.assign({ RIVET: client }, cfEnv); -// -// // Mount Rivet manager API -// if (url.pathname.startsWith(config.managerPath)) { -// const strippedPath = url.pathname.substring( -// config.managerPath.length, -// ); -// url.pathname = strippedPath; -// const modifiedRequest = new Request(url.toString(), request); -// return fetch(modifiedRequest, env, ctx); -// } -// -// if (config.fetch) { -// return config.fetch(request, env, ctx); -// } else { -// return new Response( -// "This is a RivetKit server.\n\nLearn more at https://rivetkit.org\n", -// { status: 200 }, -// ); -// } -// }, -// } satisfies ExportedHandler; -// -// return { handler, ActorHandler }; -// } +import { env } from "cloudflare:workers"; +import type { Client, Registry } from "rivetkit"; +import { createClientWithDriver } from "rivetkit"; +import { buildManagerRouter } from "rivetkit/driver-helpers"; +import { + type ActorHandlerInterface, + createActorDurableObject, + type DurableObjectConstructor, +} from "./actor-handler-do"; +import { type Config, ConfigSchema, type InputConfig } from "./config"; +import { CloudflareActorsManagerDriver } from "./manager-driver"; +import { upgradeWebSocket } from "./websocket"; + +/** Cloudflare Workers env */ +export interface Bindings { + ACTOR_KV: KVNamespace; + ACTOR_DO: DurableObjectNamespace; +} + +/** + * Stores the env for the current request. Required since some contexts like the inline client driver does not have access to the Hono context. + * + * Use getCloudflareAmbientEnv unless using CF_AMBIENT_ENV.run. + */ +export function getCloudflareAmbientEnv(): Bindings { + return env as unknown as Bindings; +} + +export interface InlineOutput> { + /** Client to communicate with the actors. */ + client: Client; + + /** Fetch handler to manually route requests to the Rivet manager API. */ + fetch: (request: Request, ...args: any) => Response | Promise; + + config: Config; + + ActorHandler: DurableObjectConstructor; +} + +export interface HandlerOutput { + handler: ExportedHandler; + ActorHandler: DurableObjectConstructor; +} + +/** + * Creates an inline client for accessing Rivet Actors privately without a public manager API. + * + * If you want to expose a public manager API, either: + * + * - Use `createHandler` to expose the Rivet API on `/api/rivet` + * - Forward Rivet API requests to `InlineOutput::fetch` + */ +export function createInlineClient>( + registry: R, + inputConfig?: InputConfig, +): InlineOutput { + // HACK: Cloudflare does not support using `crypto.randomUUID()` before start, so we pass a default value + // + // Runner key is not used on Cloudflare + inputConfig = { ...inputConfig, runnerKey: "" }; + + // Parse config + const config = ConfigSchema.parse(inputConfig); + + // Create Durable Object + const ActorHandler = createActorDurableObject( + registry, + () => upgradeWebSocket, + ); + + // Configure registry for cloudflare-workers + registry.config.noWelcome = true; + // Disable inspector since it's not supported on Cloudflare Workers + registry.config.inspector = { + enabled: false, + token: () => "", + }; + // Set manager base path to "/" since the cloudflare handler strips the /api/rivet prefix + registry.config.managerBasePath = "/"; + const parsedConfig = registry.parseConfig(); + + // Create manager driver + const managerDriver = new CloudflareActorsManagerDriver(); + + // Build the manager router (has actor management endpoints like /actors) + const { router } = buildManagerRouter( + parsedConfig, + managerDriver, + () => upgradeWebSocket, + ); + + // Create client using the manager driver + const client = createClientWithDriver(managerDriver); + + return { client, fetch: router.fetch.bind(router), config, ActorHandler }; +} + +/** + * Creates a handler to be exported from a Cloudflare Worker. + * + * This will automatically expose the Rivet manager API on `/api/rivet`. + * + * This includes a `fetch` handler and `ActorHandler` Durable Object. + */ +export function createHandler>( + registry: R, + inputConfig?: InputConfig, +): HandlerOutput { + const { client, fetch, config, ActorHandler } = createInlineClient( + registry, + inputConfig, + ); + + // Create Cloudflare handler + const handler = { + fetch: async (request, cfEnv, ctx) => { + const url = new URL(request.url); + + // Inject Rivet env + const env = Object.assign({ RIVET: client }, cfEnv); + + // Mount Rivet manager API + if (url.pathname.startsWith(config.managerPath)) { + const strippedPath = url.pathname.substring( + config.managerPath.length, + ); + url.pathname = strippedPath; + const modifiedRequest = new Request(url.toString(), request); + return fetch(modifiedRequest, env, ctx); + } + + if (config.fetch) { + return config.fetch(request, env, ctx); + } else { + return new Response( + "This is a RivetKit server.\n\nLearn more at https://rivet.dev\n", + { status: 200 }, + ); + } + }, + } satisfies ExportedHandler; + + return { handler, ActorHandler }; +} diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/manager-driver.ts b/rivetkit-typescript/packages/cloudflare-workers/src/manager-driver.ts index 3c51d9a48e..d229c9d3a9 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/manager-driver.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/manager-driver.ts @@ -1,423 +1,440 @@ -// import type { Context as HonoContext } from "hono"; -// import type { Encoding, UniversalWebSocket } from "rivetkit"; -// import { -// type ActorOutput, -// type CreateInput, -// type GetForIdInput, -// type GetOrCreateWithKeyInput, -// type GetWithKeyInput, -// generateRandomString, -// type ListActorsInput, -// type ManagerDisplayInformation, -// type ManagerDriver, -// WS_PROTOCOL_ACTOR, -// WS_PROTOCOL_CONN_PARAMS, -// WS_PROTOCOL_ENCODING, -// WS_PROTOCOL_STANDARD, -// WS_PROTOCOL_TARGET, -// } from "rivetkit/driver-helpers"; -// import { -// ActorDuplicateKey, -// ActorNotFound, -// InternalError, -// } from "rivetkit/errors"; -// import { assertUnreachable } from "rivetkit/utils"; -// import { parseActorId } from "./actor-id"; -// import { getCloudflareAmbientEnv } from "./handler"; -// import { logger } from "./log"; -// import type { Bindings } from "./mod"; -// import { serializeNameAndKey } from "./util"; -// -// const STANDARD_WEBSOCKET_HEADERS = [ -// "connection", -// "upgrade", -// "sec-websocket-key", -// "sec-websocket-version", -// "sec-websocket-protocol", -// "sec-websocket-extensions", -// ]; -// -// export class CloudflareActorsManagerDriver implements ManagerDriver { -// async sendRequest( -// actorId: string, -// actorRequest: Request, -// ): Promise { -// const env = getCloudflareAmbientEnv(); -// -// // Parse actor ID to get DO ID -// const [doId] = parseActorId(actorId); -// -// logger().debug({ -// msg: "sending request to durable object", -// actorId, -// doId, -// method: actorRequest.method, -// url: actorRequest.url, -// }); -// -// const id = env.ACTOR_DO.idFromString(doId); -// const stub = env.ACTOR_DO.get(id); -// -// return await stub.fetch(actorRequest); -// } -// -// async openWebSocket( -// path: string, -// actorId: string, -// encoding: Encoding, -// params: unknown, -// ): Promise { -// const env = getCloudflareAmbientEnv(); -// -// // Parse actor ID to get DO ID -// const [doId] = parseActorId(actorId); -// -// logger().debug({ -// msg: "opening websocket to durable object", -// actorId, -// doId, -// path, -// }); -// -// // Make a fetch request to the Durable Object with WebSocket upgrade -// const id = env.ACTOR_DO.idFromString(doId); -// const stub = env.ACTOR_DO.get(id); -// -// const protocols: string[] = []; -// protocols.push(WS_PROTOCOL_STANDARD); -// protocols.push(`${WS_PROTOCOL_TARGET}actor`); -// protocols.push(`${WS_PROTOCOL_ACTOR}${encodeURIComponent(actorId)}`); -// protocols.push(`${WS_PROTOCOL_ENCODING}${encoding}`); -// if (params) { -// protocols.push( -// `${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`, -// ); -// } -// -// const headers: Record = { -// Upgrade: "websocket", -// Connection: "Upgrade", -// "sec-websocket-protocol": protocols.join(", "), -// }; -// -// // Use the path parameter to determine the URL -// const normalizedPath = path.startsWith("/") ? path : `/${path}`; -// const url = `http://actor${normalizedPath}`; -// -// logger().debug({ msg: "rewriting websocket url", from: path, to: url }); -// -// const response = await stub.fetch(url, { -// headers, -// }); -// const webSocket = response.webSocket; -// -// if (!webSocket) { -// throw new InternalError( -// `missing websocket connection in response from DO\n\nStatus: ${response.status}\nResponse: ${await response.text()}`, -// ); -// } -// -// logger().debug({ -// msg: "durable object websocket connection open", -// actorId, -// }); -// -// webSocket.accept(); -// -// // TODO: Is this still needed? -// // HACK: Cloudflare does not call onopen automatically, so we need -// // to call this on the next tick -// setTimeout(() => { -// const event = new Event("open"); -// (webSocket as any).onopen?.(event); -// (webSocket as any).dispatchEvent(event); -// }, 0); -// -// return webSocket as unknown as UniversalWebSocket; -// } -// -// async proxyRequest( -// c: HonoContext<{ Bindings: Bindings }>, -// actorRequest: Request, -// actorId: string, -// ): Promise { -// // Parse actor ID to get DO ID -// const [doId] = parseActorId(actorId); -// -// logger().debug({ -// msg: "forwarding request to durable object", -// actorId, -// doId, -// method: actorRequest.method, -// url: actorRequest.url, -// }); -// -// const id = c.env.ACTOR_DO.idFromString(doId); -// const stub = c.env.ACTOR_DO.get(id); -// -// return await stub.fetch(actorRequest); -// } -// -// async proxyWebSocket( -// c: HonoContext<{ Bindings: Bindings }>, -// path: string, -// actorId: string, -// encoding: Encoding, -// params: unknown, -// ): Promise { -// logger().debug({ -// msg: "forwarding websocket to durable object", -// actorId, -// path, -// }); -// -// // Validate upgrade -// const upgradeHeader = c.req.header("Upgrade"); -// if (!upgradeHeader || upgradeHeader !== "websocket") { -// return new Response("Expected Upgrade: websocket", { -// status: 426, -// }); -// } -// -// const newUrl = new URL(`http://actor${path}`); -// const actorRequest = new Request(newUrl, c.req.raw); -// -// logger().debug({ -// msg: "rewriting websocket url", -// from: c.req.url, -// to: actorRequest.url, -// }); -// -// // Always build fresh request to prevent forwarding unwanted headers -// // HACK: Since we can't build a new request, we need to remove -// // non-standard headers manually -// const headerKeys: string[] = []; -// actorRequest.headers.forEach((v, k) => { -// headerKeys.push(k); -// }); -// for (const k of headerKeys) { -// if (!STANDARD_WEBSOCKET_HEADERS.includes(k)) { -// actorRequest.headers.delete(k); -// } -// } -// -// // Build protocols for WebSocket connection -// const protocols: string[] = []; -// protocols.push(WS_PROTOCOL_STANDARD); -// protocols.push(`${WS_PROTOCOL_TARGET}actor`); -// protocols.push(`${WS_PROTOCOL_ACTOR}${encodeURIComponent(actorId)}`); -// protocols.push(`${WS_PROTOCOL_ENCODING}${encoding}`); -// if (params) { -// protocols.push( -// `${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`, -// ); -// } -// actorRequest.headers.set( -// "sec-websocket-protocol", -// protocols.join(", "), -// ); -// -// // Parse actor ID to get DO ID -// const [doId] = parseActorId(actorId); -// const id = c.env.ACTOR_DO.idFromString(doId); -// const stub = c.env.ACTOR_DO.get(id); -// -// return await stub.fetch(actorRequest); -// } -// -// async getForId({ -// c, -// actorId, -// }: GetForIdInput<{ Bindings: Bindings }>): Promise< -// ActorOutput | undefined -// > { -// const env = getCloudflareAmbientEnv(); -// -// // Parse actor ID to get DO ID and expected generation -// const [doId, expectedGeneration] = parseActorId(actorId); -// -// // Get the Durable Object stub -// const id = env.ACTOR_DO.idFromString(doId); -// const stub = env.ACTOR_DO.get(id); -// -// // Call the DO's getMetadata method -// const result = await stub.getMetadata(); -// -// if (!result) { -// logger().debug({ -// msg: "getForId: actor not found", -// actorId, -// }); -// return undefined; -// } -// -// // Check if the actor IDs match in order to check if the generation matches -// if (result.actorId !== actorId) { -// logger().debug({ -// msg: "getForId: generation mismatch", -// requestedActorId: actorId, -// actualActorId: result.actorId, -// }); -// return undefined; -// } -// -// if (result.destroying) { -// throw new ActorNotFound(actorId); -// } -// -// return { -// actorId: result.actorId, -// name: result.name, -// key: result.key, -// }; -// } -// -// async getWithKey({ -// c, -// name, -// key, -// }: GetWithKeyInput<{ Bindings: Bindings }>): Promise< -// ActorOutput | undefined -// > { -// const env = getCloudflareAmbientEnv(); -// -// logger().debug({ msg: "getWithKey: searching for actor", name, key }); -// -// // Generate deterministic ID from the name and key -// const nameKeyString = serializeNameAndKey(name, key); -// const doId = env.ACTOR_DO.idFromName(nameKeyString).toString(); -// -// // Try to get the Durable Object to see if it exists -// const id = env.ACTOR_DO.idFromString(doId); -// const stub = env.ACTOR_DO.get(id); -// -// // Check if actor exists without creating it -// const result = await stub.getMetadata(); -// -// if (result) { -// logger().debug({ -// msg: "getWithKey: found actor with matching name and key", -// actorId: result.actorId, -// name: result.name, -// key: result.key, -// }); -// return { -// actorId: result.actorId, -// name: result.name, -// key: result.key, -// }; -// } else { -// logger().debug({ -// msg: "getWithKey: no actor found with matching name and key", -// name, -// key, -// doId, -// }); -// return undefined; -// } -// } -// -// async getOrCreateWithKey({ -// c, -// name, -// key, -// input, -// }: GetOrCreateWithKeyInput<{ Bindings: Bindings }>): Promise { -// const env = getCloudflareAmbientEnv(); -// -// // Create a deterministic ID from the actor name and key -// // This ensures that actors with the same name and key will have the same ID -// const nameKeyString = serializeNameAndKey(name, key); -// const doId = env.ACTOR_DO.idFromName(nameKeyString); -// -// // Get or create actor using the Durable Object's method -// const actor = env.ACTOR_DO.get(doId); -// const result = await actor.create({ -// name, -// key, -// input, -// allowExisting: true, -// }); -// if ("success" in result) { -// const { actorId, created } = result.success; -// logger().debug({ -// msg: "getOrCreateWithKey result", -// actorId, -// name, -// key, -// created, -// }); -// -// return { -// actorId, -// name, -// key, -// }; -// } else if ("error" in result) { -// throw new Error(`Error: ${JSON.stringify(result.error)}`); -// } else { -// assertUnreachable(result); -// } -// } -// -// async createActor({ -// c, -// name, -// key, -// input, -// }: CreateInput<{ Bindings: Bindings }>): Promise { -// const env = getCloudflareAmbientEnv(); -// -// // Create a deterministic ID from the actor name and key -// // This ensures that actors with the same name and key will have the same ID -// const nameKeyString = serializeNameAndKey(name, key); -// const doId = env.ACTOR_DO.idFromName(nameKeyString); -// -// // Create actor - this will fail if it already exists -// const actor = env.ACTOR_DO.get(doId); -// const result = await actor.create({ -// name, -// key, -// input, -// allowExisting: false, -// }); -// -// if ("success" in result) { -// const { actorId } = result.success; -// return { -// actorId, -// name, -// key, -// }; -// } else if ("error" in result) { -// if (result.error.actorAlreadyExists) { -// throw new ActorDuplicateKey(name, key); -// } -// -// throw new InternalError( -// `Unknown error creating actor: ${JSON.stringify(result.error)}`, -// ); -// } else { -// assertUnreachable(result); -// } -// } -// -// async listActors({ c, name }: ListActorsInput): Promise { -// logger().warn({ -// msg: "listActors not fully implemented for Cloudflare Workers", -// name, -// }); -// return []; -// } -// -// displayInformation(): ManagerDisplayInformation { -// return { -// name: "Cloudflare Workers", -// properties: {}, -// }; -// } -// -// getOrCreateInspectorAccessToken() { -// return generateRandomString(); -// } -// } +import type { Hono, Context as HonoContext } from "hono"; +import type { Encoding, RegistryConfig, UniversalWebSocket } from "rivetkit"; +import { + type ActorOutput, + type CreateInput, + type GetForIdInput, + type GetOrCreateWithKeyInput, + type GetWithKeyInput, + type ListActorsInput, + type ManagerDisplayInformation, + type ManagerDriver, + WS_PROTOCOL_ACTOR, + WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_ENCODING, + WS_PROTOCOL_STANDARD, + WS_PROTOCOL_TARGET, +} from "rivetkit/driver-helpers"; +import { + ActorDuplicateKey, + ActorNotFound, + InternalError, +} from "rivetkit/errors"; +import { assertUnreachable } from "rivetkit/utils"; +import { parseActorId } from "./actor-id"; +import { getCloudflareAmbientEnv } from "./handler"; +import { logger } from "./log"; +import type { Bindings } from "./mod"; +import { serializeNameAndKey } from "./util"; + +const STANDARD_WEBSOCKET_HEADERS = [ + "connection", + "upgrade", + "sec-websocket-key", + "sec-websocket-version", + "sec-websocket-protocol", + "sec-websocket-extensions", +]; + +export class CloudflareActorsManagerDriver implements ManagerDriver { + async sendRequest( + actorId: string, + actorRequest: Request, + ): Promise { + const env = getCloudflareAmbientEnv(); + + // Parse actor ID to get DO ID + const [doId] = parseActorId(actorId); + + logger().debug({ + msg: "sending request to durable object", + actorId, + doId, + method: actorRequest.method, + url: actorRequest.url, + }); + + const id = env.ACTOR_DO.idFromString(doId); + const stub = env.ACTOR_DO.get(id); + + return await stub.fetch(actorRequest); + } + + async openWebSocket( + path: string, + actorId: string, + encoding: Encoding, + params: unknown, + ): Promise { + const env = getCloudflareAmbientEnv(); + + // Parse actor ID to get DO ID + const [doId] = parseActorId(actorId); + + logger().debug({ + msg: "opening websocket to durable object", + actorId, + doId, + path, + }); + + // Make a fetch request to the Durable Object with WebSocket upgrade + const id = env.ACTOR_DO.idFromString(doId); + const stub = env.ACTOR_DO.get(id); + + const protocols: string[] = []; + protocols.push(WS_PROTOCOL_STANDARD); + protocols.push(`${WS_PROTOCOL_TARGET}actor`); + protocols.push(`${WS_PROTOCOL_ACTOR}${encodeURIComponent(actorId)}`); + protocols.push(`${WS_PROTOCOL_ENCODING}${encoding}`); + if (params) { + protocols.push( + `${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`, + ); + } + + const headers: Record = { + Upgrade: "websocket", + Connection: "Upgrade", + "sec-websocket-protocol": protocols.join(", "), + }; + + // Use the path parameter to determine the URL + const normalizedPath = path.startsWith("/") ? path : `/${path}`; + const url = `http://actor${normalizedPath}`; + + logger().debug({ msg: "rewriting websocket url", from: path, to: url }); + + const response = await stub.fetch(url, { + headers, + }); + const webSocket = response.webSocket; + + if (!webSocket) { + throw new InternalError( + `missing websocket connection in response from DO\n\nStatus: ${response.status}\nResponse: ${await response.text()}`, + ); + } + + logger().debug({ + msg: "durable object websocket connection open", + actorId, + }); + + webSocket.accept(); + + // TODO: Is this still needed? + // HACK: Cloudflare does not call onopen automatically, so we need + // to call this on the next tick + setTimeout(() => { + const event = new Event("open"); + (webSocket as any).onopen?.(event); + (webSocket as any).dispatchEvent(event); + }, 0); + + return webSocket as unknown as UniversalWebSocket; + } + + async proxyRequest( + c: HonoContext<{ Bindings: Bindings }>, + actorRequest: Request, + actorId: string, + ): Promise { + const env = getCloudflareAmbientEnv(); + + // Parse actor ID to get DO ID + const [doId] = parseActorId(actorId); + + logger().debug({ + msg: "forwarding request to durable object", + actorId, + doId, + method: actorRequest.method, + url: actorRequest.url, + }); + + const id = env.ACTOR_DO.idFromString(doId); + const stub = env.ACTOR_DO.get(id); + + return await stub.fetch(actorRequest); + } + + async proxyWebSocket( + c: HonoContext<{ Bindings: Bindings }>, + path: string, + actorId: string, + encoding: Encoding, + params: unknown, + ): Promise { + logger().debug({ + msg: "forwarding websocket to durable object", + actorId, + path, + }); + + // Validate upgrade + const upgradeHeader = c.req.header("Upgrade"); + if (!upgradeHeader || upgradeHeader !== "websocket") { + return new Response("Expected Upgrade: websocket", { + status: 426, + }); + } + + const newUrl = new URL(`http://actor${path}`); + const actorRequest = new Request(newUrl, c.req.raw); + + logger().debug({ + msg: "rewriting websocket url", + from: c.req.url, + to: actorRequest.url, + }); + + // Always build fresh request to prevent forwarding unwanted headers + // HACK: Since we can't build a new request, we need to remove + // non-standard headers manually + const headerKeys: string[] = []; + actorRequest.headers.forEach((v, k) => { + headerKeys.push(k); + }); + for (const k of headerKeys) { + if (!STANDARD_WEBSOCKET_HEADERS.includes(k)) { + actorRequest.headers.delete(k); + } + } + + // Build protocols for WebSocket connection + const protocols: string[] = []; + protocols.push(WS_PROTOCOL_STANDARD); + protocols.push(`${WS_PROTOCOL_TARGET}actor`); + protocols.push(`${WS_PROTOCOL_ACTOR}${encodeURIComponent(actorId)}`); + protocols.push(`${WS_PROTOCOL_ENCODING}${encoding}`); + if (params) { + protocols.push( + `${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`, + ); + } + actorRequest.headers.set( + "sec-websocket-protocol", + protocols.join(", "), + ); + + // Parse actor ID to get DO ID + const env = getCloudflareAmbientEnv(); + const [doId] = parseActorId(actorId); + const id = env.ACTOR_DO.idFromString(doId); + const stub = env.ACTOR_DO.get(id); + + return await stub.fetch(actorRequest); + } + + async getForId({ + c, + name, + actorId, + }: GetForIdInput<{ Bindings: Bindings }>): Promise< + ActorOutput | undefined + > { + const env = getCloudflareAmbientEnv(); + + // Parse actor ID to get DO ID and expected generation + const [doId, expectedGeneration] = parseActorId(actorId); + + // Get the Durable Object stub + const id = env.ACTOR_DO.idFromString(doId); + const stub = env.ACTOR_DO.get(id); + + // Call the DO's getMetadata method + const result = await stub.getMetadata(); + + if (!result) { + logger().debug({ + msg: "getForId: actor not found", + actorId, + }); + return undefined; + } + + // Check if the actor IDs match in order to check if the generation matches + if (result.actorId !== actorId) { + logger().debug({ + msg: "getForId: generation mismatch", + requestedActorId: actorId, + actualActorId: result.actorId, + }); + return undefined; + } + + if (result.destroying) { + throw new ActorNotFound(actorId); + } + + return { + actorId: result.actorId, + name: result.name, + key: result.key, + }; + } + + async getWithKey({ + c, + name, + key, + }: GetWithKeyInput<{ Bindings: Bindings }>): Promise< + ActorOutput | undefined + > { + const env = getCloudflareAmbientEnv(); + + logger().debug({ msg: "getWithKey: searching for actor", name, key }); + + // Generate deterministic ID from the name and key + const nameKeyString = serializeNameAndKey(name, key); + const doId = env.ACTOR_DO.idFromName(nameKeyString).toString(); + + // Try to get the Durable Object to see if it exists + const id = env.ACTOR_DO.idFromString(doId); + const stub = env.ACTOR_DO.get(id); + + // Check if actor exists without creating it + const result = await stub.getMetadata(); + + if (result) { + logger().debug({ + msg: "getWithKey: found actor with matching name and key", + actorId: result.actorId, + name: result.name, + key: result.key, + }); + return { + actorId: result.actorId, + name: result.name, + key: result.key, + }; + } else { + logger().debug({ + msg: "getWithKey: no actor found with matching name and key", + name, + key, + doId, + }); + return undefined; + } + } + + async getOrCreateWithKey({ + c, + name, + key, + input, + }: GetOrCreateWithKeyInput<{ Bindings: Bindings }>): Promise { + const env = getCloudflareAmbientEnv(); + + // Create a deterministic ID from the actor name and key + // This ensures that actors with the same name and key will have the same ID + const nameKeyString = serializeNameAndKey(name, key); + const doId = env.ACTOR_DO.idFromName(nameKeyString); + + // Get or create actor using the Durable Object's method + const actor = env.ACTOR_DO.get(doId); + const result = await actor.create({ + name, + key, + input, + allowExisting: true, + }); + if ("success" in result) { + const { actorId, created } = result.success; + logger().debug({ + msg: "getOrCreateWithKey result", + actorId, + name, + key, + created, + }); + + return { + actorId, + name, + key, + }; + } else if ("error" in result) { + throw new Error(`Error: ${JSON.stringify(result.error)}`); + } else { + assertUnreachable(result); + } + } + + async createActor({ + c, + name, + key, + input, + }: CreateInput<{ Bindings: Bindings }>): Promise { + const env = getCloudflareAmbientEnv(); + + // Create a deterministic ID from the actor name and key + // This ensures that actors with the same name and key will have the same ID + const nameKeyString = serializeNameAndKey(name, key); + const doId = env.ACTOR_DO.idFromName(nameKeyString); + + // Create actor - this will fail if it already exists + const actor = env.ACTOR_DO.get(doId); + const result = await actor.create({ + name, + key, + input, + allowExisting: false, + }); + + if ("success" in result) { + const { actorId } = result.success; + return { + actorId, + name, + key, + }; + } else if ("error" in result) { + if (result.error.actorAlreadyExists) { + throw new ActorDuplicateKey(name, key); + } + + throw new InternalError( + `Unknown error creating actor: ${JSON.stringify(result.error)}`, + ); + } else { + assertUnreachable(result); + } + } + + async listActors({ c, name }: ListActorsInput): Promise { + logger().warn({ + msg: "listActors not fully implemented for Cloudflare Workers", + name, + }); + return []; + } + + displayInformation(): ManagerDisplayInformation { + return { + properties: { + Driver: "Cloudflare Workers", + }, + }; + } + + setGetUpgradeWebSocket(): void { + // No-op for Cloudflare Workers - WebSocket upgrades are handled by the DO + } + + async kvGet(actorId: string, key: Uint8Array): Promise { + const env = getCloudflareAmbientEnv(); + + // Parse actor ID to get DO ID + const [doId] = parseActorId(actorId); + + const id = env.ACTOR_DO.idFromString(doId); + const stub = env.ACTOR_DO.get(id); + + const value = await stub.managerKvGet(key); + return value !== null ? new TextDecoder().decode(value) : null; + } +} diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/mod.ts b/rivetkit-typescript/packages/cloudflare-workers/src/mod.ts index 082b9bd078..40129e55fe 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/mod.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/mod.ts @@ -1,11 +1,11 @@ -// export type { Client } from "rivetkit"; -// export type { DriverContext } from "./actor-driver"; -// export { createActorDurableObject } from "./actor-handler-do"; -// export type { InputConfig as Config } from "./config"; -// export { -// type Bindings, -// createHandler, -// createInlineClient, -// HandlerOutput, -// InlineOutput, -// } from "./handler"; +export type { Client } from "rivetkit"; +export type { DriverContext } from "./actor-driver"; +export { createActorDurableObject } from "./actor-handler-do"; +export type { InputConfig as Config } from "./config"; +export { + type Bindings, + createHandler, + createInlineClient, + HandlerOutput, + InlineOutput, +} from "./handler"; diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/util.ts b/rivetkit-typescript/packages/cloudflare-workers/src/util.ts index 602bf72ef2..27d7236d79 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/util.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/util.ts @@ -1,104 +1,104 @@ -// // Constants for key handling -// export const EMPTY_KEY = "(none)"; -// export const KEY_SEPARATOR = ","; -// -// /** -// * Serializes an array of key strings into a single string for use with idFromName -// * -// * @param name The actor name -// * @param key Array of key strings to serialize -// * @returns A single string containing the serialized name and key -// */ -// export function serializeNameAndKey(name: string, key: string[]): string { -// // Escape colons in the name -// const escapedName = name.replace(/:/g, "\\:"); -// -// // For empty keys, just return the name and a marker -// if (key.length === 0) { -// return `${escapedName}:${EMPTY_KEY}`; -// } -// -// // Serialize the key array -// const serializedKey = serializeKey(key); -// -// // Combine name and serialized key -// return `${escapedName}:${serializedKey}`; -// } -// -// /** -// * Serializes an array of key strings into a single string -// * -// * @param key Array of key strings to serialize -// * @returns A single string containing the serialized key -// */ -// export function serializeKey(key: string[]): string { -// // Use a special marker for empty key arrays -// if (key.length === 0) { -// return EMPTY_KEY; -// } -// -// // Escape each key part to handle the separator and the empty key marker -// const escapedParts = key.map((part) => { -// // First check if it matches our empty key marker -// if (part === EMPTY_KEY) { -// return `\\${EMPTY_KEY}`; -// } -// -// // Escape backslashes first, then commas -// let escaped = part.replace(/\\/g, "\\\\"); -// escaped = escaped.replace(/,/g, "\\,"); -// return escaped; -// }); -// -// return escapedParts.join(KEY_SEPARATOR); -// } -// -// /** -// * Deserializes a key string back into an array of key strings -// * -// * @param keyString The serialized key string -// * @returns Array of key strings -// */ -// export function deserializeKey(keyString: string): string[] { -// // Handle empty values -// if (!keyString) { -// return []; -// } -// -// // Check for special empty key marker -// if (keyString === EMPTY_KEY) { -// return []; -// } -// -// // Split by unescaped commas and unescape the escaped characters -// const parts: string[] = []; -// let currentPart = ""; -// let escaping = false; -// -// for (let i = 0; i < keyString.length; i++) { -// const char = keyString[i]; -// -// if (escaping) { -// // This is an escaped character, add it directly -// currentPart += char; -// escaping = false; -// } else if (char === "\\") { -// // Start of an escape sequence -// escaping = true; -// } else if (char === KEY_SEPARATOR) { -// // This is a separator -// parts.push(currentPart); -// currentPart = ""; -// } else { -// // Regular character -// currentPart += char; -// } -// } -// -// // Add the last part if it exists -// if (currentPart || parts.length > 0) { -// parts.push(currentPart); -// } -// -// return parts; -// } +// Constants for key handling +export const EMPTY_KEY = "(none)"; +export const KEY_SEPARATOR = ","; + +/** + * Serializes an array of key strings into a single string for use with idFromName + * + * @param name The actor name + * @param key Array of key strings to serialize + * @returns A single string containing the serialized name and key + */ +export function serializeNameAndKey(name: string, key: string[]): string { + // Escape colons in the name + const escapedName = name.replace(/:/g, "\\:"); + + // For empty keys, just return the name and a marker + if (key.length === 0) { + return `${escapedName}:${EMPTY_KEY}`; + } + + // Serialize the key array + const serializedKey = serializeKey(key); + + // Combine name and serialized key + return `${escapedName}:${serializedKey}`; +} + +/** + * Serializes an array of key strings into a single string + * + * @param key Array of key strings to serialize + * @returns A single string containing the serialized key + */ +export function serializeKey(key: string[]): string { + // Use a special marker for empty key arrays + if (key.length === 0) { + return EMPTY_KEY; + } + + // Escape each key part to handle the separator and the empty key marker + const escapedParts = key.map((part) => { + // First check if it matches our empty key marker + if (part === EMPTY_KEY) { + return `\\${EMPTY_KEY}`; + } + + // Escape backslashes first, then commas + let escaped = part.replace(/\\/g, "\\\\"); + escaped = escaped.replace(/,/g, "\\,"); + return escaped; + }); + + return escapedParts.join(KEY_SEPARATOR); +} + +/** + * Deserializes a key string back into an array of key strings + * + * @param keyString The serialized key string + * @returns Array of key strings + */ +export function deserializeKey(keyString: string): string[] { + // Handle empty values + if (!keyString) { + return []; + } + + // Check for special empty key marker + if (keyString === EMPTY_KEY) { + return []; + } + + // Split by unescaped commas and unescape the escaped characters + const parts: string[] = []; + let currentPart = ""; + let escaping = false; + + for (let i = 0; i < keyString.length; i++) { + const char = keyString[i]; + + if (escaping) { + // This is an escaped character, add it directly + currentPart += char; + escaping = false; + } else if (char === "\\") { + // Start of an escape sequence + escaping = true; + } else if (char === KEY_SEPARATOR) { + // This is a separator + parts.push(currentPart); + currentPart = ""; + } else { + // Regular character + currentPart += char; + } + } + + // Add the last part if it exists + if (currentPart || parts.length > 0) { + parts.push(currentPart); + } + + return parts; +} diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/websocket.ts b/rivetkit-typescript/packages/cloudflare-workers/src/websocket.ts index 29bd3a77b2..39bbd68e0a 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/websocket.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/websocket.ts @@ -1,81 +1,81 @@ -// // Modified from https://github.com/honojs/hono/blob/40ea0eee58e39b31053a0246c595434f1094ad31/src/adapter/cloudflare-workers/websocket.ts#L17 -// // -// // This version calls the open event by default +// Modified from https://github.com/honojs/hono/blob/40ea0eee58e39b31053a0246c595434f1094ad31/src/adapter/cloudflare-workers/websocket.ts#L17 // -// import type { UpgradeWebSocket, WSEvents, WSReadyState } from "hono/ws"; -// import { defineWebSocketHelper, WSContext } from "hono/ws"; -// import { WS_PROTOCOL_STANDARD } from "rivetkit/driver-helpers"; -// -// // Based on https://github.com/honojs/hono/issues/1153#issuecomment-1767321332 -// export const upgradeWebSocket: UpgradeWebSocket< -// WebSocket, -// // eslint-disable-next-line @typescript-eslint/no-explicit-any -// any, -// WSEvents -// > = defineWebSocketHelper(async (c, events) => { -// const upgradeHeader = c.req.header("Upgrade"); -// if (upgradeHeader !== "websocket") { -// return; -// } -// -// const webSocketPair = new WebSocketPair(); -// const client: WebSocket = webSocketPair[0]; -// const server: WebSocket = webSocketPair[1]; -// -// const wsContext = new WSContext({ -// close: (code, reason) => server.close(code, reason), -// get protocol() { -// return server.protocol; -// }, -// raw: server, -// get readyState() { -// return server.readyState as WSReadyState; -// }, -// url: server.url ? new URL(server.url) : null, -// send: (source) => server.send(source), -// }); -// -// if (events.onClose) { -// server.addEventListener("close", (evt: CloseEvent) => -// events.onClose?.(evt, wsContext), -// ); -// } -// if (events.onMessage) { -// server.addEventListener("message", (evt: MessageEvent) => -// events.onMessage?.(evt, wsContext), -// ); -// } -// if (events.onError) { -// server.addEventListener("error", (evt: Event) => -// events.onError?.(evt, wsContext), -// ); -// } -// -// server.accept?.(); -// -// // note: cloudflare actors doesn't support 'open' event, so we call it immediately with a fake event -// // -// // we have to do this after `server.accept() is called` -// events.onOpen?.(new Event("open"), wsContext); -// -// // Build response headers -// const headers: Record = {}; -// -// // Set Sec-WebSocket-Protocol if does not exist -// const protocols = c.req.header("Sec-WebSocket-Protocol"); -// if ( -// typeof protocols === "string" && -// protocols -// .split(",") -// .map((x) => x.trim()) -// .includes(WS_PROTOCOL_STANDARD) -// ) { -// headers["Sec-WebSocket-Protocol"] = WS_PROTOCOL_STANDARD; -// } -// -// return new Response(null, { -// status: 101, -// headers, -// webSocket: client, -// }); -// }); +// This version calls the open event by default + +import type { UpgradeWebSocket, WSEvents, WSReadyState } from "hono/ws"; +import { defineWebSocketHelper, WSContext } from "hono/ws"; +import { WS_PROTOCOL_STANDARD } from "rivetkit/driver-helpers"; + +// Based on https://github.com/honojs/hono/issues/1153#issuecomment-1767321332 +export const upgradeWebSocket: UpgradeWebSocket< + WebSocket, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + any, + WSEvents +> = defineWebSocketHelper(async (c, events) => { + const upgradeHeader = c.req.header("Upgrade"); + if (upgradeHeader !== "websocket") { + return; + } + + const webSocketPair = new WebSocketPair(); + const client: WebSocket = webSocketPair[0]; + const server: WebSocket = webSocketPair[1]; + + const wsContext = new WSContext({ + close: (code, reason) => server.close(code, reason), + get protocol() { + return server.protocol; + }, + raw: server, + get readyState() { + return server.readyState as WSReadyState; + }, + url: server.url ? new URL(server.url) : null, + send: (source) => server.send(source), + }); + + if (events.onClose) { + server.addEventListener("close", (evt: CloseEvent) => + events.onClose?.(evt, wsContext), + ); + } + if (events.onMessage) { + server.addEventListener("message", (evt: MessageEvent) => + events.onMessage?.(evt, wsContext), + ); + } + if (events.onError) { + server.addEventListener("error", (evt: Event) => + events.onError?.(evt, wsContext), + ); + } + + server.accept?.(); + + // note: cloudflare actors doesn't support 'open' event, so we call it immediately with a fake event + // + // we have to do this after `server.accept() is called` + events.onOpen?.(new Event("open"), wsContext); + + // Build response headers + const headers: Record = {}; + + // Set Sec-WebSocket-Protocol if does not exist + const protocols = c.req.header("Sec-WebSocket-Protocol"); + if ( + typeof protocols === "string" && + protocols + .split(",") + .map((x) => x.trim()) + .includes(WS_PROTOCOL_STANDARD) + ) { + headers["Sec-WebSocket-Protocol"] = WS_PROTOCOL_STANDARD; + } + + return new Response(null, { + status: 101, + headers, + webSocket: client, + }); +}); diff --git a/rivetkit-typescript/packages/next-js/package.json b/rivetkit-typescript/packages/next-js/package.json index e45893a13e..b31ebce182 100644 --- a/rivetkit-typescript/packages/next-js/package.json +++ b/rivetkit-typescript/packages/next-js/package.json @@ -53,8 +53,8 @@ }, "dependencies": { "@rivetkit/react": "workspace:*", - "rivetkit": "workspace:^", - "hono": "^4.8.3" + "hono": "^4.8.3", + "rivetkit": "workspace:^" }, "peerDependencies": { "react": "^18 || ^19", diff --git a/rivetkit-typescript/packages/next-js/src/mod.ts b/rivetkit-typescript/packages/next-js/src/mod.ts index cb7c835f15..aed75ac4e7 100644 --- a/rivetkit-typescript/packages/next-js/src/mod.ts +++ b/rivetkit-typescript/packages/next-js/src/mod.ts @@ -1,232 +1,230 @@ -// import { existsSync, statSync } from "node:fs"; -// import { join } from "node:path"; -// import type { Registry, RunConfigInput } from "rivetkit"; -// import { stringifyError } from "rivetkit/utils"; -// import { logger } from "./log"; -// -// export const toNextHandler = ( -// registry: Registry, -// inputConfig: RunConfigInput = {}, -// ) => { -// // Don't run server locally since we're using the fetch handler directly -// inputConfig.disableDefaultServer = true; -// -// // Configure serverless -// inputConfig.runnerKind = "serverless"; -// -// if (process.env.NODE_ENV !== "production") { -// // Auto-configure serverless runner if not in prod -// logger().debug( -// "detected development environment, auto-starting engine and auto-configuring serverless", -// ); -// -// const publicUrl = -// process.env.NEXT_PUBLIC_SITE_URL ?? -// process.env.NEXT_PUBLIC_VERCEL_URL ?? -// `http://127.0.0.1:${process.env.PORT ?? 3000}`; -// -// inputConfig.runEngine = true; -// inputConfig.autoConfigureServerless = { -// url: `${publicUrl}/api/rivet`, -// minRunners: 0, -// maxRunners: 100_000, -// requestLifespan: 300, -// slotsPerRunner: 1, -// metadata: { provider: "next-js" }, -// }; -// } else { -// logger().debug( -// "detected production environment, will not auto-start engine and auto-configure serverless", -// ); -// } -// -// // Next logs this on every request -// inputConfig.noWelcome = true; -// -// const { fetch } = registry.start(inputConfig); -// -// // Function that Next will call when handling requests -// const fetchWrapper = async ( -// request: Request, -// { params }: { params: Promise<{ all: string[] }> }, -// ): Promise => { -// const { all } = await params; -// -// const newUrl = new URL(request.url); -// newUrl.pathname = all.join("/"); -// -// if (process.env.NODE_ENV !== "development") { -// // Handle request -// const newReq = new Request(newUrl, request); -// return await fetch(newReq); -// } else { -// // Special request handling for file watching -// return await handleRequestWithFileWatcher(request, newUrl, fetch); -// } -// }; -// -// return { -// GET: fetchWrapper, -// POST: fetchWrapper, -// PUT: fetchWrapper, -// PATCH: fetchWrapper, -// HEAD: fetchWrapper, -// OPTIONS: fetchWrapper, -// }; -// }; -// -// /** -// * Special request handler that will watch the source file to terminate this -// * request once complete. -// * -// * See docs on watchRouteFile for more information. -// */ -// async function handleRequestWithFileWatcher( -// request: Request, -// newUrl: URL, -// fetch: (request: Request, ...args: any) => Response | Promise, -// ): Promise { -// // Create a new abort controller that we can abort, since the signal on -// // the request we cannot control -// const mergedController = new AbortController(); -// const abortMerged = () => mergedController.abort(); -// request.signal?.addEventListener("abort", abortMerged); -// -// // Watch for file changes in dev -// // -// // We spawn one watcher per-request since there is not a clean way of -// // cleaning up global watchers when hot reloading in Next -// const watchIntervalId = watchRouteFile(mergedController); -// -// // Clear interval if request is aborted -// request.signal.addEventListener("abort", () => { -// logger().debug("clearing file watcher interval: request aborted"); -// clearInterval(watchIntervalId); -// }); -// -// // Replace URL and abort signal -// const newReq = new Request(newUrl, { -// // Copy old request properties -// method: request.method, -// headers: request.headers, -// body: request.body, -// credentials: request.credentials, -// cache: request.cache, -// redirect: request.redirect, -// referrer: request.referrer, -// integrity: request.integrity, -// // Override with new signal -// signal: mergedController.signal, -// // Required for streaming body -// duplex: "half", -// } as RequestInit); -// -// // Handle request -// const response = await fetch(newReq); -// -// // HACK: Next.js does not provide a way to detect when a request -// // finishes, so we need to tap the response stream -// // -// // We can't just wait for `await fetch` to finish since SSE streams run -// // for longer -// if (response.body) { -// const wrappedStream = waitForStreamFinish(response.body, () => { -// logger().debug("clearing file watcher interval: stream finished"); -// clearInterval(watchIntervalId); -// }); -// return new Response(wrappedStream, { -// status: response.status, -// statusText: response.statusText, -// headers: response.headers, -// }); -// } else { -// // No response body, clear interval immediately -// logger().debug("clearing file watcher interval: no response body"); -// clearInterval(watchIntervalId); -// return response; -// } -// } -// -// /** -// * HACK: Watch for file changes on this route in order to shut down the runner. -// * We do this because Next.js does not terminate long-running requests on file -// * change, so we need to manually shut down the runner in order to trigger a -// * new `/start` request with the new code. -// * -// * We don't use file watchers since those are frequently buggy x-platform and -// * subject to misconfigured inotify limits. -// */ -// function watchRouteFile(abortController: AbortController): NodeJS.Timeout { -// logger().debug("starting file watcher"); -// -// const routePath = join( -// process.cwd(), -// ".next/server/app/api/rivet/[...all]/route.js", -// ); -// -// let lastMtime: number | null = null; -// const checkFile = () => { -// logger().debug({ msg: "checking for file changes", routePath }); -// try { -// if (!existsSync(routePath)) { -// return; -// } -// -// const stats = statSync(routePath); -// const mtime = stats.mtimeMs; -// -// if (lastMtime !== null && mtime !== lastMtime) { -// logger().info({ msg: "route file changed", routePath }); -// abortController.abort(); -// } -// -// lastMtime = mtime; -// } catch (err) { -// logger().info({ -// msg: "failed to check for route file change", -// err: stringifyError(err), -// }); -// } -// }; -// -// checkFile(); -// -// return setInterval(checkFile, 1000); -// } -// -// /** -// * Waits for a stream to finish and calls onFinish on complete. -// * -// * Used for cancelling the file watcher. -// */ -// function waitForStreamFinish( -// body: ReadableStream, -// onFinish: () => void, -// ): ReadableStream { -// const reader = body.getReader(); -// return new ReadableStream({ -// async start(controller) { -// try { -// while (true) { -// const { done, value } = await reader.read(); -// if (done) { -// logger().debug("stream completed"); -// onFinish(); -// controller.close(); -// break; -// } -// controller.enqueue(value); -// } -// } catch (err) { -// logger().debug("stream errored"); -// onFinish(); -// controller.error(err); -// } -// }, -// cancel() { -// logger().debug("stream cancelled"); -// onFinish(); -// reader.cancel(); -// }, -// }); -// } +import { existsSync, statSync } from "node:fs"; +import { join } from "node:path"; +import type { Registry } from "rivetkit"; +import { stringifyError } from "rivetkit/utils"; +import { logger } from "./log"; + +export const toNextHandler = (registry: Registry) => { + // Don't run server locally since we're using the fetch handler directly + registry.config.serveManager = false; + + // Set basePath to "/" since Next.js route strips the /api/rivet prefix + registry.config.serverless = { ...registry.config.serverless, basePath: "/" }; + + if (process.env.NODE_ENV !== "production") { + // Auto-configure serverless runner if not in prod + logger().debug( + "detected development environment, auto-starting engine and auto-configuring serverless", + ); + + const publicUrl = + process.env.NEXT_PUBLIC_SITE_URL ?? + process.env.NEXT_PUBLIC_VERCEL_URL ?? + `http://127.0.0.1:${process.env.PORT ?? 3000}`; + + // Set these on the registry's config directly since the legacy inputConfig + // isn't used by the serverless router + registry.config.serverless.spawnEngine = true; + registry.config.serverless.configureRunnerPool = { + url: `${publicUrl}/api/rivet`, + minRunners: 0, + maxRunners: 100_000, + requestLifespan: 300, + slotsPerRunner: 1, + metadata: { provider: "next-js" }, + }; + } else { + logger().debug( + "detected production environment, will not auto-start engine and auto-configure serverless", + ); + } + + // Next logs this on every request + registry.config.noWelcome = true; + + // Function that Next will call when handling requests + const fetchWrapper = async ( + request: Request, + { params }: { params: Promise<{ all: string[] }> }, + ): Promise => { + const { all } = await params; + + const newUrl = new URL(request.url); + newUrl.pathname = `/${all.join("/")}`; + + // if (process.env.NODE_ENV === "development") { + if (false) { + // Special request handling for file watching + return await handleRequestWithFileWatcher(request, newUrl, fetch); + } else { + // Handle request + const newReq = new Request(newUrl, request); + return await registry.handler(newReq); + } + }; + + return { + GET: fetchWrapper, + POST: fetchWrapper, + PUT: fetchWrapper, + PATCH: fetchWrapper, + HEAD: fetchWrapper, + OPTIONS: fetchWrapper, + }; +}; + +/** + * Special request handler that will watch the source file to terminate this + * request once complete. + * + * See docs on watchRouteFile for more information. + */ +async function handleRequestWithFileWatcher( + request: Request, + newUrl: URL, + fetch: (request: Request, ...args: any) => Response | Promise, +): Promise { + // Create a new abort controller that we can abort, since the signal on + // the request we cannot control + const mergedController = new AbortController(); + const abortMerged = () => mergedController.abort(); + request.signal?.addEventListener("abort", abortMerged); + + // Watch for file changes in dev + // + // We spawn one watcher per-request since there is not a clean way of + // cleaning up global watchers when hot reloading in Next + const watchIntervalId = watchRouteFile(mergedController); + + // Clear interval if request is aborted + request.signal.addEventListener("abort", () => { + logger().debug("clearing file watcher interval: request aborted"); + clearInterval(watchIntervalId); + }); + + // Replace URL and abort signal + const newReq = new Request(newUrl, { + // Copy old request properties + method: request.method, + headers: request.headers, + body: request.body, + credentials: request.credentials, + cache: request.cache, + redirect: request.redirect, + referrer: request.referrer, + integrity: request.integrity, + // Override with new signal + signal: mergedController.signal, + // Required for streaming body + duplex: "half", + } as RequestInit); + + // Handle request + const response = await fetch(newReq); + + // HACK: Next.js does not provide a way to detect when a request + // finishes, so we need to tap the response stream + // + // We can't just wait for `await fetch` to finish since SSE streams run + // for longer + if (response.body) { + const wrappedStream = waitForStreamFinish(response.body, () => { + logger().debug("clearing file watcher interval: stream finished"); + clearInterval(watchIntervalId); + }); + return new Response(wrappedStream, { + status: response.status, + statusText: response.statusText, + headers: response.headers, + }); + } else { + // No response body, clear interval immediately + logger().debug("clearing file watcher interval: no response body"); + clearInterval(watchIntervalId); + return response; + } +} + +/** + * HACK: Watch for file changes on this route in order to shut down the runner. + * We do this because Next.js does not terminate long-running requests on file + * change, so we need to manually shut down the runner in order to trigger a + * new `/start` request with the new code. + * + * We don't use file watchers since those are frequently buggy x-platform and + * subject to misconfigured inotify limits. + */ +function watchRouteFile(abortController: AbortController): NodeJS.Timeout { + logger().debug("starting file watcher"); + + const routePath = join( + process.cwd(), + ".next/server/app/api/rivet/[...all]/route.js", + ); + + let lastMtime: number | null = null; + const checkFile = () => { + logger().debug({ msg: "checking for file changes", routePath }); + try { + if (!existsSync(routePath)) { + return; + } + + const stats = statSync(routePath); + const mtime = stats.mtimeMs; + + if (lastMtime !== null && mtime !== lastMtime) { + logger().info({ msg: "route file changed", routePath }); + abortController.abort(); + } + + lastMtime = mtime; + } catch (err) { + logger().info({ + msg: "failed to check for route file change", + err: stringifyError(err), + }); + } + }; + + checkFile(); + + return setInterval(checkFile, 1000); +} + +/** + * Waits for a stream to finish and calls onFinish on complete. + * + * Used for cancelling the file watcher. + */ +function waitForStreamFinish( + body: ReadableStream, + onFinish: () => void, +): ReadableStream { + const reader = body.getReader(); + return new ReadableStream({ + async start(controller) { + try { + while (true) { + const { done, value } = await reader.read(); + if (done) { + logger().debug("stream completed"); + onFinish(); + controller.close(); + break; + } + controller.enqueue(value); + } + } catch (err) { + logger().debug("stream errored"); + onFinish(); + controller.error(err); + } + }, + cancel() { + logger().debug("stream cancelled"); + onFinish(); + reader.cancel(); + }, + }); +} diff --git a/rivetkit-typescript/packages/rivetkit/runtime/index.ts b/rivetkit-typescript/packages/rivetkit/runtime/index.ts new file mode 100644 index 0000000000..20665020b6 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/runtime/index.ts @@ -0,0 +1,229 @@ +import invariant from "invariant"; +import { createClientWithDriver } from "@/client/client"; +import { configureBaseLogger, configureDefaultLogger } from "@/common/log"; +import { chooseDefaultDriver } from "@/drivers/default"; +import { ENGINE_PORT, ensureEngineProcess } from "@/engine-process/mod"; +import { getInspectorUrl } from "@/inspector/utils"; +import { buildManagerRouter } from "@/manager/router"; +import { configureServerlessRunner } from "@/serverless/configure"; +import type { GetUpgradeWebSocket } from "@/utils"; +import pkg from "../package.json" with { type: "json" }; +import { + type DriverConfig, + type RegistryActors, + type RegistryConfig, +} from "@/registry/config"; +import { logger } from "../src/registry/log"; +import { crossPlatformServe, findFreePort } from "@/registry/serve"; +import { ManagerDriver } from "@/manager/driver"; +import { buildServerlessRouter } from "@/serverless/router"; +import { Registry } from "@/registry"; + +/** + * Defines what type of server is being started. Used internally for + * Registry.#start + **/ +export type StartKind = "serverless" | "runner"; + +export class Runtime { + #registry: Registry; + managerPort?: number; + #config: RegistryConfig; + #driver: DriverConfig; + #kind: StartKind; + #managerDriver: ManagerDriver; + + get config() { + return this.#config; + } + + get driver() { + return this.#driver; + } + + get managerDriver() { + return this.#managerDriver; + } + + #serverlessRouter?: ReturnType["router"]; + + constructor(registry: Registry, kind: StartKind) { + this.#registry = registry; + this.#kind = kind; + + const config = this.#registry.parseConfig(); + this.#config = config; + + // Promise for any async operations we need to wait to complete + const readyPromises: Promise[] = []; + + // Configure logger + if (config.logging?.baseLogger) { + // Use provided base logger + configureBaseLogger(config.logging.baseLogger); + } else { + // Configure default logger with log level from config getPinoLevel + // will handle env variable priority + configureDefaultLogger(config.logging?.level); + } + + // Handle spawnEngine before choosing driver + // Start engine + invariant( + !( + kind === "serverless" && + config.serverless.spawnEngine && + config.serveManager + ), + "cannot specify spawnEngine and serveManager together", + ); + + if (kind === "serverless" && config.serverless.spawnEngine) { + this.managerPort = ENGINE_PORT; + + logger().debug({ + msg: "run engine requested", + version: config.serverless.engineVersion, + }); + + // Start the engine + const engineProcessPromise = ensureEngineProcess({ + version: config.serverless.engineVersion, + }); + + // Chain ready promise + readyPromises.push(engineProcessPromise); + } + + // Choose the driver based on configuration + const driver = chooseDefaultDriver(config); + + // Create manager driver (always needed for actor driver + inline client) + const managerDriver = driver.manager(config); + + // Start manager + if (config.serveManager) { + // Configure getUpgradeWebSocket lazily so we can assign it in crossPlatformServe + let upgradeWebSocket: any; + const getUpgradeWebSocket: GetUpgradeWebSocket = () => + upgradeWebSocket; + managerDriver.setGetUpgradeWebSocket(getUpgradeWebSocket); + + // Build router + const { router: managerRouter } = buildManagerRouter( + config, + managerDriver, + getUpgradeWebSocket, + ); + + // Serve manager + const serverPromise = (async () => { + const managerPort = await findFreePort(config.managerPort); + this.managerPort = managerPort; + + const out = await crossPlatformServe( + config, + managerPort, + managerRouter, + ); + upgradeWebSocket = out.upgradeWebSocket; + })(); + readyPromises.push(serverPromise); + } + + // Build serverless router + if (kind === "serverless") { + this.#serverlessRouter = buildServerlessRouter( + driver, + config, + ).router; + } + + this.#driver = driver; + this.#managerDriver = managerDriver; + + // Log and print welcome after all ready promises complete + // biome-ignore lint/nursery/noFloatingPromises: bg promise + Promise.all(readyPromises).then(async () => this.#onAfterReady()); + } + + async #onAfterReady() { + const config = this.#config; + const kind = this.#kind; + const driver = this.#driver; + const managerDriver = this.#managerDriver; + + // Auto-start actor driver for drivers that require it. + // + // This is only enabled for runner config since serverless will + // auto-start the actor driver on `GET /start`. + if (kind === "runner" && config.runner && driver.autoStartActorDriver) { + logger().debug("starting actor driver"); + const inlineClient = + createClientWithDriver>(managerDriver); + driver.actor(config, managerDriver, inlineClient); + } + + // Log starting + const driverLog = managerDriver.extraStartupLog?.() ?? {}; + logger().info({ + msg: "rivetkit ready", + driver: driver.name, + definitions: Object.keys(config.use).length, + ...driverLog, + }); + invariant(this.managerPort, "managerPort should be set"); + const inspectorUrl = getInspectorUrl(config, this.managerPort); + if (inspectorUrl && config.inspector.enabled) { + logger().info({ + msg: "inspector ready", + url: inspectorUrl, + }); + } + + // Print welcome information + if (!config.noWelcome) { + console.log(); + console.log(` RivetKit ${pkg.version} (${driver.displayName})`); + // Only show endpoint if manager is running or engine is spawned + const shouldShowEndpoint = + config.serveManager || + (kind === "serverless" && config.serverless.spawnEngine); + if ( + kind === "serverless" && + config.serverless.advertiseEndpoint && + shouldShowEndpoint + ) { + console.log( + ` - Endpoint: ${config.serverless.advertiseEndpoint}`, + ); + } + if (kind === "serverless" && config.serverless.spawnEngine) { + const padding = " ".repeat(Math.max(0, 13 - "Engine".length)); + console.log( + ` - Engine:${padding}v${config.serverless.engineVersion}`, + ); + } + const displayInfo = managerDriver.displayInformation(); + for (const [k, v] of Object.entries(displayInfo.properties)) { + const padding = " ".repeat(Math.max(0, 13 - k.length)); + console.log(` - ${k}:${padding}${v}`); + } + if (inspectorUrl && config.inspector.enabled) { + console.log(` - Inspector: ${inspectorUrl}`); + } + console.log(); + } + + // Configure serverless runner if enabled when actor driver is disabled + if (kind === "serverless" && config.serverless.configureRunnerPool) { + await configureServerlessRunner(config); + } + } + + public handler(request: Request): Response | Promise { + invariant(this.#kind === "serverless", "kind not serverless"); + invariant(this.#serverlessRouter, "missing serverless router"); + return this.#serverlessRouter.fetch(request); + } +} diff --git a/rivetkit-typescript/packages/rivetkit/scripts/dump-openapi.ts b/rivetkit-typescript/packages/rivetkit/scripts/dump-openapi.ts index 21a2e44b33..68fc4acb57 100644 --- a/rivetkit-typescript/packages/rivetkit/scripts/dump-openapi.ts +++ b/rivetkit-typescript/packages/rivetkit/scripts/dump-openapi.ts @@ -1,16 +1,12 @@ import * as fs from "node:fs/promises"; import { resolve } from "node:path"; import { z } from "zod"; -import { ClientConfigSchema } from "@/client/config"; import { createFileSystemOrMemoryDriver } from "@/drivers/file-system/mod"; import type { - ActorOutput, - ListActorsInput, ManagerDriver, } from "@/manager/driver"; import { buildManagerRouter } from "@/manager/router"; import { type RegistryConfig, RegistryConfigSchema } from "@/registry/config"; -import { LegacyRunnerConfigSchema } from "@/registry/config/legacy-runner"; import { VERSION } from "@/utils"; import { toJsonSchema } from "./schema-utils"; @@ -26,18 +22,19 @@ async function main() { // const registry = setup(registryConfig); const managerDriver: ManagerDriver = { - getForId: unimplemented, - getWithKey: unimplemented, - getOrCreateWithKey: unimplemented, - createActor: unimplemented, - listActors: unimplemented, - sendRequest: unimplemented, - openWebSocket: unimplemented, - proxyRequest: unimplemented, - proxyWebSocket: unimplemented, - displayInformation: unimplemented, - setGetUpgradeWebSocket: unimplemented, - }; + getForId: unimplemented, + getWithKey: unimplemented, + getOrCreateWithKey: unimplemented, + createActor: unimplemented, + listActors: unimplemented, + sendRequest: unimplemented, + openWebSocket: unimplemented, + proxyRequest: unimplemented, + proxyWebSocket: unimplemented, + displayInformation: unimplemented, + setGetUpgradeWebSocket: unimplemented, + kvGet: unimplemented, + }; // const client = createClientWithDriver( // managerDriver, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts b/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts index f53dafecd6..80e80bb477 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts @@ -3,6 +3,7 @@ import type { Actions, ActorConfig } from "./config"; import type { ActionContextOf, ActorContext } from "./contexts"; import type { AnyDatabaseProvider } from "./database"; import { ActorInstance } from "./instance/mod"; +import { DeepMutable } from "@/utils"; export type AnyActorDefinition = ActorDefinition< any, diff --git a/rivetkit-typescript/packages/rivetkit/src/client/config.ts b/rivetkit-typescript/packages/rivetkit/src/client/config.ts index 06ed71e502..0b67d561e0 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/config.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/config.ts @@ -1,19 +1,19 @@ import z from "zod"; import { EncodingSchema } from "@/actor/protocol/serde"; +import { type GetUpgradeWebSocket } from "@/utils"; +import { + getRivetEngine, + getRivetEndpoint, + getRivetToken, + getRivetNamespace, + getRivetRunner, +} from "@/utils/env-vars"; import type { RegistryConfig } from "@/registry/config"; -import type { GetUpgradeWebSocket } from "@/utils"; import { EndpointSchema, type ParsedEndpoint, zodCheckDuplicateCredentials, } from "@/utils/endpoint-parser"; -import { - getRivetEndpoint, - getRivetEngine, - getRivetNamespace, - getRivetRunner, - getRivetToken, -} from "@/utils/env-vars"; /** * Base client config schema without transforms so it can be merged in to other schemas. diff --git a/rivetkit-typescript/packages/rivetkit/src/client/utils.ts b/rivetkit-typescript/packages/rivetkit/src/client/utils.ts index 4615f3e3db..bc64cec4e5 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/utils.ts @@ -124,7 +124,7 @@ export async function sendHttpRequest< : {}), "User-Agent": httpUserAgent(), }, - body: bodyData as BodyInit | undefined, + body: bodyData, credentials: "include", signal: opts.signal, }), diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts index 4a0b911f3d..ec57d91f76 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts @@ -28,4 +28,5 @@ export type { ManagerDisplayInformation, ManagerDriver, } from "@/manager/driver"; +export { buildManagerRouter } from "@/manager/router"; export { getInitialActorKvState } from "./utils"; diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/mod.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/mod.ts index e0ea31584a..7f2f691b43 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/mod.ts @@ -1,8 +1,7 @@ import { serve as honoServe } from "@hono/node-server"; -import { createNodeWebSocket, type NodeWebSocket } from "@hono/node-ws"; +import { createNodeWebSocket } from "@hono/node-ws"; import invariant from "invariant"; import { describe } from "vitest"; -import { ClientConfigSchema } from "@/client/config"; import type { Encoding } from "@/client/mod"; import { buildManagerRouter } from "@/manager/router"; import { createClientWithDriver, type Registry } from "@/mod"; @@ -174,7 +173,7 @@ export async function createTestRuntime( // TODO: Find a cleaner way of flagging an registry as test mode (ideally not in the config itself) // Force enable test - registry.config.test.enabled = true; + registry.config.test = { ...registry.config.test, enabled: true }; registry.config.inspector = { enabled: true, token: () => "token", @@ -209,14 +208,15 @@ export async function createTestRuntime( let upgradeWebSocket: any; // Create router - const managerDriver = driver.manager?.(registry.config); + const parsedConfig = registry.parseConfig(); + const managerDriver = driver.manager?.(parsedConfig); invariant(managerDriver, "missing manager driver"); // const client = createClientWithDriver( // managerDriver, // ClientConfigSchema.parse({}), // ); const { router } = buildManagerRouter( - registry.config, + parsedConfig, managerDriver, () => upgradeWebSocket, ); @@ -226,6 +226,7 @@ export async function createTestRuntime( upgradeWebSocket = nodeWebSocket.upgradeWebSocket; managerDriver.setGetUpgradeWebSocket(() => upgradeWebSocket); + // TODO: I think this whole function is fucked, we should probably switch to calling registry.serve() directly // Start server const port = await getPort(); const server = honoServe({ diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts index 72c7bb2da8..14315b163e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts @@ -243,6 +243,9 @@ export function createTestInlineClientDriver( setGetUpgradeWebSocket: (getUpgradeWebSocketInner) => { getUpgradeWebSocket = getUpgradeWebSocketInner; }, + kvGet: (_actorId: string, _key: Uint8Array) => { + throw new Error("kvGet not impelmented on inline client driver"); + }, } satisfies ManagerDriver; } diff --git a/rivetkit-typescript/packages/rivetkit/src/engine-process/constants.ts b/rivetkit-typescript/packages/rivetkit/src/engine-process/constants.ts new file mode 100644 index 0000000000..ef5887f354 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/engine-process/constants.ts @@ -0,0 +1,2 @@ +export const ENGINE_PORT = 6420; +export const ENGINE_ENDPOINT = `http://localhost:${ENGINE_PORT}`; diff --git a/rivetkit-typescript/packages/rivetkit/src/engine-process/mod.ts b/rivetkit-typescript/packages/rivetkit/src/engine-process/mod.ts index 4b132e49f8..f991e0542f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/engine-process/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/engine-process/mod.ts @@ -12,9 +12,9 @@ import { importNodeDependencies, } from "@/utils/node"; import { logger } from "./log"; +import { ENGINE_ENDPOINT, ENGINE_PORT } from "./constants"; -export const ENGINE_PORT = 6420; -export const ENGINE_ENDPOINT = `http://localhost:${ENGINE_PORT}`; +export { ENGINE_ENDPOINT, ENGINE_PORT }; const ENGINE_BASE_URL = "https://releases.rivet.dev/rivet"; const ENGINE_BINARY_NAME = "rivet-engine"; @@ -42,14 +42,7 @@ export async function ensureEngineProcess( await ensureDirectoryExists(varDir); await ensureDirectoryExists(logsDir); - const executableName = - process.platform === "win32" - ? `${ENGINE_BINARY_NAME}-${options.version}.exe` - : `${ENGINE_BINARY_NAME}-${options.version}`; - const binaryPath = path.join(binDir, executableName); - await downloadEngineBinaryIfNeeded(binaryPath, options.version, varDir); - - // Check if the engine is already running on the port + // Check if the engine is already running on the port before downloading if (await isEngineRunning()) { try { const health = await waitForEngineHealth(); @@ -68,6 +61,13 @@ export async function ensureEngineProcess( ); } } + + const executableName = + process.platform === "win32" + ? `${ENGINE_BINARY_NAME}-${options.version}.exe` + : `${ENGINE_BINARY_NAME}-${options.version}`; + const binaryPath = path.join(binDir, executableName); + await downloadEngineBinaryIfNeeded(binaryPath, options.version, varDir); // Create log file streams with timestamp in the filename const timestamp = new Date() .toISOString() diff --git a/rivetkit-typescript/packages/rivetkit/src/inspector/utils.ts b/rivetkit-typescript/packages/rivetkit/src/inspector/utils.ts index 44584cb600..73eb54edce 100644 --- a/rivetkit-typescript/packages/rivetkit/src/inspector/utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/inspector/utils.ts @@ -1,11 +1,65 @@ -import type { RegistryConfig } from "@/mod"; +import { createMiddleware } from "hono/factory"; +import { inspectorLogger } from "./log"; +import type { RegistryConfig } from "@/registry/config"; + +export function compareSecrets(providedSecret: string, validSecret: string) { + // Early length check to avoid unnecessary processing + if (providedSecret.length !== validSecret.length) { + return false; + } + + const encoder = new TextEncoder(); + + const a = encoder.encode(providedSecret); + const b = encoder.encode(validSecret); + + if (a.byteLength !== b.byteLength) { + return false; + } + + // TODO: + // // Perform timing-safe comparison + // if (!crypto.timingSafeEqual(a, b)) { + // return false; + // } + return true; +} + +export const secureInspector = (config: RegistryConfig) => + createMiddleware(async (c, next) => { + const userToken = c.req.header("Authorization")?.replace("Bearer ", ""); + if (!userToken) { + return c.text("Unauthorized", 401); + } + + const inspectorToken = config.inspector.token(); + if (!inspectorToken) { + return c.text("Unauthorized", 401); + } + + const isValid = compareSecrets(userToken, inspectorToken); + if (!isValid) { + return c.text("Unauthorized", 401); + } + await next(); + }); + +export function getInspectorUrl( + config: RegistryConfig, + managerPort: number, +): string | undefined { + if (!config.inspector.enabled) return undefined; -export function getInspectorUrl(runConfig: RegistryConfig | undefined) { const url = new URL("https://inspect.rivet.dev"); - const overrideDefaultEndpoint = runConfig?.inspector?.defaultEndpoint; - if (overrideDefaultEndpoint) { - url.searchParams.set("u", overrideDefaultEndpoint); + // Only override endpoint if using non-default port or custom endpoint is set + const endpoint = + config.inspector.defaultEndpoint ?? + (config.managerPort !== 6420 + ? `http://localhost:${managerPort}` + : undefined); + if (endpoint) { + url.searchParams.set("u", endpoint); } return url.href; diff --git a/rivetkit-typescript/packages/rivetkit/src/manager/gateway.ts b/rivetkit-typescript/packages/rivetkit/src/manager/gateway.ts index 9598202312..f20fb4d36e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/manager/gateway.ts +++ b/rivetkit-typescript/packages/rivetkit/src/manager/gateway.ts @@ -122,6 +122,7 @@ export async function actorGateway( c: HonoContext, next: Next, ) { + // Skip test routes - let them be handled by their specific handlers if (c.req.path.startsWith("/.test/")) { return next(); diff --git a/rivetkit-typescript/packages/rivetkit/src/manager/router.ts b/rivetkit-typescript/packages/rivetkit/src/manager/router.ts index 6f2bd6e9ca..8965268398 100644 --- a/rivetkit-typescript/packages/rivetkit/src/manager/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/manager/router.ts @@ -71,7 +71,7 @@ export function buildManagerRouter( // GET / router.get("/", (c) => { return c.text( - "This is a RivetKit server.\n\nLearn more at https://rivetkit.org", + "This is a RivetKit server.\n\nLearn more at https://rivet.dev", ); }); diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/config/index.ts b/rivetkit-typescript/packages/rivetkit/src/registry/config/index.ts index abb9e737d7..fcd1f2ca04 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/config/index.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/config/index.ts @@ -3,6 +3,7 @@ import { z } from "zod"; import type { ActorDefinition, AnyActorDefinition } from "@/actor/definition"; import { resolveEndpoint } from "@/client/config"; import { type Logger, LogLevelSchema } from "@/common/log"; +import { ENGINE_ENDPOINT } from "@/engine-process/constants"; import { InspectorConfigSchema } from "@/inspector/config"; import { EndpointSchema, @@ -12,6 +13,7 @@ import { getRivetNamespace, getRivetToken, isDev } from "@/utils/env-vars"; import { type DriverConfig, DriverConfigSchema } from "./driver"; import { RunnerConfigSchema } from "./runner"; import { ServerlessConfigSchema } from "./serverless"; +import { DeepReadonly } from "@/utils"; export { DriverConfigSchema, type DriverConfig }; @@ -155,23 +157,26 @@ export const RegistryConfigSchema = z }); } - // clientEndpoint required in production without endpoint + // advertiseEndpoint required in production without endpoint if ( !isDevEnv && !resolvedEndpoint && - !config.serverless.clientEndpoint + !config.serverless.advertiseEndpoint ) { ctx.addIssue({ code: "custom", message: - "clientEndpoint is required in production mode without endpoint", - path: ["clientEndpoint"], + "advertiseEndpoint is required in production mode without endpoint", + path: ["advertiseEndpoint"], }); } } // Flatten the endpoint and apply defaults for namespace/token - const endpoint = resolvedEndpoint?.endpoint; + // If spawnEngine is enabled, set endpoint to the engine endpoint + const endpoint = config.serverless?.spawnEngine + ? ENGINE_ENDPOINT + : resolvedEndpoint?.endpoint; const namespace = resolvedEndpoint?.namespace ?? config.namespace ?? @@ -185,32 +190,31 @@ export const RegistryConfigSchema = z // - If dev mode without endpoint: start manager server // - If prod mode without endpoint: do not start manager server let serveManager: boolean; - let clientEndpoint: string; + let advertiseEndpoint: string | undefined; if (endpoint) { // Remote endpoint provided: // - Do not start manager server // - Redirect clients to remote endpoint serveManager = config.serveManager ?? false; - clientEndpoint = config.serverless.clientEndpoint ?? endpoint; + advertiseEndpoint = + config.serverless.advertiseEndpoint ?? endpoint; } else if (isDevEnv) { // Development mode, no endpoint: // - Start manager server // - Redirect clients to local server serveManager = config.serveManager ?? true; - clientEndpoint = - config.serverless.clientEndpoint ?? - `http://localhost:${config.managerPort}`; + advertiseEndpoint = config.serverless.advertiseEndpoint; } else { // Production mode, no endpoint: // - Do not start manager server // - Use file system driver serveManager = config.serveManager ?? false; invariant( - config.serverless.clientEndpoint, - "clientEndpoint is required in production mode without endpoint", + config.serverless.advertiseEndpoint, + "advertiseEndpoint is required in production mode without endpoint", ); - clientEndpoint = config.serverless.clientEndpoint; + advertiseEndpoint = config.serverless.advertiseEndpoint; } // If endpoint is set or spawning engine, we'll use engine driver - disable manager inspector @@ -224,15 +228,16 @@ export const RegistryConfigSchema = z return { ...config, - serverless: { - ...config.serverless, - clientEndpoint, - }, endpoint, namespace, token, serveManager, + advertiseEndpoint, inspector, + serverless: { + ...config.serverless, + advertiseEndpoint, + }, }; }); diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/config/serverless.ts b/rivetkit-typescript/packages/rivetkit/src/registry/config/serverless.ts index 3747263c3e..57b5845026 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/config/serverless.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/config/serverless.ts @@ -2,6 +2,20 @@ import { z } from "zod"; import { VERSION } from "@/utils"; import { getRivetRunEngineVersion, getRivetRunEngine } from "@/utils/env-vars"; +export const ConfigureRunnerPoolSchema = z + .object({ + name: z.string().optional(), + url: z.string(), + headers: z.record(z.string(), z.string()).optional(), + maxRunners: z.number().optional(), + minRunners: z.number().optional(), + requestLifespan: z.number().optional(), + runnersMargin: z.number().optional(), + slotsPerRunner: z.number().optional(), + metadata: z.record(z.string(), z.unknown()).optional(), + }) + .optional(); + export const ServerlessConfigSchema = z.object({ // MARK: Run Engine /** @@ -25,19 +39,7 @@ export const ServerlessConfigSchema = z.object({ * Can only be used when runnerKind is "serverless". * If true, uses default configuration. Can also provide custom configuration. */ - configureRunnerPool: z - .object({ - name: z.string().optional(), - url: z.string(), - headers: z.record(z.string(), z.string()).optional(), - maxRunners: z.number().optional(), - minRunners: z.number().optional(), - requestLifespan: z.number().optional(), - runnersMargin: z.number().optional(), - slotsPerRunner: z.number().optional(), - metadata: z.record(z.string(), z.unknown()).optional(), - }) - .optional(), + configureRunnerPool: ConfigureRunnerPoolSchema.optional(), // MARK: Routing // TODO: serverlessBasePath? better naming? @@ -53,7 +55,7 @@ export const ServerlessConfigSchema = z.object({ * * Auto-determined based on endpoint and NODE_ENV if not specified. */ - clientEndpoint: z.string().optional(), + advertiseEndpoint: z.string().optional(), }); export type ServerlessConfigInput = z.input; export type ServerlessConfig = z.infer; diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/index.ts b/rivetkit-typescript/packages/rivetkit/src/registry/index.ts index 4e7bdb1206..5a52f6a60f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/index.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/index.ts @@ -1,19 +1,11 @@ -import invariant from "invariant"; import type { Client } from "@/client/client"; -import { createClientWithDriver } from "@/client/client"; import { createClient } from "@/client/mod"; -import { configureBaseLogger, configureDefaultLogger } from "@/common/log"; import { chooseDefaultDriver } from "@/drivers/default"; import { ENGINE_ENDPOINT, ensureEngineProcess } from "@/engine-process/mod"; import { getInspectorUrl } from "@/inspector/utils"; import { buildManagerRouter } from "@/manager/router"; -import { configureServerlessRunner } from "@/serverless/configure"; -import { buildServerlessRouter } from "@/serverless/router"; -import type { GetUpgradeWebSocket } from "@/utils"; import { isDev } from "@/utils/env-vars"; -import pkg from "../../package.json" with { type: "json" }; import { - type DriverConfig, type RegistryActors, type RegistryConfig, type RegistryConfigInput, @@ -24,8 +16,7 @@ import { type LegacyRunnerConfigInput, LegacyRunnerConfigSchema, } from "./config/legacy-runner"; -import { logger } from "./log"; -import { crossPlatformServe, findFreePort } from "./serve"; +import { Runtime } from "../../runtime"; export type FetchHandler = ( request: Request, @@ -43,25 +34,23 @@ export interface LegacyStartServerOutput> { fetch: FetchHandler; } -/** - * Defines what type of server is being started. Used internally for - * Registry.#start - **/ -type StartKind = "serverless" | "runner"; - export class Registry { - #config: RegistryConfig; - - /** - * Cached serverless state. Subsequent calls to `handler()` will use this. - */ - #serverlessState: ServerlessHandler | null = null; + #config: RegistryConfigInput; - public get config(): RegistryConfig { + get config(): RegistryConfigInput { return this.#config; } - constructor(config: RegistryConfig) { + parseConfig(): RegistryConfig { + return RegistryConfigSchema.parse(this.#config); + } + + // HACK: We need to be able to call `registry.handler` cheaply without + // re-initializing the runtime every time. We lazily create the runtime and + // store it here for future calls to `registry.handler`. + #cachedServerlessRuntime?: Runtime; + + constructor(config: RegistryConfigInput) { this.#config = config; } @@ -76,8 +65,7 @@ export class Registry { * ``` */ public handler(request: Request): Response | Promise { - const { fetch } = this.#ensureServerlessInitialized(); - return fetch(request); + return this.#ensureServerlessInitialized().handler(request); } /** @@ -92,182 +80,19 @@ export class Registry { return { fetch: this.handler.bind(this) }; } - /** - * Starts an actor runner for standalone server deployments. - */ - public startRunner() { - this.#start("runner"); - } - /** Lazily initializes serverless state on first request, caches for subsequent calls. */ - #ensureServerlessInitialized(): { fetch: FetchHandler } { - if (!this.#serverlessState) { - const { driver } = this.#start("serverless"); - - const { router } = buildServerlessRouter(driver, this.#config); - this.#serverlessState = { fetch: router.fetch.bind(router) }; + #ensureServerlessInitialized(): Runtime { + if (!this.#cachedServerlessRuntime) { + this.#cachedServerlessRuntime = new Runtime(this, "serverless"); } - return this.#serverlessState; + return this.#cachedServerlessRuntime; } - #start(kind: StartKind): { driver: DriverConfig } { - const config = this.#config; - - // Promise for any async operations we need to wait to complete - const readyPromises: Promise[] = []; - - // Configure logger - if (config.logging?.baseLogger) { - // Use provided base logger - configureBaseLogger(config.logging.baseLogger); - } else { - // Configure default logger with log level from config getPinoLevel - // will handle env variable priority - configureDefaultLogger(config.logging?.level); - } - - // Handle spawnEngine before choosing driver - // Start engine - invariant( - !( - kind === "serverless" && - config.serverless.spawnEngine && - config.serveManager - ), - "cannot specify spawnEngine and serveManager together", - ); - - if (kind === "serverless" && config.serverless.spawnEngine) { - logger().debug({ - msg: "run engine requested", - version: config.serverless.engineVersion, - }); - - // Set config to point to the engine - invariant( - config.endpoint === undefined, - "cannot specify endpoint with spawnEngine", - ); - config.endpoint = ENGINE_ENDPOINT; - - // Start the engine - const engineProcessPromise = ensureEngineProcess({ - version: config.serverless.engineVersion, - }); - - // Chain ready promise - readyPromises.push(engineProcessPromise); - } - - // Choose the driver based on configuration (after endpoint may have been set by spawnEngine) - const driver = chooseDefaultDriver(config); - - // Create manager driver (always needed for actor driver + inline client) - const managerDriver = driver.manager(this.#config); - - if (config.serveManager) { - // Configure getUpgradeWebSocket lazily so we can assign it in crossPlatformServe - let upgradeWebSocket: any; - const getUpgradeWebSocket: GetUpgradeWebSocket = () => - upgradeWebSocket; - managerDriver.setGetUpgradeWebSocket(getUpgradeWebSocket); - - // Build router - const { router: managerRouter } = buildManagerRouter( - this.#config, - managerDriver, - getUpgradeWebSocket, - ); - - // Serve manager - const serverPromise = (async () => { - const managerPort = await findFreePort(config.managerPort); - config.managerPort = managerPort; - - const out = await crossPlatformServe(config, managerRouter); - upgradeWebSocket = out.upgradeWebSocket; - })(); - readyPromises.push(serverPromise); - } - - // Log and print welcome after all ready promises complete - // biome-ignore lint/nursery/noFloatingPromises: bg promise - Promise.all(readyPromises).then(async () => { - // Auto-start actor driver for drivers that require it. - // - // This is only enabled for runner config since serverless will - // auto-start the actor driver on `GET /start`. - if ( - kind === "runner" && - config.runner && - driver.autoStartActorDriver - ) { - logger().debug("starting actor driver"); - const inlineClient = - createClientWithDriver(managerDriver); - driver.actor(this.#config, managerDriver, inlineClient); - } - - // Log starting - const driverLog = managerDriver.extraStartupLog?.() ?? {}; - logger().info({ - msg: "rivetkit ready", - driver: driver.name, - definitions: Object.keys(this.#config.use).length, - ...driverLog, - }); - const inspectorUrl = getInspectorUrl(config); - if (inspectorUrl && config.inspector.enabled) { - logger().info({ - msg: "inspector ready", - url: inspectorUrl, - }); - } - - // Print welcome information - if (!config.noWelcome) { - console.log(); - console.log( - ` RivetKit ${pkg.version} (${driver.displayName})`, - ); - // Only show endpoint if manager is running or engine is spawned - if ( - config.serveManager || - (kind === "serverless" && config.serverless.spawnEngine) - ) { - console.log( - ` - Endpoint: ${config.endpoint ?? config.serverless.clientEndpoint}`, - ); - } - if (kind === "serverless" && config.serverless.spawnEngine) { - const padding = " ".repeat( - Math.max(0, 13 - "Engine".length), - ); - console.log( - ` - Engine:${padding}v${config.serverless.engineVersion}`, - ); - } - const displayInfo = managerDriver.displayInformation(); - for (const [k, v] of Object.entries(displayInfo.properties)) { - const padding = " ".repeat(Math.max(0, 13 - k.length)); - console.log(` - ${k}:${padding}${v}`); - } - if (inspectorUrl && config.inspector.enabled) { - console.log(` - Inspector: ${inspectorUrl}`); - } - console.log(); - } - - // Configure serverless runner if enabled when actor driver is disabled - if ( - kind === "serverless" && - config.serverless.configureRunnerPool - ) { - await configureServerlessRunner(config); - } - }); - - return { driver }; + /** + * Starts an actor runner for standalone server deployments. + */ + public startRunner() { + new Runtime(this, "runner"); } // MARK: Legacy @@ -332,7 +157,7 @@ export class Registry { ): LegacyStartServerOutput { // Start the runner // Note: Legacy config is ignored - all config should now be passed to setup() - this.startRunner(); + const runtime = new Runtime(this, "runner"); // Create client for the legacy return value const client = createClient({ @@ -342,15 +167,11 @@ export class Registry { headers: config.headers, }); - // For normal runner, we need to build a manager router to get the fetch handler - const driver = chooseDefaultDriver(this.#config); - const managerDriver = driver.manager(this.#config); - // Configure getUpgradeWebSocket as undefined for this legacy path // since it's only used when actually serving const { router } = buildManagerRouter( - this.#config, - managerDriver, + runtime.config, + runtime.managerDriver, undefined, // getUpgradeWebSocket ); @@ -364,8 +185,7 @@ export class Registry { export function setup( input: RegistryConfigInput, ): Registry { - const config = RegistryConfigSchema.parse(input); - return new Registry(config); + return new Registry(input); } export type { RegistryConfig, RegistryActors }; diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/serve.ts b/rivetkit-typescript/packages/rivetkit/src/registry/serve.ts index c2aece9a43..ea8f4d7285 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/serve.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/serve.ts @@ -32,6 +32,7 @@ export async function findFreePort( export async function crossPlatformServe( config: RegistryConfig, + managerPort: number, app: Hono, ): Promise<{ upgradeWebSocket: any }> { const runtime = detectRuntime(); @@ -39,18 +40,19 @@ export async function crossPlatformServe( switch (runtime) { case "deno": - return serveDeno(config, app); + return serveDeno(config, managerPort, app); case "bun": - return serveBun(config, app); + return serveBun(config, managerPort, app); case "node": - return serveNode(config, app); + return serveNode(config, managerPort, app); default: - return serveNode(config, app); + return serveNode(config, managerPort, app); } } async function serveNode( config: RegistryConfig, + managerPort: number, app: Hono, ): Promise<{ upgradeWebSocket: any }> { // Import @hono/node-server using string variable to prevent static analysis @@ -93,7 +95,7 @@ async function serveNode( }); // Start server - const port = config.managerPort; + const port = managerPort; const server = serve({ fetch: app.fetch, port }, () => logger().info({ msg: "server listening", port }), ); @@ -104,6 +106,7 @@ async function serveNode( async function serveDeno( config: RegistryConfig, + managerPort: number, app: Hono, ): Promise<{ upgradeWebSocket: any }> { // Import hono/deno using string variable to prevent static analysis @@ -134,6 +137,7 @@ async function serveDeno( async function serveBun( config: RegistryConfig, + managerPort: number, app: Hono, ): Promise<{ upgradeWebSocket: any }> { // Import hono/bun using string variable to prevent static analysis diff --git a/rivetkit-typescript/packages/rivetkit/src/serverless/router.ts b/rivetkit-typescript/packages/rivetkit/src/serverless/router.ts index 1c7f8c9e8b..217cf3bcbf 100644 --- a/rivetkit-typescript/packages/rivetkit/src/serverless/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/serverless/router.ts @@ -94,7 +94,7 @@ export function buildServerlessRouter( c, config, { serverless: {} }, - config.serverless.clientEndpoint, + config.serverless.advertiseEndpoint, ), ); }); diff --git a/rivetkit-typescript/packages/rivetkit/src/test/mod.ts b/rivetkit-typescript/packages/rivetkit/src/test/mod.ts index 5ddb524f7b..fc8451a87c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/test/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/test/mod.ts @@ -6,9 +6,9 @@ import { type TestContext, vi } from "vitest"; import { ClientConfigSchema } from "@/client/config"; import { type Client, createClient } from "@/client/mod"; import { createFileSystemOrMemoryDriver } from "@/drivers/file-system/mod"; -import { buildManagerRouter } from "@/manager/router"; import { createClientWithDriver, type Registry } from "@/mod"; import { RegistryConfig, RegistryConfigSchema } from "@/registry/config"; +import { buildManagerRouter } from "@/manager/router"; import { logger } from "./log"; export interface SetupTestResult> { @@ -21,7 +21,7 @@ export async function setupTest>( registry: A, ): Promise> { // Force enable test mode - registry.config.test.enabled = true; + registry.config.test = { ...registry.config.test, enabled: true }; // Create driver const driver = await createFileSystemOrMemoryDriver( @@ -39,14 +39,15 @@ export async function setupTest>( }; // Create router - const managerDriver = driver.manager?.(registry.config); + const parsedConfig = registry.parseConfig(); + const managerDriver = driver.manager?.(parsedConfig); invariant(managerDriver, "missing manager driver"); // const internalClient = createClientWithDriver( // managerDriver, // ClientConfigSchema.parse({}), // ); const { router } = buildManagerRouter( - registry.config, + parsedConfig, managerDriver, () => upgradeWebSocket!, ); @@ -55,6 +56,7 @@ export async function setupTest>( const nodeWebSocket = createNodeWebSocket({ app: router }); upgradeWebSocket = nodeWebSocket.upgradeWebSocket; + // TODO: I think this whole function is fucked, we should probably switch to calling registry.serve() directly // Start server const port = await getPort(); const server = honoServe({ diff --git a/rivetkit-typescript/packages/rivetkit/src/utils.ts b/rivetkit-typescript/packages/rivetkit/src/utils.ts index 3ecdfac380..78410528f0 100644 --- a/rivetkit-typescript/packages/rivetkit/src/utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/utils.ts @@ -293,3 +293,11 @@ export function detectRuntime(): Runtime { } return "node"; } + +export type DeepReadonly = { + readonly [K in keyof T]: T[K] extends object ? DeepReadonly : T[K]; +}; + +export type DeepMutable = { + -readonly [K in keyof T]: T[K] extends object ? DeepMutable : T[K]; +}; diff --git a/rivetkit-typescript/packages/rivetkit/tests/driver-engine.test.ts b/rivetkit-typescript/packages/rivetkit/tests/driver-engine.test.ts index 1041857953..8e41522ac3 100644 --- a/rivetkit-typescript/packages/rivetkit/tests/driver-engine.test.ts +++ b/rivetkit-typescript/packages/rivetkit/tests/driver-engine.test.ts @@ -47,21 +47,27 @@ runDriverTests({ // Create driver config const driverConfig = createEngineDriver(); + // TODO: We should not have to do this, we should have access to the Runtime instead + const parsedConfig = registry.parseConfig(); + // Start the actor driver registry.config.driver = driverConfig; registry.config.endpoint = endpoint; registry.config.namespace = namespace; registry.config.token = token; - registry.config.runner.runnerName = runnerName; - const managerDriver = driverConfig.manager?.(registry.config); + registry.config.runner = { + ...registry.config.runner, + runnerName, + }; + const managerDriver = driverConfig.manager?.(parsedConfig); invariant(managerDriver, "missing manager driver"); const inlineClient = createClientWithDriver( managerDriver, - convertRegistryConfigToClientConfig(registry.config), + convertRegistryConfigToClientConfig(parsedConfig), ); const actorDriver = driverConfig.actor( - registry.config, + parsedConfig, managerDriver, inlineClient, ); diff --git a/rivetkit-typescript/packages/rivetkit/tsconfig.json b/rivetkit-typescript/packages/rivetkit/tsconfig.json index abc57f2f60..f1223270fc 100644 --- a/rivetkit-typescript/packages/rivetkit/tsconfig.json +++ b/rivetkit-typescript/packages/rivetkit/tsconfig.json @@ -16,5 +16,5 @@ "scripts/**/*", "fixtures/driver-test-suite/**/*", "dist/schemas/**/*" - ] +, "runtime/index.ts" ] } diff --git a/scripts/run/docker/engine-rocksdb.sh b/scripts/run/docker/engine-rocksdb.sh index c105ed78e4..3dc33785a4 100755 --- a/scripts/run/docker/engine-rocksdb.sh +++ b/scripts/run/docker/engine-rocksdb.sh @@ -20,4 +20,4 @@ RIVET__PEGBOARD__BASE_RETRY_TIMEOUT="100" \ RIVET__PEGBOARD__RESCHEDULE_BACKOFF_MAX_EXPONENT="1" \ RIVET__PEGBOARD__RUNNER_ELIGIBLE_THRESHOLD="5000" \ RIVET__PEGBOARD__RUNNER_LOST_THRESHOLD="7000" \ -cargo run --bin rivet-engine "${FILTERED_ARGS[@]}" -- start 2>&1 | tee -i /tmp/rivet-engine.log +cargo run --bin rivet-engine ${FILTERED_ARGS[@]+"${FILTERED_ARGS[@]}"} -- start 2>&1 | tee -i /tmp/rivet-engine.log diff --git a/website/src/content/docs/actors/quickstart/cloudflare-workers.mdx b/website/src/content/docs/actors/quickstart/cloudflare-workers.mdx index bbdb5c025d..a2ad829fa9 100644 --- a/website/src/content/docs/actors/quickstart/cloudflare-workers.mdx +++ b/website/src/content/docs/actors/quickstart/cloudflare-workers.mdx @@ -48,7 +48,7 @@ Choose your preferred web framework: import { createHandler } from "@rivetkit/cloudflare-workers"; import { registry } from "./registry"; -// The `/rivet` endpoint is automatically exposed here for external clients +// The `/api/rivet` endpoint is automatically exposed here for external clients const { handler, ActorHandler } = createHandler(registry); export { handler as default, ActorHandler }; ``` @@ -71,16 +71,16 @@ app.post("/increment/:name", async (c) => { return c.json({ count: newCount }); }); -// The `/rivet` endpoint is automatically exposed here for external clients +// The `/api/rivet` endpoint is automatically exposed here for external clients const { handler, ActorHandler } = createHandler(registry, { fetch: app.fetch }); export { handler as default, ActorHandler }; ``` -```ts {{"title":"No Router"}} +```ts {{"title":"Manual Routing"}} import { createHandler } from "@rivetkit/cloudflare-workers"; import { registry } from "./registry"; -// The `/rivet` endpoint is automatically mounted on this router for external clients +// The `/api//rivet` endpoint is automatically mounted on this router for external clients const { handler, ActorHandler } = createHandler(registry, { fetch: async (request, env, ctx) => { const url = new URL(request.url); @@ -133,9 +133,9 @@ export default { }); } - // Optional: Mount /rivet path to access actors from external clients - if (url.pathname.startsWith("/rivet")) { - const strippedPath = url.pathname.substring("/rivet".length); + // Optional: Mount /api/rivet path to access actors from external clients + if (url.pathname.startsWith("/api/rivet")) { + const strippedPath = url.pathname.substring("/api/rivet".length); url.pathname = strippedPath; const modifiedRequest = new Request(url.toString(), request); return rivetFetch(modifiedRequest, env, ctx); @@ -248,7 +248,7 @@ import { createClient } from "rivetkit/client"; import type { registry } from "./registry"; // Create typed client (use your deployed URL) -const client = createClient("https://your-app.workers.dev/rivet"); +const client = createClient("https://your-app.workers.dev/api/rivet"); // Use the counter actor directly const counter = client.counter.getOrCreate(["my-counter"]); @@ -282,7 +282,7 @@ import { createRivetKit } from "@rivetkit/react"; import { useState } from "react"; import type { registry } from "./registry"; -const { useActor } = createRivetKit("https://your-app.workers.dev/rivet"); +const { useActor } = createRivetKit("https://your-app.workers.dev/api/rivet"); function Counter() { const [count, setCount] = useState(0); @@ -320,7 +320,7 @@ use serde_json::json; #[tokio::main] async fn main() -> Result<(), Box> { let client = Client::new( - "https://your-app.workers.dev/rivet", + "https://your-app.workers.dev/api/rivet", TransportKind::Sse, EncodingKind::Json ); @@ -351,6 +351,6 @@ See the [Rust client documentation](/docs/clients/rust) for more information. - Cloudflare Workers mounts the Rivet endpoint on `/rivet` by default. + Cloudflare Workers mounts the Rivet endpoint on `/api/rivet` by default.