Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/proud-nails-grin.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@trigger.dev/core": patch
---

Add optional billing info to DequeuedMessage for tiered scheduling
6 changes: 6 additions & 0 deletions apps/supervisor/src/env.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ const Env = z.object({
KUBERNETES_EPHEMERAL_STORAGE_SIZE_LIMIT: z.string().default("10Gi"),
KUBERNETES_EPHEMERAL_STORAGE_SIZE_REQUEST: z.string().default("2Gi"),

// Tier scheduling settings
ENABLE_TIER_SCHEDULING: BoolEnv.default(false),
TIER_LABEL_KEY: z.string().default("node.cluster.x-k8s.io/paid"),
TIER_LABEL_VALUE_FREE: z.string().default("false"),
TIER_LABEL_VALUE_PAID: z.string().default("true"),

// Metrics
METRICS_ENABLED: BoolEnv.default(true),
METRICS_COLLECT_DEFAULTS: BoolEnv.default(true),
Expand Down
1 change: 1 addition & 0 deletions apps/supervisor/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ class ManagedSupervisor {
nextAttemptNumber: message.run.attemptNumber,
snapshotId: message.snapshot.id,
snapshotFriendlyId: message.snapshot.friendlyId,
isPaidTier: message.billing?.currentPlan.isPaying ?? false,
});

// Disabled for now
Expand Down
39 changes: 38 additions & 1 deletion apps/supervisor/src/workloadManager/kubernetes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,22 @@ type ResourceQuantities = {
[K in "cpu" | "memory" | "ephemeral-storage"]?: string;
};

interface TierConfig {
enabled: boolean;
labelKey: string;
freeValue: string;
paidValue: string;
}

export class KubernetesWorkloadManager implements WorkloadManager {
private readonly logger = new SimpleStructuredLogger("kubernetes-workload-provider");
private k8s: K8sApi;
private namespace = env.KUBERNETES_NAMESPACE;
private tierConfig: TierConfig;

constructor(private opts: WorkloadManagerOptions) {
this.k8s = createK8sApi();
this.tierConfig = this.tierSchedulingConfig;

if (opts.workloadApiDomain) {
this.logger.warn("[KubernetesWorkloadManager] ⚠️ Custom workload API domain", {
Expand All @@ -28,6 +37,34 @@ export class KubernetesWorkloadManager implements WorkloadManager {
}
}

private get tierSchedulingConfig(): TierConfig {
return {
enabled: env.ENABLE_TIER_SCHEDULING,
labelKey: env.TIER_LABEL_KEY,
freeValue: env.TIER_LABEL_VALUE_FREE,
paidValue: env.TIER_LABEL_VALUE_PAID,
};
}

private addTierScheduling(
podSpec: Omit<k8s.V1PodSpec, "containers">,
isPaidTier: boolean
): Omit<k8s.V1PodSpec, "containers"> {
if (!this.tierConfig.enabled) {
return podSpec;
}

const labelValue = isPaidTier ? this.tierConfig.paidValue : this.tierConfig.freeValue;

return {
...podSpec,
nodeSelector: {
...podSpec.nodeSelector,
[this.tierConfig.labelKey]: labelValue,
},
};
}

async create(opts: WorkloadManagerCreateOptions) {
this.logger.log("[KubernetesWorkloadManager] Creating container", { opts });

Expand All @@ -48,7 +85,7 @@ export class KubernetesWorkloadManager implements WorkloadManager {
},
},
spec: {
...this.#defaultPodSpec,
...this.addTierScheduling(this.#defaultPodSpec, opts.isPaidTier ?? false),
terminationGracePeriodSeconds: 60 * 60,
containers: [
{
Expand Down
2 changes: 2 additions & 0 deletions apps/supervisor/src/workloadManager/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ export interface WorkloadManagerCreateOptions {
runFriendlyId: string;
snapshotId: string;
snapshotFriendlyId: string;
// tier scheduling
isPaidTier?: boolean;
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import { redirectWithErrorMessage } from "~/models/message.server";
import { logger } from "~/services/logger.server";
import { setPlan } from "~/services/platform.v3.server";
import { requireUser } from "~/services/session.server";
import { engine } from "~/v3/runEngine.server";
import { cn } from "~/utils/cn";
import { sendToPlain } from "~/utils/plain.server";

Expand Down Expand Up @@ -152,7 +153,9 @@ export async function action({ request, params }: ActionFunctionArgs) {
}
}

return setPlan(organization, request, form.callerPath, payload);
return setPlan(organization, request, form.callerPath, payload, {
invalidateBillingCache: engine.invalidateBillingCache.bind(engine),
});
}

const pricingDefinitions = {
Expand Down
9 changes: 8 additions & 1 deletion apps/webapp/app/services/platform.v3.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ export async function setPlan(
organization: { id: string; slug: string },
request: Request,
callerPath: string,
plan: SetPlanBody
plan: SetPlanBody,
opts?: { invalidateBillingCache?: (orgId: string) => void }
) {
if (!client) {
throw redirectWithErrorMessage(callerPath, request, "Error setting plan");
Expand All @@ -308,6 +309,8 @@ export async function setPlan(
}
case "free_connected": {
if (result.accepted) {
// Invalidate billing cache since plan changed
opts?.invalidateBillingCache?.(organization.id);
return redirect(newProjectPath(organization, "You're on the Free plan."));
} else {
return redirectWithErrorMessage(
Expand All @@ -321,13 +324,17 @@ export async function setPlan(
return redirect(result.checkoutUrl);
}
case "updated_subscription": {
// Invalidate billing cache since subscription changed
opts?.invalidateBillingCache?.(organization.id);
return redirectWithSuccessMessage(
callerPath,
request,
"Subscription updated successfully."
);
}
case "canceled_subscription": {
// Invalidate billing cache since subscription was canceled
opts?.invalidateBillingCache?.(organization.id);
return redirectWithSuccessMessage(callerPath, request, "Subscription canceled.");
}
}
Expand Down
17 changes: 16 additions & 1 deletion apps/webapp/app/v3/runEngine.server.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { RunEngine } from "@internal/run-engine";
import { $replica, prisma } from "~/db.server";
import { env } from "~/env.server";
import { defaultMachine } from "~/services/platform.v3.server";
import { defaultMachine, getCurrentPlan } from "~/services/platform.v3.server";
import { singleton } from "~/utils/singleton";
import { allMachines } from "./machinePresets.server";
import { meter, tracer } from "./tracer.server";
Expand Down Expand Up @@ -105,6 +105,21 @@ function createRunEngine() {
SUSPENDED: env.RUN_ENGINE_TIMEOUT_SUSPENDED,
},
retryWarmStartThresholdMs: env.RUN_ENGINE_RETRY_WARM_START_THRESHOLD_MS,
billing: {
getCurrentPlan: async (orgId: string) => {
const plan = await getCurrentPlan(orgId);

if (!plan) {
return { isPaying: false };
}

if (!plan.v3Subscription) {
return { isPaying: false };
}

return { isPaying: plan.v3Subscription.isPaying };
},
},
});

return engine;
Expand Down
10 changes: 10 additions & 0 deletions internal-packages/run-engine/src/engine/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ export class RunEngine {
executionSnapshotSystem: this.executionSnapshotSystem,
runAttemptSystem: this.runAttemptSystem,
machines: this.options.machines,
billing: this.options.billing,
redisOptions: this.options.cache?.redis ?? this.options.runLock.redis,
});
}

Expand Down Expand Up @@ -1347,4 +1349,12 @@ export class RunEngine {
orgId: run.organizationId!,
}));
}

/**
* Invalidates the billing cache for an organization when their plan changes
* Runs in background and handles all errors internally
*/
invalidateBillingCache(orgId: string): void {
this.dequeueSystem.invalidateBillingCache(orgId);
}
}
89 changes: 89 additions & 0 deletions internal-packages/run-engine/src/engine/systems/dequeueSystem.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
import {
createCache,
DefaultStatefulContext,
MemoryStore,
Namespace,
RedisCacheStore,
UnkeyCache,
} from "@internal/cache";
import type { RedisOptions } from "@internal/redis";
import { startSpan } from "@internal/tracing";
import { assertExhaustive } from "@trigger.dev/core";
import { DequeuedMessage, RetryOptions } from "@trigger.dev/core/v3";
Expand All @@ -11,23 +20,66 @@ import { RunEngineOptions } from "../types.js";
import { ExecutionSnapshotSystem, getLatestExecutionSnapshot } from "./executionSnapshotSystem.js";
import { RunAttemptSystem } from "./runAttemptSystem.js";
import { SystemResources } from "./systems.js";
import { ServiceValidationError } from "../errors.js";

export type DequeueSystemOptions = {
resources: SystemResources;
machines: RunEngineOptions["machines"];
executionSnapshotSystem: ExecutionSnapshotSystem;
runAttemptSystem: RunAttemptSystem;
billing?: RunEngineOptions["billing"];
redisOptions: RedisOptions;
};

// Cache TTLs for billing information - shorter than other caches since billing can change
const BILLING_FRESH_TTL = 60000 * 5; // 5 minutes
const BILLING_STALE_TTL = 60000 * 10; // 10 minutes

export class DequeueSystem {
private readonly $: SystemResources;
private readonly executionSnapshotSystem: ExecutionSnapshotSystem;
private readonly runAttemptSystem: RunAttemptSystem;
private readonly billingCache: UnkeyCache<{
billing: { isPaying: boolean };
}>;

constructor(private readonly options: DequeueSystemOptions) {
this.$ = options.resources;
this.executionSnapshotSystem = options.executionSnapshotSystem;
this.runAttemptSystem = options.runAttemptSystem;

// Initialize billing cache
const ctx = new DefaultStatefulContext();
const memory = new MemoryStore({ persistentMap: new Map() });
const redisCacheStore = new RedisCacheStore({
name: "dequeue-system",
connection: {
...options.redisOptions,
keyPrefix: "engine:dequeue-system:cache:",
},
useModernCacheKeyBuilder: true,
});

this.billingCache = createCache({
billing: new Namespace<{ isPaying: boolean }>(ctx, {
stores: [memory, redisCacheStore],
fresh: BILLING_FRESH_TTL,
stale: BILLING_STALE_TTL,
}),
});
}

/**
* Invalidates the billing cache for an organization when their plan changes
* Runs in background and handles all errors internally
*/
invalidateBillingCache(orgId: string): void {
this.billingCache.billing.remove(orgId).catch((error) => {
this.$.logger.warn("Failed to invalidate billing cache", {
orgId,
error: error.message,
});
});
}

/**
Expand Down Expand Up @@ -380,6 +432,9 @@ export class DequeueSystem {
const currentAttemptNumber = lockedTaskRun.attemptNumber ?? 0;
const nextAttemptNumber = currentAttemptNumber + 1;

// Get billing information if available
const billing = await this.#getBillingInfo({ orgId, runId });

const newSnapshot = await this.executionSnapshotSystem.createExecutionSnapshot(
prisma,
{
Expand Down Expand Up @@ -448,6 +503,7 @@ export class DequeueSystem {
project: {
id: lockedTaskRun.projectId,
},
billing,
} satisfies DequeuedMessage;
}
);
Expand Down Expand Up @@ -612,4 +668,37 @@ export class DequeueSystem {
});
});
}

async #getBillingInfo({
orgId,
runId,
}: {
orgId: string;
runId: string;
}): Promise<{ currentPlan: { isPaying: boolean } }> {
if (!this.options.billing?.getCurrentPlan) {
return { currentPlan: { isPaying: false } };
}

const result = await this.billingCache.billing.swr(orgId, async () => {
// This is safe because options can't change at runtime
const planResult = await this.options.billing!.getCurrentPlan(orgId);

return { isPaying: planResult.isPaying };
});

if (result.err) {
throw result.err;
}

if (!result.val) {
throw new ServiceValidationError(
`Could not resolve billing information for organization ${orgId}`,
undefined,
{ orgId, runId }
);
}

return { currentPlan: { isPaying: result.val.isPaying } };
}
}
3 changes: 3 additions & 0 deletions internal-packages/run-engine/src/engine/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ export type RunEngineOptions = {
machines: Record<string, MachinePreset>;
baseCostInCents: number;
};
billing?: {
getCurrentPlan: (orgId: string) => Promise<{ isPaying: boolean }>;
};
queue: {
redis: RedisOptions;
shardCount?: number;
Expand Down
7 changes: 7 additions & 0 deletions packages/core/src/v3/schemas/runEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -261,5 +261,12 @@ export const DequeuedMessage = z.object({
project: z.object({
id: z.string(),
}),
billing: z
.object({
currentPlan: z.object({
isPaying: z.boolean(),
}),
})
.optional(),
});
export type DequeuedMessage = z.infer<typeof DequeuedMessage>;
Loading