diff --git a/bot/admin/web/src/app/rag/rag-settings/models/engines-configurations.ts b/bot/admin/web/src/app/rag/rag-settings/models/engines-configurations.ts
index 925980ad0e..d862251d20 100644
--- a/bot/admin/web/src/app/rag/rag-settings/models/engines-configurations.ts
+++ b/bot/admin/web/src/app/rag/rag-settings/models/engines-configurations.ts
@@ -27,35 +27,102 @@ import { PromptDefinitionFormatter } from '../../../shared/model/ai-settings';
-export const QuestionCondensingDefaultPrompt: string = `Given a chat history and the latest user question which might reference context in the chat history, formulate a standalone question which can be understood without the chat history. Do NOT answer the question, just reformulate it if needed and otherwise return it as is.`;
+export const QuestionCondensingDefaultPrompt: string = `You are a helpful assistant that reformulates questions.
-export const QuestionAnsweringDefaultPrompt: string = `# TOCK (The Open Conversation Kit) chatbot
-
-## General context
-
-You are a chatbot designed to provide short conversational messages in response to user queries.
-
-## Guidelines
-
-Incorporate any relevant details from the provided context into your answers, ensuring they are directly related to the user's query.
+You are given:
+- The conversation history between the user and the assistant
+- The most recent user question
-## Style and format
+Your task:
+- Reformulate the user’s latest question into a clear, standalone query.
+- Incorporate relevant context from the conversation history.
+- Do NOT answer the question.
+- If the history does not provide additional context, keep the question as is.
-Your tone is empathetic, informative and polite.
+Return only the reformulated question.`;
-## Additional instructions
-
-Use the following pieces of retrieved context to answer the question.
-If you dont know the answer, answer (exactly) with "{{no_answer}}".
-Answer in {{locale}}.
-
-## Context
-
-{{context}}
-
-## Question
+export const QuestionAnsweringDefaultPrompt: string = `# TOCK (The Open Conversation Kit) chatbot
-{{question}}
+## Instructions:
+You must answer STRICTLY in valid JSON format (no extra text, no explanations).
+Use only the following context and the rules below to answer the question.
+
+### Rules for JSON output:
+
+- If the answer is found in the context:
+  - "status": "found_in_context"
+
+- If the answer is NOT found in the context:
+  - "status": "not_found_in_context"
+  - "answer":
+    - The "answer" must not be a generic refusal. Instead, generate a helpful and intelligent response:
+      - If a similar or related element exists in the context (e.g., another product, service, or regulation with a close name, date, or wording), suggest it naturally in the answer.
+      - If no similar element exists, politely acknowledge the lack of information while encouraging clarification or rephrasing.
+      - Always ensure the response is phrased in a natural and user-friendly way, rather than a dry "not found in context".
+
+- If the question matches a special case defined below:
+  - "status": ""
+
+And for all cases (MANDATORY):
+  - "answer": ""
+  - "topic": ""
+  - "suggested_topics": [""]
+
+Exception: If the question is small talk (limited to conversational rituals such as greetings (e.g., “hello”, “hi”) and farewells or leave-takings (e.g., “goodbye”, “see you”)), you may ignore the context and generate a natural small-talk response in the "answer". In this case:
+  - "status": "small_talk"
+  - "topic": ""
+  - "suggested_topics": []
+  - "context": []
+
+### Context tracing requirements (MANDATORY):
+- You MUST include **every** chunk from the input context in the "context" array, in the same order they appear. **No chunk may be omitted**.
+- If explicit chunk identifiers are present in the context, use them; otherwise assign sequential numbers starting at 1.
+- For each chunk object:
+  - "chunk": ""
+  - "sentences": [""] — leave empty \`[]\` if none.
+  - "reason": null if the chunk contributed; otherwise a concise explanation of why this chunk is not relevant to the question (e.g., "general background only", "different product", "no data for the asked period", etc.).
+- If there are zero chunks in the context, return \`"context": []\`.
+
+### Predefined list of topics (use EXACT spelling, no variations):
+
+## Context:
+{{ context }}
+
+## Conversation history
+{{ chat_history }}
+
+## User question
+{{ question }}
+
+## Output format (JSON only):
+Return your response in the following format:
+
+{
+  "status": "found_in_context" | "not_found_in_context" | "small_talk",
+  "answer": "TEXTUAL_ANSWER",
+  "topic": "EXACT_TOPIC_FROM_LIST_OR_UNKNOWN",
+  "suggested_topics": [
+    "SUGGESTED_TOPIC_1",
+    "SUGGESTED_TOPIC_2"
+  ],
+  "context": [
+    {
+      "chunk": "1",
+      "sentences": ["SENTENCE_1", "SENTENCE_2"],
+      "reason": null
+    },
+    {
+      "chunk": "2",
+      "sentences": [],
+      "reason": "General description; no details related to the question."
+    },
+    {
+      "chunk": "3",
+      "sentences": ["SENTENCE_X"],
+      "reason": null
+    }
+  ]
+}
 `;
 
 export const QuestionCondensing_prompt: ProvidersConfigurationParam[] = [
diff --git a/bot/engine/src/main/kotlin/admin/bot/rag/BotRAGConfiguration.kt b/bot/engine/src/main/kotlin/admin/bot/rag/BotRAGConfiguration.kt
index 0665f3a1a5..6680b11237 100644
--- a/bot/engine/src/main/kotlin/admin/bot/rag/BotRAGConfiguration.kt
+++ b/bot/engine/src/main/kotlin/admin/bot/rag/BotRAGConfiguration.kt
@@ -36,6 +36,7 @@ data class BotRAGConfiguration(
     val llmSetting: LLMSetting? = null,
     val emSetting: EMSetting,
     val indexSessionId: String? = null,
+    @Deprecated("Replaced by LLM answer status")
     val noAnswerSentence: String,
     val noAnswerStoryId: String? = null,
     val documentsRequired: Boolean = true,
diff --git a/bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt b/bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt
index c43698900d..c3f97cfb23 100644
--- a/bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt
+++ b/bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt
@@ -31,9 +31,9 @@ import ai.tock.bot.engine.action.SendSentenceWithFootnotes
 import ai.tock.bot.engine.dialog.Dialog
 import ai.tock.bot.engine.user.PlayerType
 import ai.tock.genai.orchestratorclient.requests.*
+import ai.tock.genai.orchestratorclient.responses.LLMAnswer
 import ai.tock.genai.orchestratorclient.responses.ObservabilityInfo
 import ai.tock.genai.orchestratorclient.responses.RAGResponse
-import ai.tock.genai.orchestratorclient.responses.TextWithFootnotes
 import ai.tock.genai.orchestratorclient.retrofit.GenAIOrchestratorBusinessError
 import ai.tock.genai.orchestratorclient.retrofit.GenAIOrchestratorValidationError
 import ai.tock.genai.orchestratorclient.services.RAGService
@@ -60,7 +60,7 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler {
             BotRepository.saveMetric(createMetric(MetricType.STORY_HANDLED))
 
             // Call RAG Api - Gen AI Orchestrator
-            val (answer, debug, noAnswerStory, observabilityInfo) = rag(this)
+            val (answer, footnotes, debug, noAnswerStory, observabilityInfo) = rag(this)
 
             // Add debug data if available and if debugging is enabled
             if (debug != null) {
@@ -75,14 +75,18 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler {
                 val modifiedObservabilityInfo = observabilityInfo?.let { updateObservabilityInfo(this, it) }
 
                 send(
-                    SendSentenceWithFootnotes(
-                        botId, connectorId, userId, text = answer.text, footnotes = answer.footnotes.map {
+                    action = SendSentenceWithFootnotes(
+                        playerId = botId,
+                        applicationId = connectorId,
+                        recipientId = userId,
+                        text = answer.answer,
+                        footnotes = footnotes?.map {
                             Footnote(
                                 it.identifier, it.title, it.url,
                                 if(action.metadata.sourceWithContent) it.content else null,
                                 it.score
                             )
-                        }.toMutableList(),
+                        }?.toMutableList() ?: mutableListOf(),
                        // modifiedObservabilityInfo includes the public langfuse URL if filled.
                        metadata = ActionMetadata(isGenAiRagAnswer = true, observabilityInfo = modifiedObservabilityInfo)
                    )
@@ -116,13 +120,13 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler {
     private fun ragStoryRedirection(botBus: BotBus, response: RAGResponse?): StoryDefinition? {
         return with(botBus) {
             botDefinition.ragConfiguration?.let { ragConfig ->
-                if (response?.answer?.text.equals(ragConfig.noAnswerSentence, ignoreCase = true)) {
+                if (response?.answer?.status.equals("not_found_in_context", ignoreCase = true)) {
                     // Save no answer metric
                     saveRagMetric(IndicatorValues.NO_ANSWER)
 
                     // Switch to no answer story if configured
                     if (!ragConfig.noAnswerStoryId.isNullOrBlank()) {
-                        logger.info { "The RAG response is equal to the configured no-answer sentence, so switch to the no-answer story." }
+                        logger.info { "RAG answer status is 'not_found_in_context', switching to the no-answer RAG story." }
                        getNoAnswerRAGStory(ragConfig)
                    } else null
                } else {
@@ -221,7 +225,7 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler {
             )
 
             // Handle RAG response
-            return RAGResult(response?.answer, response?.debug, ragStoryRedirection(this, response), response?.observabilityInfo)
+            return RAGResult(response?.answer, response?.footnotes, response?.debug, ragStoryRedirection(this, response), response?.observabilityInfo)
         } catch (exc: Exception) {
             logger.error { exc }
             // Save failure metric
@@ -232,7 +236,7 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler {
                 RAGResult(noAnswerStory = getNoAnswerRAGStory(ragConfiguration))
             } else RAGResult(
-                answer = TextWithFootnotes(text = technicalErrorMessage),
+                answer = LLMAnswer(status = "error", answer = technicalErrorMessage),
                 debug = when(exc) {
                     is GenAIOrchestratorBusinessError -> RAGError(exc.message, exc.error)
                     is GenAIOrchestratorValidationError -> RAGError(exc.message, exc.detail)
@@ -282,7 +286,8 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler {
     * Aggregation of RAG answer, debug and the no answer Story.
     */
    data class RAGResult(
-        val answer: TextWithFootnotes? = null,
+        val answer: LLMAnswer? = null,
+        val footnotes: List<Footnote>? = null,
         val debug: Any? = null,
         val noAnswerStory: StoryDefinition? = null,
         val observabilityInfo: ObservabilityInfo? = null,
diff --git a/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/responses/RAGResponse.kt b/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/responses/RAGResponse.kt
index a4dfde4c3e..fb699fd7f3 100644
--- a/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/responses/RAGResponse.kt
+++ b/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/responses/RAGResponse.kt
@@ -17,7 +17,8 @@ package ai.tock.genai.orchestratorclient.responses
 
 data class RAGResponse(
-    val answer: TextWithFootnotes,
+    val answer: LLMAnswer,
+    val footnotes: List<Footnote> = emptyList(),
     val debug: Any? = null,
     val observabilityInfo: ObservabilityInfo? = null,
 )
diff --git a/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/responses/models.kt b/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/responses/models.kt
index 4131d334fd..fa0314a143 100644
--- a/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/responses/models.kt
+++ b/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/responses/models.kt
@@ -17,9 +17,18 @@ package ai.tock.genai.orchestratorclient.responses
 
-data class TextWithFootnotes(
-    val text: String,
-    val footnotes: List<Footnote> = emptyList(),
+data class ChunkSentences(
+    val chunk: String? = null,
+    val sentences: List<String>? = emptyList(),
+    val reason: String? = null,
+)
+
+data class LLMAnswer(
+    val status: String,
+    val answer: String,
+    val topic: String? = null,
+    val suggestedTopics: List<String>? = null,
+    val context: List<ChunkSentences>?
= null, ) data class Footnote( diff --git a/gen-ai/orchestrator-core/src/main/kotlin/ai/tock/genai/orchestratorcore/models/llm/OllamaLLMSetting.kt b/gen-ai/orchestrator-core/src/main/kotlin/ai/tock/genai/orchestratorcore/models/llm/OllamaLLMSetting.kt index 0d72aa1de3..28077e3a3f 100644 --- a/gen-ai/orchestrator-core/src/main/kotlin/ai/tock/genai/orchestratorcore/models/llm/OllamaLLMSetting.kt +++ b/gen-ai/orchestrator-core/src/main/kotlin/ai/tock/genai/orchestratorcore/models/llm/OllamaLLMSetting.kt @@ -28,4 +28,3 @@ data class OllamaLLMSetting( } } -// TODO MASS : Check Compile + TU (car dernier commit) diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/models/rag/rag_models.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/models/rag/rag_models.py index a20a78e73c..c52d7a42df 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/models/rag/rag_models.py +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/models/rag/rag_models.py @@ -52,16 +52,55 @@ class Footnote(Source): identifier: str = Field(description='Footnote identifier', examples=['1']) +class ChunkInfos(BaseModel): + """A model representing information about a chunk used in the RAG context.""" -class TextWithFootnotes(BaseModel): - """Text with its footnotes. Used for RAG response""" - - text: str = Field( - description='Text with footnotes used to list outside sources', - examples=['This is page content [1], and this is more content [2]'], + chunk: Optional[str] = Field( + description='Unique identifier of the chunk.', + examples=['cd6d8221-ba9f-44da-86ee-0e25a3c9a5c7'], + default=None + ) + sentences: Optional[List[str]] = Field( + description='List of verbatim sentences from the chunk that were used by the LLM.', + default=None ) - footnotes: set[Footnote] = Field(description='Set of footnotes') + reason: Optional[str] = Field( + description='Reason why the chunk was not used (e.g., irrelevant, general background).', + default=None + ) + + +class LLMAnswer(BaseModel): + """ + A model representing the structured answer generated by the LLM + in response to a user query, based on the provided RAG context. + """ + status: Optional[str] = Field( + description="The status of the answer generation. " + "Possible values: 'found_in_context', 'not_found_in_context', 'small_talk', " + "or other case-specific codes.", + default=None + ) + answer: Optional[str] = Field( + description="The textual answer generated by the LLM, in the user's locale.", + default=None + ) + topic: Optional[str] = Field( + description="The main topic assigned to the answer. Must be one of the predefined list " + "of topics, or 'unknown' if no match is possible.", + default=None + ) + suggested_topics: Optional[List[str]] = Field( + description="A list of suggested alternative or related topics, " + "used when the main topic is 'unknown'.", + default=None + ) + context: Optional[List[ChunkInfos]] = Field( + description="The list of chunks from the context that contributed to or were considered " + "in the LLM's answer. Each entry contains identifiers, sentences, and reasons.", + default=None + ) @unique class ChatMessageType(str, Enum): @@ -154,4 +193,4 @@ class RAGDebugData(QADebugData): 'Question: Hello, how to plan a trip to Morocco ?. Answer in French.' 
], ) - answer: str = Field(description='The RAG answer.') + answer: LLMAnswer = Field(description='The RAG answer.') diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/routers/responses/responses.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/routers/responses/responses.py index fe633d2f56..bbab4cf8bb 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/routers/responses/responses.py +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/routers/responses/responses.py @@ -25,9 +25,8 @@ ErrorInfo, ) from gen_ai_orchestrator.models.llm.llm_provider import LLMProvider -from gen_ai_orchestrator.models.rag.rag_models import Source, TextWithFootnotes +from gen_ai_orchestrator.models.rag.rag_models import Source, LLMAnswer, Footnote from gen_ai_orchestrator.models.observability.observability_provider import ObservabilityProvider -from gen_ai_orchestrator.models.rag.rag_models import TextWithFootnotes from gen_ai_orchestrator.models.vector_stores.vectore_store_provider import VectorStoreProvider @@ -122,9 +121,10 @@ class ObservabilityInfo(BaseModel): class RAGResponse(BaseModel): """The RAG response model""" - answer: TextWithFootnotes = Field( - description='The RAG answer, with outside sources.' + answer: Optional[LLMAnswer] = Field( + description='The RAG answer' ) + footnotes: set[Footnote] = Field(description='Set of footnotes') debug: Optional[Any] = Field( description='Debug data', examples=[{'action': 'retrieve', 'result': 'OK', 'errors': []}], diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py index ed7e666cdb..fe8d289118 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py @@ -17,31 +17,30 @@ It uses LangChain to perform a Conversational Retrieval Chain """ +import json import logging import time from functools import partial from logging import ERROR, WARNING -from typing import List, Optional +from operator import itemgetter +from typing import List, Optional, Tuple -from langchain.chains.conversational_retrieval.base import ( - ConversationalRetrievalChain, -) from langchain.retrievers.contextual_compression import ( ContextualCompressionRetriever, ) from langchain_community.chat_message_histories import ChatMessageHistory from langchain_core.callbacks import BaseCallbackHandler from langchain_core.documents import Document -from langchain_core.output_parsers import StrOutputParser +from langchain_core.messages import HumanMessage, AIMessage +from langchain_core.output_parsers import StrOutputParser, JsonOutputParser from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate from langchain_core.runnables import ( RunnableParallel, RunnablePassthrough, - RunnableSerializable, + RunnableSerializable, RunnableConfig, RunnableLambda, ) from langchain_core.vectorstores import VectorStoreRetriever -from langfuse.callback import CallbackHandler as LangfuseCallbackHandler from typing_extensions import Any from gen_ai_orchestrator.errors.exceptions.exceptions import ( @@ -68,11 +67,10 @@ RAGDebugData, RAGDocument, 
RAGDocumentMetadata, - TextWithFootnotes, + LLMAnswer, ) from gen_ai_orchestrator.routers.requests.requests import RAGRequest from gen_ai_orchestrator.routers.responses.responses import ( - ObservabilityInfo, RAGResponse, ) from gen_ai_orchestrator.services.langchain.callbacks.rag_callback_handler import ( @@ -109,7 +107,7 @@ async def execute_rag_chain( Args: request: The RAG request debug: True if RAG data debug should be returned with the response. - custom_observability_handler: Custom observability handler + custom_observability_handler: Custom observability handler (Used in the tooling run_experiment.py script) Returns: The RAG response (Answer and document sources) """ @@ -120,96 +118,105 @@ async def execute_rag_chain( conversational_retrieval_chain = create_rag_chain(request=request) message_history = ChatMessageHistory() - session_id = None - user_id = None - tags = [] if request.dialog: for msg in request.dialog.history: if ChatMessageType.HUMAN == msg.type: message_history.add_user_message(msg.text) else: message_history.add_ai_message(msg.text) - session_id = (request.dialog.dialog_id,) - user_id = (request.dialog.user_id,) - tags = (request.dialog.tags,) - logger.debug( - 'RAG chain - Use chat history: %s', - 'Yes' if len(message_history.messages) > 0 else 'No', - ) + logger.debug('RAG chain - Use chat history: %s', len(message_history.messages) > 0) + logger.debug('RAG chain - Use RAGCallbackHandler for debugging : %s', debug) + + records_handler, observability_handler = get_callback_handlers(request, debug) + + callbacks = [ + handler + for handler in (records_handler, observability_handler, custom_observability_handler) + if handler is not None + ] inputs = { **request.question_answering_prompt.inputs, 'chat_history': message_history.messages, } - logger.debug( - 'RAG chain - Use RAGCallbackHandler for debugging : %s', - debug, - ) - - callback_handlers = [] - records_callback_handler = RAGCallbackHandler() - observability_handler = None - if debug: - # Debug callback handler - callback_handlers.append(records_callback_handler) - if custom_observability_handler is not None: - callback_handlers.append(custom_observability_handler) - if request.observability_setting is not None: - # Langfuse callback handler - observability_handler = create_observability_callback_handler( - observability_setting=request.observability_setting, - trace_name=ObservabilityTrace.RAG.value, - session_id=session_id, - user_id=user_id, - tags=tags, - ) - callback_handlers.append(observability_handler) - response = await conversational_retrieval_chain.ainvoke( input=inputs, - config={'callbacks': callback_handlers}, + config=RunnableConfig(callbacks=callbacks) ) + llm_answer = LLMAnswer(**response['answer']) # RAG Guard - rag_guard(inputs, response, request.documents_required) + rag_guard(inputs, llm_answer, response, request.documents_required) # Guardrail if request.guardrail_setting: guardrail = get_guardrail_factory( setting=request.guardrail_setting ).get_parser() - guardrail_output = guardrail.parse(response['answer']) + guardrail_output = guardrail.parse(llm_answer.answer) check_guardrail_output(guardrail_output) # Calculation of RAG processing time rag_duration = '{:.2f}'.format(time.time() - start_time) logger.info('RAG chain - End of execution. 
(Duration : %s seconds)', rag_duration) + # Group contexts by chunk id + contexts_by_chunk = { + ctx.chunk: ctx + for ctx in (llm_answer.context or []) + if ctx.sentences + } + # Returning RAG response return RAGResponse( - answer=TextWithFootnotes( - text=response['answer'], - footnotes=set( - map( - lambda doc: Footnote( - identifier=doc.metadata['id'], - title=doc.metadata['title'], - url=doc.metadata['source'], - content=get_source_content(doc), - score=doc.metadata.get('retriever_score', None), - ), - response['documents'], - ) - ), - ), + answer=llm_answer, + footnotes={ + Footnote( + identifier=doc.metadata['id'], + title=doc.metadata['title'], + url=doc.metadata['source'], + content=get_source_content(doc), + score=doc.metadata.get('retriever_score', None), + ) + for doc in response["documents"] + if doc.metadata['id'] in contexts_by_chunk + }, observability_info=get_observability_info(observability_handler), - debug=get_rag_debug_data(request, records_callback_handler, rag_duration) + debug=get_rag_debug_data(request, records_handler, rag_duration) if debug else None, ) +def get_callback_handlers(request, debug) -> Tuple[ + Optional[RAGCallbackHandler], + Optional[object], +]: + records_handler = RAGCallbackHandler() if debug else None + observability_handler = None + + if request.observability_setting is not None: + if request.dialog: + session_id = request.dialog.dialog_id + user_id = request.dialog.user_id + tags = request.dialog.tags + else: + session_id = None + user_id = None + tags = None + observability_handler = create_observability_callback_handler( + observability_setting=request.observability_setting, + trace_name=ObservabilityTrace.RAG.value, + session_id=session_id, + user_id=user_id, + tags=tags, + ) + + return ( + records_handler, + observability_handler, + ) def get_source_content(doc: Document) -> str: """ @@ -279,31 +286,62 @@ def create_rag_chain( if question_condensing_llm_factory is not None: question_condensing_llm = question_condensing_llm_factory.get_language_model() question_answering_llm = question_answering_llm_factory.get_language_model() - rag_prompt = build_rag_prompt(request) - # Construct the RAG chain using the prompt and LLM, - # This chain will consume the documents retrieved by the retriever as input. - rag_chain = construct_rag_chain(question_answering_llm, rag_prompt) + # Fallback in case of missing condensing LLM setting using the answering LLM setting. 
+ if question_condensing_llm is not None: + condensing_llm = question_condensing_llm + else : + condensing_llm = question_answering_llm # Build the chat chain for question contextualization - chat_chain = build_question_condensation_chain( - question_condensing_llm - if question_condensing_llm is not None - else question_answering_llm, - request.question_condensing_prompt, - ) + chat_chain = build_question_condensation_chain(condensing_llm, request.question_condensing_prompt) + rag_prompt = build_rag_prompt(request) # Function to contextualize the question based on chat history contextualize_question_fn = partial(contextualize_question, chat_chain=chat_chain) - # Final RAG chain with retriever and source documents - rag_chain_with_retriever = ( - contextualize_question_fn - | RunnableParallel({'documents': retriever, 'question': RunnablePassthrough()}) - | RunnablePassthrough.assign(answer=rag_chain) - ) - - return rag_chain_with_retriever + # Calculate the condensed question + with_condensed_question = RunnableParallel({ + "condensed_question": contextualize_question_fn, + "question": itemgetter("question"), + "chat_history": itemgetter("chat_history"), + }) + + def retrieve_with_variants(inputs): + variants = [ + # inputs["question"], Deactivated. It's an example to prove the multi retriever process + inputs["condensed_question"] + ] + docs = [] + for v in variants: + docs.extend(retriever.invoke(v)) + # Deduplicate docs + unique_docs = {d.metadata['id']: d for d in docs} + + # TODO [DERCBOT-1649] Apply the RRF Algo on unique_docs. + return list(unique_docs.values()) + + # Build the RAG inputs + rag_inputs = with_condensed_question | RunnableParallel({ + "question": itemgetter("condensed_question"), + "chat_history": itemgetter("chat_history"), + "documents": RunnableLambda(retrieve_with_variants), + }) + + return rag_inputs | RunnablePassthrough.assign(answer=( + { + "context": lambda x: json.dumps([ + { + "chunk_id": doc.metadata['id'], + "chunk_text": doc.page_content, + } + for doc in x["documents"] + ], ensure_ascii=False, indent=2), + "chat_history": format_chat_history, + } + | rag_prompt + | question_answering_llm + | JsonOutputParser(pydantic_object=LLMAnswer, name="rag_chain_output"))) def build_rag_prompt(request: RAGRequest) -> LangChainPromptTemplate: @@ -316,25 +354,14 @@ def build_rag_prompt(request: RAGRequest) -> LangChainPromptTemplate: partial_variables=request.question_answering_prompt.inputs, ) - -def construct_rag_chain(llm, rag_prompt): - """ - Construct the RAG chain from LLM and prompt. - """ - return ( - { - 'context': lambda inputs: '\n\n'.join( - doc.page_content for doc in inputs['documents'] - ), - 'question': lambda inputs: inputs[ - 'question' - ], # Override the user's original question with the condensed one - } - | rag_prompt - | llm - | StrOutputParser(name='rag_chain_output') - ) - +def format_chat_history(x): + messages = [] + for msg in x["chat_history"]: + if isinstance(msg, HumanMessage): + messages.append({"user": msg.content}) + elif isinstance(msg, AIMessage): + messages.append({"assistant": msg.content}) + return json.dumps(messages, ensure_ascii=False, indent=2) def build_question_condensation_chain( llm, prompt: Optional[PromptTemplate] @@ -342,14 +369,27 @@ def build_question_condensation_chain( """ Build the chat chain for contextualizing questions. """ + # TODO deprecated : All Gen configurations are supposed to have this prompt now. It is mandatory in the RAG configuration. 
     if prompt is None:
         # Default prompt
         prompt = PromptTemplate(
             formatter=PromptFormatter.F_STRING,
             inputs={},
-            template='Given a chat history and the latest user question which might reference context in \
-the chat history, formulate a standalone question which can be understood without the chat history. \
-Do NOT answer the question, just reformulate it if needed and otherwise return it as is.',
+            template="""
+You are a helpful assistant that reformulates questions.
+
+You are given:
+- The conversation history between the user and the assistant
+- The most recent user question
+
+Your task:
+- Reformulate the user’s latest question into a clear, standalone query.
+- Incorporate relevant context from the conversation history.
+- Do NOT answer the question.
+- If the history does not provide additional context, keep the question as is.
+
+Return only the reformulated question.
+"""
         )
 
     return (
@@ -373,51 +413,39 @@ def contextualize_question(inputs: dict, chat_chain) -> str:
         return chat_chain
     return inputs['question']
 
-
-def rag_guard(inputs, response, documents_required):
+def rag_guard(question, answer, response, documents_required):
     """
     Validates the RAG system's response based on the presence or absence of source documents
     and the `documentsRequired` setting.
 
     Args:
-        inputs: question answering prompt inputs
+        question: user question
+        answer: the LLM answer
         response: the RAG response
         documents_required (bool): Specifies whether documents are mandatory for the response.
     """
-    no_docs_retrieved = response['documents'] == []
-    no_docs_but_required = no_docs_retrieved and documents_required
-    chain_can_give_no_answer_reply = 'no_answer' in inputs
-    chain_reply_no_answer = False
-
-    if chain_can_give_no_answer_reply:
-        chain_reply_no_answer = response['answer'] == inputs['no_answer']
-
-    if no_docs_but_required:
-        if chain_can_give_no_answer_reply and chain_reply_no_answer:
-            # We expect the chain to use its non-response value, and it has done so, which is the expected behavior.
-            return
-        # Everything else isn't expected
-        message = 'The RAG system cannot provide an answer when no documents are found and documents are required'
-        rag_log(level=ERROR, message=message, inputs=inputs, response=response)
+    if documents_required and answer.status == "found_in_context" and len(response['documents']) == 0:
+        message = 'No documents were retrieved, yet an answer was attempted.'
+        rag_log(level=ERROR, message=message, question=question, answer=answer.answer, response=response)
         raise GenAIGuardCheckException(ErrorInfo(cause=message))
 
-    if chain_reply_no_answer and not no_docs_retrieved:
-        # If the chain responds with its non-response value and the documents are retrieved,
-        # so we remove them from the RAG response.
-        message = 'The RAG gives no answer for user question, but some documents has been found!'
-        rag_log(level=WARNING, message=message, inputs=inputs, response=response)
+    if answer.status == "not_found_in_context" and len(response['documents']) > 0:
+        # If the answer was not found in the context but documents were retrieved, remove them from the RAG response.
+        message = 'No answer found in the retrieved context. The documents are therefore removed from the RAG response.'
+        rag_log(level=WARNING, message=message, question=question, answer=answer.answer, response=response)
         response['documents'] = []
 
-def rag_log(level, message, inputs, response):
+def rag_log(level, message, question, answer, response):
     """
     RAG logging
 
    Args:
        level: logging level
        message: message to log
-        inputs: question answering prompt inputs
+        question: the user question
+        answer: LLM answer
        response: the RAG response
    """
 
@@ -427,9 +455,9 @@
         'RAG chain - question="%(question)s", answer="%(answer)s", documents="%(documents)s"',
         {
             'message': message,
-            'question': inputs['question'],
-            'answer': response['answer'],
-            'documents': response['documents'],
+            'question': question,
+            'answer': answer,
+            'documents': len(response['documents']),
         },
     )
 
@@ -451,6 +479,8 @@ def get_rag_documents(handler: RAGCallbackHandler) -> List[RAGDocument]:
         for doc in handler.records['documents']
     ]
 
+def get_llm_answer(rag_chain_output) -> LLMAnswer:
+    return LLMAnswer(**json.loads(rag_chain_output.strip().removeprefix("```json").removesuffix("```").strip()))
 
 def get_rag_debug_data(
     request: RAGRequest, records_callback_handler: RAGCallbackHandler, rag_duration
@@ -470,7 +500,7 @@ def get_rag_debug_data(
         documents=get_rag_documents(records_callback_handler),
         document_index_name=request.document_index_name,
         document_search_params=request.document_search_params,
-        answer=records_callback_handler.records['rag_chain_output'],
+        answer=get_llm_answer(records_callback_handler.records['rag_chain_output']),
         duration=rag_duration,
     )
diff --git a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py
index 7a1b5edb1b..8d88a9629c 100644
--- a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py
+++ b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py
@@ -28,9 +28,9 @@
 from gen_ai_orchestrator.models.guardrail.bloomz.bloomz_guardrail_setting import (
     BloomzGuardrailSetting,
 )
+from gen_ai_orchestrator.models.rag.rag_models import LLMAnswer
 from gen_ai_orchestrator.routers.requests.requests import RAGRequest
 from gen_ai_orchestrator.services.langchain import rag_chain
-
 from gen_ai_orchestrator.services.langchain.factories.langchain_factory import (
     get_guardrail_factory,
 )
@@ -48,19 +48,19 @@
 @patch('gen_ai_orchestrator.services.langchain.rag_chain.RAGCallbackHandler')
 @patch('gen_ai_orchestrator.services.langchain.rag_chain.rag_guard')
 @patch('gen_ai_orchestrator.services.langchain.rag_chain.RAGResponse')
-@patch('gen_ai_orchestrator.services.langchain.rag_chain.TextWithFootnotes')
 @patch('gen_ai_orchestrator.services.langchain.rag_chain.RAGDebugData')
+@patch('gen_ai_orchestrator.services.langchain.rag_chain.get_llm_answer')
 @pytest.mark.asyncio
 async def test_rag_chain(
-    mocked_rag_debug_data,
-    mocked_text_with_footnotes,
-    mocked_rag_response,
-    mocked_rag_guard,
-    mocked_callback_init,
-    mocked_create_rag_chain,
-    mocked_get_callback_handler_factory,
-    mocked_get_document_compressor_factory,
-    mocked_guardrail_parse,
+        mocked_get_llm_answer,
+        mocked_rag_debug_data,
+        mocked_rag_response,
+        mocked_rag_guard,
+        mocked_callback_init,
+        mocked_create_rag_chain,
+        mocked_get_callback_handler_factory,
+        mocked_get_document_compressor_factory,
+        mocked_guardrail_parse,
 ):
     """Test the full execute_qa_chain method by mocking all external calls."""
     # Build a test RAGRequest
@@ -90,7 +90,7 @@
async def test_rag_chain( {question} Answer in {locale}:""", - 'inputs' : { + 'inputs': { 'question': 'How to get started playing guitar ?', 'no_answer': 'Sorry, I don t know.', 'locale': 'French', @@ -154,9 +154,19 @@ async def test_rag_chain( } docs = [Document( page_content='some page content', - metadata={'id':'123-abc', 'title':'my-title', 'source': None}, + metadata={'id': '123-abc', 'title': 'my-title', 'source': None}, )] - response = {'answer': 'an answer from llm', 'documents': docs} + response = { + 'answer': { + 'status': '', + 'answer': 'an answer from llm', + 'topic': None, + 'suggested_topics': None, + 'context': [] + }, + 'documents': docs + } + llm_answer = LLMAnswer(**response['answer']) # Setup mock factories/init return value observability_factory_instance = mocked_get_callback_handler_factory.return_value @@ -186,10 +196,8 @@ async def test_rag_chain( ) # Assert the response is build using the expected settings mocked_rag_response.assert_called_once_with( - # TextWithFootnotes must be mocked or mapping the footnotes will fail - answer=mocked_text_with_footnotes( - text=mocked_rag_answer['answer'], footnotes=[] - ), + answer=llm_answer, + footnotes=set(), debug=mocked_rag_debug_data(request, mocked_rag_answer, mocked_callback, 1), observability_info=None ) @@ -199,23 +207,32 @@ async def test_rag_chain( # Assert the rag guardrail is called mocked_guardrail_parse.assert_called_once_with( os.path.join(request.guardrail_setting.api_base, 'guardrail'), - json={'text': [mocked_rag_answer['answer']]}, + json={'text': [mocked_rag_answer['answer']['answer']]}, ) # Assert the rag guard is called mocked_rag_guard.assert_called_once_with( - inputs, response, request.documents_required + inputs, llm_answer, response, request.documents_required ) + @patch('gen_ai_orchestrator.services.langchain.impls.guardrail.bloomz_guardrail.requests.post') def test_guardrail_parse_succeed_with_toxicities_encountered( - mocked_guardrail_response, + mocked_guardrail_response, ): guardrail = get_guardrail_factory( BloomzGuardrailSetting( provider='BloomzGuardrail', max_score=0.5, api_base='http://test-guard.com' ) ).get_parser() - rag_response = {'answer': 'This is a sample text.'} + rag_response = { + 'answer': { + 'status': '', + 'answer': 'This is a sample text.', + 'topic': None, + 'suggested_topics': None, + 'context': [] + } + } mocked_response = MagicMock() mocked_response.status_code = 200 @@ -231,11 +248,11 @@ def test_guardrail_parse_succeed_with_toxicities_encountered( } mocked_guardrail_response.return_value = mocked_response - guardrail_output = guardrail.parse(rag_response['answer']) + guardrail_output = guardrail.parse(rag_response['answer']['answer']) mocked_guardrail_response.assert_called_once_with( os.path.join(guardrail.endpoint, 'guardrail'), - json={'text': [rag_response['answer']]}, + json={'text': [rag_response['answer']['answer']]}, ) assert guardrail_output == { 'content': 'This is a sample text.', @@ -251,21 +268,29 @@ def test_guardrail_parse_fail(mocked_guardrail_response): provider='BloomzGuardrail', max_score=0.5, api_base='http://test-guard.com' ) ).get_parser() - rag_response = {'answer': 'This is a sample text.'} + rag_response = { + 'answer': { + 'status': '', + 'answer': 'This is a sample text.', + 'topic': None, + 'suggested_topics': None, + 'context': [] + } + } mocked_response = MagicMock() mocked_response.status_code = 500 mocked_guardrail_response.return_value = mocked_response with pytest.raises( - HTTPError, - match=f"Error {mocked_response.status_code}. 
Bloomz guardrail didn't respond as expected.", + HTTPError, + match=f"Error {mocked_response.status_code}. Bloomz guardrail didn't respond as expected.", ): - guardrail.parse(rag_response['answer']) + guardrail.parse(rag_response['answer']['answer']) mocked_guardrail_response.assert_called_once_with( os.path.join(guardrail.endpoint, 'guardrail'), - json={'text': [rag_response['answer']]}, + json={'text': [rag_response['answer']['answer']]}, ) @@ -409,59 +434,83 @@ def test_check_guardrail_output_is_ok(): @patch('gen_ai_orchestrator.services.langchain.rag_chain.rag_log') def test_rag_guard_fails_if_no_docs_in_valid_answer(mocked_log): - inputs = {'no_answer': "Sorry, I don't know."} + question = 'Hi!' response = { - 'answer': 'a valid answer', + 'answer': { + 'status': 'found_in_context', + 'answer': 'a valid answer' + }, 'documents': [], } try: - rag_chain.rag_guard(inputs, response,documents_required=True) + rag_chain.rag_guard(question, LLMAnswer(**response['answer']), response, documents_required=True) except Exception as e: assert isinstance(e, GenAIGuardCheckException) @patch('gen_ai_orchestrator.services.langchain.rag_chain.rag_log') def test_rag_guard_accepts_no_answer_even_with_docs(mocked_log): - inputs = {'no_answer': "Sorry, I don't know."} + question = 'Hi!' response = { - 'answer': "Sorry, I don't know.", + 'answer': { + 'status': 'not_found_in_context', + 'answer': 'Sorry, I don t know.', + 'context': [ + { + 'chunk': 1, + 'sentences': ["str1"], + } + ] + }, 'documents': ['a doc as a string'], } - rag_chain.rag_guard(inputs, response, documents_required=True) - assert response['documents'] == ['a doc as a string'] + rag_chain.rag_guard(question, LLMAnswer(**response['answer']), response, documents_required=True) + # No answer found in the retrieved context. The documents are therefore removed from the RAG response. + assert response['documents'] == [] @patch('gen_ai_orchestrator.services.langchain.rag_chain.rag_log') def test_rag_guard_valid_answer_with_docs(mocked_log): - inputs = {'no_answer': "Sorry, I don't know."} + question = 'Hi!' response = { - 'answer': 'a valid answer', + 'answer': { + 'status': 'found_in_context', + 'answer': 'a valid answer', + }, 'documents': ['doc1', 'doc2'], } - rag_chain.rag_guard(inputs, response, documents_required=True) + rag_chain.rag_guard(question, LLMAnswer(**response['answer']), response, documents_required=True) assert response['documents'] == ['doc1', 'doc2'] + @patch('gen_ai_orchestrator.services.langchain.rag_chain.rag_log') def test_rag_guard_no_answer_with_no_docs(mocked_log): - inputs = {'no_answer': "Sorry, I don't know."} + question = 'Hi!' response = { - 'answer': "Sorry, I don't know.", + 'answer': { + 'status': 'not_found_in_context', + 'answer': 'Sorry, I don t know.' + }, 'documents': [], } - rag_chain.rag_guard(inputs, response, documents_required=True) + rag_chain.rag_guard(question, LLMAnswer(**response['answer']), response, documents_required=True) assert response['documents'] == [] + @patch('gen_ai_orchestrator.services.langchain.rag_chain.rag_log') def test_rag_guard_without_no_answer_input(mocked_log): """Test that __rag_guard handles missing no_answer input correctly.""" - inputs = {} # No 'no_answer' key + question = 'Hi!' 
response = { - 'answer': 'some answer', + 'answer': { + 'status': 'found_in_context', + 'answer': 'a valid answer', + }, 'documents': [], } with pytest.raises(GenAIGuardCheckException) as exc: - rag_chain.rag_guard(inputs, response, documents_required=True) + rag_chain.rag_guard(question, LLMAnswer(**response['answer']), response, documents_required=True) mocked_log.assert_called_once() - assert isinstance(exc.value, GenAIGuardCheckException) \ No newline at end of file + assert isinstance(exc.value, GenAIGuardCheckException)
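
Note on the new answer contract: the JSON schema introduced by QuestionAnsweringDefaultPrompt and consumed in rag_chain.py can be exercised in isolation. The sketch below is illustrative only and not part of the patch; it uses simplified stand-ins for the ChunkInfos/LLMAnswer pydantic models from rag_models.py (assuming pydantic v2) and mirrors the markdown-fence stripping done by get_llm_answer.

# Minimal sketch: parsing the structured RAG answer produced by the new
# QuestionAnsweringDefaultPrompt. ChunkInfos/LLMAnswer are simplified stand-ins
# for the models in rag_models.py; assumes pydantic v2 is available.
import json
from typing import List, Optional

from pydantic import BaseModel


class ChunkInfos(BaseModel):
    chunk: Optional[str] = None
    sentences: Optional[List[str]] = None
    reason: Optional[str] = None


class LLMAnswer(BaseModel):
    status: Optional[str] = None
    answer: Optional[str] = None
    topic: Optional[str] = None
    suggested_topics: Optional[List[str]] = None
    context: Optional[List[ChunkInfos]] = None


def parse_llm_answer(raw: str) -> LLMAnswer:
    # Strip an optional ```json ... ``` fence, as get_llm_answer does in rag_chain.py.
    cleaned = raw.strip().removeprefix("```json").removesuffix("```").strip()
    return LLMAnswer(**json.loads(cleaned))


raw_output = """```json
{
  "status": "not_found_in_context",
  "answer": "I couldn't find this in the documentation. Could you rephrase or give more details?",
  "topic": "unknown",
  "suggested_topics": ["billing"],
  "context": [
    {"chunk": "1", "sentences": [], "reason": "General description; no details related to the question."}
  ]
}
```"""

parsed = parse_llm_answer(raw_output)
assert parsed.status == "not_found_in_context"
print(parsed.answer)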