diff --git a/apps/api/src/comments/comment-mention-notifier.service.ts b/apps/api/src/comments/comment-mention-notifier.service.ts index 447d61b62..bfcdecbbb 100644 --- a/apps/api/src/comments/comment-mention-notifier.service.ts +++ b/apps/api/src/comments/comment-mention-notifier.service.ts @@ -30,6 +30,168 @@ function extractMentionedUserIds(content: string | null): string[] { } import { CommentEntityType } from '@db'; +function getAppBaseUrl(): string { + return ( + process.env.NEXT_PUBLIC_APP_URL ?? + process.env.BETTER_AUTH_URL ?? + 'https://app.trycomp.ai' + ); +} + +function getAllowedOrigins(): string[] { + const candidates = [ + process.env.NEXT_PUBLIC_APP_URL, + process.env.BETTER_AUTH_URL, + 'https://app.trycomp.ai', + ].filter(Boolean) as string[]; + + const origins = new Set(); + for (const candidate of candidates) { + try { + origins.add(new URL(candidate).origin); + } catch { + // ignore invalid env values + } + } + + return [...origins]; +} + +function tryNormalizeContextUrl(params: { + organizationId: string; + contextUrl?: string; +}): string | null { + const { organizationId, contextUrl } = params; + if (!contextUrl) return null; + + try { + const url = new URL(contextUrl); + const allowedOrigins = new Set(getAllowedOrigins()); + if (!allowedOrigins.has(url.origin)) return null; + + // Ensure the URL is for the same org so we don't accidentally deep-link elsewhere. + // Use startsWith to prevent path traversal attacks (e.g., /attacker_org/victim_org/) + if (!url.pathname.startsWith(`/${organizationId}/`)) return null; + + return url.toString(); + } catch { + return null; + } +} + +async function buildFallbackCommentContext(params: { + organizationId: string; + entityType: CommentEntityType; + entityId: string; +}): Promise<{ + entityName: string; + entityRoutePath: string; + commentUrl: string; +} | null> { + const { organizationId, entityType, entityId } = params; + const appUrl = getAppBaseUrl(); + + if (entityType === CommentEntityType.task) { + // CommentEntityType.task can be: + // - TaskItem id (preferred) + // - Task id (legacy) + // Use findFirst with organizationId to ensure entity belongs to correct org + const taskItem = await db.taskItem.findFirst({ + where: { id: entityId, organizationId }, + select: { title: true, entityType: true, entityId: true }, + }); + + if (taskItem) { + const parentRoutePath = taskItem.entityType === 'vendor' ? 'vendors' : 'risk'; + const url = new URL( + `${appUrl}/${organizationId}/${parentRoutePath}/${taskItem.entityId}`, + ); + url.searchParams.set('taskItemId', entityId); + url.hash = 'task-items'; + + return { + entityName: taskItem.title || 'Task', + entityRoutePath: parentRoutePath, + commentUrl: url.toString(), + }; + } + + const task = await db.task.findFirst({ + where: { id: entityId, organizationId }, + select: { title: true }, + }); + + if (!task) { + // Entity not found in this organization - do not send notification + return null; + } + + const url = new URL(`${appUrl}/${organizationId}/tasks/${entityId}`); + + return { + entityName: task.title || 'Task', + entityRoutePath: 'tasks', + commentUrl: url.toString(), + }; + } + + if (entityType === CommentEntityType.vendor) { + const vendor = await db.vendor.findFirst({ + where: { id: entityId, organizationId }, + select: { name: true }, + }); + + if (!vendor) { + return null; + } + + const url = new URL(`${appUrl}/${organizationId}/vendors/${entityId}`); + + return { + entityName: vendor.name || 'Vendor', + entityRoutePath: 'vendors', + commentUrl: url.toString(), + }; + } + + if (entityType === CommentEntityType.risk) { + const risk = await db.risk.findFirst({ + where: { id: entityId, organizationId }, + select: { title: true }, + }); + + if (!risk) { + return null; + } + + const url = new URL(`${appUrl}/${organizationId}/risk/${entityId}`); + + return { + entityName: risk.title || 'Risk', + entityRoutePath: 'risk', + commentUrl: url.toString(), + }; + } + + // CommentEntityType.policy + const policy = await db.policy.findFirst({ + where: { id: entityId, organizationId }, + select: { name: true }, + }); + + if (!policy) { + return null; + } + + const url = new URL(`${appUrl}/${organizationId}/policies/${entityId}`); + + return { + entityName: policy.name || 'Policy', + entityRoutePath: 'policies', + commentUrl: url.toString(), + }; +} + @Injectable() export class CommentMentionNotifierService { private readonly logger = new Logger(CommentMentionNotifierService.name); @@ -45,6 +207,7 @@ export class CommentMentionNotifierService { commentContent: string; entityType: CommentEntityType; entityId: string; + contextUrl?: string; mentionedUserIds: string[]; mentionedByUserId: string; }): Promise { @@ -54,6 +217,7 @@ export class CommentMentionNotifierService { commentContent, entityType, entityId, + contextUrl, mentionedUserIds, mentionedByUserId, } = params; @@ -62,14 +226,6 @@ export class CommentMentionNotifierService { return; } - // Only send notifications for task comments - if (entityType !== CommentEntityType.task) { - this.logger.log( - `Skipping comment mention notifications: only task comments are supported (entityType: ${entityType})`, - ); - return; - } - try { // Get the user who mentioned others const mentionedByUser = await db.user.findUnique({ @@ -90,31 +246,27 @@ export class CommentMentionNotifierService { }, }); - // Get entity name for context (only for task comments) - const taskItem = await db.taskItem.findUnique({ - where: { id: entityId }, - select: { title: true, entityType: true, entityId: true }, + const normalizedContextUrl = tryNormalizeContextUrl({ + organizationId, + contextUrl, }); - const entityName = taskItem?.title || 'Unknown Task'; - // For task comments, we need to get the parent entity route - let entityRoutePath = ''; - if (taskItem?.entityType === 'risk') { - entityRoutePath = 'risk'; - } else if (taskItem?.entityType === 'vendor') { - entityRoutePath = 'vendors'; + const fallback = await buildFallbackCommentContext({ + organizationId, + entityType, + entityId, + }); + + // If entity not found in this organization, skip notifications for security + if (!fallback) { + this.logger.warn( + `Skipping comment mention notifications: entity ${entityId} (${entityType}) not found in organization ${organizationId}`, + ); + return; } - // Build comment URL (only for task comments) - const appUrl = - process.env.NEXT_PUBLIC_APP_URL ?? - process.env.BETTER_AUTH_URL ?? - 'https://app.trycomp.ai'; - - // For task comments, link to the task item's parent entity - const parentRoutePath = taskItem?.entityType === 'vendor' ? 'vendors' : 'risk'; - const commentUrl = taskItem - ? `${appUrl}/${organizationId}/${parentRoutePath}/${taskItem.entityId}?taskItemId=${entityId}#task-items` - : ''; + const entityName = fallback.entityName; + const entityRoutePath = fallback.entityRoutePath; + const commentUrl = normalizedContextUrl ?? fallback.commentUrl; const mentionedByName = mentionedByUser.name || mentionedByUser.email || 'Someone'; diff --git a/apps/api/src/comments/comments.controller.ts b/apps/api/src/comments/comments.controller.ts index 15949d70f..c2d9f5937 100644 --- a/apps/api/src/comments/comments.controller.ts +++ b/apps/api/src/comments/comments.controller.ts @@ -163,6 +163,7 @@ export class CommentsController { commentId, userId, updateCommentDto.content, + updateCommentDto.contextUrl, ); } diff --git a/apps/api/src/comments/comments.service.ts b/apps/api/src/comments/comments.service.ts index 5cb7e578e..d78b66cd5 100644 --- a/apps/api/src/comments/comments.service.ts +++ b/apps/api/src/comments/comments.service.ts @@ -279,6 +279,7 @@ export class CommentsService { commentContent: createCommentDto.content, entityType: createCommentDto.entityType, entityId: createCommentDto.entityId, + contextUrl: createCommentDto.contextUrl, mentionedUserIds, mentionedByUserId: userId, }); @@ -315,6 +316,7 @@ export class CommentsService { commentId: string, userId: string, content: string, + contextUrl?: string, ): Promise { try { // Get comment and verify ownership/permissions @@ -378,6 +380,7 @@ export class CommentsService { commentContent: content, entityType: existingComment.entityType, entityId: existingComment.entityId, + contextUrl, mentionedUserIds: newlyMentionedUserIds, mentionedByUserId: userId, }); diff --git a/apps/api/src/comments/dto/create-comment.dto.ts b/apps/api/src/comments/dto/create-comment.dto.ts index 086f31b4d..f83266ae1 100644 --- a/apps/api/src/comments/dto/create-comment.dto.ts +++ b/apps/api/src/comments/dto/create-comment.dto.ts @@ -39,6 +39,19 @@ export class CreateCommentDto { @IsEnum(CommentEntityType) entityType: CommentEntityType; + @ApiProperty({ + description: + 'Optional URL of the page where the comment was created, used for deep-linking in notifications', + example: + 'https://app.trycomp.ai/org_abc123/vendors/vnd_abc123?taskItemId=tki_abc123#task-items', + required: false, + maxLength: 2048, + }) + @IsOptional() + @IsString() + @MaxLength(2048) + contextUrl?: string; + @ApiProperty({ description: 'Optional attachments to include with the comment', type: [UploadAttachmentDto], diff --git a/apps/api/src/comments/dto/update-comment.dto.ts b/apps/api/src/comments/dto/update-comment.dto.ts index 883438837..00b24d3c9 100644 --- a/apps/api/src/comments/dto/update-comment.dto.ts +++ b/apps/api/src/comments/dto/update-comment.dto.ts @@ -12,6 +12,19 @@ export class UpdateCommentDto { @MaxLength(2000) content: string; + @ApiProperty({ + description: + 'Optional URL of the page where the comment was updated, used for deep-linking in notifications', + example: + 'https://app.trycomp.ai/org_abc123/risk/rsk_abc123?taskItemId=tki_abc123#task-items', + required: false, + maxLength: 2048, + }) + @IsOptional() + @IsString() + @MaxLength(2048) + contextUrl?: string; + @ApiProperty({ description: 'User ID of the comment author (required for API key auth, ignored for JWT auth)', diff --git a/apps/api/src/vendors/vendors.service.ts b/apps/api/src/vendors/vendors.service.ts index e5237802f..10e277eb2 100644 --- a/apps/api/src/vendors/vendors.service.ts +++ b/apps/api/src/vendors/vendors.service.ts @@ -82,6 +82,20 @@ export class VendorsService { id, organizationId, }, + include: { + assignee: { + include: { + user: { + select: { + id: true, + name: true, + email: true, + image: true, + }, + }, + }, + }, + }, }); if (!vendor) { @@ -90,8 +104,51 @@ export class VendorsService { ); } + // Fetch risk assessment from GlobalVendors if vendor has a website + const domain = extractDomain(vendor.website); + let globalVendorData: { + website: string; + riskAssessmentData: Prisma.JsonValue; + riskAssessmentVersion: string | null; + riskAssessmentUpdatedAt: Date | null; + } | null = null; + + if (domain) { + const duplicates = await db.globalVendors.findMany({ + where: { + website: { + contains: domain, + }, + }, + select: { + website: true, + riskAssessmentData: true, + riskAssessmentVersion: true, + riskAssessmentUpdatedAt: true, + }, + orderBy: [ + { riskAssessmentUpdatedAt: 'desc' }, + { createdAt: 'desc' }, + ], + }); + + // Prefer record WITH risk assessment data (most recent) + globalVendorData = + duplicates.find((gv) => gv.riskAssessmentData !== null) ?? + duplicates[0] ?? + null; + } + + // Merge GlobalVendors risk assessment data into response + const vendorWithRiskAssessment = { + ...vendor, + riskAssessmentData: globalVendorData?.riskAssessmentData ?? null, + riskAssessmentVersion: globalVendorData?.riskAssessmentVersion ?? null, + riskAssessmentUpdatedAt: globalVendorData?.riskAssessmentUpdatedAt ?? null, + }; + this.logger.log(`Retrieved vendor: ${vendor.name} (${id})`); - return vendor; + return vendorWithRiskAssessment; } catch (error) { if (error instanceof NotFoundException) { throw error; diff --git a/apps/app/src/app/(app)/[orgId]/layout.tsx b/apps/app/src/app/(app)/[orgId]/layout.tsx index 99d1b69ea..d14f95951 100644 --- a/apps/app/src/app/(app)/[orgId]/layout.tsx +++ b/apps/app/src/app/(app)/[orgId]/layout.tsx @@ -41,9 +41,12 @@ export default async function Layout({ const isCollapsed = cookieStore.get('sidebar-collapsed')?.value === 'true'; let publicAccessToken = cookieStore.get('publicAccessToken')?.value || undefined; + // Get headers once to avoid multiple async calls + const requestHeaders = await headers(); + // Check if user has access to this organization const session = await auth.api.getSession({ - headers: await headers(), + headers: requestHeaders, }); if (!session) { @@ -74,6 +77,23 @@ export default async function Layout({ return redirect('/auth/unauthorized'); } + // Sync activeOrganizationId BEFORE any redirects that might use it + // This ensures session.activeOrganizationId is always correct for users with multiple orgs + const currentActiveOrgId = session.session.activeOrganizationId; + if (!currentActiveOrgId || currentActiveOrgId !== requestedOrgId) { + try { + await auth.api.setActiveOrganization({ + headers: requestHeaders, + body: { + organizationId: requestedOrgId, + }, + }); + } catch (error) { + console.error('[Layout] Failed to sync activeOrganizationId:', error); + // Continue anyway - the URL params are the source of truth for this request + } + } + const roles = parseRolesString(member.role); const hasAccess = roles.includes(Role.owner) || roles.includes(Role.admin) || roles.includes(Role.auditor); diff --git a/apps/app/src/app/(app)/[orgId]/policies/[policyId]/components/PolicyPage.tsx b/apps/app/src/app/(app)/[orgId]/policies/[policyId]/components/PolicyPage.tsx index ac3304831..f3bd72d21 100644 --- a/apps/app/src/app/(app)/[orgId]/policies/[policyId]/components/PolicyPage.tsx +++ b/apps/app/src/app/(app)/[orgId]/policies/[policyId]/components/PolicyPage.tsx @@ -13,6 +13,7 @@ export default function PolicyPage({ allControls, isPendingApproval, policyId, + organizationId, logs, }: { policy: (Policy & { approver: (Member & { user: User }) | null }) | null; @@ -21,6 +22,8 @@ export default function PolicyPage({ allControls: Control[]; isPendingApproval: boolean; policyId: string; + /** Organization ID - required for correct org context in comments */ + organizationId: string; logs: AuditLogWithRelations[]; }) { return ( @@ -42,7 +45,7 @@ export default function PolicyPage({ - + ); } diff --git a/apps/app/src/app/(app)/[orgId]/policies/[policyId]/page.tsx b/apps/app/src/app/(app)/[orgId]/policies/[policyId]/page.tsx index f1892e377..b3b412d65 100644 --- a/apps/app/src/app/(app)/[orgId]/policies/[policyId]/page.tsx +++ b/apps/app/src/app/(app)/[orgId]/policies/[policyId]/page.tsx @@ -29,6 +29,7 @@ export default async function PolicyDetails({ toast.success('Regeneration triggered. This may take a moment.'), + onSuccess: () => { + toast.success('Regeneration triggered. This may take a moment.'); + // Trigger SWR revalidation for risk detail, list views, and comments + refreshRisk(); + globalMutate( + (key) => Array.isArray(key) && key[0] === 'risks', + undefined, + { revalidate: true }, + ); + // Invalidate comments cache for this risk + globalMutate( + (key) => typeof key === 'string' && key.includes(`/v1/comments`) && key.includes(riskId), + undefined, + { revalidate: true }, + ); + }, onError: () => toast.error('Failed to trigger mitigation regeneration'), }); diff --git a/apps/app/src/app/(app)/[orgId]/risk/[riskId]/components/RiskPageClient.tsx b/apps/app/src/app/(app)/[orgId]/risk/[riskId]/components/RiskPageClient.tsx new file mode 100644 index 000000000..8695ea269 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/risk/[riskId]/components/RiskPageClient.tsx @@ -0,0 +1,97 @@ +'use client'; + +import { Comments } from '@/components/comments/Comments'; +import { InherentRiskChart } from '@/components/risks/charts/InherentRiskChart'; +import { ResidualRiskChart } from '@/components/risks/charts/ResidualRiskChart'; +import { RiskOverview } from '@/components/risks/risk-overview'; +import { TaskItems } from '@/components/task-items/TaskItems'; +import { useRisk, type RiskResponse } from '@/hooks/use-risks'; +import { CommentEntityType } from '@db'; +import type { Member, Risk, User } from '@db'; +import { useMemo } from 'react'; + +type RiskWithAssignee = Risk & { + assignee: { user: User } | null; +}; + +/** + * Normalize API response to match Prisma types + * API returns dates as strings, Prisma returns Date objects + */ +function normalizeRisk(apiRisk: RiskResponse): RiskWithAssignee { + return { + ...apiRisk, + createdAt: new Date(apiRisk.createdAt), + updatedAt: new Date(apiRisk.updatedAt), + assignee: apiRisk.assignee + ? { + ...apiRisk.assignee, + user: apiRisk.assignee.user as User, + } + : null, + } as unknown as RiskWithAssignee; +} + +interface RiskPageClientProps { + riskId: string; + orgId: string; + initialRisk: RiskWithAssignee; + assignees: (Member & { user: User })[]; + isViewingTask: boolean; +} + +/** + * Client component for risk detail page content + * Uses SWR for real-time updates and caching + * + * Benefits: + * - Instant initial render (uses server-fetched data) + * - Real-time updates via polling (5s interval) + * - Mutations trigger automatic refresh via mutate() + */ +export function RiskPageClient({ + riskId, + orgId, + initialRisk, + assignees, + isViewingTask, +}: RiskPageClientProps) { + // Use SWR for real-time updates with polling + const { risk: swrRisk, isLoading } = useRisk(riskId, { + organizationId: orgId, + }); + + // Normalize and memoize the risk data + // Use SWR data when available, fall back to initial data + const risk = useMemo(() => { + if (swrRisk) { + return normalizeRisk(swrRisk); + } + return initialRisk; + }, [swrRisk, initialRisk]); + + return ( +
+ {!isViewingTask && ( + <> + +
+ + +
+ + )} + + {!isViewingTask && ( + + )} +
+ ); +} + +/** + * Export the risk mutate function for use by mutation components + * Call this after updating risk data to trigger SWR revalidation + */ +export { useRisk } from '@/hooks/use-risks'; + diff --git a/apps/app/src/app/(app)/[orgId]/risk/[riskId]/page.tsx b/apps/app/src/app/(app)/[orgId]/risk/[riskId]/page.tsx index 02dfc6aca..08cdc0059 100644 --- a/apps/app/src/app/(app)/[orgId]/risk/[riskId]/page.tsx +++ b/apps/app/src/app/(app)/[orgId]/risk/[riskId]/page.tsx @@ -1,16 +1,12 @@ import PageWithBreadcrumb from '@/components/pages/PageWithBreadcrumb'; -import { InherentRiskChart } from '@/components/risks/charts/InherentRiskChart'; -import { ResidualRiskChart } from '@/components/risks/charts/ResidualRiskChart'; -import { RiskOverview } from '@/components/risks/risk-overview'; import { auth } from '@/utils/auth'; -import { CommentEntityType, db } from '@db'; +import { db } from '@db'; import type { Metadata } from 'next'; import { headers } from 'next/headers'; import { redirect } from 'next/navigation'; import { cache } from 'react'; -import { Comments } from '../../../../../components/comments/Comments'; -import { TaskItems } from '../../../../../components/task-items/TaskItems'; import { RiskActions } from './components/RiskActions'; +import { RiskPageClient } from './components/RiskPageClient'; interface PageProps { searchParams: Promise<{ @@ -24,16 +20,23 @@ interface PageProps { params: Promise<{ riskId: string; orgId: string }>; } +/** + * Risk detail page - server component + * Fetches initial data server-side for fast first render + * Passes data to RiskPageClient which uses SWR for real-time updates + */ export default async function RiskPage({ searchParams, params }: PageProps) { const { riskId, orgId } = await params; const { taskItemId } = await searchParams; const risk = await getRisk(riskId); const assignees = await getAssignees(); + if (!risk) { redirect('/'); } const shortTaskId = (id: string) => id.slice(-6).toUpperCase(); + const isViewingTask = Boolean(taskItemId); return ( } + headerRight={} > -
- {!taskItemId && ( - <> - -
- - -
- - )} - - {!taskItemId && ( - - )} -
+
); } diff --git a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/SingleTask.tsx b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/SingleTask.tsx index 5a696d5a9..67f2a669b 100644 --- a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/SingleTask.tsx +++ b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/SingleTask.tsx @@ -215,6 +215,7 @@ export function SingleTask({ diff --git a/apps/app/src/app/(app)/[orgId]/vendors/(overview)/components/VendorDeleteCell.tsx b/apps/app/src/app/(app)/[orgId]/vendors/(overview)/components/VendorDeleteCell.tsx index d40c9f6d6..ee7c18f70 100644 --- a/apps/app/src/app/(app)/[orgId]/vendors/(overview)/components/VendorDeleteCell.tsx +++ b/apps/app/src/app/(app)/[orgId]/vendors/(overview)/components/VendorDeleteCell.tsx @@ -12,6 +12,7 @@ import { Button } from '@comp/ui/button'; import { Trash2 } from 'lucide-react'; import * as React from 'react'; import { toast } from 'sonner'; +import { useSWRConfig } from 'swr'; import { deleteVendor } from '../actions/deleteVendor'; import type { GetVendorsResult } from '../data/queries'; @@ -22,6 +23,7 @@ interface VendorDeleteCellProps { } export const VendorDeleteCell: React.FC = ({ vendor }) => { + const { mutate } = useSWRConfig(); const [isRemoveAlertOpen, setIsRemoveAlertOpen] = React.useState(false); const [isDeleting, setIsDeleting] = React.useState(false); @@ -34,6 +36,12 @@ export const VendorDeleteCell: React.FC = ({ vendor }) => if (response?.data?.success) { toast.success(`Vendor "${vendor.name}" has been deleted.`); setIsRemoveAlertOpen(false); + // Invalidate all vendors SWR caches (any key starting with 'vendors') + mutate( + (key) => Array.isArray(key) && key[0] === 'vendors', + undefined, + { revalidate: true }, + ); } else { toast.error(String(response?.data?.error) || 'Failed to delete vendor.'); } diff --git a/apps/app/src/app/(app)/[orgId]/vendors/(overview)/components/VendorsTable.tsx b/apps/app/src/app/(app)/[orgId]/vendors/(overview)/components/VendorsTable.tsx index 9c107e1b9..89127ebe3 100644 --- a/apps/app/src/app/(app)/[orgId]/vendors/(overview)/components/VendorsTable.tsx +++ b/apps/app/src/app/(app)/[orgId]/vendors/(overview)/components/VendorsTable.tsx @@ -109,10 +109,11 @@ export function VendorsTable({ return await callGetVendorsAction({ orgId, searchParams: currentSearchParams }); }, [orgId, currentSearchParams]); - // Use SWR to fetch vendors with polling when onboarding is active + // Use SWR to fetch vendors with polling for real-time updates + // Poll faster during onboarding, slower otherwise const { data: vendorsData } = useSWR(swrKey, fetcher, { fallbackData: { data: initialVendors, pageCount: initialPageCount }, - refreshInterval: isActive ? 1000 : 0, // Poll every 1 second when onboarding is active + refreshInterval: isActive ? 1000 : 5000, // 1s during onboarding, 5s otherwise revalidateOnFocus: false, revalidateOnReconnect: true, keepPreviousData: true, diff --git a/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/components/VendorActions.tsx b/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/components/VendorActions.tsx index c8c5f6097..fd4aab0fb 100644 --- a/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/components/VendorActions.tsx +++ b/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/components/VendorActions.tsx @@ -1,6 +1,7 @@ 'use client'; import { regenerateVendorMitigationAction } from '@/app/(app)/[orgId]/vendors/[vendorId]/actions/regenerate-vendor-mitigation'; +import { useVendor } from '@/hooks/use-vendors'; import { Button } from '@comp/ui/button'; import { Dialog, @@ -21,12 +22,34 @@ import { useAction } from 'next-safe-action/hooks'; import { useQueryState } from 'nuqs'; import { useState } from 'react'; import { toast } from 'sonner'; +import { useSWRConfig } from 'swr'; -export function VendorActions({ vendorId }: { vendorId: string }) { +export function VendorActions({ vendorId, orgId }: { vendorId: string; orgId: string }) { + const { mutate: globalMutate } = useSWRConfig(); const [_, setOpen] = useQueryState('vendor-overview-sheet'); const [isConfirmOpen, setIsConfirmOpen] = useState(false); + + // Get SWR mutate function to refresh vendor data after mutations + // Pass orgId to ensure same cache key as VendorPageClient + const { mutate: refreshVendor } = useVendor(vendorId, { organizationId: orgId }); + const regenerate = useAction(regenerateVendorMitigationAction, { - onSuccess: () => toast.success('Regeneration triggered. This may take a moment.'), + onSuccess: () => { + toast.success('Regeneration triggered. This may take a moment.'); + // Trigger SWR revalidation for vendor detail, list views, and comments + refreshVendor(); + globalMutate( + (key) => Array.isArray(key) && key[0] === 'vendors', + undefined, + { revalidate: true }, + ); + // Invalidate comments cache for this vendor + globalMutate( + (key) => typeof key === 'string' && key.includes(`/v1/comments`) && key.includes(vendorId), + undefined, + { revalidate: true }, + ); + }, onError: () => toast.error('Failed to trigger mitigation regeneration'), }); diff --git a/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/components/VendorPageClient.tsx b/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/components/VendorPageClient.tsx new file mode 100644 index 000000000..60ba8c80a --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/components/VendorPageClient.tsx @@ -0,0 +1,115 @@ +'use client'; + +import { Comments } from '@/components/comments/Comments'; +import { TaskItems } from '@/components/task-items/TaskItems'; +import { useVendor, type VendorResponse } from '@/hooks/use-vendors'; +import { CommentEntityType } from '@db'; +import type { Member, User, Vendor } from '@db'; +import type { Prisma } from '@prisma/client'; +import { useMemo } from 'react'; +import { SecondaryFields } from './secondary-fields/secondary-fields'; +import { VendorHeader } from './VendorHeader'; +import { VendorInherentRiskChart } from './VendorInherentRiskChart'; +import { VendorResidualRiskChart } from './VendorResidualRiskChart'; +import { VendorTabs } from './VendorTabs'; + +// Vendor with risk assessment data merged from GlobalVendors +type VendorWithRiskAssessment = Vendor & { + assignee: { user: User | null } | null; + riskAssessmentData?: Prisma.InputJsonValue | null; + riskAssessmentVersion?: string | null; + riskAssessmentUpdatedAt?: Date | null; +}; + +/** + * Normalize API response to match Prisma types + * API returns dates as strings, Prisma returns Date objects + */ +function normalizeVendor(apiVendor: VendorResponse): VendorWithRiskAssessment { + return { + ...apiVendor, + createdAt: new Date(apiVendor.createdAt), + updatedAt: new Date(apiVendor.updatedAt), + riskAssessmentUpdatedAt: apiVendor.riskAssessmentUpdatedAt + ? new Date(apiVendor.riskAssessmentUpdatedAt) + : null, + assignee: apiVendor.assignee + ? { + ...apiVendor.assignee, + user: apiVendor.assignee.user as User | null, + } + : null, + } as unknown as VendorWithRiskAssessment; +} + +interface VendorPageClientProps { + vendorId: string; + orgId: string; + initialVendor: VendorWithRiskAssessment; + assignees: (Member & { user: User })[]; + isViewingTask: boolean; +} + +/** + * Client component for vendor detail page content + * Uses SWR for real-time updates and caching + * + * Benefits: + * - Instant initial render (uses server-fetched data) + * - Real-time updates via polling (5s interval) + * - Mutations trigger automatic refresh via mutate() + */ +export function VendorPageClient({ + vendorId, + orgId, + initialVendor, + assignees, + isViewingTask, +}: VendorPageClientProps) { + // Use SWR for real-time updates with polling + const { vendor: swrVendor, isLoading } = useVendor(vendorId, { + organizationId: orgId, + }); + + // Normalize and memoize the vendor data + // Use SWR data when available, fall back to initial data + const vendor = useMemo(() => { + if (swrVendor) { + return normalizeVendor(swrVendor); + } + return initialVendor; + }, [swrVendor, initialVendor]); + + return ( + <> + {!isViewingTask && } + {!isViewingTask && } +
+ {!isViewingTask && ( + <> + +
+ + +
+ + )} + + {!isViewingTask && ( + + )} +
+ + ); +} + +/** + * Export the vendor mutate function for use by mutation components + * Call this after updating vendor data to trigger SWR revalidation + */ +export { useVendor } from '@/hooks/use-vendors'; + diff --git a/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/page.tsx b/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/page.tsx index cee6954b4..857ccfceb 100644 --- a/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/page.tsx +++ b/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/page.tsx @@ -1,21 +1,13 @@ -'use server'; - import PageWithBreadcrumb from '@/components/pages/PageWithBreadcrumb'; import { auth } from '@/utils/auth'; import { extractDomain } from '@/utils/normalize-website'; -import { CommentEntityType, db } from '@db'; +import { db } from '@db'; import type { Metadata } from 'next'; import { headers } from 'next/headers'; import { redirect } from 'next/navigation'; import { cache } from 'react'; -import { Comments } from '../../../../../components/comments/Comments'; -import { TaskItems } from '../../../../../components/task-items/TaskItems'; import { VendorActions } from './components/VendorActions'; -import { VendorInherentRiskChart } from './components/VendorInherentRiskChart'; -import { VendorResidualRiskChart } from './components/VendorResidualRiskChart'; -import { VendorTabs } from './components/VendorTabs'; -import { VendorHeader } from './components/VendorHeader'; -import { SecondaryFields } from './components/secondary-fields/secondary-fields'; +import { VendorPageClient } from './components/VendorPageClient'; interface PageProps { params: Promise<{ vendorId: string; locale: string; orgId: string }>; @@ -24,17 +16,22 @@ interface PageProps { }>; } +/** + * Vendor detail page - server component + * Fetches initial data server-side for fast first render + * Passes data to VendorPageClient which uses SWR for real-time updates + */ export default async function VendorPage({ params, searchParams }: PageProps) { const { vendorId, orgId } = await params; const { taskItemId } = (await searchParams) ?? {}; // Fetch data in parallel for faster loading - const [vendor, assignees] = await Promise.all([ + const [vendorData, assignees] = await Promise.all([ getVendor({ vendorId, organizationId: orgId }), getAssignees(orgId), ]); - if (!vendor || !vendor.vendor) { + if (!vendorData || !vendorData.vendor) { redirect('/'); } @@ -46,38 +43,21 @@ export default async function VendorPage({ params, searchParams }: PageProps) { breadcrumbs={[ { label: 'Vendors', href: `/${orgId}/vendors` }, { - label: vendor.vendor?.name ?? '', + label: vendorData.vendor?.name ?? '', // Make vendor name clickable when viewing a task to navigate back to vendor overview href: isViewingTask ? `/${orgId}/vendors/${vendorId}` : undefined, current: !isViewingTask, }, ]} - headerRight={} + headerRight={} > - {!isViewingTask && } - {!isViewingTask && } -
- {!isViewingTask && ( - <> - -
- - -
- - )} - - {!isViewingTask && ( - - )} -
+ ); } diff --git a/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/review/page.tsx b/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/review/page.tsx index a9a2cf698..1f368bb96 100644 --- a/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/review/page.tsx +++ b/apps/app/src/app/(app)/[orgId]/vendors/[vendorId]/review/page.tsx @@ -49,7 +49,7 @@ export default async function ReviewPage({ params, searchParams }: ReviewPagePro current: !isViewingTask, }, ]} - headerRight={} + headerRight={} > {!isViewingTask && } {!isViewingTask && } diff --git a/apps/app/src/app/(app)/[orgId]/vendors/components/create-vendor-form.tsx b/apps/app/src/app/(app)/[orgId]/vendors/components/create-vendor-form.tsx index d90caa5c5..fe4f64e5e 100644 --- a/apps/app/src/app/(app)/[orgId]/vendors/components/create-vendor-form.tsx +++ b/apps/app/src/app/(app)/[orgId]/vendors/components/create-vendor-form.tsx @@ -18,6 +18,7 @@ import { useQueryState } from 'nuqs'; import { useState } from 'react'; import { useForm } from 'react-hook-form'; import { toast } from 'sonner'; +import { useSWRConfig } from 'swr'; import { z } from 'zod'; import { createVendorAction } from '../actions/create-vendor-action'; import { searchGlobalVendorsAction } from '../actions/search-global-vendors-action'; @@ -42,6 +43,7 @@ export function CreateVendorForm({ assignees: (Member & { user: User })[]; organizationId: string; }) { + const { mutate } = useSWRConfig(); const [_, setCreateVendorSheet] = useQueryState('createVendorSheet'); const [searchQuery, setSearchQuery] = useState(''); @@ -53,6 +55,12 @@ export function CreateVendorForm({ onSuccess: async () => { toast.success('Vendor created successfully'); setCreateVendorSheet(null); + // Invalidate all vendors SWR caches (any key starting with 'vendors') + mutate( + (key) => Array.isArray(key) && key[0] === 'vendors', + undefined, + { revalidate: true }, + ); }, onError: () => { toast.error('Failed to create vendor'); diff --git a/apps/app/src/app/(app)/onboarding/[orgId]/page.tsx b/apps/app/src/app/(app)/onboarding/[orgId]/page.tsx index effe7ca65..933844d08 100644 --- a/apps/app/src/app/(app)/onboarding/[orgId]/page.tsx +++ b/apps/app/src/app/(app)/onboarding/[orgId]/page.tsx @@ -11,16 +11,19 @@ interface OnboardingPageProps { export default async function OnboardingPage({ params }: OnboardingPageProps) { const { orgId } = await params; + // Get headers once to avoid multiple async calls + const requestHeaders = await headers(); + // Get current user const session = await auth.api.getSession({ - headers: await headers(), + headers: requestHeaders, }); if (!session?.user?.id) { redirect('/auth'); } - // Get organization with subscription info + // Verify membership BEFORE syncing activeOrganizationId const organization = await db.organization.findFirst({ where: { id: orgId, @@ -45,6 +48,21 @@ export default async function OnboardingPage({ params }: OnboardingPageProps) { notFound(); } + // Sync activeOrganizationId only after membership is verified + const currentActiveOrgId = session.session.activeOrganizationId; + if (!currentActiveOrgId || currentActiveOrgId !== orgId) { + try { + await auth.api.setActiveOrganization({ + headers: requestHeaders, + body: { + organizationId: orgId, + }, + }); + } catch (error) { + console.error('[OnboardingPage] Failed to sync activeOrganizationId:', error); + } + } + // Check if already completed onboarding if (organization.onboardingCompleted) { redirect(`/${orgId}/`); diff --git a/apps/app/src/app/(app)/upgrade/[orgId]/page.tsx b/apps/app/src/app/(app)/upgrade/[orgId]/page.tsx index 2acc446ae..1fc7911b2 100644 --- a/apps/app/src/app/(app)/upgrade/[orgId]/page.tsx +++ b/apps/app/src/app/(app)/upgrade/[orgId]/page.tsx @@ -15,16 +15,19 @@ interface PageProps { export default async function UpgradePage({ params }: PageProps) { const { orgId } = await params; + // Get headers once to avoid multiple async calls + const requestHeaders = await headers(); + // Check auth const authSession = await auth.api.getSession({ - headers: await headers(), + headers: requestHeaders, }); if (!authSession?.user?.id) { redirect('/sign-in'); } - // Verify user has access to this org + // Verify user has access to this org BEFORE syncing activeOrganizationId const member = await db.member.findFirst({ where: { organizationId: orgId, @@ -40,6 +43,21 @@ export default async function UpgradePage({ params }: PageProps) { redirect('/'); } + // Sync activeOrganizationId only after membership is verified + const currentActiveOrgId = authSession.session.activeOrganizationId; + if (!currentActiveOrgId || currentActiveOrgId !== orgId) { + try { + await auth.api.setActiveOrganization({ + headers: requestHeaders, + body: { + organizationId: orgId, + }, + }); + } catch (error) { + console.error('[UpgradePage] Failed to sync activeOrganizationId:', error); + } + } + let hasAccess = member.organization.hasAccess; // Auto-approve based on user's email domain diff --git a/apps/app/src/app/api/cloud-tests/findings/route.ts b/apps/app/src/app/api/cloud-tests/findings/route.ts index b0f3ebe2b..d790583f3 100644 --- a/apps/app/src/app/api/cloud-tests/findings/route.ts +++ b/apps/app/src/app/api/cloud-tests/findings/route.ts @@ -41,22 +41,39 @@ export async function GET(request: NextRequest) { const newConnections = await db.integrationConnection.findMany({ where: { organizationId: orgId, + status: 'active', provider: { slug: { in: CLOUD_PROVIDER_SLUGS, }, }, }, - select: { - id: true, - provider: { - select: { - slug: true, - }, + include: { + provider: true, + }, + }); + + // ==================================================================== + // Fetch from OLD integration table (Integration) - for backward compat + // ==================================================================== + const legacyIntegrations = await db.integration.findMany({ + where: { + organizationId: orgId, + integrationId: { + in: CLOUD_PROVIDER_SLUGS, }, }, }); + // Filter out legacy integrations that have been migrated to new platform + const newConnectionSlugs = new Set(newConnections.map((c) => c.provider.slug)); + const activeLegacyIntegrations = legacyIntegrations.filter( + (i) => !newConnectionSlugs.has(i.integrationId), + ); + + // ==================================================================== + // Fetch findings from NEW platform (IntegrationCheckResult) + // ==================================================================== const newConnectionIds = newConnections.map((c) => c.id); const connectionToSlug = Object.fromEntries(newConnections.map((c) => [c.id, c.provider.slug])); @@ -77,13 +94,12 @@ export async function GET(request: NextRequest) { const latestRunIds = latestRuns.map((r) => r.id); const checkRunMap = Object.fromEntries(latestRuns.map((cr) => [cr.id, cr])); - // Fetch only failed results from the latest runs (findings only, no passing results) + // Fetch results only from the latest runs (both passed and failed) const newResults = latestRunIds.length > 0 ? await db.integrationCheckResult.findMany({ where: { checkRunId: { in: latestRunIds }, - passed: false, }, select: { id: true, @@ -93,6 +109,7 @@ export async function GET(request: NextRequest) { severity: true, collectedAt: true, checkRunId: true, + passed: true, }, orderBy: { collectedAt: 'desc', @@ -107,33 +124,26 @@ export async function GET(request: NextRequest) { title: result.title, description: result.description, remediation: result.remediation, - status: 'failed', + status: result.passed ? 'passed' : 'failed', severity: result.severity, completedAt: result.collectedAt, integration: { - integrationId: checkRun - ? connectionToSlug[checkRun.connectionId] || 'unknown' - : 'unknown', + integrationId: checkRun ? connectionToSlug[checkRun.connectionId] || 'unknown' : 'unknown', }, }; }); // ==================================================================== - // Fetch from OLD integration platform + // Fetch findings from OLD platform (IntegrationResult) // ==================================================================== - // Filter out cloud providers that have migrated to new platform - const newConnectionSlugs = new Set(newConnections.map((c) => c.provider.slug)); - const legacySlugs = CLOUD_PROVIDER_SLUGS.filter((s) => !newConnectionSlugs.has(s)); + const legacyIntegrationIds = activeLegacyIntegrations.map((i) => i.id); const legacyResults = - legacySlugs.length > 0 + legacyIntegrationIds.length > 0 ? await db.integrationResult.findMany({ where: { - organizationId: orgId, - integration: { - integrationId: { - in: legacySlugs, - }, + integrationId: { + in: legacyIntegrationIds, }, }, select: { @@ -171,7 +181,7 @@ export async function GET(request: NextRequest) { })); // ==================================================================== - // Merge and sort by date + // Merge all findings and sort by date // ==================================================================== const findings = [...newFindings, ...legacyFindings].sort((a, b) => { const dateA = a.completedAt ? new Date(a.completedAt).getTime() : 0; diff --git a/apps/app/src/components/comments/CommentForm.tsx b/apps/app/src/components/comments/CommentForm.tsx index f03c59410..67b55ad47 100644 --- a/apps/app/src/components/comments/CommentForm.tsx +++ b/apps/app/src/components/comments/CommentForm.tsx @@ -19,14 +19,16 @@ import type { JSONContent } from '@tiptap/react'; import { CommentRichTextField } from './CommentRichTextField'; import { useOrganizationMembers } from '@/hooks/use-organization-members'; import { useMemo } from 'react'; -import { usePathname } from 'next/navigation'; +import { useParams, usePathname } from 'next/navigation'; interface CommentFormProps { entityId: string; entityType: CommentEntityType; + /** Optional org override; otherwise uses `orgId` from URL params */ + organizationId?: string; } -export function CommentForm({ entityId, entityType }: CommentFormProps) { +export function CommentForm({ entityId, entityType, organizationId }: CommentFormProps) { const [newComment, setNewComment] = useState(null); const [pendingFiles, setPendingFiles] = useState([]); const [isSubmitting, setIsSubmitting] = useState(false); @@ -35,21 +37,39 @@ export function CommentForm({ entityId, entityType }: CommentFormProps) { const [filesToAdd, setFilesToAdd] = useState([]); const [isSelectingMention, setIsSelectingMention] = useState(false); const pathname = usePathname(); + const params = useParams(); + const orgIdFromParams = + typeof params?.orgId === 'string' + ? params.orgId + : Array.isArray(params?.orgId) + ? params.orgId[0] + : undefined; + const resolvedOrgId = organizationId ?? orgIdFromParams; // Use SWR hooks for generic comments - const { mutate: refreshComments } = useComments(entityId, entityType); + // Pass organizationId explicitly to ensure correct org context + const { mutate: refreshComments } = useComments(entityId, entityType, { + organizationId: resolvedOrgId, + enabled: Boolean(resolvedOrgId), + }); const { createCommentWithFiles } = useCommentWithAttachments(); const { members } = useOrganizationMembers(); - // Convert members to MentionUser format + // Convert members to MentionUser format - only show admin/owner users const mentionMembers = useMemo(() => { if (!members) return []; - return members.map((member) => ({ - id: member.user.id, - name: member.user.name || member.user.email || 'Unknown', - email: member.user.email || '', - image: member.user.image, - })); + return members + .filter((member) => { + if (!member.role) return false; + const roles = member.role.split(',').map((r) => r.trim().toLowerCase()); + return roles.includes('owner') || roles.includes('admin'); + }) + .map((member) => ({ + id: member.user.id, + name: member.user.name || member.user.email || 'Unknown', + email: member.user.email || '', + image: member.user.image, + })); }, [members]); const triggerFileInput = () => { diff --git a/apps/app/src/components/comments/CommentItem.tsx b/apps/app/src/components/comments/CommentItem.tsx index b535dde83..1689fff49 100644 --- a/apps/app/src/components/comments/CommentItem.tsx +++ b/apps/app/src/components/comments/CommentItem.tsx @@ -99,15 +99,21 @@ export function CommentItem({ comment, refreshComments }: CommentItemProps) { const { get: apiGet } = useApi(); const { members } = useOrganizationMembers(); - // Convert members to MentionUser format + // Convert members to MentionUser format - only show admin/owner users const mentionMembers = useMemo(() => { if (!members) return []; - return members.map((member) => ({ - id: member.user.id, - name: member.user.name || member.user.email || 'Unknown', - email: member.user.email || '', - image: member.user.image, - })); + return members + .filter((member) => { + if (!member.role) return false; + const roles = member.role.split(',').map((r) => r.trim().toLowerCase()); + return roles.includes('owner') || roles.includes('admin'); + }) + .map((member) => ({ + id: member.user.id, + name: member.user.name || member.user.email || 'Unknown', + email: member.user.email || '', + image: member.user.image, + })); }, [members]); // Parse comment content to JSONContent diff --git a/apps/app/src/components/comments/Comments.tsx b/apps/app/src/components/comments/Comments.tsx index ba49f19e4..900ec4ea4 100644 --- a/apps/app/src/components/comments/Comments.tsx +++ b/apps/app/src/components/comments/Comments.tsx @@ -3,6 +3,7 @@ import { useComments } from '@/hooks/use-comments-api'; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@comp/ui/card'; import { CommentEntityType } from '@db'; +import { useParams } from 'next/navigation'; import { CommentForm } from './CommentForm'; import { CommentList } from './CommentList'; @@ -29,6 +30,11 @@ export type CommentWithAuthor = { interface CommentsProps { entityId: string; entityType: CommentEntityType; + /** + * Optional organization ID override. + * Best practice: omit this and let the component use `orgId` from URL params. + */ + organizationId?: string; /** Optional custom title for the comments section */ title?: string; /** Optional custom description */ @@ -57,17 +63,30 @@ interface CommentsProps { export const Comments = ({ entityId, entityType, + organizationId, title = 'Comments', description, variant = 'card', }: CommentsProps) => { + const params = useParams(); + const orgIdFromParams = + typeof params?.orgId === 'string' + ? params.orgId + : Array.isArray(params?.orgId) + ? params.orgId[0] + : undefined; + const resolvedOrgId = organizationId ?? orgIdFromParams; + // Use SWR hooks for real-time comment fetching const { data: commentsData, error: commentsError, isLoading: commentsLoading, mutate: refreshComments, - } = useComments(entityId, entityType); + } = useComments(entityId, entityType, { + organizationId: resolvedOrgId, + enabled: Boolean(resolvedOrgId), + }); // Extract comments from SWR response const comments = commentsData?.data || []; @@ -77,7 +96,7 @@ export const Comments = ({ const content = (
- + {commentsLoading && (
diff --git a/apps/app/src/components/forms/risks/create-risk-form.tsx b/apps/app/src/components/forms/risks/create-risk-form.tsx index 5e3e98f3a..1c6c024ca 100644 --- a/apps/app/src/components/forms/risks/create-risk-form.tsx +++ b/apps/app/src/components/forms/risks/create-risk-form.tsx @@ -17,31 +17,11 @@ import { useAction } from 'next-safe-action/hooks'; import { useQueryState } from 'nuqs'; import { useForm } from 'react-hook-form'; import { toast } from 'sonner'; +import { useSWRConfig } from 'swr'; import type { z } from 'zod'; export function CreateRisk({ assignees }: { assignees: (Member & { user: User })[] }) { - // Get the same query parameters as the table - const [search] = useQueryState('search'); - const [page] = useQueryState('page', { - defaultValue: 1, - parse: Number.parseInt, - }); - const [pageSize] = useQueryState('pageSize', { - defaultValue: 10, - parse: Number, - }); - const [status] = useQueryState('status', { - defaultValue: null, - parse: (value) => value as RiskStatus | null, - }); - const [department] = useQueryState('department', { - defaultValue: null, - parse: (value) => value as Departments | null, - }); - const [assigneeId] = useQueryState('assigneeId', { - defaultValue: null, - parse: (value) => value, - }); + const { mutate } = useSWRConfig(); const [_, setCreateRiskSheet] = useQueryState('create-risk-sheet'); @@ -49,6 +29,12 @@ export function CreateRisk({ assignees }: { assignees: (Member & { user: User }) onSuccess: async () => { toast.success('Risk created successfully'); setCreateRiskSheet(null); + // Invalidate all risks SWR caches (any key starting with 'risks') + mutate( + (key) => Array.isArray(key) && key[0] === 'risks', + undefined, + { revalidate: true }, + ); }, onError: () => { toast.error('Failed to create risk'); diff --git a/apps/app/src/components/task-items/TaskItemEditableDescription.tsx b/apps/app/src/components/task-items/TaskItemEditableDescription.tsx index 96d982aad..04ae6aa7b 100644 --- a/apps/app/src/components/task-items/TaskItemEditableDescription.tsx +++ b/apps/app/src/components/task-items/TaskItemEditableDescription.tsx @@ -73,14 +73,21 @@ export function TaskItemEditableDescription({ const isSelectingFileRef = useRef(false); const { members } = useOrganizationMembers(); + // Only show admin/owner users in mention suggestions const mentionMembers: MentionUser[] = useMemo(() => { if (!members) return []; - return members.map((member) => ({ - id: member.user.id, - name: member.user.name || member.user.email || 'Unknown', - email: member.user.email || '', - image: member.user.image, - })); + return members + .filter((member) => { + if (!member.role) return false; + const roles = member.role.split(',').map((r) => r.trim().toLowerCase()); + return roles.includes('owner') || roles.includes('admin'); + }) + .map((member) => ({ + id: member.user.id, + name: member.user.name || member.user.email || 'Unknown', + email: member.user.email || '', + image: member.user.image, + })); }, [members]); const descriptionInputRef = useRef(null); diff --git a/apps/app/src/components/task-items/TaskSmartForm.tsx b/apps/app/src/components/task-items/TaskSmartForm.tsx index a0b5990a1..b1637f34a 100644 --- a/apps/app/src/components/task-items/TaskSmartForm.tsx +++ b/apps/app/src/components/task-items/TaskSmartForm.tsx @@ -116,15 +116,21 @@ export function TaskSmartForm({ return filterMembersByOwnerOrAdmin({ members }); }, [members]); - // All members for mentions (not filtered) + // Only show admin/owner users in mention suggestions const mentionMembers = useMemo(() => { if (!members) return []; - return members.map((member) => ({ - id: member.user.id, - name: member.user.name || member.user.email || 'Unknown', - email: member.user.email || '', - image: member.user.image, - })); + return members + .filter((member) => { + if (!member.role) return false; + const roles = member.role.split(',').map((r) => r.trim().toLowerCase()); + return roles.includes('owner') || roles.includes('admin'); + }) + .map((member) => ({ + id: member.user.id, + name: member.user.name || member.user.email || 'Unknown', + email: member.user.email || '', + image: member.user.image, + })); }, [members]); const handleFileUpload = useCallback( diff --git a/apps/app/src/hooks/use-comments-api.ts b/apps/app/src/hooks/use-comments-api.ts index dbab2e0dd..18b9b79f1 100644 --- a/apps/app/src/hooks/use-comments-api.ts +++ b/apps/app/src/hooks/use-comments-api.ts @@ -36,6 +36,7 @@ interface CreateCommentData { content: string; entityId: string; entityType: CommentEntityType; + contextUrl?: string; attachments?: Array<{ fileName: string; fileType: string; @@ -45,20 +46,37 @@ interface CreateCommentData { interface UpdateCommentData { content: string; + contextUrl?: string; +} + +// Default polling interval for real-time updates (5 seconds) +const DEFAULT_COMMENTS_POLLING_INTERVAL = 5000; + +export interface UseCommentsOptions extends UseApiSWROptions { + /** Organization ID - MUST be passed to ensure correct org context */ + organizationId?: string; } /** * Generic hook to fetch comments for any entity using SWR + * Includes polling for real-time updates (e.g., when trigger.dev tasks create comments) + * + * IMPORTANT: Always pass organizationId from URL params to ensure correct org context + * when user navigates to a different org's page while active org is different. */ export function useComments( entityId: string | null, entityType: CommentEntityType | null, - options: UseApiSWROptions = {}, + options: UseCommentsOptions = {}, ) { const endpoint = entityId && entityType ? `/v1/comments?entityId=${entityId}&entityType=${entityType}` : null; - return useApiSWR(endpoint, options); + return useApiSWR(endpoint, { + ...options, + // Enable polling for real-time updates (when trigger.dev tasks create comments) + refreshInterval: options.refreshInterval ?? DEFAULT_COMMENTS_POLLING_INTERVAL, + }); } /** @@ -69,7 +87,11 @@ export function useCommentActions() { const createComment = useCallback( async (data: CreateCommentData) => { - const response = await api.post('/v1/comments', data); + const response = await api.post('/v1/comments', { + ...data, + contextUrl: + data.contextUrl ?? (typeof window !== 'undefined' ? window.location.href : undefined), + }); if (response.error) { throw new Error(response.error); } @@ -80,7 +102,11 @@ export function useCommentActions() { const updateComment = useCallback( async (commentId: string, data: UpdateCommentData) => { - const response = await api.put(`/v1/comments/${commentId}`, data); + const response = await api.put(`/v1/comments/${commentId}`, { + ...data, + contextUrl: + data.contextUrl ?? (typeof window !== 'undefined' ? window.location.href : undefined), + }); if (response.error) { throw new Error(response.error); } @@ -108,10 +134,18 @@ export function useCommentActions() { }; } +export interface UseCommentWithAttachmentsOptions { + /** Organization ID - for consistency with other hooks */ + organizationId?: string; +} + /** * Utility hook that combines file handling with comment creation */ -export function useCommentWithAttachments() { +export function useCommentWithAttachments(_options: UseCommentWithAttachmentsOptions = {}) { + // Note: useCommentActions uses useApi which gets orgId from URL params + // The options.organizationId is accepted for API consistency but not currently used + // since useApi already handles org context from URL const { createComment } = useCommentActions(); const createCommentWithFiles = useCallback( @@ -144,6 +178,7 @@ export function useCommentWithAttachments() { content, entityId, entityType, + contextUrl: typeof window !== 'undefined' ? window.location.href : undefined, attachments, }); }, @@ -303,7 +338,7 @@ export function useOptimisticComments(entityId: string, entityType: CommentEntit /** * Convenience hook for task comments */ -export function useTaskComments(taskId: string | null, options: UseApiSWROptions = {}) { +export function useTaskComments(taskId: string | null, options: UseCommentsOptions = {}) { return useComments(taskId, 'task', options); } @@ -312,7 +347,7 @@ export function useTaskComments(taskId: string | null, options: UseApiSWROptions */ export function usePolicyComments( policyId: string | null, - options: UseApiSWROptions = {}, + options: UseCommentsOptions = {}, ) { return useComments(policyId, 'policy', options); } @@ -322,7 +357,7 @@ export function usePolicyComments( */ export function useVendorComments( vendorId: string | null, - options: UseApiSWROptions = {}, + options: UseCommentsOptions = {}, ) { return useComments(vendorId, 'vendor', options); } @@ -330,7 +365,7 @@ export function useVendorComments( /** * Convenience hook for risk comments */ -export function useRiskComments(riskId: string | null, options: UseApiSWROptions = {}) { +export function useRiskComments(riskId: string | null, options: UseCommentsOptions = {}) { return useComments(riskId, 'risk', options); } diff --git a/apps/app/src/hooks/use-risks.ts b/apps/app/src/hooks/use-risks.ts new file mode 100644 index 000000000..0163dfbed --- /dev/null +++ b/apps/app/src/hooks/use-risks.ts @@ -0,0 +1,269 @@ +'use client'; + +import { useApi } from '@/hooks/use-api'; +import { useApiSWR, UseApiSWROptions } from '@/hooks/use-api-swr'; +import { ApiResponse } from '@/lib/api-client'; +import { useCallback } from 'react'; +import type { + RiskCategory, + Departments, + RiskStatus, + Likelihood, + Impact, + RiskTreatmentType, +} from '@db'; + +// Default polling interval for real-time updates (5 seconds) +const DEFAULT_POLLING_INTERVAL = 5000; + +export interface RiskAssignee { + id: string; + user: { + id: string; + name: string | null; + email: string; + image: string | null; + }; +} + +export interface Risk { + id: string; + title: string; + description: string; + category: RiskCategory; + department: Departments | null; + status: RiskStatus; + likelihood: Likelihood; + impact: Impact; + residualLikelihood: Likelihood; + residualImpact: Impact; + treatmentStrategyDescription: string | null; + treatmentStrategy: RiskTreatmentType; + organizationId: string; + assigneeId: string | null; + assignee?: RiskAssignee | null; + createdAt: string; + updatedAt: string; +} + +export interface RisksResponse { + data: Risk[]; + count: number; +} + +/** + * Risk response from API - same as Risk for now + */ +export type RiskResponse = Risk; + +interface CreateRiskData { + title: string; + description?: string; + category?: RiskCategory; + department?: Departments; + status?: RiskStatus; + likelihood?: Likelihood; + impact?: Impact; + residualLikelihood?: Likelihood; + residualImpact?: Impact; + treatmentStrategy?: RiskTreatmentType; + treatmentStrategyDescription?: string; + assigneeId?: string; +} + +interface UpdateRiskData { + title?: string; + description?: string; + category?: RiskCategory; + department?: Departments | null; + status?: RiskStatus; + likelihood?: Likelihood; + impact?: Impact; + residualLikelihood?: Likelihood; + residualImpact?: Impact; + treatmentStrategy?: RiskTreatmentType; + treatmentStrategyDescription?: string | null; + assigneeId?: string | null; +} + +export interface UseRisksOptions extends UseApiSWROptions { + /** Initial data from server for hydration - avoids loading state on first render */ + initialData?: Risk[]; +} + +export interface UseRiskOptions extends UseApiSWROptions { + /** Initial data from server for hydration - avoids loading state on first render */ + initialData?: RiskResponse; +} + +/** + * Hook to fetch all risks for the current organization using SWR + * Provides automatic caching, revalidation, and real-time updates + * + * @example + * // With server-side initial data (recommended for pages) + * const { data, mutate } = useRisks({ initialData: serverRisks }); + * + * @example + * // Without initial data (shows loading state) + * const { data, isLoading, mutate } = useRisks(); + */ +export function useRisks(options: UseRisksOptions = {}) { + const { initialData, ...restOptions } = options; + + return useApiSWR('/v1/risks', { + ...restOptions, + // Refresh risks periodically for real-time updates + refreshInterval: restOptions.refreshInterval ?? 30000, + // Use initial data as fallback for instant render + ...(initialData && { + fallbackData: { + data: { data: initialData, count: initialData.length }, + status: 200, + } as ApiResponse, + }), + }); +} + +/** + * Hook to fetch a single risk by ID using SWR + * Provides real-time updates via polling + * + * @example + * // With server-side initial data (recommended for detail pages) + * const { data, mutate } = useRisk(riskId, { initialData: serverRisk }); + * + * @example + * // Without initial data (shows loading state) + * const { data, isLoading, mutate } = useRisk(riskId); + */ +export function useRisk( + riskId: string | null, + options: UseRiskOptions = {}, +) { + const { initialData, ...restOptions } = options; + + const swrResult = useApiSWR( + riskId ? `/v1/risks/${riskId}` : null, + { + ...restOptions, + // Enable polling for real-time updates (when trigger.dev tasks complete) + refreshInterval: restOptions.refreshInterval ?? DEFAULT_POLLING_INTERVAL, + // Continue polling even when window is not focused + refreshWhenHidden: false, + // Use initial data as fallback for instant render + ...(initialData && { + fallbackData: { + data: initialData, + status: 200, + } as ApiResponse, + }), + }, + ); + + // Extract risk data from response + const risk = swrResult.data?.data ?? null; + + return { + ...swrResult, + risk, + }; +} + +/** + * Hook for risk CRUD operations (mutations) + * Use alongside useRisks/useRisk and call mutate() after mutations + */ +export function useRiskActions() { + const api = useApi(); + + const createRisk = useCallback( + async (data: CreateRiskData) => { + const response = await api.post('/v1/risks', data); + if (response.error) { + throw new Error(response.error); + } + return response.data!; + }, + [api], + ); + + const updateRisk = useCallback( + async (riskId: string, data: UpdateRiskData) => { + const response = await api.patch(`/v1/risks/${riskId}`, data); + if (response.error) { + throw new Error(response.error); + } + return response.data!; + }, + [api], + ); + + const deleteRisk = useCallback( + async (riskId: string) => { + const response = await api.delete(`/v1/risks/${riskId}`); + if (response.error) { + throw new Error(response.error); + } + return { success: true, status: response.status }; + }, + [api], + ); + + return { + createRisk, + updateRisk, + deleteRisk, + }; +} + +/** + * Combined hook for risks with data fetching and mutations + * Provides a complete solution for risk management with optimistic updates + */ +export function useRisksWithMutations(options: UseApiSWROptions = {}) { + const { data, error, isLoading, mutate } = useRisks(options); + const { createRisk, updateRisk, deleteRisk } = useRiskActions(); + + const create = useCallback( + async (riskData: CreateRiskData) => { + const result = await createRisk(riskData); + // Revalidate the risks list after creation + await mutate(); + return result; + }, + [createRisk, mutate], + ); + + const update = useCallback( + async (riskId: string, riskData: UpdateRiskData) => { + const result = await updateRisk(riskId, riskData); + // Revalidate the risks list after update + await mutate(); + return result; + }, + [updateRisk, mutate], + ); + + const remove = useCallback( + async (riskId: string) => { + const result = await deleteRisk(riskId); + // Revalidate the risks list after deletion + await mutate(); + return result; + }, + [deleteRisk, mutate], + ); + + return { + risks: data?.data?.data ?? [], + count: data?.data?.count ?? 0, + isLoading, + error, + mutate, + createRisk: create, + updateRisk: update, + deleteRisk: remove, + }; +} + diff --git a/apps/app/src/hooks/use-task-items.ts b/apps/app/src/hooks/use-task-items.ts index d6abdcdfb..02c3b5645 100644 --- a/apps/app/src/hooks/use-task-items.ts +++ b/apps/app/src/hooks/use-task-items.ts @@ -6,6 +6,9 @@ import { useCallback } from 'react'; export type TaskItemEntityType = 'vendor' | 'risk'; +// Default polling interval for cross-user updates (5 seconds) +const DEFAULT_TASK_ITEMS_POLLING_INTERVAL = 5000; + export type TaskItemStatus = 'todo' | 'in_progress' | 'in_review' | 'done' | 'canceled'; export type TaskItemPriority = 'urgent' | 'high' | 'medium' | 'low'; @@ -211,6 +214,9 @@ export function useTaskItems( return useApiSWR(endpoint, { ...options, + // Cross-user updates: when another teammate edits tasks, this view should update without refresh + refreshInterval: options.refreshInterval ?? DEFAULT_TASK_ITEMS_POLLING_INTERVAL, + revalidateOnFocus: options.revalidateOnFocus ?? true, // Keep previous data visible while loading new page keepPreviousData: true, }); @@ -229,7 +235,11 @@ export function useTaskItemsStats( ? `/v1/task-management/stats?entityId=${entityId}&entityType=${entityType}` : null; - return useApiSWR(endpoint, options); + return useApiSWR(endpoint, { + ...options, + refreshInterval: options.refreshInterval ?? DEFAULT_TASK_ITEMS_POLLING_INTERVAL, + revalidateOnFocus: options.revalidateOnFocus ?? true, + }); } /** diff --git a/apps/app/src/hooks/use-vendors.ts b/apps/app/src/hooks/use-vendors.ts new file mode 100644 index 000000000..ef90609c7 --- /dev/null +++ b/apps/app/src/hooks/use-vendors.ts @@ -0,0 +1,264 @@ +'use client'; + +import { useApi } from '@/hooks/use-api'; +import { useApiSWR, UseApiSWROptions } from '@/hooks/use-api-swr'; +import { ApiResponse } from '@/lib/api-client'; +import { useCallback } from 'react'; +import type { + VendorCategory, + VendorStatus, + Likelihood, + Impact, +} from '@db'; +import type { JsonValue } from '@prisma/client/runtime/library'; + +// Default polling interval for real-time updates (5 seconds) +const DEFAULT_POLLING_INTERVAL = 5000; + +export interface VendorAssignee { + id: string; + user: { + id: string; + name: string | null; + email: string; + image: string | null; + }; +} + +export interface Vendor { + id: string; + name: string; + description: string; + category: VendorCategory; + status: VendorStatus; + inherentProbability: Likelihood; + inherentImpact: Impact; + residualProbability: Likelihood; + residualImpact: Impact; + website: string | null; + organizationId: string; + assigneeId: string | null; + assignee?: VendorAssignee | null; + createdAt: string; + updatedAt: string; +} + +export interface VendorsResponse { + data: Vendor[]; + count: number; +} + +/** + * Vendor response from API includes GlobalVendors risk assessment data + */ +export interface VendorResponse extends Vendor { + // GlobalVendors risk assessment data merged by API + riskAssessmentData?: JsonValue | null; + riskAssessmentVersion?: string | null; + riskAssessmentUpdatedAt?: string | null; +} + +interface CreateVendorData { + name: string; + description?: string; + category?: VendorCategory; + website?: string; + assigneeId?: string; +} + +interface UpdateVendorData { + name?: string; + description?: string; + category?: VendorCategory; + status?: VendorStatus; + website?: string; + assigneeId?: string | null; + inherentProbability?: Likelihood; + inherentImpact?: Impact; + residualProbability?: Likelihood; + residualImpact?: Impact; +} + +export interface UseVendorsOptions extends UseApiSWROptions { + /** Initial data from server for hydration - avoids loading state on first render */ + initialData?: Vendor[]; +} + +export interface UseVendorOptions extends UseApiSWROptions { + /** Initial data from server for hydration - avoids loading state on first render */ + initialData?: VendorResponse; +} + +/** + * Hook to fetch all vendors for the current organization using SWR + * Provides automatic caching, revalidation, and real-time updates + * + * @example + * // With server-side initial data (recommended for pages) + * const { vendors, mutate } = useVendors({ initialData: serverVendors }); + * + * @example + * // Without initial data (shows loading state) + * const { vendors, isLoading, mutate } = useVendors(); + */ +export function useVendors(options: UseVendorsOptions = {}) { + const { initialData, ...restOptions } = options; + + const swrResponse = useApiSWR('/v1/vendors', { + ...restOptions, + // Refresh vendors periodically for real-time updates + refreshInterval: restOptions.refreshInterval ?? 30000, + // Use initial data as fallback for instant render + ...(initialData && { + fallbackData: { + data: { data: initialData, count: initialData.length }, + status: 200, + } as ApiResponse, + }), + }); + + return swrResponse; +} + +/** + * Hook to fetch a single vendor by ID using SWR + * Provides real-time updates via polling + * + * @example + * // With server-side initial data (recommended for detail pages) + * const { data, mutate } = useVendor(vendorId, { initialData: serverVendor }); + * + * @example + * // Without initial data (shows loading state) + * const { data, isLoading, mutate } = useVendor(vendorId); + */ +export function useVendor( + vendorId: string | null, + options: UseVendorOptions = {}, +) { + const { initialData, ...restOptions } = options; + + const swrResult = useApiSWR( + vendorId ? `/v1/vendors/${vendorId}` : null, + { + ...restOptions, + // Enable polling for real-time updates (when trigger.dev tasks complete) + refreshInterval: restOptions.refreshInterval ?? DEFAULT_POLLING_INTERVAL, + // Continue polling even when window is not focused + refreshWhenHidden: false, + // Use initial data as fallback for instant render + ...(initialData && { + fallbackData: { + data: initialData, + status: 200, + } as ApiResponse, + }), + }, + ); + + // Extract vendor data from response + const vendor = swrResult.data?.data ?? null; + + return { + ...swrResult, + vendor, + }; +} + +/** + * Hook for vendor CRUD operations (mutations) + * Use alongside useVendors/useVendor and call mutate() after mutations + */ +export function useVendorActions() { + const api = useApi(); + + const createVendor = useCallback( + async (data: CreateVendorData) => { + const response = await api.post('/v1/vendors', data); + if (response.error) { + throw new Error(response.error); + } + return response.data!; + }, + [api], + ); + + const updateVendor = useCallback( + async (vendorId: string, data: UpdateVendorData) => { + const response = await api.patch(`/v1/vendors/${vendorId}`, data); + if (response.error) { + throw new Error(response.error); + } + return response.data!; + }, + [api], + ); + + const deleteVendor = useCallback( + async (vendorId: string) => { + const response = await api.delete(`/v1/vendors/${vendorId}`); + if (response.error) { + throw new Error(response.error); + } + return { success: true, status: response.status }; + }, + [api], + ); + + return { + createVendor, + updateVendor, + deleteVendor, + }; +} + +/** + * Combined hook for vendors with data fetching and mutations + * Provides a complete solution for vendor management with optimistic updates + */ +export function useVendorsWithMutations(options: UseApiSWROptions = {}) { + const { data, error, isLoading, mutate } = useVendors(options); + const { createVendor, updateVendor, deleteVendor } = useVendorActions(); + + const create = useCallback( + async (vendorData: CreateVendorData) => { + const result = await createVendor(vendorData); + // Revalidate the vendors list after creation + await mutate(); + return result; + }, + [createVendor, mutate], + ); + + const update = useCallback( + async (vendorId: string, vendorData: UpdateVendorData) => { + const result = await updateVendor(vendorId, vendorData); + // Revalidate the vendors list after update + await mutate(); + return result; + }, + [updateVendor, mutate], + ); + + const remove = useCallback( + async (vendorId: string) => { + const result = await deleteVendor(vendorId); + // Revalidate the vendors list after deletion + await mutate(); + return result; + }, + [deleteVendor, mutate], + ); + + return { + vendors: data?.data?.data ?? [], + count: data?.data?.count ?? 0, + isLoading, + error, + mutate, + createVendor: create, + updateVendor: update, + deleteVendor: remove, + }; +} + diff --git a/packages/docs/openapi.json b/packages/docs/openapi.json index 2540bc143..628b959cb 100644 --- a/packages/docs/openapi.json +++ b/packages/docs/openapi.json @@ -14872,6 +14872,12 @@ ], "example": "task" }, + "contextUrl": { + "type": "string", + "description": "Optional URL of the page where the comment was created, used for deep-linking in notifications", + "example": "https://app.trycomp.ai/org_abc123/vendors/vnd_abc123?taskItemId=tki_abc123#task-items", + "maxLength": 2048 + }, "attachments": { "description": "Optional attachments to include with the comment", "type": "array", @@ -14900,6 +14906,12 @@ "example": "This task needs to be completed by end of week (updated)", "maxLength": 2000 }, + "contextUrl": { + "type": "string", + "description": "Optional URL of the page where the comment was updated, used for deep-linking in notifications", + "example": "https://app.trycomp.ai/org_abc123/risk/rsk_abc123?taskItemId=tki_abc123#task-items", + "maxLength": 2048 + }, "userId": { "type": "string", "description": "User ID of the comment author (required for API key auth, ignored for JWT auth)", diff --git a/packages/ui/src/components/editor/extensions/mention.tsx b/packages/ui/src/components/editor/extensions/mention.tsx index 49b99f514..1537c0dd9 100644 --- a/packages/ui/src/components/editor/extensions/mention.tsx +++ b/packages/ui/src/components/editor/extensions/mention.tsx @@ -17,13 +17,27 @@ export interface MentionListProps { items: MentionUser[]; command: (item: MentionUser) => void; onSelect?: () => void; + // Callback to register the onKeyDown handler for parent access + onKeyDownRef?: React.MutableRefObject<((props: { event: KeyboardEvent }) => boolean) | null>; } -function MentionList({ items, command, onSelect }: MentionListProps) { +function MentionList({ items, command, onSelect, onKeyDownRef }: MentionListProps) { const safeItems = items || []; const [selectedIndex, setSelectedIndex] = useState(0); const itemRefs = useRef<(HTMLButtonElement | null)[]>([]); const containerRef = useRef(null); + + // Store current state in refs for the keydown handler + const selectedIndexRef = useRef(selectedIndex); + const safeItemsRef = useRef(safeItems); + + useEffect(() => { + selectedIndexRef.current = selectedIndex; + }, [selectedIndex]); + + useEffect(() => { + safeItemsRef.current = safeItems; + }, [safeItems]); useEffect(() => { setSelectedIndex(0); @@ -50,6 +64,8 @@ function MentionList({ items, command, onSelect }: MentionListProps) { }; const handleKeyDown = (event: React.KeyboardEvent) => { + if (safeItems.length === 0) return; + if (event.key === 'ArrowDown') { event.preventDefault(); setSelectedIndex((prev) => (prev + 1) % safeItems.length); @@ -64,30 +80,41 @@ function MentionList({ items, command, onSelect }: MentionListProps) { } }; - // Expose onKeyDown for ReactRenderer + // Register the onKeyDown handler for external access useEffect(() => { - if (containerRef.current) { - (containerRef.current as any).onKeyDown = (props: { event: KeyboardEvent }) => { + if (onKeyDownRef) { + onKeyDownRef.current = (props: { event: KeyboardEvent }) => { const { event } = props; + const currentItems = safeItemsRef.current; + const currentIndex = selectedIndexRef.current; + + if (currentItems.length === 0) return false; + if (event.key === 'ArrowDown') { event.preventDefault(); - setSelectedIndex((prev) => (prev + 1) % safeItems.length); + setSelectedIndex((prev) => (prev + 1) % currentItems.length); return true; } else if (event.key === 'ArrowUp') { event.preventDefault(); - setSelectedIndex((prev) => (prev - 1 + safeItems.length) % safeItems.length); + setSelectedIndex((prev) => (prev - 1 + currentItems.length) % currentItems.length); return true; } else if (event.key === 'Enter') { event.preventDefault(); - if (safeItems[selectedIndex]) { - handleSelect(safeItems[selectedIndex]); + if (currentItems[currentIndex]) { + handleSelect(currentItems[currentIndex]); } return true; } return false; }; } - }, [safeItems, selectedIndex]); + + return () => { + if (onKeyDownRef) { + onKeyDownRef.current = null; + } + }; + }, [onKeyDownRef, command, onSelect]); if (safeItems.length === 0) { return ( @@ -183,19 +210,22 @@ export function createMentionExtension({ suggestion }: CreateMentionExtensionOpt render: () => { let component: ReactRenderer; let popup: TippyInstance | null = null; + // Mutable ref to store the keydown handler from the component + const keyDownHandlerRef: { current: ((props: { event: KeyboardEvent }) => boolean) | null } = { current: null }; return { onStart: (props) => { // Ensure items is always an array const items = Array.isArray(props.items) ? props.items : []; - component = new ReactRenderer(MentionList as any, { + component = new ReactRenderer(MentionList, { props: { ...props, items, onSelect: suggestion.onSelect, + onKeyDownRef: keyDownHandlerRef, }, - editor: props.editor as any, + editor: props.editor, }); if (!props.clientRect) { @@ -228,9 +258,12 @@ export function createMentionExtension({ suggestion }: CreateMentionExtensionOpt // Ensure items is always an array const items = Array.isArray(props.items) ? props.items : []; + // Include onSelect and onKeyDownRef to preserve keyboard navigation component.updateProps({ ...props, items, + onSelect: suggestion.onSelect, + onKeyDownRef: keyDownHandlerRef, }); if (!props.clientRect) { @@ -248,15 +281,9 @@ export function createMentionExtension({ suggestion }: CreateMentionExtensionOpt return true; } - // Use the ref's onKeyDown method if available - const ref = component.ref as { - onKeyDown?: (p: { event: KeyboardEvent }) => boolean; - } | null; - if (ref?.onKeyDown) { - const result = ref.onKeyDown(props); - if (result) { - return true; - } + // Use the registered keydown handler from the component + if (keyDownHandlerRef.current) { + return keyDownHandlerRef.current(props); } return false;