diff --git a/apps/sim/app/_shell/providers/session-provider.tsx b/apps/sim/app/_shell/providers/session-provider.tsx index 70fe344bd0..29ab636e74 100644 --- a/apps/sim/app/_shell/providers/session-provider.tsx +++ b/apps/sim/app/_shell/providers/session-provider.tsx @@ -2,6 +2,7 @@ import type React from 'react' import { createContext, useCallback, useEffect, useMemo, useState } from 'react' +import { useQueryClient } from '@tanstack/react-query' import posthog from 'posthog-js' import { client } from '@/lib/auth/auth-client' @@ -35,12 +36,15 @@ export function SessionProvider({ children }: { children: React.ReactNode }) { const [data, setData] = useState(null) const [isPending, setIsPending] = useState(true) const [error, setError] = useState(null) + const queryClient = useQueryClient() - const loadSession = useCallback(async () => { + const loadSession = useCallback(async (bypassCache = false) => { try { setIsPending(true) setError(null) - const res = await client.getSession() + const res = bypassCache + ? await client.getSession({ query: { disableCookieCache: true } }) + : await client.getSession() setData(res?.data ?? null) } catch (e) { setError(e instanceof Error ? e : new Error('Failed to fetch session')) @@ -50,8 +54,25 @@ export function SessionProvider({ children }: { children: React.ReactNode }) { }, []) useEffect(() => { - loadSession() - }, [loadSession]) + // Check if user was redirected after plan upgrade + const params = new URLSearchParams(window.location.search) + const wasUpgraded = params.get('upgraded') === 'true' + + if (wasUpgraded) { + params.delete('upgraded') + const newUrl = params.toString() + ? `${window.location.pathname}?${params.toString()}` + : window.location.pathname + window.history.replaceState({}, '', newUrl) + } + + loadSession(wasUpgraded).then(() => { + if (wasUpgraded) { + queryClient.invalidateQueries({ queryKey: ['organizations'] }) + queryClient.invalidateQueries({ queryKey: ['subscription'] }) + } + }) + }, [loadSession, queryClient]) useEffect(() => { if (isPending || typeof posthog.identify !== 'function') { diff --git a/apps/sim/app/workspace/[workspaceId]/w/components/sidebar/components/settings-modal/components/subscription/subscription.tsx b/apps/sim/app/workspace/[workspaceId]/w/components/sidebar/components/settings-modal/components/subscription/subscription.tsx index 5eafe5b90c..c9c7995259 100644 --- a/apps/sim/app/workspace/[workspaceId]/w/components/sidebar/components/settings-modal/components/subscription/subscription.tsx +++ b/apps/sim/app/workspace/[workspaceId]/w/components/sidebar/components/settings-modal/components/subscription/subscription.tsx @@ -8,6 +8,7 @@ import { Skeleton } from '@/components/ui' import { useSession } from '@/lib/auth/auth-client' import { useSubscriptionUpgrade } from '@/lib/billing/client/upgrade' import { USAGE_THRESHOLDS } from '@/lib/billing/client/usage-visualization' +import { getEffectiveSeats } from '@/lib/billing/subscriptions/utils' import { cn } from '@/lib/core/utils/cn' import { getBaseUrl } from '@/lib/core/utils/urls' import { getUserRole } from '@/lib/workspaces/organization/utils' @@ -191,7 +192,13 @@ export function Subscription() { const [upgradeError, setUpgradeError] = useState<'pro' | 'team' | null>(null) const usageLimitRef = useRef(null) - const isLoading = isSubscriptionLoading || isUsageLimitLoading || isWorkspaceLoading + const isOrgPlan = + subscriptionData?.data?.plan === 'team' || subscriptionData?.data?.plan === 'enterprise' + const isLoading = + isSubscriptionLoading || + isUsageLimitLoading || + isWorkspaceLoading || + (isOrgPlan && isOrgBillingLoading) const subscription = { isFree: subscriptionData?.data?.plan === 'free' || !subscriptionData?.data?.plan, @@ -204,7 +211,7 @@ export function Subscription() { subscriptionData?.data?.status === 'active', plan: subscriptionData?.data?.plan || 'free', status: subscriptionData?.data?.status || 'inactive', - seats: organizationBillingData?.totalSeats ?? 0, + seats: getEffectiveSeats(subscriptionData?.data), } const usage = { @@ -445,16 +452,10 @@ export function Subscription() { ? `${subscription.seats} seats` : undefined } - current={ - subscription.isEnterprise || subscription.isTeam - ? (organizationBillingData?.totalCurrentUsage ?? usage.current) - : usage.current - } + current={usage.current} limit={ subscription.isEnterprise || subscription.isTeam - ? organizationBillingData?.totalUsageLimit || - organizationBillingData?.minimumBillingAmount || - usage.limit + ? organizationBillingData?.data?.totalUsageLimit : !subscription.isFree && (permissions.canEditUsageLimit || permissions.showTeamMemberView) ? usage.current // placeholder; rightContent will render UsageLimit @@ -468,19 +469,31 @@ export function Subscription() { { logger.info('Usage limit updated') }} diff --git a/apps/sim/app/workspace/[workspaceId]/w/components/sidebar/components/settings-modal/settings-modal.tsx b/apps/sim/app/workspace/[workspaceId]/w/components/sidebar/components/settings-modal/settings-modal.tsx index 63c6748519..811f60c811 100644 --- a/apps/sim/app/workspace/[workspaceId]/w/components/sidebar/components/settings-modal/settings-modal.tsx +++ b/apps/sim/app/workspace/[workspaceId]/w/components/sidebar/components/settings-modal/settings-modal.tsx @@ -174,6 +174,7 @@ export function SettingsModal({ open, onOpenChange }: SettingsModalProps) { const userEmail = session?.user?.email const userId = session?.user?.id + const userRole = getUserRole(activeOrganization, userEmail) const isOwner = userRole === 'owner' const isAdmin = userRole === 'admin' diff --git a/apps/sim/lib/auth/auth.ts b/apps/sim/lib/auth/auth.ts index 43e4c919ad..38d9e80754 100644 --- a/apps/sim/lib/auth/auth.ts +++ b/apps/sim/lib/auth/auth.ts @@ -2184,8 +2184,22 @@ export const auth = betterAuth({ status: subscription.status, }) - const resolvedSubscription = - await ensureOrganizationForTeamSubscription(subscription) + let resolvedSubscription = subscription + try { + resolvedSubscription = await ensureOrganizationForTeamSubscription(subscription) + } catch (orgError) { + logger.error( + '[onSubscriptionComplete] Failed to ensure organization for team subscription', + { + subscriptionId: subscription.id, + referenceId: subscription.referenceId, + plan: subscription.plan, + error: orgError instanceof Error ? orgError.message : String(orgError), + stack: orgError instanceof Error ? orgError.stack : undefined, + } + ) + throw orgError + } await handleSubscriptionCreated(resolvedSubscription) @@ -2206,8 +2220,22 @@ export const auth = betterAuth({ plan: subscription.plan, }) - const resolvedSubscription = - await ensureOrganizationForTeamSubscription(subscription) + let resolvedSubscription = subscription + try { + resolvedSubscription = await ensureOrganizationForTeamSubscription(subscription) + } catch (orgError) { + logger.error( + '[onSubscriptionUpdate] Failed to ensure organization for team subscription', + { + subscriptionId: subscription.id, + referenceId: subscription.referenceId, + plan: subscription.plan, + error: orgError instanceof Error ? orgError.message : String(orgError), + stack: orgError instanceof Error ? orgError.stack : undefined, + } + ) + throw orgError + } try { await syncSubscriptionUsageLimits(resolvedSubscription) diff --git a/apps/sim/lib/billing/client/upgrade.ts b/apps/sim/lib/billing/client/upgrade.ts index 953f585a94..acd7e651ce 100644 --- a/apps/sim/lib/billing/client/upgrade.ts +++ b/apps/sim/lib/billing/client/upgrade.ts @@ -81,12 +81,15 @@ export function useSubscriptionUpgrade() { } const currentUrl = `${window.location.origin}${window.location.pathname}` + const successUrlObj = new URL(window.location.href) + successUrlObj.searchParams.set('upgraded', 'true') + const successUrl = successUrlObj.toString() try { const upgradeParams = { plan: targetPlan, referenceId, - successUrl: currentUrl, + successUrl, cancelUrl: currentUrl, ...(targetPlan === 'team' && { seats: CONSTANTS.INITIAL_TEAM_SEATS }), } as const diff --git a/apps/sim/lib/billing/organization.ts b/apps/sim/lib/billing/organization.ts index 579dfbd886..eff6a03c0c 100644 --- a/apps/sim/lib/billing/organization.ts +++ b/apps/sim/lib/billing/organization.ts @@ -1,5 +1,11 @@ import { db } from '@sim/db' -import * as schema from '@sim/db/schema' +import { + member, + organization, + session, + subscription as subscriptionTable, + user, +} from '@sim/db/schema' import { createLogger } from '@sim/logger' import { and, eq } from 'drizzle-orm' import { getPlanPricing } from '@/lib/billing/core/billing' @@ -20,16 +26,16 @@ type SubscriptionData = { */ async function getUserOwnedOrganization(userId: string): Promise { const existingMemberships = await db - .select({ organizationId: schema.member.organizationId }) - .from(schema.member) - .where(and(eq(schema.member.userId, userId), eq(schema.member.role, 'owner'))) + .select({ organizationId: member.organizationId }) + .from(member) + .where(and(eq(member.userId, userId), eq(member.role, 'owner'))) .limit(1) if (existingMemberships.length > 0) { const [existingOrg] = await db - .select({ id: schema.organization.id }) - .from(schema.organization) - .where(eq(schema.organization.id, existingMemberships[0].organizationId)) + .select({ id: organization.id }) + .from(organization) + .where(eq(organization.id, existingMemberships[0].organizationId)) .limit(1) return existingOrg?.id || null @@ -40,6 +46,8 @@ async function getUserOwnedOrganization(userId: string): Promise /** * Create a new organization and add user as owner + * Uses transaction to ensure org + member are created atomically + * Also updates user's active sessions to set the new org as active */ async function createOrganizationWithOwner( userId: string, @@ -48,32 +56,40 @@ async function createOrganizationWithOwner( metadata: Record = {} ): Promise { const orgId = `org_${crypto.randomUUID()}` + let sessionsUpdated = 0 - const [newOrg] = await db - .insert(schema.organization) - .values({ + await db.transaction(async (tx) => { + await tx.insert(organization).values({ id: orgId, name: organizationName, slug: organizationSlug, metadata, }) - .returning({ id: schema.organization.id }) - - // Add user as owner/admin of the organization - await db.insert(schema.member).values({ - id: crypto.randomUUID(), - userId: userId, - organizationId: newOrg.id, - role: 'owner', + + await tx.insert(member).values({ + id: crypto.randomUUID(), + userId: userId, + organizationId: orgId, + role: 'owner', + }) + + const updatedSessions = await tx + .update(session) + .set({ activeOrganizationId: orgId }) + .where(eq(session.userId, userId)) + .returning({ id: session.id }) + + sessionsUpdated = updatedSessions.length }) logger.info('Created organization with owner', { userId, - organizationId: newOrg.id, + organizationId: orgId, organizationName, + sessionsUpdated, }) - return newOrg.id + return orgId } export async function createOrganizationForTeamPlan( @@ -132,12 +148,12 @@ export async function ensureOrganizationForTeamSubscription( const existingMembership = await db .select({ - id: schema.member.id, - organizationId: schema.member.organizationId, - role: schema.member.role, + id: member.id, + organizationId: member.organizationId, + role: member.role, }) - .from(schema.member) - .where(eq(schema.member.userId, userId)) + .from(member) + .where(eq(member.userId, userId)) .limit(1) if (existingMembership.length > 0) { @@ -148,10 +164,17 @@ export async function ensureOrganizationForTeamSubscription( organizationId: membership.organizationId, }) - await db - .update(schema.subscription) - .set({ referenceId: membership.organizationId }) - .where(eq(schema.subscription.id, subscription.id)) + await db.transaction(async (tx) => { + await tx + .update(subscriptionTable) + .set({ referenceId: membership.organizationId }) + .where(eq(subscriptionTable.id, subscription.id)) + + await tx + .update(session) + .set({ activeOrganizationId: membership.organizationId }) + .where(eq(session.userId, userId)) + }) return { ...subscription, referenceId: membership.organizationId } } @@ -165,9 +188,9 @@ export async function ensureOrganizationForTeamSubscription( } const [userData] = await db - .select({ name: schema.user.name, email: schema.user.email }) - .from(schema.user) - .where(eq(schema.user.id, userId)) + .select({ name: user.name, email: user.email }) + .from(user) + .where(eq(user.id, userId)) .limit(1) const orgId = await createOrganizationForTeamPlan( @@ -177,9 +200,9 @@ export async function ensureOrganizationForTeamSubscription( ) await db - .update(schema.subscription) + .update(subscriptionTable) .set({ referenceId: orgId }) - .where(eq(schema.subscription.id, subscription.id)) + .where(eq(subscriptionTable.id, subscription.id)) logger.info('Created organization and updated subscription referenceId', { subscriptionId: subscription.id, @@ -204,9 +227,9 @@ export async function syncSubscriptionUsageLimits(subscription: SubscriptionData // Check if this is a user or organization subscription const users = await db - .select({ id: schema.user.id }) - .from(schema.user) - .where(eq(schema.user.id, subscription.referenceId)) + .select({ id: user.id }) + .from(user) + .where(eq(user.id, subscription.referenceId)) .limit(1) if (users.length > 0) { @@ -230,9 +253,9 @@ export async function syncSubscriptionUsageLimits(subscription: SubscriptionData // Only set if not already set or if updating to a higher value based on seats const orgData = await db - .select({ orgUsageLimit: schema.organization.orgUsageLimit }) - .from(schema.organization) - .where(eq(schema.organization.id, organizationId)) + .select({ orgUsageLimit: organization.orgUsageLimit }) + .from(organization) + .where(eq(organization.id, organizationId)) .limit(1) const currentLimit = @@ -243,12 +266,12 @@ export async function syncSubscriptionUsageLimits(subscription: SubscriptionData // Update if no limit set, or if new seat-based minimum is higher if (currentLimit < orgLimit) { await db - .update(schema.organization) + .update(organization) .set({ orgUsageLimit: orgLimit.toFixed(2), updatedAt: new Date(), }) - .where(eq(schema.organization.id, organizationId)) + .where(eq(organization.id, organizationId)) logger.info('Set organization usage limit for team plan', { organizationId, @@ -262,17 +285,17 @@ export async function syncSubscriptionUsageLimits(subscription: SubscriptionData // Sync usage limits for all members const members = await db - .select({ userId: schema.member.userId }) - .from(schema.member) - .where(eq(schema.member.organizationId, organizationId)) + .select({ userId: member.userId }) + .from(member) + .where(eq(member.organizationId, organizationId)) if (members.length > 0) { - for (const member of members) { + for (const m of members) { try { - await syncUsageLimitsFromSubscription(member.userId) + await syncUsageLimitsFromSubscription(m.userId) } catch (memberError) { logger.error('Failed to sync usage limits for organization member', { - userId: member.userId, + userId: m.userId, organizationId, subscriptionId: subscription.id, error: memberError,