diff --git a/server/ai/fetchModels.ts b/server/ai/fetchModels.ts index 8c643058a..3cb265e2c 100644 --- a/server/ai/fetchModels.ts +++ b/server/ai/fetchModels.ts @@ -7,45 +7,346 @@ import { modelDetailsMap } from "./mappers" const Logger = getLogger(Subsystem.AI) -// Helper function to parse cost value (handles both numbers and scientific notation strings) -function parseCostValue(value: any): number { +const CACHE_TTL_MS = 5 * 60 * 1000 +const DEFAULT_MAX_INPUT_TOKENS = 128_000 + +type NumericLike = number | string | null | undefined + +interface RawLiteLLMParams { + model?: string + input_cost_per_token?: NumericLike + output_cost_per_token?: NumericLike + custom_llm_provider?: string +} + +interface RawModelInfoDetails { + description?: string | null + deep_research?: boolean | null + id?: string + input_cost_per_token?: NumericLike + litellm_provider?: string | null + max_input_tokens?: number | null + max_output_tokens?: number | null + output_cost_per_token?: NumericLike + reasoning?: boolean | null + supports_function_calling?: boolean | null + supports_reasoning?: boolean | null + supports_web_search?: boolean | null + supports_vision?: boolean | null + websearch?: boolean | null +} + +export interface RawModelInfoRecord { + model_name: string + litellm_params?: RawLiteLLMParams | null + model_info?: RawModelInfoDetails | null +} + +export interface NormalizedModelMetadata { + actualName: string + customLLMProvider?: string + deepResearch: boolean + description: string + hasConflictingMaxInputTokens: boolean + hasConflictingMaxOutputTokens: boolean + inputCostPerToken?: number + litellmProvider?: string + maxInputTokens?: number + maxOutputTokens?: number + modelInfoId?: string + modelName: string + outputCostPerToken?: number + reasoning: boolean + sourceCount: number + supportsFunctionCalling: boolean + supportsVision: boolean + websearch: boolean +} + +interface ModelInfoCache { + rawData: RawModelInfoRecord[] + normalizedByModelName: Map + timestamp: number +} + +let modelInfoCache: ModelInfoCache | null = null + +const warnedMissingTokenLimitKeys = new Set() + +function parseCostValue(value: NumericLike): number { if (typeof value === "number") { return value } if (typeof value === "string") { - // Handle scientific notation strings like "6e-07" const parsed = parseFloat(value) - return isNaN(parsed) ? 0 : parsed + return Number.isNaN(parsed) ? 0 : parsed } return 0 } -// Cache for model info from API -interface ModelInfoCache { - data: any[] - timestamp: number +function parseOptionalNumber(value: unknown): number | undefined { + if (typeof value === "number" && Number.isFinite(value) && value > 0) { + return value + } + return undefined } -let modelInfoCache: ModelInfoCache | null = null -const CACHE_TTL_MS = 5 * 60 * 1000 // 5 minutes +function parseOptionalCost(value: NumericLike): number | undefined { + const parsed = parseCostValue(value) + return parsed > 0 ? parsed : undefined +} + +function hasConflictingValues(values: number[]): boolean { + return new Set(values).size > 1 +} + +function resolveMinValue(values: Array): { + value?: number + hasConflict: boolean +} { + const presentValues = values.filter((value): value is number => value !== undefined) + if (presentValues.length === 0) { + return { value: undefined, hasConflict: false } + } + + return { + value: Math.min(...presentValues), + hasConflict: hasConflictingValues(presentValues), + } +} + +function firstNonEmpty(values: Array): string | undefined { + return values.find((value) => typeof value === "string" && value.trim().length > 0)?.trim() +} + +function normalizeBoolean(value: unknown): boolean { + return value === true +} + +function normalizeModelInfoRecords( + records: RawModelInfoRecord[], +): Map { + const groupedRecords = new Map() + + for (const record of records) { + if (!record.model_name) continue + const existing = groupedRecords.get(record.model_name) ?? [] + existing.push(record) + groupedRecords.set(record.model_name, existing) + } + + const normalized = new Map() + + for (const [modelName, group] of groupedRecords.entries()) { + const actualName = + firstNonEmpty(group.map((record) => record.litellm_params?.model)) ?? modelName + const inputTokenResolution = resolveMinValue( + group.map((record) => parseOptionalNumber(record.model_info?.max_input_tokens)), + ) + const outputTokenResolution = resolveMinValue( + group.map((record) => parseOptionalNumber(record.model_info?.max_output_tokens)), + ) + const inputCost = resolveMinValue( + group.map((record) => + parseOptionalCost(record.model_info?.input_cost_per_token) ?? + parseOptionalCost(record.litellm_params?.input_cost_per_token), + ), + ).value + const outputCost = resolveMinValue( + group.map((record) => + parseOptionalCost(record.model_info?.output_cost_per_token) ?? + parseOptionalCost(record.litellm_params?.output_cost_per_token), + ), + ).value + + const metadata: NormalizedModelMetadata = { + actualName, + customLLMProvider: firstNonEmpty( + group.map((record) => record.litellm_params?.custom_llm_provider), + ), + deepResearch: group.some( + (record) => normalizeBoolean(record.model_info?.deep_research), + ), + description: firstNonEmpty( + group.map((record) => record.model_info?.description), + ) ?? "", + hasConflictingMaxInputTokens: inputTokenResolution.hasConflict, + hasConflictingMaxOutputTokens: outputTokenResolution.hasConflict, + inputCostPerToken: inputCost, + litellmProvider: firstNonEmpty( + group.map((record) => record.model_info?.litellm_provider), + ), + maxInputTokens: inputTokenResolution.value, + maxOutputTokens: outputTokenResolution.value, + modelInfoId: firstNonEmpty(group.map((record) => record.model_info?.id)), + modelName, + outputCostPerToken: outputCost, + reasoning: group.some( + (record) => + normalizeBoolean(record.model_info?.reasoning) || + normalizeBoolean(record.model_info?.supports_reasoning), + ), + sourceCount: group.length, + supportsFunctionCalling: group.some( + (record) => normalizeBoolean(record.model_info?.supports_function_calling), + ), + supportsVision: group.some( + (record) => normalizeBoolean(record.model_info?.supports_vision), + ), + websearch: group.some( + (record) => + normalizeBoolean(record.model_info?.websearch) || + normalizeBoolean(record.model_info?.supports_web_search), + ), + } + + normalized.set(modelName, metadata) + } + + return normalized +} + +function logNormalizationConflicts( + normalizedByModelName: Map, +) { + for (const metadata of normalizedByModelName.values()) { + if ( + !metadata.hasConflictingMaxInputTokens && + !metadata.hasConflictingMaxOutputTokens + ) { + continue + } + + Logger.warn( + { + modelName: metadata.modelName, + actualName: metadata.actualName, + maxInputTokens: metadata.maxInputTokens, + maxOutputTokens: metadata.maxOutputTokens, + sourceCount: metadata.sourceCount, + hasConflictingMaxInputTokens: metadata.hasConflictingMaxInputTokens, + hasConflictingMaxOutputTokens: metadata.hasConflictingMaxOutputTokens, + }, + "[ModelInfo] Resolved duplicate upstream model records conservatively.", + ) + } +} -// Shared function to fetch model info from API with caching -export async function fetchModelInfoFromAPI(forceRefresh = false): Promise { - // Return cached data if still valid +function updateModelInfoCache(records: RawModelInfoRecord[]) { + const normalizedByModelName = normalizeModelInfoRecords(records) + modelInfoCache = { + rawData: records, + normalizedByModelName, + timestamp: Date.now(), + } + + logNormalizationConflicts(normalizedByModelName) +} + +function getCachedNormalizedModelMetadata(): NormalizedModelMetadata[] { + return [...(modelInfoCache?.normalizedByModelName.values() ?? [])] +} + +function matchesModelIdentifier( + metadata: NormalizedModelMetadata, + modelId: string, +): boolean { + if (metadata.modelName === modelId) return true + if (metadata.actualName === modelId) return true + if (metadata.actualName.endsWith(`/${modelId}`)) return true + if (modelId.endsWith(`/${metadata.modelName}`)) return true + return false +} + +function resolveModelMetadata( + modelId?: Models | string | null, +): NormalizedModelMetadata | undefined { + if (!modelId) { + return undefined + } + + const normalizedMetadata = getCachedNormalizedModelMetadata() + const requestedModelId = String(modelId) + const configuredActualName = + MODEL_CONFIGURATIONS[requestedModelId as Models]?.actualName + + return normalizedMetadata.find((metadata) => { + if (matchesModelIdentifier(metadata, requestedModelId)) return true + if (configuredActualName && matchesModelIdentifier(metadata, configuredActualName)) { + return true + } + return false + }) +} + +function warnMissingTokenLimitOnce( + modelId: string, + tokenKind: "input" | "output" | "metadata", +) { + const key = `${modelId}:${tokenKind}` + if (warnedMissingTokenLimitKeys.has(key)) { + return + } + + warnedMissingTokenLimitKeys.add(key) + Logger.warn( + { modelId, tokenKind }, + "[ModelInfo] Missing upstream token metadata for model in use.", + ) +} + +function isLiteLLMAllowlistedModel(modelName: string): boolean { + if ( + modelName === Models.LiteLLM_Claude_Sonnet_4_6 && + config.allowSonnet46 + ) { + return true + } + if (modelName === Models.LiteLLM_Claude_Opus_4_6 && config.allowOpus46) { + return true + } + return false +} + +function shouldIncludeLiteLLMModelForListing( + metadata: NormalizedModelMetadata, +): boolean { + if (metadata.litellmProvider === "hosted_vllm") { + return true + } + return isLiteLLMAllowlistedModel(metadata.modelName) +} + +type AvailableModel = { + actualName: string + labelName: string + provider: string + reasoning: boolean + websearch: boolean + deepResearch: boolean + description: string +} + +export type ResolvedModelTokenLimits = { + maxInputTokens: number + maxOutputTokens?: number +} + +export async function fetchModelInfoFromAPI( + forceRefresh = false, +): Promise { if (!forceRefresh && modelInfoCache) { const age = Date.now() - modelInfoCache.timestamp if (age < CACHE_TTL_MS) { - return modelInfoCache.data + return modelInfoCache.rawData } } - // Use API key from config if (!config.LiteLLMApiKey) { Logger.warn("LiteLLM API key not configured, returning empty array") - return [] + return modelInfoCache?.rawData ?? [] } - // Set timeout of 5 seconds const controller = new AbortController() const timeoutId = setTimeout(() => controller.abort(), 5000) @@ -54,57 +355,66 @@ export async function fetchModelInfoFromAPI(forceRefresh = false): Promise { + await fetchModelInfoFromAPI(forceRefresh) + return getCachedNormalizedModelMetadata() +} + export const preloadModelInfoCache = async (): Promise => { if (config.LiteLLMApiKey && config.LiteLLMBaseUrl) { try { - await fetchModelInfoFromAPI(true) // Force refresh on startup + await fetchModelInfoFromAPI(true) Logger.info("Model info cache preloaded successfully") } catch (error) { Logger.warn("Failed to preload model info cache", { @@ -114,137 +424,90 @@ export const preloadModelInfoCache = async (): Promise => { } } -// Export function to get cost config for a specific model (uses cached data) +export const getModelTokenLimits = ( + modelId?: Models | string | null, +): ResolvedModelTokenLimits => { + if (!modelId) { + return { maxInputTokens: DEFAULT_MAX_INPUT_TOKENS } + } + + const metadata = resolveModelMetadata(modelId) + const requestedModelId = String(modelId) + + if (!metadata) { + warnMissingTokenLimitOnce(requestedModelId, "metadata") + return { maxInputTokens: DEFAULT_MAX_INPUT_TOKENS } + } + + if (metadata.maxInputTokens === undefined) { + warnMissingTokenLimitOnce(requestedModelId, "input") + } + + if (metadata.maxOutputTokens === undefined) { + warnMissingTokenLimitOnce(requestedModelId, "output") + } + + return { + maxInputTokens: metadata.maxInputTokens ?? DEFAULT_MAX_INPUT_TOKENS, + maxOutputTokens: metadata.maxOutputTokens, + } +} + +export const getEffectiveMaxOutputTokens = ( + modelId: Models | string | null | undefined, + requestedMaxTokens?: number, +): number | undefined => { + if (requestedMaxTokens === undefined) { + return undefined + } + + const { maxOutputTokens } = getModelTokenLimits(modelId) + return maxOutputTokens !== undefined + ? Math.min(requestedMaxTokens, maxOutputTokens) + : requestedMaxTokens +} + export const getCostConfigForModel = async ( modelId: string, -): Promise<{ pricePerThousandInputTokens: number; pricePerThousandOutputTokens: number }> => { - const data = await fetchModelInfoFromAPI() - - // Find the model in the API response - // Match by model_name (enum value like "glm-latest") or by the actual model name in litellm_params.model - // Also handle cases where modelId might be the full path like "hosted_vllm/zai-org/GLM-4.7-dev" - const modelInfo = data.find( - (m: any) => { - // Direct match by model_name (enum value) - if (m.model_name === modelId) return true - - // Match by litellm_params.model (full path) - if (m.litellm_params?.model === modelId) return true - - // Match if modelId is at the end of the full path - if (m.litellm_params?.model?.endsWith(`/${modelId}`)) return true - - // Match if modelId contains the model_name - if (m.litellm_params?.model?.includes(`/${modelId}`)) return true - - return false - }, - ) +): Promise<{ + pricePerThousandInputTokens: number + pricePerThousandOutputTokens: number +}> => { + await fetchModelInfoFromAPI() + const metadata = resolveModelMetadata(modelId) - if (modelInfo) { - // Try to get costs from model_info first (as numbers), then from litellm_params (as strings) - const inputCost = modelInfo.model_info?.input_cost_per_token ?? - modelInfo.litellm_params?.input_cost_per_token - const outputCost = modelInfo.model_info?.output_cost_per_token ?? - modelInfo.litellm_params?.output_cost_per_token - - if (inputCost !== undefined && inputCost !== null && - outputCost !== undefined && outputCost !== null) { - const parsedInputCost = parseCostValue(inputCost) - const parsedOutputCost = parseCostValue(outputCost) - - if (parsedInputCost > 0 || parsedOutputCost > 0) { - return { - pricePerThousandInputTokens: parsedInputCost * 1000, - pricePerThousandOutputTokens: parsedOutputCost * 1000, - } - } + if ( + metadata?.inputCostPerToken !== undefined && + metadata?.outputCostPerToken !== undefined && + (metadata.inputCostPerToken > 0 || metadata.outputCostPerToken > 0) + ) { + return { + pricePerThousandInputTokens: metadata.inputCostPerToken * 1000, + pricePerThousandOutputTokens: metadata.outputCostPerToken * 1000, } } - - // Fallback to default config from modelDetailsMap + return modelDetailsMap[modelId]?.cost?.onDemand ?? { pricePerThousandInputTokens: 0, pricePerThousandOutputTokens: 0, } } -export const fetchModelConfigs = async (): Promise> => { - const data = await fetchModelInfoFromAPI() - - const availableModels: Array<{ - actualName: string - labelName: string - provider: string - reasoning: boolean - websearch: boolean - deepResearch: boolean - description: string - }> = [] - - // Use Set to track seen model IDs to avoid duplicates - const seenModelIds = new Set() - - // Filter models with litellm_provider === "hosted_vllm" and return in expected format - for (const modelInfo of data) { - // Only process models with litellm_provider === "hosted_vllm" - // Check model_info.litellm_provider (from API response structure) - const modelId = modelInfo.model_name - const actualName = modelInfo.litellm_params?.model || modelId - if (modelInfo.model_info?.litellm_provider !== "hosted_vllm") { - const allowlist = { - [Models.LiteLLM_Claude_Sonnet_4_6]: { - enabled: config.allowSonnet46, - name: "Claude Sonnet 4.5", - }, - [Models.LiteLLM_Claude_Opus_4_6]: { - enabled: config.allowOpus46, - name: "Claude Opus 4.5", - }, - }; - - const modelAllowlistInfo = allowlist[modelId as keyof typeof allowlist]; - - if (modelAllowlistInfo) { - if (modelAllowlistInfo.enabled) { - Logger.info(`Allowing ${modelAllowlistInfo.name} model despite litellm_provider not being 'hosted_vllm'`); - } else { - continue; - } - } else { - continue; - } - } +export const fetchModelConfigs = async (): Promise => { + const metadata = await fetchNormalizedModelMetadata() + const availableModels: AvailableModel[] = [] - // Skip if we've already processed this model (deduplicate by model_name) - if (seenModelIds.has(modelId)) { + for (const modelMetadata of metadata) { + if (!shouldIncludeLiteLLMModelForListing(modelMetadata)) { continue } - seenModelIds.add(modelId) - // Find the corresponding enum key in Models - const modelEnumKey = Object.keys(Models).find( - (key) => Models[key as keyof typeof Models] === modelId, - ) as keyof typeof Models | undefined - - // Get the enum value from the key (MODEL_CONFIGURATIONS is indexed by enum values, not keys) - const modelEnumValue = modelEnumKey ? (Models[modelEnumKey] as Models) : undefined - - // Get model configuration from MODEL_CONFIGURATIONS if it exists - const modelConfig = modelEnumValue ? MODEL_CONFIGURATIONS[modelEnumValue] : null + const modelConfig = + MODEL_CONFIGURATIONS[modelMetadata.modelName as Models] ?? null if (modelConfig) { - // Use configuration from MODEL_CONFIGURATIONS availableModels.push({ - actualName: actualName, + actualName: modelMetadata.actualName, labelName: modelConfig.labelName, provider: "LiteLLM", reasoning: modelConfig.reasoning, @@ -252,219 +515,231 @@ export const fetchModelConfigs = async (): Promise +> => { + const metadata = await fetchNormalizedModelMetadata() + const models = metadata + .filter(shouldIncludeLiteLLMModelForListing) + .map((modelMetadata) => { + const modelConfig = + MODEL_CONFIGURATIONS[modelMetadata.modelName as Models] ?? null + const labelName = modelConfig?.labelName ?? modelMetadata.modelName + const description = modelConfig?.description ?? modelMetadata.description + const reasoning = modelConfig?.reasoning ?? modelMetadata.reasoning + const websearch = modelConfig?.websearch ?? modelMetadata.websearch + const deepResearch = + modelConfig?.deepResearch ?? modelMetadata.deepResearch + const modelType = modelMetadata.modelName.includes("gemini") + ? "gemini" + : modelMetadata.modelName.includes("claude") + ? "claude" + : "other" + + return { + enumValue: modelMetadata.modelName, + labelName, + actualName: modelMetadata.actualName, + description, + reasoning, + websearch, + deepResearch, + modelType, + } + }) + + models.sort((a, b) => { + const typeOrder: Record = { claude: 1, gemini: 2, other: 3 } + const orderA = typeOrder[a.modelType] ?? 99 + const orderB = typeOrder[b.modelType] ?? 99 + if (orderA !== orderB) { + return orderA - orderB + } + return a.labelName.localeCompare(b.labelName) + }) + + return models +} + export const getAvailableModels = async (providerConfig: { - AwsAccessKey?: string - AwsSecretKey?: string - OpenAIKey?: string - OllamaModel?: string - TogetherAIModel?: string - TogetherApiKey?: string - FireworksAIModel?: string - FireworksApiKey?: string - GeminiAIModel?: string - GeminiApiKey?: string - VertexAIModel?: string - VertexProjectId?: string - VertexRegion?: string - LiteLLMApiKey?: string - LiteLLMBaseUrl?: string + AwsAccessKey?: string + AwsSecretKey?: string + OpenAIKey?: string + OllamaModel?: string + TogetherAIModel?: string + TogetherApiKey?: string + FireworksAIModel?: string + FireworksApiKey?: string + GeminiAIModel?: string + GeminiApiKey?: string + VertexAIModel?: string + VertexProjectId?: string + VertexRegion?: string + LiteLLMApiKey?: string + LiteLLMBaseUrl?: string }) => { - const availableModels: Array<{ - actualName: string - labelName: string - provider: string - reasoning: boolean - websearch: boolean - deepResearch: boolean - description: string - }> = [] - - // Priority (LiteLLM > AWS > OpenAI > Ollama > Together > Fireworks > Gemini > Vertex) - // Using if-else logic to ensure only ONE provider is active at a time - if (providerConfig.LiteLLMApiKey && providerConfig.LiteLLMBaseUrl) { - // Fetch models from API (hosted_vllm only) - const fetchedModels = await fetchModelConfigs() - if (fetchedModels.length > 0) { - // Use models fetched from API - availableModels.push(...fetchedModels) - } else { - // Fallback to static MODEL_CONFIGURATIONS if API call fails (with same allowlist gating as API path) - const isSonnet46 = (modelId: Models) => modelId === Models.LiteLLM_Claude_Sonnet_4_6 - const isOpus46 = (modelId: Models) => modelId === Models.LiteLLM_Claude_Opus_4_6 - Object.entries(MODEL_CONFIGURATIONS) - .filter(([, model]) => model.provider === AIProviders.LiteLLM) - .filter(([modelId, model]) => { - const id = modelId as Models - if (isSonnet46(id)) return config.allowSonnet46 - if (isOpus46(id)) return config.allowOpus46 - return true - }) - .forEach(([modelId, model]) => { - const id = modelId as Models - if (isSonnet46(id) && config.allowSonnet46) { - Logger.info("Allowing Claude Sonnet 4.5 model despite litellm_provider not being 'hosted_vllm'") - } - if (isOpus46(id) && config.allowOpus46) { - Logger.info("Allowing Claude Opus 4.5 model despite litellm_provider not being 'hosted_vllm'") - } - availableModels.push({ - actualName: model.actualName ?? "", - labelName: model.labelName, - provider: "LiteLLM", - reasoning: model.reasoning, - websearch: model.websearch, - deepResearch: model.deepResearch, - description: model.description, - }) - }) - } - } else if (providerConfig.AwsAccessKey && providerConfig.AwsSecretKey) { - // Add only AWS Bedrock models - Object.values(MODEL_CONFIGURATIONS) - .filter((model) => model.provider === AIProviders.AwsBedrock) - .forEach((model) => { - availableModels.push({ - actualName: model.actualName ?? "", - labelName: model.labelName, - provider: "AWS Bedrock", - reasoning: model.reasoning, - websearch: model.websearch, - deepResearch: model.deepResearch, - description: model.description, - }) + const availableModels: AvailableModel[] = [] + + if (providerConfig.LiteLLMApiKey && providerConfig.LiteLLMBaseUrl) { + const fetchedModels = await fetchModelConfigs() + if (fetchedModels.length > 0) { + availableModels.push(...fetchedModels) + } else { + Object.entries(MODEL_CONFIGURATIONS) + .filter(([, model]) => model.provider === AIProviders.LiteLLM) + .filter(([modelId]) => { + const id = modelId as Models + if (id === Models.LiteLLM_Claude_Sonnet_4_6) { + return config.allowSonnet46 + } + if (id === Models.LiteLLM_Claude_Opus_4_6) { + return config.allowOpus46 + } + return true }) - } else if (providerConfig.OpenAIKey) { - // Add only OpenAI models - Object.values(MODEL_CONFIGURATIONS) - .filter((model) => model.provider === AIProviders.OpenAI) - .forEach((model) => { - availableModels.push({ + .forEach(([modelId, model]) => { + availableModels.push({ actualName: model.actualName ?? "", labelName: model.labelName, - provider: "OpenAI", + provider: "LiteLLM", reasoning: model.reasoning, websearch: model.websearch, deepResearch: model.deepResearch, description: model.description, - }) + }) }) - } else if (providerConfig.OllamaModel) { - // Add only Ollama model + } + } else if (providerConfig.AwsAccessKey && providerConfig.AwsSecretKey) { + Object.values(MODEL_CONFIGURATIONS) + .filter((model) => model.provider === AIProviders.AwsBedrock) + .forEach((model) => { availableModels.push({ - actualName: providerConfig.OllamaModel, - labelName: providerConfig.OllamaModel, - provider: "Ollama", - reasoning: false, - websearch: true, - deepResearch: false, - description: "", + actualName: model.actualName ?? "", + labelName: model.labelName, + provider: "AWS Bedrock", + reasoning: model.reasoning, + websearch: model.websearch, + deepResearch: model.deepResearch, + description: model.description, }) - } else if (providerConfig.TogetherAIModel && providerConfig.TogetherApiKey) { - // Add only Together AI model + }) + } else if (providerConfig.OpenAIKey) { + Object.values(MODEL_CONFIGURATIONS) + .filter((model) => model.provider === AIProviders.OpenAI) + .forEach((model) => { availableModels.push({ - actualName: providerConfig.TogetherAIModel, - labelName: providerConfig.TogetherAIModel, - provider: "Together AI", - reasoning: false, - websearch: true, - deepResearch: false, - description: "", + actualName: model.actualName ?? "", + labelName: model.labelName, + provider: "OpenAI", + reasoning: model.reasoning, + websearch: model.websearch, + deepResearch: model.deepResearch, + description: model.description, }) - } else if (providerConfig.FireworksAIModel && providerConfig.FireworksApiKey) { - // Add only Fireworks AI model + }) + } else if (providerConfig.OllamaModel) { + availableModels.push({ + actualName: providerConfig.OllamaModel, + labelName: providerConfig.OllamaModel, + provider: "Ollama", + reasoning: false, + websearch: true, + deepResearch: false, + description: "", + }) + } else if (providerConfig.TogetherAIModel && providerConfig.TogetherApiKey) { + availableModels.push({ + actualName: providerConfig.TogetherAIModel, + labelName: providerConfig.TogetherAIModel, + provider: "Together AI", + reasoning: false, + websearch: true, + deepResearch: false, + description: "", + }) + } else if (providerConfig.FireworksAIModel && providerConfig.FireworksApiKey) { + availableModels.push({ + actualName: providerConfig.FireworksAIModel, + labelName: providerConfig.FireworksAIModel, + provider: "Fireworks AI", + reasoning: false, + websearch: true, + deepResearch: false, + description: "", + }) + } else if (providerConfig.GeminiAIModel && providerConfig.GeminiApiKey) { + Object.values(MODEL_CONFIGURATIONS) + .filter((model) => model.provider === AIProviders.GoogleAI) + .forEach((model) => { availableModels.push({ - actualName: providerConfig.FireworksAIModel, - labelName: providerConfig.FireworksAIModel, - provider: "Fireworks AI", - reasoning: false, - websearch: true, - deepResearch: false, - description: "", + actualName: model.actualName ?? "", + labelName: model.labelName, + provider: "Google AI", + reasoning: model.reasoning, + websearch: model.websearch, + deepResearch: model.deepResearch, + description: model.description, }) - } else if (providerConfig.GeminiAIModel && providerConfig.GeminiApiKey) { - // Add all Google AI models - Object.values(MODEL_CONFIGURATIONS) - .filter((model) => model.provider === AIProviders.GoogleAI) - .forEach((model) => { - availableModels.push({ - actualName: model.actualName ?? "", - labelName: model.labelName, - provider: "Google AI", - reasoning: model.reasoning, - websearch: model.websearch, - deepResearch: model.deepResearch, - description: model.description, - }) - }) - } else if (providerConfig.VertexProjectId && providerConfig.VertexRegion) { - // Add all Vertex AI models - no longer dependent on VERTEX_AI_MODEL being set - Object.values(MODEL_CONFIGURATIONS) - .filter((model) => model.provider === AIProviders.VertexAI) - .forEach((model) => { - availableModels.push({ - actualName: model.actualName ?? "", - labelName: model.labelName, - provider: "Vertex AI", - reasoning: model.reasoning, - websearch: model.websearch, - deepResearch: model.deepResearch, - description: model.description, - }) + }) + } else if (providerConfig.VertexProjectId && providerConfig.VertexRegion) { + Object.values(MODEL_CONFIGURATIONS) + .filter((model) => model.provider === AIProviders.VertexAI) + .forEach((model) => { + availableModels.push({ + actualName: model.actualName ?? "", + labelName: model.labelName, + provider: "Vertex AI", + reasoning: model.reasoning, + websearch: model.websearch, + deepResearch: model.deepResearch, + description: model.description, }) - } - - return availableModels -} - -// Legacy function for backward compatibility (returns old format) -export const getAvailableModelsLegacy = async (providerConfig: { - AwsAccessKey?: string - AwsSecretKey?: string - OpenAIKey?: string - OllamaModel?: string - TogetherAIModel?: string - TogetherApiKey?: string - FireworksAIModel?: string - FireworksApiKey?: string - GeminiAIModel?: string - GeminiApiKey?: string - VertexAIModel?: string - VertexProjectId?: string - VertexRegion?: string - LiteLLMApiKey?: string - LiteLLMBaseUrl?: string -}) => { - const newModels = await getAvailableModels(providerConfig) - return newModels.map( - (model: { - actualName: string - labelName: string - provider: string - reasoning: boolean - websearch: boolean - deepResearch: boolean - }) => ({ - label: model.labelName, - provider: model.provider, - }), - ) -} \ No newline at end of file + }) + } + + return availableModels +} + +export const __modelInfoInternals = { + getCachedNormalizedModelMetadata, + normalizeModelInfoRecords, + resolveModelMetadata, + resetModelInfoCacheForTests: () => { + modelInfoCache = null + warnedMissingTokenLimitKeys.clear() + }, + setModelInfoCacheForTests: (records: RawModelInfoRecord[]) => { + warnedMissingTokenLimitKeys.clear() + updateModelInfoCache(records) + }, +} diff --git a/server/ai/provider/base.ts b/server/ai/provider/base.ts index 10368b2cf..14e1c70cf 100644 --- a/server/ai/provider/base.ts +++ b/server/ai/provider/base.ts @@ -1,6 +1,7 @@ import { type Message } from "@aws-sdk/client-bedrock-runtime" import type { ConverseResponse, LLMProvider, ModelParams } from "@/ai/types" import { AIProviders } from "@/ai/types" +import { getEffectiveMaxOutputTokens } from "@/ai/fetchModels" import { MODEL_CONFIGURATIONS } from "@/ai/modelConfig" import config from "@/config" import path from "path" @@ -19,10 +20,14 @@ abstract class Provider implements LLMProvider { getModelParams(params: ModelParams) { // Look up the actual model name from MODEL_CONFIGURATIONS // This resolves enum values like "vertex-claude-sonnet-4" to actual API model names like "claude-sonnet-4@20250514" - const modelConfig = MODEL_CONFIGURATIONS[params.modelId || defaultFastModel] - const actualModelId = modelConfig?.actualName || params.modelId || defaultFastModel + const resolvedModelId = params.modelId || defaultFastModel + const modelConfig = MODEL_CONFIGURATIONS[resolvedModelId] + const actualModelId = modelConfig?.actualName || resolvedModelId || defaultFastModel + const requestedMaxTokens = params.max_new_tokens || 1024 * 8 return { - maxTokens: params.max_new_tokens || 1024 * 8, + maxTokens: + getEffectiveMaxOutputTokens(resolvedModelId, requestedMaxTokens) ?? + requestedMaxTokens, topP: params.top_p || 0.9, temperature: params.temperature || 0.6, modelId: actualModelId || defaultFastModel, diff --git a/server/api/chat/final-answer-synthesis.ts b/server/api/chat/final-answer-synthesis.ts new file mode 100644 index 000000000..d186a9fbf --- /dev/null +++ b/server/api/chat/final-answer-synthesis.ts @@ -0,0 +1,1686 @@ +import { + getEffectiveMaxOutputTokens, + getModelTokenLimits, +} from "@/ai/fetchModels" +import { getProviderByModel, jsonParseLLMOutput } from "@/ai/provider" +import { Models, type ConverseResponse, type ModelParams } from "@/ai/types" +import config from "@/config" +import { getLogger, getLoggerWithChild } from "@/logger" +import { AgentReasoningStepType } from "@/shared/types" +import { Subsystem } from "@/types" +import { ConversationRole } from "@aws-sdk/client-bedrock-runtime" +import type { Message } from "@aws-sdk/client-bedrock-runtime" +import { z } from "zod" +import type { AgentRunContext, PlanState, SubTask } from "./agent-schemas" +import { + isMessageAgentStopError, + raceWithStop, + throwIfStopRequested, +} from "./agent-stop" +import { + buildAgentSystemPromptContextBlock, + formatFragmentWithMetadata, + formatFragmentsWithMetadata, +} from "./message-agents-metadata" +import type { FragmentImageReference, MinimalAgentFragment } from "./types" + +const { defaultBestModel, IMAGE_CONTEXT_CONFIG } = config + +const Logger = getLogger(Subsystem.Chat) +const loggerWithChild = getLoggerWithChild(Subsystem.Chat) + +const IMAGE_TOKEN_ESTIMATE = 1_844 +const FINAL_OUTPUT_HEADROOM_RATIO = 0.15 +const FALLBACK_OUTPUT_TOKENS = 1_500 +const PREVIEW_TEXT_LENGTH = 320 +const MAX_SECTION_COUNT = 5 +const MAPPER_CONCURRENCY = 4 +const SECTION_CONCURRENCY = 4 +const FINAL_SYNTHESIS_STREAM_CHUNK_SIZE = 200 + +export type FinalSynthesisExecutionResult = { + textLength: number + totalImagesAvailable: number + imagesProvided: number + estimatedCostUsd: number + mode: "single" | "sectional" +} + +type SynthesisModeSelection = { + mode: "single" | "sectional" + maxInputTokens: number + safeInputBudget: number + estimatedInputTokens: number +} + +type FinalSection = { + sectionId: number + title: string + objective: string +} + +type SectionAnswerResult = { + sectionId: number + title: string + body: string +} + +type FragmentPreviewRecord = { + fragmentIndex: number + docId: string + title?: string + app?: string + entity?: string + timestamp?: string + previewText: string +} + +type FragmentAssignmentBatch = { + fragmentIndex: number + fragment: MinimalAgentFragment +} + +type SelectedImagesResult = { + selected: string[] + total: number + dropped: string[] + userAttachmentCount: number +} + +type SectionMappingEnvelope = { + sections?: Record +} + +const SectionPlanSchema = z.object({ + sections: z + .array( + z.object({ + sectionId: z.number().int().positive().optional(), + title: z.string().trim().min(1), + objective: z.string().trim().min(1), + }), + ) + .min(1) + .max(MAX_SECTION_COUNT), +}) + +const SectionMappingSchema = z.object({ + sections: z + .record(z.string(), z.array(z.number().int().positive())) + .default({}), +}) + +function normalizeWhitespace(value: string): string { + return value.replace(/\s+/g, " ").trim() +} + +function truncateText(value: string, maxLength: number): string { + if (value.length <= maxLength) return value + return `${value.slice(0, Math.max(1, maxLength - 1))}…` +} + +export function estimateTextTokens(text: string): number { + return Math.ceil(text.length / 4) +} + +function estimateImageTokens(imageCount: number): number { + return imageCount * IMAGE_TOKEN_ESTIMATE +} + +function estimatePromptTokens( + systemPrompt: string, + userMessage: string, + imageCount = 0, +): number { + return ( + estimateTextTokens(systemPrompt) + + estimateTextTokens(userMessage) + + estimateImageTokens(imageCount) + ) +} + +function getErrorStringProperty( + error: unknown, + key: "code" | "name" | "message", +): string | undefined { + if (typeof error !== "object" || error === null || !(key in error)) { + return undefined + } + + const value = (error as Record)[key] + if (value === undefined || value === null) { + return undefined + } + + return String(value) +} + +function classifyPlannerFallbackError(error: unknown): { + errorCode?: string + errorName?: string + errorMessage: string + isContextLengthError: boolean + isTransportError: boolean +} { + const errorCode = getErrorStringProperty(error, "code") + const errorName = getErrorStringProperty(error, "name") + const errorMessage = + getErrorStringProperty(error, "message") ?? + (error instanceof Error ? error.message : String(error)) + const normalized = `${errorName ?? ""} ${errorCode ?? ""} ${errorMessage}`.toLowerCase() + + return { + errorCode, + errorName, + errorMessage, + isContextLengthError: + errorName === "ValidationException" || + normalized.includes("input is too long") || + normalized.includes("context length") || + normalized.includes("context window") || + normalized.includes("maximum context") || + normalized.includes("prompt is too long") || + normalized.includes("too many tokens"), + isTransportError: + errorName === "AbortError" || + normalized.includes("timeout") || + normalized.includes("timed out") || + normalized.includes("transport") || + normalized.includes("network") || + normalized.includes("connection reset") || + normalized.includes("econnreset") || + normalized.includes("etimedout") || + normalized.includes("econnaborted") || + normalized.includes("eai_again") || + normalized.includes("502") || + normalized.includes("503") || + normalized.includes("504"), + } +} + +function buildStoppedFinalSynthesisResult( + mode: FinalSynthesisExecutionResult["mode"], + imageSelection: SelectedImagesResult, + estimatedCostUsd: number, + textLength: number, + imagesProvided: number, +): FinalSynthesisExecutionResult { + return { + textLength, + totalImagesAvailable: imageSelection.total, + imagesProvided, + estimatedCostUsd, + mode, + } +} + +function logFinalSynthesisStop( + context: AgentRunContext, + phase: string, + details: Record = {}, +) { + Logger.info( + { + chatId: context.chat.externalId, + phase, + ...details, + }, + "[FinalAnswerSynthesis] Stop requested; ending synthesis early.", + ) +} + +async function cancelConverseIterator( + iterator: AsyncIterableIterator, +) { + const cancelableIterator = iterator as AsyncIterableIterator & { + cancel?: () => Promise | unknown + close?: () => Promise | unknown + return?: ( + value?: unknown, + ) => Promise> | IteratorResult + } + + try { + if (typeof cancelableIterator.return === "function") { + await cancelableIterator.return(undefined) + return + } + if (typeof cancelableIterator.cancel === "function") { + await cancelableIterator.cancel() + return + } + if (typeof cancelableIterator.close === "function") { + await cancelableIterator.close() + } + } catch (error) { + Logger.warn( + { + err: error instanceof Error ? error.message : String(error), + }, + "[FinalAnswerSynthesis] Failed to cancel provider stream iterator after stop request.", + ) + } +} + +async function streamFinalAnswerChunk( + context: AgentRunContext, + text: string, +) { + if (!text) return + + const streamAnswer = context.runtime?.streamAnswerText + if (!streamAnswer) { + throw new Error("Streaming channel unavailable. Cannot deliver final answer.") + } + + throwIfStopRequested(context.stopSignal) + await raceWithStop(streamAnswer(text), context.stopSignal) + context.finalSynthesis.streamedText += text +} + +function estimateSafeInputBudget( + modelId: Models, + maxOutputTokens?: number, +): { maxInputTokens: number; safeInputBudget: number } { + const { maxInputTokens } = getModelTokenLimits(modelId) + const effectiveMaxOutputTokens = + getEffectiveMaxOutputTokens( + modelId, + maxOutputTokens ?? FALLBACK_OUTPUT_TOKENS, + ) ?? + maxOutputTokens ?? + FALLBACK_OUTPUT_TOKENS + const reservedOutputTokens = Math.ceil( + effectiveMaxOutputTokens * (1 + FINAL_OUTPUT_HEADROOM_RATIO), + ) + const safeInputBudget = Math.max(1_024, maxInputTokens - reservedOutputTokens) + return { maxInputTokens, safeInputBudget } +} + +export function formatPlanForPrompt(plan: PlanState | null): string { + if (!plan) return "" + const lines = [`Goal: ${plan.goal}`] + plan.subTasks.forEach((task, idx) => { + const icon = + task.status === "completed" + ? "✓" + : task.status === "in_progress" + ? "→" + : task.status === "failed" + ? "✗" + : task.status === "blocked" + ? "!" + : "○" + const baseLine = `${idx + 1}. [${icon}] ${task.description}` + const detailParts: string[] = [] + if (task.result) detailParts.push(`Result: ${task.result}`) + if (task.toolsRequired?.length) { + detailParts.push(`Tools: ${task.toolsRequired.join(", ")}`) + } + lines.push( + detailParts.length > 0 ? `${baseLine}\n ${detailParts.join(" | ")}` : baseLine, + ) + }) + return lines.join("\n") +} + +export function formatClarificationsForPrompt( + clarifications: AgentRunContext["clarifications"], +): string { + if (!clarifications?.length) return "" + return clarifications + .map( + (clarification, idx) => + `${idx + 1}. Q: ${clarification.question}\n A: ${clarification.answer}`, + ) + .join("\n") +} + +export function buildSharedFinalAnswerContext( + context: AgentRunContext, +): string { + const agentSystemPromptBlock = buildAgentSystemPromptContextBlock( + context.dedicatedAgentSystemPrompt, + ) + const agentSystemPromptSection = agentSystemPromptBlock + ? `Agent System Prompt Context:\n${agentSystemPromptBlock}` + : "" + const planSection = formatPlanForPrompt(context.plan) + const clarificationSection = formatClarificationsForPrompt( + context.clarifications, + ) + const workspaceSection = context.userContext?.trim() + ? `Workspace Context:\n${context.userContext}` + : "" + + return [ + `User Question:\n${context.message.text}`, + agentSystemPromptSection, + planSection ? `Execution Plan Snapshot:\n${planSection}` : "", + clarificationSection + ? `Clarifications Resolved:\n${clarificationSection}` + : "", + workspaceSection, + ] + .filter(Boolean) + .join("\n\n") +} + +export function buildBaseFinalAnswerSystemPrompt( + mode: "final" | "section" = "final", +): string { + const mission = + mode === "final" + ? "- Deliver the user's final answer using the conversation, plan snapshot, clarifications, workspace context, context fragments, and supplied images; never plan or call tools." + : "- Deliver only the assigned answer section using the conversation, plan snapshot, clarifications, workspace context, mapped context fragments, and supplied images. Other sections are being generated in parallel and a final ordered answer will be assembled later; never attempt to write the full final answer." + const sectionRules = + mode === "section" + ? ` + +### Section Constraints +- Write only the requested section body for the assigned section. +- Do not add a global introduction, conclusion, or next-step sentence that assumes other sections are already visible. +- Do not repeat section headings for other sections. +- Treat the provided section list as context only; answer exclusively for the assigned section. +`.trim() + : "" + + return ` +### Mission +${mission} + +### Evidence Intake +- Prioritize the highest-signal fragments, but pull any supporting fragment that improves accuracy. +- Only draw on context that directly answers the user's question; ignore unrelated fragments even if they were retrieved earlier. +- Treat delegated-agent outputs as citeable fragments; reference them like any other context entry. +- Describe evidence gaps plainly before concluding; never guess. +- Extract actionable details from provided images and cite them via their fragment indices. +- Respect user-imposed constraints using fragment metadata (any metadata field). If compliant evidence is missing, state that clearly. +- If "This is the system prompt of agent:" is present, analyse for instructions relevant for answering and strictly bind by them . + +### Response Construction +- Lead with the conclusion, then stack proof underneath. +- Organize output into tight sections (e.g., **Summary**, **Proof**, **Next Steps** when relevant); omit empty sections. +- Never mention internal tooling, planning logs, or this synthesis process. + +### Constraint Handling +- When the user asks for an action the system cannot execute (e.g., sending an email), deliver the closest actionable substitute (draft, checklist, explicit next steps) inside the answer. +- Pair the substitute with a concise explanation of the limitation and the manual action the user must take. + +### File & Chunk Formatting (CRITICAL) +- Each file starts with a header line exactly like: + index {docId} {file context begins here...} +- \`docId\` is a unique identifier for that file (e.g., 0, 1, 2, etc.). +- Inside the file context, text is split into chunks. +- Each chunk might begin with a bracketed numeric index, e.g.: [0], [1], [2], etc. +- This is the chunk index within that file, if it exists. + +### Guidelines for Response +1. Data Interpretation: + - Use ONLY the provided files and their chunks as your knowledge base. + - Treat every file header \`index {docId} ...\` as the start of a new document. + - Treat every bracketed number like [0], [1], [2] as the authoritative chunk index within that document. + - If dates exist, interpret them relative to the user's timezone when paraphrasing. +2. Response Structure: + - Start with the most relevant facts from the chunks across files. + - Keep order chronological when it helps comprehension. + - Every factual statement MUST cite the exact chunk it came from using the format: + K[docId_chunkIndex] + where: + - \`docId\` is taken from the file header line ("index {docId} ..."). + - \`chunkIndex\` is the bracketed number prefixed on that chunk within the same file. + - Examples: + - Single citation: "X is true K[12_3]." + - Two citations in one sentence (from different files or chunks): "X K[12_3] and Y K[7_0]." + - Use at most 1-2 citations per sentence; NEVER add more than 2 for one sentence. +3. Citation Rules (DOCUMENT+CHUNK LEVEL ONLY): + - ALWAYS cite at the chunk level with the K[docId_chunkIndex] format. + - Every chunk level citation must start with the K prefix eg. K[12_3] K[7_0] correct, but K[12_3] [7_0] is incorrect. + - Place the citation immediately after the relevant claim. + - Do NOT group indices inside one set of brackets (WRONG: "K[12_3,7_1]"). + - If a sentence draws on two distinct chunks (possibly from different files), include two separate citations inline, e.g., "... K[12_3] ... K[7_1]". + - Only cite information that appears verbatim or is directly inferable from the cited chunk. + - If you cannot ground a claim to a specific chunk, do not make the claim. +4. Quality Assurance: + - Cross-check across multiple chunks/files when available and briefly note inconsistencies if they exist. + - Keep tone professional and concise. + - Acknowledge gaps if the provided chunks don't contain enough detail. + +### Tone & Delivery +- Answer with confident, declarative, verb-first sentences that use concrete nouns. +- Highlight key deliverables using **bold** labels or short lists; keep wording razor-concise. +- Ask one targeted follow-up question only if missing info blocks action. + +### Tool Spotlighting +- Reference critical tool outputs explicitly, e.g., "**Slack Search:** Ops escalated the RCA at 09:42 [2]." +- Explain why each highlighted tool mattered so reviewers see coverage breadth. +- When multiple tools contribute, show the sequence, e.g., "**Vespa Search:** context -> **Sheet Lookup:** metrics." +${sectionRules ? `\n\n${sectionRules}` : ""} + +### Finish +- Close with a single sentence confirming completion or the next action you recommend. +`.trim() +} + +export function buildFinalSynthesisPayload( + context: AgentRunContext, + fragmentsLimit = Math.max(12, context.allFragments.length || 1), +): { systemPrompt: string; userMessage: string } { + const sharedContext = buildSharedFinalAnswerContext(context) + const formattedFragments = formatFragmentsWithMetadata( + context.allFragments, + fragmentsLimit, + ) + const fragmentsSection = formattedFragments + ? `Context Fragments:\n${formattedFragments}` + : "" + + return { + systemPrompt: buildBaseFinalAnswerSystemPrompt("final"), + userMessage: [sharedContext, fragmentsSection].filter(Boolean).join("\n\n"), + } +} + +function formatSectionPlanOverview(sections: FinalSection[]): string { + return sections + .map( + (section) => + `${section.sectionId}. ${section.title}\n Objective: ${section.objective}`, + ) + .join("\n") +} + +function buildPlannerSystemPrompt(): string { + return ` +You are planning a final answer for a large evidence set. + +Return JSON only in the shape: +{ + "sections": [ + { "sectionId": 1, "title": "string", "objective": "string" } + ] +} + +Rules: +- Produce between 2 and ${MAX_SECTION_COUNT} sections when possible. +- Keep sections ordered exactly as the final answer should appear. +- Use concise, user-facing titles. +- Objectives should state what each section must accomplish. +- Do not include a catch-all section unless needed. +- Do not mention internal processing, tools, or token limits. +`.trim() +} + +function buildMapperSystemPrompt(): string { + return ` +You are mapping evidence fragments to pre-planned answer sections. + +Return JSON only in the shape: +{ + "sections": { + "1": [3, 7], + "2": [1] + } +} + +Rules: +- Keys are section ids. +- Values are fragment indexes from the provided batch. +- A fragment may belong to multiple sections when directly relevant. +- Omit fragments that do not help any section. +- Omit section ids with no fragments from this batch. +- Never invent fragment indexes. +`.trim() +} + +function formatSectionFragments( + entries: FragmentAssignmentBatch[], +): string { + return entries + .map((entry) => formatFragmentWithMetadata(entry.fragment, entry.fragmentIndex - 1)) + .join("\n\n") +} + +function findTimestamp(fragment: MinimalAgentFragment): string | undefined { + const source = fragment.source ?? {} + return ( + source.closedAt || + source.resolvedAt || + source.createdAt || + undefined + ) +} + +function buildFragmentPreviewRecord( + fragment: MinimalAgentFragment, + fragmentIndex: number, +): FragmentPreviewRecord { + const source = fragment.source ?? {} + const previewText = truncateText( + normalizeWhitespace(fragment.content ?? ""), + PREVIEW_TEXT_LENGTH, + ) + + return { + fragmentIndex, + docId: source.docId || fragment.id, + title: source.title || source.page_title || undefined, + app: source.app ? String(source.app) : undefined, + entity: source.entity ? String(source.entity) : undefined, + timestamp: findTimestamp(fragment), + previewText, + } +} + +function formatPreviewRecord(preview: FragmentPreviewRecord): string { + const meta = [ + `fragmentIndex: ${preview.fragmentIndex}`, + `docId: ${preview.docId}`, + preview.title ? `title: ${preview.title}` : "", + preview.app ? `app: ${preview.app}` : "", + preview.entity ? `entity: ${preview.entity}` : "", + preview.timestamp ? `timestamp: ${preview.timestamp}` : "", + ] + .filter(Boolean) + .join(" | ") + + return `${meta}\npreviewText: ${preview.previewText}` +} + +function buildPreviewOmissionSummary(previews: FragmentPreviewRecord[]): string { + if (previews.length === 0) return "" + const counts = new Map() + for (const preview of previews) { + const key = preview.app || "unknown" + counts.set(key, (counts.get(key) ?? 0) + 1) + } + const summary = Array.from(counts.entries()) + .sort((a, b) => b[1] - a[1]) + .map(([key, count]) => `${key}: ${count}`) + .join(", ") + return `Additional fragment previews omitted due to budget: ${previews.length} (${summary}).` +} + +function buildPreviewTextWithinBudget( + previews: FragmentPreviewRecord[], + budgetTokens: number, +): { includedText: string; omittedSummary: string } { + if (previews.length === 0) { + return { includedText: "None.", omittedSummary: "" } + } + + const included: string[] = [] + let usedTokens = 0 + let cutoff = previews.length + + for (let index = 0; index < previews.length; index++) { + const previewText = formatPreviewRecord(previews[index]) + const previewTokens = estimateTextTokens(`${previewText}\n\n`) + if (included.length > 0 && usedTokens + previewTokens > budgetTokens) { + cutoff = index + break + } + if (included.length === 0 && previewTokens > budgetTokens) { + included.push(previewText) + cutoff = index + 1 + usedTokens += previewTokens + break + } + included.push(previewText) + usedTokens += previewTokens + } + + return { + includedText: included.join("\n\n"), + omittedSummary: buildPreviewOmissionSummary(previews.slice(cutoff)), + } +} + +function buildFragmentBatchesWithinBudget( + entries: FragmentAssignmentBatch[], + baseTokens: number, + budgetTokens: number, +): FragmentAssignmentBatch[][] { + if (entries.length === 0) return [] + + const batches: FragmentAssignmentBatch[][] = [] + let currentBatch: FragmentAssignmentBatch[] = [] + let currentTokens = baseTokens + + for (const entry of entries) { + const itemTokens = estimateTextTokens( + `${formatFragmentWithMetadata(entry.fragment, entry.fragmentIndex - 1)}\n\n`, + ) + const wouldOverflow = + currentBatch.length > 0 && currentTokens + itemTokens > budgetTokens + + if (wouldOverflow) { + batches.push(currentBatch) + currentBatch = [] + currentTokens = baseTokens + } + + currentBatch.push(entry) + currentTokens += itemTokens + } + + if (currentBatch.length > 0) { + batches.push(currentBatch) + } + + return batches +} + +function normalizeSectionPlan(data: z.infer): FinalSection[] { + return data.sections.slice(0, MAX_SECTION_COUNT).map((section, index) => ({ + sectionId: index + 1, + title: normalizeWhitespace(section.title), + objective: normalizeWhitespace(section.objective), + })) +} + +function normalizeSectionAssignments( + raw: SectionMappingEnvelope, + sections: FinalSection[], +): Map> { + const validSectionIds = new Set(sections.map((section) => section.sectionId)) + const merged = new Map>() + + for (const [key, indexes] of Object.entries(raw.sections ?? {})) { + const sectionId = Number(key) + if (!validSectionIds.has(sectionId)) continue + const target = merged.get(sectionId) ?? new Set() + for (const index of indexes) { + if (Number.isInteger(index) && index > 0) { + target.add(index) + } + } + if (target.size > 0) { + merged.set(sectionId, target) + } + } + + return merged +} + +function mergeSectionAssignments( + assignments: Array>>, + sections: FinalSection[], +): Map { + const merged = new Map>() + for (const section of sections) { + merged.set(section.sectionId, new Set()) + } + for (const batchAssignment of assignments) { + for (const [sectionId, indexes] of batchAssignment.entries()) { + const target = merged.get(sectionId) ?? new Set() + for (const index of indexes) { + target.add(index) + } + merged.set(sectionId, target) + } + } + + return new Map( + Array.from(merged.entries()).map(([sectionId, indexes]) => [ + sectionId, + Array.from(indexes).sort((a, b) => a - b), + ]), + ) +} + +async function runWithConcurrency( + items: T[], + limit: number, + worker: (item: T, index: number) => Promise, +): Promise { + if (items.length === 0) return [] + const results = new Array(items.length) + let cursor = 0 + + const runWorker = async () => { + while (true) { + const current = cursor + cursor += 1 + if (current >= items.length) { + return + } + results[current] = await worker(items[current], current) + } + } + + const concurrency = Math.max(1, Math.min(limit, items.length)) + await Promise.all(Array.from({ length: concurrency }, () => runWorker())) + return results +} + +function buildDefaultSectionPlan(): FinalSection[] { + return [ + { + sectionId: 1, + title: "Answer", + objective: "Provide the best complete answer using the mapped evidence.", + }, + ] +} + +function createSelectedImagesResult( + images: FragmentImageReference[], + turnCount: number, +): SelectedImagesResult { + const total = images.length + if (!IMAGE_CONTEXT_CONFIG.enabled || total === 0) { + return { selected: [], total, dropped: [], userAttachmentCount: 0 } + } + + const attachments = images.filter((img) => img.isUserAttachment) + const nonAttachments = images + .filter((img) => !img.isUserAttachment) + .sort((a, b) => { + const ageA = turnCount - a.addedAtTurn + const ageB = turnCount - b.addedAtTurn + return ageA - ageB + }) + + const prioritized = [...attachments, ...nonAttachments] + const uniqueNames: string[] = [] + const seen = new Set() + for (const image of prioritized) { + if (seen.has(image.fileName)) continue + seen.add(image.fileName) + uniqueNames.push(image.fileName) + } + + let selected = uniqueNames + let dropped: string[] = [] + + if ( + IMAGE_CONTEXT_CONFIG.maxImagesPerCall > 0 && + uniqueNames.length > IMAGE_CONTEXT_CONFIG.maxImagesPerCall + ) { + selected = uniqueNames.slice(0, IMAGE_CONTEXT_CONFIG.maxImagesPerCall) + dropped = uniqueNames.slice(IMAGE_CONTEXT_CONFIG.maxImagesPerCall) + } + + return { + selected, + total, + dropped, + userAttachmentCount: attachments.length, + } +} + +function selectImagesForFinalSynthesis( + context: AgentRunContext, +): SelectedImagesResult { + return createSelectedImagesResult(context.allImages, context.turnCount) +} + +function selectImagesForFragmentIds( + context: AgentRunContext, + fragmentIds: Set, +): SelectedImagesResult { + const images = context.allImages.filter((image) => + fragmentIds.has(image.sourceFragmentId), + ) + return createSelectedImagesResult(images, context.turnCount) +} + +function selectMappedEntriesWithinBudget( + entries: FragmentAssignmentBatch[], + baseTokens: number, + budgetTokens: number, +): { + selected: FragmentAssignmentBatch[] + trimmedCount: number + skippedForBudgetCount: number +} { + if (entries.length === 0) { + return { selected: [], trimmedCount: 0, skippedForBudgetCount: 0 } + } + + const selected: FragmentAssignmentBatch[] = [] + let usedTokens = baseTokens + let skippedForBudgetCount = 0 + + for (const entry of entries) { + const entryTokens = estimateTextTokens( + `${formatFragmentWithMetadata(entry.fragment, entry.fragmentIndex - 1)}\n\n`, + ) + if (usedTokens + entryTokens > budgetTokens) { + skippedForBudgetCount += 1 + continue + } + selected.push(entry) + usedTokens += entryTokens + } + + return { + selected, + trimmedCount: Math.max(entries.length - selected.length, 0), + skippedForBudgetCount, + } +} + +function buildSectionAnswerPayload( + context: AgentRunContext, + sections: FinalSection[], + section: FinalSection, + entries: FragmentAssignmentBatch[], + imageFileNames: string[], +): { systemPrompt: string; userMessage: string; imageFileNames: string[] } { + const sharedContext = buildSharedFinalAnswerContext(context) + const sectionOverview = formatSectionPlanOverview(sections) + const fragmentsText = formatSectionFragments(entries) + const fragmentsSection = fragmentsText + ? `Context Fragments For This Section:\n${fragmentsText}` + : "Context Fragments For This Section:\nNone." + + const userMessage = [ + sharedContext, + `All Planned Sections (generated in parallel; a final ordered answer will be assembled later):\n${sectionOverview}`, + `Assigned Section:\n${section.sectionId}. ${section.title}\nObjective: ${section.objective}`, + [ + "Section Instructions:", + "- Write only this section.", + "- Do not write the full final answer.", + "- Other sections are being generated in parallel.", + "- A final ordered answer will be assembled afterwards.", + "- Avoid intro or outro language that assumes the whole answer is already visible.", + "- Use the provided global fragment indexes exactly as shown for citations.", + ].join("\n"), + fragmentsSection, + ] + .filter(Boolean) + .join("\n\n") + + return { + systemPrompt: buildBaseFinalAnswerSystemPrompt("section"), + userMessage, + imageFileNames, + } +} + +function decideSynthesisMode( + context: AgentRunContext, + modelId: Models, + imageSelection: SelectedImagesResult, +): SynthesisModeSelection { + const payload = buildFinalSynthesisPayload( + context, + Math.max(12, context.allFragments.length || 1), + ) + const { maxInputTokens, safeInputBudget } = estimateSafeInputBudget( + modelId, + context.maxOutputTokens, + ) + const estimatedInputTokens = estimatePromptTokens( + payload.systemPrompt, + payload.userMessage, + imageSelection.selected.length, + ) + + return { + mode: estimatedInputTokens > safeInputBudget ? "sectional" : "single", + maxInputTokens, + safeInputBudget, + estimatedInputTokens, + } +} + +async function planSections( + context: AgentRunContext, + providerModelId: Models, + safeInputBudget: number, +): Promise<{ sections: FinalSection[]; estimatedCostUsd: number }> { + throwIfStopRequested(context.stopSignal) + + const previews = context.allFragments.map((fragment, index) => + buildFragmentPreviewRecord(fragment, index + 1), + ) + const sharedContext = buildSharedFinalAnswerContext(context) + const plannerUserIntro = [ + sharedContext, + "Create the final answer section plan using the fragment previews below.", + `Return at most ${MAX_SECTION_COUNT} sections.`, + "Fragment Previews:", + ].join("\n\n") + const baseTokens = + estimatePromptTokens( + buildPlannerSystemPrompt(), + plannerUserIntro, + 0, + ) + 256 + if (baseTokens >= safeInputBudget) { + Logger.warn( + { + baseTokens, + safeInputBudget, + chatId: context.chat.externalId, + }, + "[FinalAnswerSynthesis] Planner base prompt exceeds safe budget; falling back to default section plan.", + ) + return { + sections: buildDefaultSectionPlan(), + estimatedCostUsd: 0, + } + } + + const previewBudget = safeInputBudget - baseTokens + const { includedText, omittedSummary } = buildPreviewTextWithinBudget( + previews, + previewBudget, + ) + const userMessage = [plannerUserIntro, includedText, omittedSummary] + .filter(Boolean) + .join("\n\n") + + let response: ConverseResponse + throwIfStopRequested(context.stopSignal) + try { + response = await raceWithStop( + getProviderByModel(providerModelId).converse( + [ + { + role: ConversationRole.USER, + content: [{ text: userMessage }], + }, + ], + { + modelId: providerModelId, + json: true, + stream: false, + temperature: 0, + max_new_tokens: 800, + systemPrompt: buildPlannerSystemPrompt(), + }, + ), + context.stopSignal, + ) + } catch (error) { + if (isMessageAgentStopError(error)) { + throw error + } + + const { + errorCode, + errorMessage, + errorName, + isContextLengthError, + isTransportError, + } = classifyPlannerFallbackError(error) + Logger.warn( + { + baseTokens, + previewBudget, + errorCode, + errorMessage, + errorName, + isContextLengthError, + isTransportError, + chatId: context.chat.externalId, + }, + "[FinalAnswerSynthesis] Planner request failed; falling back to default section plan.", + ) + return { + sections: buildDefaultSectionPlan(), + estimatedCostUsd: 0, + } + } + + throwIfStopRequested(context.stopSignal) + const parsed = SectionPlanSchema.safeParse( + jsonParseLLMOutput(response.text ?? ""), + ) + + if (!parsed.success) { + Logger.warn( + { + issues: parsed.error.issues, + response: response.text, + chatId: context.chat.externalId, + }, + "[FinalAnswerSynthesis] Invalid planner output; falling back to default section plan.", + ) + return { + sections: buildDefaultSectionPlan(), + estimatedCostUsd: response.cost ?? 0, + } + } + + return { + sections: normalizeSectionPlan(parsed.data), + estimatedCostUsd: response.cost ?? 0, + } +} + +async function mapFragmentsToSections( + context: AgentRunContext, + sections: FinalSection[], + providerModelId: Models, + safeInputBudget: number, +): Promise<{ assignments: Map; estimatedCostUsd: number }> { + throwIfStopRequested(context.stopSignal) + + const sharedContext = buildSharedFinalAnswerContext(context) + const sectionOverview = formatSectionPlanOverview(sections) + const batchIntro = [ + sharedContext, + `Sections:\n${sectionOverview}`, + "Map the following fragments to the relevant section ids.", + "Fragments:", + ].join("\n\n") + const baseTokens = + estimatePromptTokens(buildMapperSystemPrompt(), batchIntro, 0) + 256 + const entries = context.allFragments.map((fragment, index) => ({ + fragmentIndex: index + 1, + fragment, + })) + const batches = buildFragmentBatchesWithinBudget( + entries, + baseTokens, + safeInputBudget, + ) + + if (batches.length === 0) { + return { assignments: new Map(), estimatedCostUsd: 0 } + } + + const batchResults = await raceWithStop( + runWithConcurrency( + batches, + MAPPER_CONCURRENCY, + async (batch) => { + throwIfStopRequested(context.stopSignal) + + const fragmentsText = formatSectionFragments(batch) + const userMessage = [batchIntro, fragmentsText].join("\n\n") + try { + const response = await raceWithStop( + getProviderByModel(providerModelId).converse( + [ + { + role: ConversationRole.USER, + content: [{ text: userMessage }], + }, + ], + { + modelId: providerModelId, + json: true, + stream: false, + temperature: 0, + max_new_tokens: 1_000, + systemPrompt: buildMapperSystemPrompt(), + }, + ), + context.stopSignal, + ) + + throwIfStopRequested(context.stopSignal) + const parsed = SectionMappingSchema.safeParse( + jsonParseLLMOutput(response.text ?? ""), + ) + if (!parsed.success) { + Logger.warn( + { + issues: parsed.error.issues, + response: response.text, + chatId: context.chat.externalId, + }, + "[FinalAnswerSynthesis] Invalid mapper output for batch; skipping batch.", + ) + return { + assignments: new Map>(), + estimatedCostUsd: response.cost ?? 0, + } + } + + return { + assignments: normalizeSectionAssignments(parsed.data, sections), + estimatedCostUsd: response.cost ?? 0, + } + } catch (error) { + if (isMessageAgentStopError(error)) { + throw error + } + + Logger.warn( + { + err: error instanceof Error ? error.message : String(error), + chatId: context.chat.externalId, + }, + "[FinalAnswerSynthesis] Mapper batch failed; skipping batch.", + ) + return { + assignments: new Map>(), + estimatedCostUsd: 0, + } + } + }, + ), + context.stopSignal, + ) + + return { + assignments: mergeSectionAssignments( + batchResults.map((result) => result.assignments), + sections, + ), + estimatedCostUsd: batchResults.reduce( + (sum, result) => sum + result.estimatedCostUsd, + 0, + ), + } +} + +function buildDefaultAssignments( + context: AgentRunContext, + sections: FinalSection[], +): Map { + const indexes = context.allFragments.map((_, index) => index + 1) + return new Map( + sections.map((section) => [section.sectionId, [...indexes]]), + ) +} + +async function synthesizeSingleAnswer( + context: AgentRunContext, + modelId: Models, + imageSelection: SelectedImagesResult, +): Promise { + if (!context.runtime?.streamAnswerText) { + throw new Error("Streaming channel unavailable. Cannot deliver final answer.") + } + if (context.stopSignal?.aborted) { + logFinalSynthesisStop(context, "single:before-stream-start") + return buildStoppedFinalSynthesisResult( + "single", + imageSelection, + 0, + context.finalSynthesis.streamedText.length, + 0, + ) + } + + const { systemPrompt, userMessage } = buildFinalSynthesisPayload(context) + const provider = getProviderByModel(modelId) + let streamedCharacters = 0 + let estimatedCostUsd = 0 + + const iterator = provider.converseStream( + [ + { + role: ConversationRole.USER, + content: [ + { + text: `${userMessage}\n\nSynthesize the final answer using the evidence above.`, + }, + ], + }, + ], + { + modelId, + systemPrompt, + stream: true, + temperature: 0.2, + max_new_tokens: context.maxOutputTokens ?? FALLBACK_OUTPUT_TOKENS, + imageFileNames: imageSelection.selected, + }, + ) + + let stopRequested = false + try { + throwIfStopRequested(context.stopSignal) + for await (const chunk of iterator) { + throwIfStopRequested(context.stopSignal) + + if (chunk.text) { + await streamFinalAnswerChunk(context, chunk.text) + streamedCharacters += chunk.text.length + throwIfStopRequested(context.stopSignal) + } + + const chunkCost = chunk.metadata?.cost + if (typeof chunkCost === "number" && !Number.isNaN(chunkCost)) { + estimatedCostUsd += chunkCost + } + } + throwIfStopRequested(context.stopSignal) + } catch (error) { + if (!isMessageAgentStopError(error)) { + throw error + } + + stopRequested = true + logFinalSynthesisStop(context, "single:streaming", { + estimatedCostUsd, + streamedCharacters, + }) + return buildStoppedFinalSynthesisResult( + "single", + imageSelection, + estimatedCostUsd, + streamedCharacters, + imageSelection.selected.length, + ) + } finally { + if (stopRequested || context.stopSignal?.aborted) { + await cancelConverseIterator(iterator) + } + } + + return { + textLength: streamedCharacters, + totalImagesAvailable: imageSelection.total, + imagesProvided: imageSelection.selected.length, + estimatedCostUsd, + mode: "single", + } +} + +async function synthesizeSection( + context: AgentRunContext, + sections: FinalSection[], + section: FinalSection, + mappedIndexes: number[], + providerModelId: Models, + safeInputBudget: number, +): Promise<{ + result: SectionAnswerResult | null + estimatedCostUsd: number + imageFileNames: string[] +}> { + throwIfStopRequested(context.stopSignal) + + const orderedEntries = mappedIndexes + .map((index) => ({ + fragmentIndex: index, + fragment: context.allFragments[index - 1], + })) + .filter((entry) => !!entry.fragment) as FragmentAssignmentBatch[] + + if (orderedEntries.length === 0) { + return { result: null, estimatedCostUsd: 0, imageFileNames: [] } + } + + let emptyPayload = buildSectionAnswerPayload( + context, + sections, + section, + [], + [], + ) + let baseTokens = + estimatePromptTokens(emptyPayload.systemPrompt, emptyPayload.userMessage, 0) + + 128 + let { selected, trimmedCount, skippedForBudgetCount } = + selectMappedEntriesWithinBudget( + orderedEntries, + baseTokens, + safeInputBudget, + ) + let fragmentIds = new Set(selected.map((entry) => entry.fragment.id)) + let imageSelection = selectImagesForFragmentIds(context, fragmentIds) + + emptyPayload = buildSectionAnswerPayload( + context, + sections, + section, + [], + imageSelection.selected, + ) + const imageAwareBaseTokens = + estimatePromptTokens( + emptyPayload.systemPrompt, + emptyPayload.userMessage, + imageSelection.selected.length, + ) + 128 + + if (imageAwareBaseTokens > baseTokens) { + baseTokens = imageAwareBaseTokens + const imageAwareSelection = selectMappedEntriesWithinBudget( + orderedEntries, + baseTokens, + safeInputBudget, + ) + selected = imageAwareSelection.selected + trimmedCount = imageAwareSelection.trimmedCount + skippedForBudgetCount = imageAwareSelection.skippedForBudgetCount + fragmentIds = new Set(selected.map((entry) => entry.fragment.id)) + imageSelection = selectImagesForFragmentIds(context, fragmentIds) + } + + if (trimmedCount > 0) { + loggerWithChild({ email: context.user.email }).info( + { + chatId: context.chat.externalId, + sectionId: section.sectionId, + orderedEntryCount: orderedEntries.length, + selectedCount: selected.length, + trimmedCount, + skippedForBudgetCount, + imageCount: imageSelection.selected.length, + }, + "[FinalAnswerSynthesis] Trimmed mapped fragments to fit section input budget.", + ) + } + + if (selected.length === 0) { + loggerWithChild({ email: context.user.email }).info( + { + chatId: context.chat.externalId, + sectionId: section.sectionId, + orderedEntryCount: orderedEntries.length, + skippedForBudgetCount, + }, + "[FinalAnswerSynthesis] No mapped fragments fit within the section input budget; omitting section.", + ) + return { result: null, estimatedCostUsd: 0, imageFileNames: [] } + } + + const payload = buildSectionAnswerPayload( + context, + sections, + section, + selected, + imageSelection.selected, + ) + const sectionMaxTokens = Math.min( + context.maxOutputTokens ?? FALLBACK_OUTPUT_TOKENS, + Math.max( + 250, + Math.ceil( + (context.maxOutputTokens ?? FALLBACK_OUTPUT_TOKENS) / + Math.max(sections.length, 1), + ), + ), + ) + + throwIfStopRequested(context.stopSignal) + try { + const response = await raceWithStop( + getProviderByModel(providerModelId).converse( + [ + { + role: ConversationRole.USER, + content: [{ text: payload.userMessage }], + }, + ], + { + modelId: providerModelId, + stream: false, + temperature: 0.2, + max_new_tokens: sectionMaxTokens, + systemPrompt: payload.systemPrompt, + imageFileNames: payload.imageFileNames, + }, + ), + context.stopSignal, + ) + + throwIfStopRequested(context.stopSignal) + const body = response.text?.trim() ?? "" + return { + result: body + ? { + sectionId: section.sectionId, + title: section.title, + body, + } + : null, + estimatedCostUsd: response.cost ?? 0, + imageFileNames: payload.imageFileNames, + } + } catch (error) { + if (isMessageAgentStopError(error)) { + throw error + } + + Logger.warn( + { + err: error instanceof Error ? error.message : String(error), + chatId: context.chat.externalId, + sectionId: section.sectionId, + }, + "[FinalAnswerSynthesis] Section synthesis failed; omitting section.", + ) + return { + result: null, + estimatedCostUsd: 0, + imageFileNames: payload.imageFileNames, + } + } +} + +function assembleSectionAnswers(results: SectionAnswerResult[]): string { + return results + .sort((a, b) => a.sectionId - b.sectionId) + .map((result) => `**${result.title}**\n${result.body}`) + .join("\n\n") + .trim() +} + +async function synthesizeSectionalAnswer( + context: AgentRunContext, + modelId: Models, + imageSelection: SelectedImagesResult, + safeInputBudget: number, +): Promise { + let estimatedCostUsd = 0 + const uniqueImagesProvided = new Set() + + try { + throwIfStopRequested(context.stopSignal) + + await raceWithStop( + context.runtime?.emitReasoning?.({ + text: `Final synthesis exceeded the model input budget. Switching to sectional synthesis across ${context.allFragments.length} fragments.`, + step: { type: AgentReasoningStepType.LogMessage }, + }) ?? Promise.resolve(), + context.stopSignal, + ) + + throwIfStopRequested(context.stopSignal) + const planned = await raceWithStop( + planSections(context, modelId, safeInputBudget), + context.stopSignal, + ) + estimatedCostUsd += planned.estimatedCostUsd + let sections = planned.sections + + throwIfStopRequested(context.stopSignal) + const mapped = await raceWithStop( + mapFragmentsToSections(context, sections, modelId, safeInputBudget), + context.stopSignal, + ) + estimatedCostUsd += mapped.estimatedCostUsd + + let assignments = mapped.assignments + const hasAssignments = Array.from(assignments.values()).some( + (indexes) => indexes.length > 0, + ) + + if (!hasAssignments) { + sections = buildDefaultSectionPlan() + assignments = buildDefaultAssignments(context, sections) + } + + throwIfStopRequested(context.stopSignal) + const sectionResults = await raceWithStop( + runWithConcurrency( + sections, + SECTION_CONCURRENCY, + async (section) => + synthesizeSection( + context, + sections, + section, + assignments.get(section.sectionId) ?? [], + modelId, + safeInputBudget, + ), + ), + context.stopSignal, + ) + + estimatedCostUsd += sectionResults.reduce( + (sum, result) => sum + result.estimatedCostUsd, + 0, + ) + + for (const result of sectionResults) { + for (const imageName of result.imageFileNames) { + uniqueImagesProvided.add(imageName) + } + } + + let assembledText = assembleSectionAnswers( + sectionResults + .map((result) => result.result) + .filter((result): result is SectionAnswerResult => !!result), + ) + + if (!assembledText) { + throwIfStopRequested(context.stopSignal) + const fallbackSections = buildDefaultSectionPlan() + const fallbackAssignments = buildDefaultAssignments( + context, + fallbackSections, + ) + const fallback = await raceWithStop( + synthesizeSection( + context, + fallbackSections, + fallbackSections[0], + fallbackAssignments.get(1) ?? [], + modelId, + safeInputBudget, + ), + context.stopSignal, + ) + estimatedCostUsd += fallback.estimatedCostUsd + assembledText = assembleSectionAnswers( + fallback.result ? [fallback.result] : [], + ) + if (fallback.result) { + sectionResults.push(fallback) + } + for (const imageName of fallback.imageFileNames) { + uniqueImagesProvided.add(imageName) + } + } + + throwIfStopRequested(context.stopSignal) + if (!assembledText) { + throw new Error("Sectional final synthesis produced no answer text.") + } + + context.finalSynthesis.streamedText = "" + let streamedCharacters = 0 + for ( + let offset = 0; + offset < assembledText.length; + offset += FINAL_SYNTHESIS_STREAM_CHUNK_SIZE + ) { + throwIfStopRequested(context.stopSignal) + const chunk = assembledText.slice( + offset, + offset + FINAL_SYNTHESIS_STREAM_CHUNK_SIZE, + ) + await streamFinalAnswerChunk(context, chunk) + streamedCharacters += chunk.length + } + + return { + textLength: streamedCharacters, + totalImagesAvailable: imageSelection.total, + imagesProvided: uniqueImagesProvided.size, + estimatedCostUsd, + mode: "sectional", + } + } catch (error) { + if (!isMessageAgentStopError(error)) { + throw error + } + + logFinalSynthesisStop(context, "sectional", { + estimatedCostUsd, + streamedCharacters: context.finalSynthesis.streamedText.length, + imagesProvided: uniqueImagesProvided.size, + }) + return buildStoppedFinalSynthesisResult( + "sectional", + imageSelection, + estimatedCostUsd, + context.finalSynthesis.streamedText.length, + uniqueImagesProvided.size, + ) + } +} + +export async function executeFinalSynthesis( + context: AgentRunContext, +): Promise { + const modelId = + (context.modelId as Models) || + (defaultBestModel as Models) || + Models.Gpt_4o + const imageSelection = selectImagesForFinalSynthesis(context) + const modeSelection = decideSynthesisMode(context, modelId, imageSelection) + + loggerWithChild({ email: context.user.email }).debug( + { + chatId: context.chat.externalId, + mode: modeSelection.mode, + maxInputTokens: modeSelection.maxInputTokens, + safeInputBudget: modeSelection.safeInputBudget, + estimatedInputTokens: modeSelection.estimatedInputTokens, + fragmentsCount: context.allFragments.length, + selectedImages: imageSelection.selected, + droppedImages: imageSelection.dropped, + userAttachmentCount: imageSelection.userAttachmentCount, + }, + "[FinalAnswerSynthesis] Selected final synthesis mode.", + ) + + if (imageSelection.dropped.length > 0) { + loggerWithChild({ email: context.user.email }).info( + { + chatId: context.chat.externalId, + droppedCount: imageSelection.dropped.length, + limit: IMAGE_CONTEXT_CONFIG.maxImagesPerCall, + totalImages: imageSelection.total, + }, + "[FinalAnswerSynthesis] Image limit enforced for single-shot selection.", + ) + } + + if (context.stopSignal?.aborted) { + logFinalSynthesisStop(context, "execute:before-mode-branch") + return buildStoppedFinalSynthesisResult( + modeSelection.mode, + imageSelection, + 0, + context.finalSynthesis.streamedText.length, + 0, + ) + } + + return modeSelection.mode === "single" + ? synthesizeSingleAnswer(context, modelId, imageSelection) + : synthesizeSectionalAnswer( + context, + modelId, + imageSelection, + modeSelection.safeInputBudget, + ) +} + +export const __finalAnswerSynthesisInternals = { + buildFragmentPreviewRecord, + buildSectionAnswerPayload, + decideSynthesisMode, + selectImagesForFragmentIds, + selectMappedEntriesWithinBudget, + synthesizeSection, +} diff --git a/server/api/chat/message-agents.ts b/server/api/chat/message-agents.ts index 8ddde8c57..ee96fa85b 100644 --- a/server/api/chat/message-agents.ts +++ b/server/api/chat/message-agents.ts @@ -138,6 +138,11 @@ import { parseAgentAppIntegrations } from "./tools/utils" import { buildAgentPromptAddendum } from "./agentPromptCreation" import { getConnectorById } from "@/db/connector" import { getToolsByConnectorId } from "@/db/tool" +import { + executeFinalSynthesis, + formatClarificationsForPrompt, + formatPlanForPrompt, +} from "./final-answer-synthesis" import { buildMCPJAFTools, type FinalToolsList, @@ -159,17 +164,16 @@ import { getUserPersonalizationByEmail } from "@/db/personalization" import { getChunkCountPerDoc } from "./chunk-selection" import { getPrecomputedDbContextIfNeeded } from "@/lib/databaseContext" import { - buildAgentSystemPromptContextBlock, enforceMetadataConstraintsOnSelection, extractMetadataConstraintsFromUserMessage, formatFragmentWithMetadata, - formatFragmentsWithMetadata, rankFragmentsByMetadataConstraints, sanitizeAgentSystemPromptSnapshot, withAgentSystemPromptMessage, } from "./message-agents-metadata" export { __messageAgentsMetadataInternals } from "./message-agents-metadata" +export { buildFinalSynthesisPayload } from "./final-answer-synthesis" const { defaultBestModel, @@ -722,33 +726,6 @@ function getMetadataValue( return undefined } -function formatPlanForPrompt(plan: PlanState | null): string { - if (!plan) return "" - const lines = [`Goal: ${plan.goal}`] - plan.subTasks.forEach((task, idx) => { - const icon = - task.status === "completed" - ? "✓" - : task.status === "in_progress" - ? "→" - : task.status === "failed" - ? "✗" - : task.status === "blocked" - ? "!" - : "○" - const baseLine = `${idx + 1}. [${icon}] ${task.description}` - const detailParts: string[] = [] - if (task.result) detailParts.push(`Result: ${task.result}`) - if (task.toolsRequired?.length) { - detailParts.push(`Tools: ${task.toolsRequired.join(", ")}`) - } - lines.push( - detailParts.length > 0 ? `${baseLine}\n ${detailParts.join(" | ")}` : baseLine - ) - }) - return lines.join("\n") -} - function selectActiveSubTaskId(plan: PlanState | null): string | null { if (!plan || !Array.isArray(plan.subTasks) || plan.subTasks.length === 0) { return null @@ -860,180 +837,6 @@ function advancePlanAfterTool( } } -function formatClarificationsForPrompt( - clarifications: AgentRunContext["clarifications"] -): string { - if (!clarifications?.length) return "" - const formatted = clarifications - .map( - (clarification, idx) => - `${idx + 1}. Q: ${clarification.question}\n A: ${clarification.answer}` - ) - .join("\n") - return formatted -} - -export function buildFinalSynthesisPayload( - context: AgentRunContext, - fragmentsLimit = Math.max(12, context.allFragments.length || 1) -): { systemPrompt: string; userMessage: string } { - const fragments = context.allFragments - const agentSystemPromptBlock = buildAgentSystemPromptContextBlock( - context.dedicatedAgentSystemPrompt - ) - const agentSystemPromptSection = agentSystemPromptBlock - ? `Agent System Prompt Context:\n${agentSystemPromptBlock}` - : "" - const formattedFragments = formatFragmentsWithMetadata(fragments, fragmentsLimit) - const fragmentsSection = formattedFragments - ? `Context Fragments:\n${formattedFragments}` - : "" - const planSection = formatPlanForPrompt(context.plan) - const clarificationSection = formatClarificationsForPrompt(context.clarifications) - const workspaceSection = context.userContext?.trim() - ? `Workspace Context:\n${context.userContext}` - : "" - - const parts = [ - `User Question:\n${context.message.text}`, - agentSystemPromptSection, - planSection ? `Execution Plan Snapshot:\n${planSection}` : "", - clarificationSection ? `Clarifications Resolved:\n${clarificationSection}` : "", - workspaceSection, - fragmentsSection, - ].filter(Boolean) - - const userMessage = parts.join("\n\n") - - const systemPrompt = ` -### Mission -- Deliver the user's final answer using the conversation, plan snapshot, clarifications, workspace context, context fragments, and supplied images; never plan or call tools. - -### Evidence Intake -- Prioritize the highest-signal fragments, but pull any supporting fragment that improves accuracy. -- Only draw on context that directly answers the user's question; ignore unrelated fragments even if they were retrieved earlier. -- Treat delegated-agent outputs as citeable fragments; reference them like any other context entry. -- Describe evidence gaps plainly before concluding; never guess. -- Extract actionable details from provided images and cite them via their fragment indices. -- Respect user-imposed constraints using fragment metadata (any metadata field). If compliant evidence is missing, state that clearly. -- If "This is the system prompt of agent:" is present, analyse for instructions relevant for answering and strictly bind by them . - -### Response Construction -- Lead with the conclusion, then stack proof underneath. -- Organize output into tight sections (e.g., **Summary**, **Proof**, **Next Steps** when relevant); omit empty sections. -- Never mention internal tooling, planning logs, or this synthesis process. - -### Constraint Handling -- When the user asks for an action the system cannot execute (e.g., sending an email), deliver the closest actionable substitute (draft, checklist, explicit next steps) inside the answer. -- Pair the substitute with a concise explanation of the limitation and the manual action the user must take. - -### File & Chunk Formatting (CRITICAL) -- Each file starts with a header line exactly like: - index {docId} {file context begins here...} -- \`docId\` is a unique identifier for that file (e.g., 0, 1, 2, etc.). -- Inside the file context, text is split into chunks. -- Each chunk might begin with a bracketed numeric index, e.g.: [0], [1], [2], etc. -- This is the chunk index within that file, if it exists. - -### Guidelines for Response -1. Data Interpretation: - - Use ONLY the provided files and their chunks as your knowledge base. - - Treat every file header \`index {docId} ...\` as the start of a new document. - - Treat every bracketed number like [0], [1], [2] as the authoritative chunk index within that document. - - If dates exist, interpret them relative to the user's timezone when paraphrasing. -2. Response Structure: - - Start with the most relevant facts from the chunks across files. - - Keep order chronological when it helps comprehension. - - Every factual statement MUST cite the exact chunk it came from using the format: - K[docId_chunkIndex] - where: - - \`docId\` is taken from the file header line ("index {docId} ..."). - - \`chunkIndex\` is the bracketed number prefixed on that chunk within the same file. - - Examples: - - Single citation: "X is true K[12_3]." - - Two citations in one sentence (from different files or chunks): "X K[12_3] and Y K[7_0]." - - Use at most 1-2 citations per sentence; NEVER add more than 2 for one sentence. -3. Citation Rules (DOCUMENT+CHUNK LEVEL ONLY): - - ALWAYS cite at the chunk level with the K[docId_chunkIndex] format. - - Every chunk level citation must start with the K prefix eg. K[12_3] K[7_0] correct, but K[12_3] [7_0] is incorrect. - - Place the citation immediately after the relevant claim. - - Do NOT group indices inside one set of brackets (WRONG: "K[12_3,7_1]"). - - If a sentence draws on two distinct chunks (possibly from different files), include two separate citations inline, e.g., "... K[12_3] ... K[7_1]". - - Only cite information that appears verbatim or is directly inferable from the cited chunk. - - If you cannot ground a claim to a specific chunk, do not make the claim. -4. Quality Assurance: - - Cross-check across multiple chunks/files when available and briefly note inconsistencies if they exist. - - Keep tone professional and concise. - - Acknowledge gaps if the provided chunks don't contain enough detail. - -### Tone & Delivery -- Answer with confident, declarative, verb-first sentences that use concrete nouns. -- Highlight key deliverables using **bold** labels or short lists; keep wording razor-concise. -- Ask one targeted follow-up question only if missing info blocks action. - -### Tool Spotlighting -- Reference critical tool outputs explicitly, e.g., "**Slack Search:** Ops escalated the RCA at 09:42 [2]." -- Explain why each highlighted tool mattered so reviewers see coverage breadth. -- When multiple tools contribute, show the sequence, e.g., "**Vespa Search:** context -> **Sheet Lookup:** metrics." - -### Finish -- Close with a single sentence confirming completion or the next action you recommend. -`.trim() - - return { systemPrompt, userMessage } -} - -function selectImagesForFinalSynthesis( - context: AgentRunContext -): { - selected: string[] - total: number - dropped: string[] - userAttachmentCount: number -} { - const images = context.allImages - const total = images.length - if (!IMAGE_CONTEXT_CONFIG.enabled || total === 0) { - return { selected: [], total, dropped: [], userAttachmentCount: 0 } - } - - const attachments = images.filter((img) => img.isUserAttachment) - const nonAttachments = images - .filter((img) => !img.isUserAttachment) - .sort((a, b) => { - const ageA = context.turnCount - a.addedAtTurn - const ageB = context.turnCount - b.addedAtTurn - return ageA - ageB - }) - - const prioritized = [...attachments, ...nonAttachments] - const uniqueNames: string[] = [] - const seen = new Set() - for (const image of prioritized) { - if (seen.has(image.fileName)) continue - seen.add(image.fileName) - uniqueNames.push(image.fileName) - } - - let selected = uniqueNames - let dropped: string[] = [] - - if ( - IMAGE_CONTEXT_CONFIG.maxImagesPerCall > 0 && - uniqueNames.length > IMAGE_CONTEXT_CONFIG.maxImagesPerCall - ) { - selected = uniqueNames.slice(0, IMAGE_CONTEXT_CONFIG.maxImagesPerCall) - dropped = uniqueNames.slice(IMAGE_CONTEXT_CONFIG.maxImagesPerCall) - } - - return { - selected, - total, - dropped, - userAttachmentCount: attachments.length, - } -} - function buildAttachmentToolMessage( fragments: MinimalAgentFragment[], summary: string @@ -3308,37 +3111,14 @@ function createFinalSynthesisTool(): Tool { ) } - const streamAnswer = mutableContext.runtime?.streamAnswerText - if (!streamAnswer) { + if (!mutableContext.runtime?.streamAnswerText) { return ToolResponse.error( "EXECUTION_FAILED", "Streaming channel unavailable. Cannot deliver final answer." ) } - const { selected, total, dropped, userAttachmentCount } = - selectImagesForFinalSynthesis(context) - loggerWithChild({ email: context.user.email }).debug( - { - chatId: context.chat.externalId, - selectedImages: selected, - totalImages: total, - droppedImages: dropped, - userAttachmentCount, - }, - "[MessageAgents][FinalSynthesis] Image payload" - ) - - const { systemPrompt, userMessage } = buildFinalSynthesisPayload(context) const fragmentsCount = context.allFragments.length - loggerWithChild({ email: context.user.email }).debug( - { - chatId: context.chat.externalId, - finalSynthesisSystemPrompt: systemPrompt, - finalSynthesisUserMessage: userMessage, - }, - "[MessageAgents][FinalSynthesis] Full context payload" - ) mutableContext.finalSynthesis.requested = true mutableContext.finalSynthesis.suppressAssistantStreaming = true @@ -3346,93 +3126,20 @@ function createFinalSynthesisTool(): Tool { mutableContext.finalSynthesis.streamedText = "" await mutableContext.runtime?.emitReasoning?.({ - text: `Initiating final synthesis with ${fragmentsCount} context fragments and ${selected.length}/${total} images (${userAttachmentCount} user attachments).`, + text: `Initiating final synthesis with ${fragmentsCount} context fragments.`, step: { type: AgentReasoningStepType.LogMessage }, }) - const logger = loggerWithChild({ email: context.user.email }) - if (dropped.length > 0) { - logger.info( - { - droppedCount: dropped.length, - limit: IMAGE_CONTEXT_CONFIG.maxImagesPerCall, - totalImages: total, - }, - "Final synthesis image limit enforced; dropped oldest references." - ) - } - - const modelId = - (context.modelId as Models) || (defaultBestModel as Models) || Models.Gpt_4o - const modelParams: ModelParams = { - modelId, - systemPrompt, - stream: true, - temperature: 0.2, - max_new_tokens: context.maxOutputTokens ?? 1500, - imageFileNames: selected, - } - - const finalUserPrompt = `${userMessage}\n\nSynthesize the final answer using the evidence above.` - const messages: Message[] = [ - { - role: ConversationRole.USER, - content: [{ text: finalUserPrompt }], - }, - ] - Logger.debug( - { - email: context.user.email, - chatId: context.chat.externalId, - fragmentsCount: context.allFragments.length, - planPresent: !!context.plan, - clarificationsCount: context.clarifications.length, - toolOutputsThisTurn: context.currentTurnArtifacts.toolOutputs.length, - imageNames: selected, - }, - "[MessageAgents][FinalSynthesis] Context summary for synthesis call" - ) - - Logger.debug({ - email: context.user.email, - chatId: context.chat.externalId, - modelId, - systemPrompt, - messagesCount: messages.length, - imagesProvided: selected.length, - }, "[MessageAgents][FinalSynthesis] LLM call parameters") - - const provider = getProviderByModel(modelId) - let streamedCharacters = 0 - let estimatedCostUsd = 0 - try { - const iterator = provider.converseStream(messages, modelParams) - for await (const chunk of iterator) { - if (chunk.text) { - streamedCharacters += chunk.text.length - context.finalSynthesis.streamedText += chunk.text - await streamAnswer(chunk.text) - } - const chunkCost = chunk.metadata?.cost - if (typeof chunkCost === "number" && !Number.isNaN(chunkCost)) { - estimatedCostUsd += chunkCost - } - } - + const synthesisResult = await executeFinalSynthesis(mutableContext) + throwIfStopRequested(mutableContext.stopSignal) context.finalSynthesis.completed = true - loggerWithChild({ email: context.user.email }).debug( - { - chatId: context.chat.externalId, - streamedCharacters, - estimatedCostUsd, - imagesProvided: selected, - }, - "[MessageAgents][FinalSynthesis] LLM call completed" - ) await context.runtime?.emitReasoning?.({ - text: "Final synthesis completed and streamed to the user.", + text: + synthesisResult.mode === "sectional" + ? "Sectional final synthesis completed and delivered to the user." + : "Final synthesis completed and streamed to the user.", step: { type: AgentReasoningStepType.LogMessage }, }) @@ -3441,20 +3148,27 @@ function createFinalSynthesisTool(): Tool { result: "Final answer streamed to user.", streamed: true, metadata: { - textLength: streamedCharacters, - totalImagesAvailable: total, - imagesProvided: selected.length, + textLength: synthesisResult.textLength, + totalImagesAvailable: synthesisResult.totalImagesAvailable, + imagesProvided: synthesisResult.imagesProvided, }, }, { - estimatedCostUsd, + estimatedCostUsd: synthesisResult.estimatedCostUsd, } ) } catch (error) { + if (isMessageAgentStopError(error)) { + context.finalSynthesis.suppressAssistantStreaming = false + context.finalSynthesis.requested = false + context.finalSynthesis.completed = false + throw error + } + context.finalSynthesis.suppressAssistantStreaming = false context.finalSynthesis.requested = false context.finalSynthesis.completed = false - logger.error( + loggerWithChild({ email: context.user.email }).error( { err: error instanceof Error ? error.message : String(error) }, "Final synthesis tool failed." ) diff --git a/server/api/workflow.ts b/server/api/workflow.ts index 7ba3d2c00..0e0414002 100644 --- a/server/api/workflow.ts +++ b/server/api/workflow.ts @@ -5631,7 +5631,7 @@ export const GetVertexAIModelEnumsApi = async (c: Context) => { // Get model enum names for workflow tools export const GetModelEnumsApi = async (c: Context) => { try { - const { fetchModelInfoFromAPI } = await import("@/ai/fetchModels") + const { getLiteLLMWorkflowModels } = await import("@/ai/fetchModels") const { MODEL_CONFIGURATIONS } = await import("@/ai/modelConfig") const { AIProviders } = await import("@/ai/types") @@ -5647,80 +5647,7 @@ export const GetModelEnumsApi = async (c: Context) => { // For LiteLLM provider, fetch models from API if (activeProvider === AIProviders.LiteLLM) { - const apiModels = await fetchModelInfoFromAPI() - - // Filter models with litellm_provider === "hosted_vllm" - const hostedVllmModels = apiModels.filter( - (m: any) => m.model_info?.litellm_provider === "hosted_vllm" - ) - - // Use Set to track seen model IDs to avoid duplicates - const seenModelIds = new Set() - const modelEnums: Array<{ - enumValue: string - labelName: string - actualName: string - description: string - reasoning: boolean - websearch: boolean - deepResearch: boolean - modelType: string - }> = [] - - for (const modelInfo of hostedVllmModels) { - const modelId = modelInfo.model_name - - // Skip duplicates - if (seenModelIds.has(modelId)) { - continue - } - seenModelIds.add(modelId) - - const actualName = modelInfo.litellm_params?.model || modelId - - // Check if model exists in MODEL_CONFIGURATIONS for additional metadata - const modelConfig = MODEL_CONFIGURATIONS[modelId as keyof typeof MODEL_CONFIGURATIONS] - - if (modelConfig) { - // Use configuration from MODEL_CONFIGURATIONS - modelEnums.push({ - enumValue: modelId, - labelName: modelConfig.labelName, - actualName: actualName, - description: modelConfig.description, - reasoning: modelConfig.reasoning, - websearch: modelConfig.websearch, - deepResearch: modelConfig.deepResearch, - modelType: modelId.includes('gemini') ? 'gemini' : - modelId.includes('claude') ? 'claude' : 'other', - }) - } else { - // For models not in MODEL_CONFIGURATIONS, use API data with defaults - modelEnums.push({ - enumValue: modelId, - labelName: modelId, // Use model_name as label - actualName: actualName, - description: modelInfo.model_info?.description || "", - reasoning: modelInfo.model_info?.reasoning ?? false, - websearch: modelInfo.model_info?.websearch ?? false, - deepResearch: modelInfo.model_info?.deep_research ?? false, - modelType: modelId.includes('gemini') ? 'gemini' : - modelId.includes('claude') ? 'claude' : 'other', - }) - } - } - - // Sort by model type and then by label name - modelEnums.sort((a, b) => { - const typeOrder: Record = { claude: 1, gemini: 2, other: 3 }; - const orderA = typeOrder[a.modelType] ?? 99; - const orderB = typeOrder[b.modelType] ?? 99; - - if (orderA !== orderB) { - return orderA - orderB; - } - return a.labelName.localeCompare(b.labelName); - }) + const modelEnums = await getLiteLLMWorkflowModels() console.log(`Model enums for LiteLLM provider from API:`, modelEnums.length, "models") diff --git a/server/logger/index.ts b/server/logger/index.ts index 119615143..7550e95e2 100644 --- a/server/logger/index.ts +++ b/server/logger/index.ts @@ -1,4 +1,4 @@ -import { levels, pino, type Logger } from "pino" +import { pino, type Logger } from "pino" import { Subsystem, type loggerChildSchema } from "@/types" import type { MiddlewareHandler, Context, Next } from "hono" import { getPath } from "hono/utils/url" diff --git a/server/shared/types.ts b/server/shared/types.ts index 325e87dc1..c992d9c34 100644 --- a/server/shared/types.ts +++ b/server/shared/types.ts @@ -827,6 +827,8 @@ export interface ModelConfiguration { websearch: boolean deepResearch: boolean description: string + maxInputTokens?: number + maxOutputTokens?: number } export const getDocumentSchema = z.object({ docId: z.string().min(1), diff --git a/server/tests/fetchModels.test.ts b/server/tests/fetchModels.test.ts new file mode 100644 index 000000000..2d5405def --- /dev/null +++ b/server/tests/fetchModels.test.ts @@ -0,0 +1,128 @@ +import { afterEach, beforeEach, describe, expect, test } from "bun:test" +import { + __modelInfoInternals, + getEffectiveMaxOutputTokens, + getModelTokenLimits, +} from "@/ai/fetchModels" +import { Models } from "@/ai/types" + +describe("fetchModels normalization", () => { + beforeEach(() => { + __modelInfoInternals.resetModelInfoCacheForTests() + }) + + afterEach(() => { + __modelInfoInternals.resetModelInfoCacheForTests() + }) + + test("normalizes duplicate model records conservatively", () => { + const normalized = __modelInfoInternals.normalizeModelInfoRecords([ + { + model_name: "kimi-latest", + litellm_params: { model: "hosted_vllm/kimi-k2-5-dev" }, + model_info: { + max_input_tokens: 262_000, + max_output_tokens: 16_000, + }, + }, + { + model_name: "kimi-latest", + litellm_params: { model: "hosted_vllm/kimi-k2-5-dev" }, + model_info: { + max_input_tokens: 200_000, + max_output_tokens: 8_000, + }, + }, + ]) + + const metadata = normalized.get("kimi-latest") + expect(metadata?.maxInputTokens).toBe(200_000) + expect(metadata?.maxOutputTokens).toBe(8_000) + expect(metadata?.hasConflictingMaxInputTokens).toBe(true) + expect(metadata?.hasConflictingMaxOutputTokens).toBe(true) + }) + + test("ignores null token limits when another duplicate has values", () => { + const normalized = __modelInfoInternals.normalizeModelInfoRecords([ + { + model_name: "glm-latest", + litellm_params: { model: "openai/zai-org/GLM-5-Dev" }, + model_info: { + max_input_tokens: null, + max_output_tokens: null, + }, + }, + { + model_name: "glm-latest", + litellm_params: { model: "openai/zai-org/GLM-5-Dev" }, + model_info: { + max_input_tokens: 128_000, + max_output_tokens: 12_000, + }, + }, + ]) + + const metadata = normalized.get("glm-latest") + expect(metadata?.maxInputTokens).toBe(128_000) + expect(metadata?.maxOutputTokens).toBe(12_000) + expect(metadata?.hasConflictingMaxInputTokens).toBe(false) + expect(metadata?.hasConflictingMaxOutputTokens).toBe(false) + }) + + test("resolves model token limits by enum id and actual provider model name", () => { + __modelInfoInternals.setModelInfoCacheForTests([ + { + model_name: Models.GLM_LATEST, + litellm_params: { model: "openai/zai-org/GLM-5-Dev" }, + model_info: { + max_input_tokens: 128_000, + max_output_tokens: 32_000, + }, + }, + ]) + + expect(getModelTokenLimits(Models.GLM_LATEST)).toEqual({ + maxInputTokens: 128_000, + maxOutputTokens: 32_000, + }) + expect(getModelTokenLimits("openai/zai-org/GLM-5-Dev")).toEqual({ + maxInputTokens: 128_000, + maxOutputTokens: 32_000, + }) + }) + + test("falls back to generic input default and preserves requested output when upstream output max is absent", () => { + __modelInfoInternals.setModelInfoCacheForTests([ + { + model_name: Models.KIMI_LATEST, + litellm_params: { model: "hosted_vllm/kimi-k2-5-dev" }, + model_info: { + max_input_tokens: null, + max_output_tokens: null, + }, + }, + ]) + + expect(getModelTokenLimits(Models.KIMI_LATEST)).toEqual({ + maxInputTokens: 128_000, + maxOutputTokens: undefined, + }) + expect(getEffectiveMaxOutputTokens(Models.KIMI_LATEST, 1_500)).toBe(1_500) + }) + + test("clamps requested output tokens to the upstream model maximum", () => { + __modelInfoInternals.setModelInfoCacheForTests([ + { + model_name: Models.Claude_Sonnet_4, + litellm_params: { model: "us.anthropic.claude-sonnet-4-20250514-v1:0" }, + model_info: { + max_input_tokens: 200_000, + max_output_tokens: 4_096, + }, + }, + ]) + + expect(getEffectiveMaxOutputTokens(Models.Claude_Sonnet_4, 8_000)).toBe(4_096) + expect(getEffectiveMaxOutputTokens(Models.Claude_Sonnet_4, 2_000)).toBe(2_000) + }) +}) diff --git a/server/tests/finalAnswerSynthesis.test.ts b/server/tests/finalAnswerSynthesis.test.ts new file mode 100644 index 000000000..4946c6f19 --- /dev/null +++ b/server/tests/finalAnswerSynthesis.test.ts @@ -0,0 +1,285 @@ +import { afterEach, beforeEach, describe, expect, test } from "bun:test" +import type { AgentRunContext } from "@/api/chat/agent-schemas" +import { + __finalAnswerSynthesisInternals, + buildFinalSynthesisPayload, +} from "@/api/chat/final-answer-synthesis" +import { __modelInfoInternals } from "@/ai/fetchModels" +import type { MinimalAgentFragment } from "@/api/chat/types" +import { Models } from "@/ai/types" +import { Apps } from "@xyne/vespa-ts/types" + +const baseFragment: MinimalAgentFragment = { + id: "doc-1", + content: "Quarterly ARR grew 12% and pipeline coverage improved.", + source: { + docId: "doc-1", + title: "ARR Summary", + url: "https://example.com/doc-1", + app: Apps.KnowledgeBase, + entity: "file" as any, + }, + confidence: 0.9, +} + +const createMockContext = (): AgentRunContext => ({ + user: { + email: "tester@example.com", + workspaceId: "workspace", + id: "user-1", + }, + chat: { + externalId: "chat-1", + metadata: {}, + }, + message: { + text: "How is ARR tracking?", + attachments: [], + timestamp: new Date().toISOString(), + }, + modelId: Models.Gpt_4o, + plan: null, + currentSubTask: null, + userContext: "", + agentPrompt: undefined, + dedicatedAgentSystemPrompt: undefined, + clarifications: [], + ambiguityResolved: true, + toolCallHistory: [], + seenDocuments: new Set(), + allFragments: [], + turnFragments: new Map(), + allImages: [], + imagesByTurn: new Map(), + recentImages: [], + currentTurnArtifacts: { + fragments: [], + expectations: [], + toolOutputs: [], + images: [], + }, + turnCount: 1, + totalLatency: 0, + totalCost: 0, + tokenUsage: { input: 0, output: 0 }, + availableAgents: [], + usedAgents: [], + enabledTools: new Set(), + delegationEnabled: true, + failedTools: new Map(), + retryCount: 0, + maxRetries: 3, + review: { + lastReviewTurn: null, + reviewFrequency: 5, + outstandingAnomalies: [], + clarificationQuestions: [], + lastReviewResult: null, + lockedByFinalSynthesis: false, + lockedAtTurn: null, + }, + decisions: [], + finalSynthesis: { + requested: false, + completed: false, + suppressAssistantStreaming: false, + streamedText: "", + ackReceived: false, + }, + stopRequested: false, +}) + +const createFragment = ( + id: string, + content: string, +): MinimalAgentFragment => ({ + ...baseFragment, + id, + content, + source: { + ...baseFragment.source, + docId: id, + title: `Title ${id}`, + url: `https://example.com/${id}`, + }, +}) + +describe("final-answer-synthesis", () => { + beforeEach(() => { + __modelInfoInternals.resetModelInfoCacheForTests() + }) + + afterEach(() => { + __modelInfoInternals.resetModelInfoCacheForTests() + }) + + test("builds deterministic fragment previews from raw fragment content", () => { + const preview = __finalAnswerSynthesisInternals.buildFragmentPreviewRecord( + { + ...baseFragment, + content: + " Quarterly ARR grew 12% and pipeline coverage improved.\n\nCustomers expanded seats. ", + source: { + ...baseFragment.source, + createdAt: "2026-03-08T09:00:00.000Z", + }, + }, + 7, + ) + + expect(preview.fragmentIndex).toBe(7) + expect(preview.docId).toBe("doc-1") + expect(preview.title).toBe("ARR Summary") + expect(preview.app).toBe(String(Apps.KnowledgeBase)) + expect(preview.previewText).toBe( + "Quarterly ARR grew 12% and pipeline coverage improved. Customers expanded seats.", + ) + expect(preview.timestamp).toBe("2026-03-08T09:00:00.000Z") + }) + + test("section payload keeps shared context and adds section-only instructions", () => { + const context = createMockContext() + context.dedicatedAgentSystemPrompt = + "You are an enterprise agent. Always use verified workspace evidence." + context.allFragments = [baseFragment] + + const payload = __finalAnswerSynthesisInternals.buildSectionAnswerPayload( + context, + [ + { + sectionId: 1, + title: "Summary", + objective: "Summarize the ARR status.", + }, + { + sectionId: 2, + title: "Evidence", + objective: "Provide supporting evidence.", + }, + ], + { + sectionId: 2, + title: "Evidence", + objective: "Provide supporting evidence.", + }, + [{ fragmentIndex: 3, fragment: baseFragment }], + ["3_doc-1_0"], + ) + + expect(payload.systemPrompt).toContain("Deliver only the assigned answer section") + expect(payload.userMessage).toContain( + "All Planned Sections (generated in parallel; a final ordered answer will be assembled later):", + ) + expect(payload.userMessage).toContain("Assigned Section:\n2. Evidence") + expect(payload.userMessage).toContain("Write only this section.") + expect(payload.userMessage).toContain("index 3 {file context begins here...}") + expect(payload.userMessage).toContain("Agent System Prompt Context:") + expect(payload.imageFileNames).toEqual(["3_doc-1_0"]) + }) + + test("switches to sectional mode when full final payload exceeds the model input budget", () => { + __modelInfoInternals.setModelInfoCacheForTests([ + { + model_name: Models.Gpt_4, + litellm_params: { + model: "gpt-4", + }, + model_info: { + max_input_tokens: 8_192, + max_output_tokens: 4_096, + }, + }, + ]) + + const context = createMockContext() + context.modelId = Models.Gpt_4 + context.allFragments = [ + { + ...baseFragment, + content: "A".repeat(40_000), + }, + ] + + const payload = buildFinalSynthesisPayload(context) + expect(payload.userMessage.length).toBeGreaterThan(0) + + const decision = __finalAnswerSynthesisInternals.decideSynthesisMode( + context, + Models.Gpt_4, + { + selected: [], + total: 0, + dropped: [], + userAttachmentCount: 0, + }, + ) + + expect(decision.mode).toBe("sectional") + expect(decision.estimatedInputTokens).toBeGreaterThan(decision.safeInputBudget) + }) + + test("skips an oversize first mapped fragment and keeps later entries that fit", () => { + const selection = __finalAnswerSynthesisInternals.selectMappedEntriesWithinBudget( + [ + { fragmentIndex: 1, fragment: createFragment("doc-large", "A".repeat(6_000)) }, + { fragmentIndex: 2, fragment: createFragment("doc-small-1", "brief evidence") }, + { fragmentIndex: 3, fragment: createFragment("doc-small-2", "more brief evidence") }, + ], + 100, + 500, + ) + + expect(selection.selected.map((entry) => entry.fragmentIndex)).toEqual([2, 3]) + expect(selection.trimmedCount).toBe(1) + expect(selection.skippedForBudgetCount).toBe(1) + }) + + test("skips an oversize middle mapped fragment without reordering selected entries", () => { + const selection = __finalAnswerSynthesisInternals.selectMappedEntriesWithinBudget( + [ + { fragmentIndex: 1, fragment: createFragment("doc-small-1", "brief evidence") }, + { fragmentIndex: 2, fragment: createFragment("doc-large", "B".repeat(6_000)) }, + { fragmentIndex: 3, fragment: createFragment("doc-small-2", "more brief evidence") }, + ], + 100, + 500, + ) + + expect(selection.selected.map((entry) => entry.fragmentIndex)).toEqual([1, 3]) + expect(selection.trimmedCount).toBe(1) + expect(selection.skippedForBudgetCount).toBe(1) + }) + + test("omits section synthesis when no mapped fragment fits within the safe budget", async () => { + const context = createMockContext() + context.allFragments = [ + createFragment("doc-large-1", "A".repeat(6_000)), + createFragment("doc-large-2", "B".repeat(6_000)), + ] + + const result = await __finalAnswerSynthesisInternals.synthesizeSection( + context, + [ + { + sectionId: 1, + title: "Answer", + objective: "Provide the best complete answer using the mapped evidence.", + }, + ], + { + sectionId: 1, + title: "Answer", + objective: "Provide the best complete answer using the mapped evidence.", + }, + [1, 2], + Models.Gpt_4, + 500, + ) + + expect(result).toEqual({ + result: null, + estimatedCostUsd: 0, + imageFileNames: [], + }) + }) +})