Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/proud-nails-grin.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@trigger.dev/core": patch
---

Add optional billing info to DequeuedMessage for tiered scheduling
6 changes: 6 additions & 0 deletions apps/supervisor/src/env.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ const Env = z.object({
KUBERNETES_EPHEMERAL_STORAGE_SIZE_LIMIT: z.string().default("10Gi"),
KUBERNETES_EPHEMERAL_STORAGE_SIZE_REQUEST: z.string().default("2Gi"),

// Tier scheduling settings
ENABLE_TIER_SCHEDULING: BoolEnv.default(false),
TIER_LABEL_KEY: z.string().default("node.cluster.x-k8s.io/paid"),
TIER_LABEL_VALUE_FREE: z.string().default("false"),
TIER_LABEL_VALUE_PAID: z.string().default("true"),

// Metrics
METRICS_ENABLED: BoolEnv.default(true),
METRICS_COLLECT_DEFAULTS: BoolEnv.default(true),
Expand Down
1 change: 1 addition & 0 deletions apps/supervisor/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ class ManagedSupervisor {
nextAttemptNumber: message.run.attemptNumber,
snapshotId: message.snapshot.id,
snapshotFriendlyId: message.snapshot.friendlyId,
isPaidTier: message.billing?.currentPlan.isPaying ?? false,
});

// Disabled for now
Expand Down
39 changes: 38 additions & 1 deletion apps/supervisor/src/workloadManager/kubernetes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,22 @@ type ResourceQuantities = {
[K in "cpu" | "memory" | "ephemeral-storage"]?: string;
};

interface TierConfig {
enabled: boolean;
labelKey: string;
freeValue: string;
paidValue: string;
}

export class KubernetesWorkloadManager implements WorkloadManager {
private readonly logger = new SimpleStructuredLogger("kubernetes-workload-provider");
private k8s: K8sApi;
private namespace = env.KUBERNETES_NAMESPACE;
private tierConfig: TierConfig;

constructor(private opts: WorkloadManagerOptions) {
this.k8s = createK8sApi();
this.tierConfig = this.tierSchedulingConfig;

if (opts.workloadApiDomain) {
this.logger.warn("[KubernetesWorkloadManager] ⚠️ Custom workload API domain", {
Expand All @@ -28,6 +37,34 @@ export class KubernetesWorkloadManager implements WorkloadManager {
}
}

private get tierSchedulingConfig(): TierConfig {
return {
enabled: env.ENABLE_TIER_SCHEDULING,
labelKey: env.TIER_LABEL_KEY,
freeValue: env.TIER_LABEL_VALUE_FREE,
paidValue: env.TIER_LABEL_VALUE_PAID,
};
}

private addTierScheduling(
podSpec: Omit<k8s.V1PodSpec, "containers">,
isPaidTier: boolean
): Omit<k8s.V1PodSpec, "containers"> {
if (!this.tierConfig.enabled) {
return podSpec;
}

const labelValue = isPaidTier ? this.tierConfig.paidValue : this.tierConfig.freeValue;

return {
...podSpec,
nodeSelector: {
...podSpec.nodeSelector,
[this.tierConfig.labelKey]: labelValue,
},
};
}

async create(opts: WorkloadManagerCreateOptions) {
this.logger.log("[KubernetesWorkloadManager] Creating container", { opts });

Expand All @@ -48,7 +85,7 @@ export class KubernetesWorkloadManager implements WorkloadManager {
},
},
spec: {
...this.#defaultPodSpec,
...this.addTierScheduling(this.#defaultPodSpec, opts.isPaidTier ?? false),
terminationGracePeriodSeconds: 60 * 60,
containers: [
{
Expand Down
2 changes: 2 additions & 0 deletions apps/supervisor/src/workloadManager/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ export interface WorkloadManagerCreateOptions {
runFriendlyId: string;
snapshotId: string;
snapshotFriendlyId: string;
// tier scheduling
isPaidTier?: boolean;
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import { redirectWithErrorMessage } from "~/models/message.server";
import { logger } from "~/services/logger.server";
import { setPlan } from "~/services/platform.v3.server";
import { requireUser } from "~/services/session.server";
import { engine } from "~/v3/runEngine.server";
import { cn } from "~/utils/cn";
import { sendToPlain } from "~/utils/plain.server";

Expand Down Expand Up @@ -152,7 +153,9 @@ export async function action({ request, params }: ActionFunctionArgs) {
}
}

return setPlan(organization, request, form.callerPath, payload);
return setPlan(organization, request, form.callerPath, payload, {
invalidateBillingCache: engine.invalidateBillingCache.bind(engine),
});
}

const pricingDefinitions = {
Expand Down
9 changes: 8 additions & 1 deletion apps/webapp/app/services/platform.v3.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ export async function setPlan(
organization: { id: string; slug: string },
request: Request,
callerPath: string,
plan: SetPlanBody
plan: SetPlanBody,
opts?: { invalidateBillingCache?: (orgId: string) => void }
) {
if (!client) {
throw redirectWithErrorMessage(callerPath, request, "Error setting plan");
Expand All @@ -308,6 +309,8 @@ export async function setPlan(
}
case "free_connected": {
if (result.accepted) {
// Invalidate billing cache since plan changed
opts?.invalidateBillingCache?.(organization.id);
return redirect(newProjectPath(organization, "You're on the Free plan."));
} else {
return redirectWithErrorMessage(
Expand All @@ -321,13 +324,17 @@ export async function setPlan(
return redirect(result.checkoutUrl);
}
case "updated_subscription": {
// Invalidate billing cache since subscription changed
opts?.invalidateBillingCache?.(organization.id);
return redirectWithSuccessMessage(
callerPath,
request,
"Subscription updated successfully."
);
}
case "canceled_subscription": {
// Invalidate billing cache since subscription was canceled
opts?.invalidateBillingCache?.(organization.id);
return redirectWithSuccessMessage(callerPath, request, "Subscription canceled.");
}
}
Expand Down
26 changes: 25 additions & 1 deletion apps/webapp/app/v3/runEngine.server.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { RunEngine } from "@internal/run-engine";
import { $replica, prisma } from "~/db.server";
import { env } from "~/env.server";
import { defaultMachine } from "~/services/platform.v3.server";
import { defaultMachine, getCurrentPlan } from "~/services/platform.v3.server";
import { singleton } from "~/utils/singleton";
import { allMachines } from "./machinePresets.server";
import { meter, tracer } from "./tracer.server";
Expand Down Expand Up @@ -105,6 +105,30 @@ function createRunEngine() {
SUSPENDED: env.RUN_ENGINE_TIMEOUT_SUSPENDED,
},
retryWarmStartThresholdMs: env.RUN_ENGINE_RETRY_WARM_START_THRESHOLD_MS,
billing: {
getCurrentPlan: async (orgId: string) => {
const plan = await getCurrentPlan(orgId);

if (!plan) {
return {
isPaying: false,
type: "free",
};
}

if (!plan.v3Subscription) {
return {
isPaying: false,
type: "free",
};
}

return {
isPaying: plan.v3Subscription.isPaying,
type: plan.v3Subscription.plan?.type ?? "free",
};
},
},
});

return engine;
Expand Down
2 changes: 2 additions & 0 deletions internal-packages/cache/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ export {
DefaultStatefulContext,
Namespace,
type Cache as UnkeyCache,
type CacheError,
} from "@unkey/cache";
export { type Result, Ok, Err } from "@unkey/error";
export { MemoryStore } from "@unkey/cache/stores";
export { RedisCacheStore } from "./stores/redis.js";
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "TaskRun" ADD COLUMN "planType" TEXT;
3 changes: 3 additions & 0 deletions internal-packages/database/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,9 @@ model TaskRun {
/// Run error
error Json?

/// Organization's billing plan type (cached for fallback when billing API fails)
planType String?

maxDurationInSeconds Int?

@@unique([oneTimeUseToken])
Expand Down
92 changes: 92 additions & 0 deletions internal-packages/run-engine/src/engine/billingCache.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import {
createCache,
DefaultStatefulContext,
MemoryStore,
Namespace,
Ok,
RedisCacheStore,
type UnkeyCache,
type CacheError,
type Result,
} from "@internal/cache";
import type { RedisOptions } from "@internal/redis";
import type { Logger } from "@trigger.dev/core/logger";
import type { RunEngineOptions } from "./types.js";

// Cache TTLs for billing information - shorter than other caches since billing can change
const BILLING_FRESH_TTL = 60000 * 5; // 5 minutes
const BILLING_STALE_TTL = 60000 * 10; // 10 minutes

export type BillingPlan = {
isPaying: boolean;
type: "free" | "paid" | "enterprise";
};

export type BillingCacheOptions = {
billingOptions?: RunEngineOptions["billing"];
redisOptions: RedisOptions;
logger: Logger;
};

export class BillingCache {
private readonly cache: UnkeyCache<{
currentPlan: BillingPlan;
}>;
private readonly logger: Logger;
private readonly billingOptions?: RunEngineOptions["billing"];

constructor(options: BillingCacheOptions) {
this.logger = options.logger;
this.billingOptions = options.billingOptions;

// Initialize cache
const ctx = new DefaultStatefulContext();
const memory = new MemoryStore({ persistentMap: new Map() });
const redisCacheStore = new RedisCacheStore({
name: "billing-cache",
connection: {
...options.redisOptions,
keyPrefix: "engine:billing:cache:",
},
useModernCacheKeyBuilder: true,
});

this.cache = createCache({
currentPlan: new Namespace<BillingPlan>(ctx, {
stores: [memory, redisCacheStore],
fresh: BILLING_FRESH_TTL,
stale: BILLING_STALE_TTL,
}),
});
}

/**
* Gets the current billing plan for an organization
* Returns a Result that allows the caller to handle errors and missing values
*/
async getCurrentPlan(orgId: string): Promise<Result<BillingPlan | undefined, CacheError>> {
if (!this.billingOptions?.getCurrentPlan) {
// Return a successful result with default free plan
return Ok({ isPaying: false, type: "free" });
}

return await this.cache.currentPlan.swr(orgId, async () => {
// This is safe because options can't change at runtime
const planResult = await this.billingOptions!.getCurrentPlan(orgId);
return { isPaying: planResult.isPaying, type: planResult.type };
});
}

/**
* Invalidates the billing cache for an organization when their plan changes
* Runs in background and handles all errors internally
*/
invalidate(orgId: string): void {
this.cache.currentPlan.remove(orgId).catch((error) => {
this.logger.warn("Failed to invalidate billing cache", {
orgId,
error: error instanceof Error ? error.message : String(error),
});
});
}
}
Loading
Loading