@@ -1,4 +1,4 @@
-# Copyright (C) 2023-2024 Credit Mutuel Arkea
+# Copyright (C) 2023-2025 Credit Mutuel Arkea
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,18 +23,23 @@
 from logging import ERROR, WARNING
 from typing import List, Optional
 
-from gen_ai_orchestrator.services.observability.observabilty_service import get_observability_info
 from langchain.chains.conversational_retrieval.base import (
     ConversationalRetrievalChain,
 )
 from langchain.retrievers.contextual_compression import (
     ContextualCompressionRetriever,
 )
 from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.documents import Document
 from langchain_core.output_parsers import StrOutputParser
-from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate, ChatPromptTemplate, MessagesPlaceholder
-from langchain_core.runnables import RunnablePassthrough, RunnableParallel, RunnableSerializable
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate
+from langchain_core.runnables import (
+    RunnableParallel,
+    RunnablePassthrough,
+    RunnableSerializable,
+)
 from langchain_core.vectorstores import VectorStoreRetriever
 from langfuse.callback import CallbackHandler as LangfuseCallbackHandler
 from typing_extensions import Any
@@ -48,7 +53,9 @@
 from gen_ai_orchestrator.errors.handlers.opensearch.opensearch_exception_handler import (
     opensearch_exception_handler,
 )
-from gen_ai_orchestrator.models.document_compressor.document_compressor_setting import BaseDocumentCompressorSetting
+from gen_ai_orchestrator.models.document_compressor.document_compressor_setting import (
+    BaseDocumentCompressorSetting,
+)
 from gen_ai_orchestrator.models.errors.errors_models import ErrorInfo
 from gen_ai_orchestrator.models.observability.observability_trace import (
     ObservabilityTrace,
@@ -64,10 +71,13 @@
     TextWithFootnotes,
 )
 from gen_ai_orchestrator.routers.requests.requests import RAGRequest
+from gen_ai_orchestrator.routers.responses.responses import (
+    ObservabilityInfo,
+    RAGResponse,
+)
 from gen_ai_orchestrator.services.langchain.callbacks.rag_callback_handler import (
     RAGCallbackHandler,
 )
-from gen_ai_orchestrator.routers.responses.responses import RAGResponse, ObservabilityInfo
 from gen_ai_orchestrator.services.langchain.factories.langchain_factory import (
     create_observability_callback_handler,
     get_compressor_factory,
@@ -76,20 +86,30 @@
     get_llm_factory,
     get_vector_store_factory,
 )
-from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template
+from gen_ai_orchestrator.services.observability.observabilty_service import (
+    get_observability_info,
+)
+from gen_ai_orchestrator.services.utils.prompt_utility import (
+    validate_prompt_template,
+)
 
 logger = logging.getLogger(__name__)
 
 
 @opensearch_exception_handler
 @openai_exception_handler(provider='OpenAI or AzureOpenAIService')
-async def execute_rag_chain(request: RAGRequest, debug: bool) -> RAGResponse:
+async def execute_rag_chain(
+    request: RAGRequest,
+    debug: bool,
+    custom_observability_handler: Optional[BaseCallbackHandler] = None,
+) -> RAGResponse:
     """
     RAG chain execution, using the LLM and Embedding settings specified in the request
 
     Args:
         request: The RAG request
         debug: True if RAG data debug should be returned with the response.
+        custom_observability_handler: Custom observability handler
     Returns:
         The RAG response (Answer and document sources)
     """
@@ -109,17 +129,18 @@ async def execute_rag_chain(request: RAGRequest, debug: bool) -> RAGResponse:
                 message_history.add_user_message(msg.text)
             else:
                 message_history.add_ai_message(msg.text)
-        session_id = request.dialog.dialog_id,
-        user_id = request.dialog.user_id,
-        tags = request.dialog.tags,
+        session_id = (request.dialog.dialog_id,)
+        user_id = (request.dialog.user_id,)
+        tags = (request.dialog.tags,)
 
     logger.debug(
-        'RAG chain - Use chat history: %s', 'Yes' if len(message_history.messages) > 0 else 'No'
+        'RAG chain - Use chat history: %s',
+        'Yes' if len(message_history.messages) > 0 else 'No',
     )
 
     inputs = {
         **request.question_answering_prompt.inputs,
-        'chat_history': message_history.messages
+        'chat_history': message_history.messages,
     }
 
     logger.debug(
@@ -133,6 +154,8 @@ async def execute_rag_chain(request: RAGRequest, debug: bool) -> RAGResponse:
     if debug:
         # Debug callback handler
         callback_handlers.append(records_callback_handler)
+    if custom_observability_handler is not None:
+        callback_handlers.append(custom_observability_handler)
     if request.observability_setting is not None:
         # Langfuse callback handler
         observability_handler = create_observability_callback_handler(
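
Anything that satisfies the `langchain_core` callback interface can be passed as `custom_observability_handler`. A minimal sketch of such a handler (the class name and logging choices are illustrative, not part of the project):

```python
import logging
from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler

logger = logging.getLogger(__name__)


class LoggingObservabilityHandler(BaseCallbackHandler):
    """Illustrative handler: logs LLM starts and chain ends."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        logger.debug('LLM call started with %d prompt(s)', len(prompts))

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        logger.debug('Chain finished, output keys: %s', list(outputs))
```

An instance of this class can then be handed to `execute_rag_chain` alongside the debug and Langfuse handlers.
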
@@ -154,7 +177,9 @@ async def execute_rag_chain(request: RAGRequest, debug: bool) -> RAGResponse:
 
     # Guardrail
     if request.guardrail_setting:
-        guardrail = get_guardrail_factory(setting=request.guardrail_setting).get_parser()
+        guardrail = get_guardrail_factory(
+            setting=request.guardrail_setting
+        ).get_parser()
         guardrail_output = guardrail.parse(response['answer'])
         check_guardrail_output(guardrail_output)
 
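
The guardrail step is a parse-then-check contract: the factory's parser turns the raw answer into a verdict, and `check_guardrail_output` raises when the verdict flags the content. A self-contained toy with stand-in types (the real verdict shape lives in the orchestrator's guardrail models):

```python
from typing import Any, Dict


class ToyGuardrailParser:
    """Stand-in for get_guardrail_factory(...).get_parser()."""

    def parse(self, text: str) -> Dict[str, Any]:
        # Assumed verdict shape, for illustration only.
        return {'content': text, 'flagged': 'forbidden' in text.lower()}


def toy_check_guardrail_output(verdict: Dict[str, Any]) -> None:
    if verdict['flagged']:
        raise ValueError('Guardrail rejected the generated answer')


toy_check_guardrail_output(ToyGuardrailParser().parse('a safe answer'))  # passes
```
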
@@ -173,20 +198,19 @@ async def execute_rag_chain(request: RAGRequest, debug: bool) -> RAGResponse:
                         title=doc.metadata['title'],
                         url=doc.metadata['source'],
                         content=get_source_content(doc),
-                        score=doc.metadata.get('retriever_score', None)
+                        score=doc.metadata.get('retriever_score', None),
                     ),
                     response['documents'],
                 )
             ),
         ),
         observability_info=get_observability_info(observability_handler),
-        debug=get_rag_debug_data(
-            request, records_callback_handler, rag_duration
-        )
+        debug=get_rag_debug_data(request, records_callback_handler, rag_duration)
         if debug
         else None,
     )
 
+
 def get_source_content(doc: Document) -> str:
     """
     Find and delete the title followed by two line breaks
@@ -203,8 +227,9 @@ def get_source_content(doc: Document) -> str:
     return doc.page_content
 
 
-def create_rag_chain(request: RAGRequest, vector_db_async_mode: Optional[bool] = True) -> RunnableSerializable[
-    Any, dict[str, Any]]:
+def create_rag_chain(
+    request: RAGRequest, vector_db_async_mode: Optional[bool] = True
+) -> RunnableSerializable[Any, dict[str, Any]]:
     """
     Create the RAG chain from RAGRequest, using the LLM and Embedding settings specified in the request.
 
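
Once built, the returned runnable is invoked with the condensation inputs. A sketch under the assumption that `request` is a fully populated `RAGRequest` whose question-answering prompt declares a `question` input:

```python
# Build once, then invoke (or `await chain.ainvoke(...)` in async code).
chain = create_rag_chain(request, vector_db_async_mode=False)
result = chain.invoke({'question': 'What is TOCK?', 'chat_history': []})
print(result['answer'], len(result['documents']))
```
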
@@ -217,14 +242,22 @@ def create_rag_chain(request: RAGRequest, vector_db_async_mode: Optional[bool] =
 
     # Log progress and validate prompt template
     logger.info('RAG chain - Validating LLM prompt template')
-    validate_prompt_template(request.question_answering_prompt, 'Question answering prompt')
+    validate_prompt_template(
+        request.question_answering_prompt, 'Question answering prompt'
+    )
     if request.question_condensing_prompt is not None:
-        validate_prompt_template(request.question_condensing_prompt, 'Question condensing prompt')
+        validate_prompt_template(
+            request.question_condensing_prompt, 'Question condensing prompt'
+        )
 
     question_condensing_llm_factory = None
     if request.question_condensing_llm_setting is not None:
-        question_condensing_llm_factory = get_llm_factory(setting=request.question_condensing_llm_setting)
-    question_answering_llm_factory = get_llm_factory(setting=request.question_answering_llm_setting)
+        question_condensing_llm_factory = get_llm_factory(
+            setting=request.question_condensing_llm_setting
+        )
+    question_answering_llm_factory = get_llm_factory(
+        setting=request.question_answering_llm_setting
+    )
     em_factory = get_em_factory(setting=request.embedding_question_em_setting)
     vector_store_factory = get_vector_store_factory(
         setting=request.vector_store_setting,
@@ -234,7 +267,7 @@ def create_rag_chain(request: RAGRequest, vector_db_async_mode: Optional[bool] =
 
     retriever = vector_store_factory.get_vector_store_retriever(
         search_kwargs=request.document_search_params.to_dict(),
-        async_mode=vector_db_async_mode
+        async_mode=vector_db_async_mode,
    )
     if request.compressor_setting:
         retriever = add_document_compressor(retriever, request.compressor_setting)
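
`add_document_compressor` is defined further down in this module; a plausible shape of that wrapping step, using the `ContextualCompressionRetriever` imported at the top of the file (the `get_compressor()` accessor is an assumption, not the project's confirmed API):

```python
def add_document_compressor_sketch(
    retriever: VectorStoreRetriever, setting: BaseDocumentCompressorSetting
) -> ContextualCompressionRetriever:
    # Assumed factory accessor; the real helper may differ.
    compressor = get_compressor_factory(setting=setting).get_compressor()
    return ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )
```
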
@@ -254,17 +287,20 @@ def create_rag_chain(request: RAGRequest, vector_db_async_mode: Optional[bool] =
 
     # Build the chat chain for question contextualization
     chat_chain = build_question_condensation_chain(
-        question_condensing_llm if question_condensing_llm is not None else question_answering_llm,
-        request.question_condensing_prompt)
+        question_condensing_llm
+        if question_condensing_llm is not None
+        else question_answering_llm,
+        request.question_condensing_prompt,
+    )
 
     # Function to contextualize the question based on chat history
     contextualize_question_fn = partial(contextualize_question, chat_chain=chat_chain)
 
     # Final RAG chain with retriever and source documents
     rag_chain_with_retriever = (
-        contextualize_question_fn |
-        RunnableParallel({"documents": retriever, "question": RunnablePassthrough()}) |
-        RunnablePassthrough.assign(answer=rag_chain)
+        contextualize_question_fn
+        | RunnableParallel({'documents': retriever, 'question': RunnablePassthrough()})
+        | RunnablePassthrough.assign(answer=rag_chain)
     )
 
     return rag_chain_with_retriever
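
The reformatted pipeline reads left to right: condense the question, fan out to the retriever while passing the question through, then attach the generated answer. A self-contained toy with `RunnableLambda` stand-ins shows the resulting dict shape:

```python
from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)

# Toy stand-ins for the condensation step, the retriever and the answer chain.
condense = RunnableLambda(lambda q: q)                    # pass the question through
retrieve = RunnableLambda(lambda q: [f'doc about {q}'])   # fake retriever
answer = RunnableLambda(lambda d: f"answer to {d['question']}")

toy_chain = (
    condense
    | RunnableParallel({'documents': retrieve, 'question': RunnablePassthrough()})
    | RunnablePassthrough.assign(answer=answer)
)

print(toy_chain.invoke('what is RAG?'))
# {'documents': ['doc about what is RAG?'], 'question': 'what is RAG?',
#  'answer': 'answer to what is RAG?'}
```
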
@@ -277,44 +313,66 @@ def build_rag_prompt(request: RAGRequest) -> LangChainPromptTemplate:
     return LangChainPromptTemplate.from_template(
         template=request.question_answering_prompt.template,
         template_format=request.question_answering_prompt.formatter.value,
-        partial_variables=request.question_answering_prompt.inputs
+        partial_variables=request.question_answering_prompt.inputs,
     )
 
+
 def construct_rag_chain(llm, rag_prompt):
     """
     Construct the RAG chain from LLM and prompt.
     """
-    return {
-        "context": lambda inputs: "\n\n".join(doc.page_content for doc in inputs["documents"]),
-        "question": lambda inputs: inputs["question"]  # Override the user's original question with the condensed one
-    } | rag_prompt | llm | StrOutputParser(name="rag_chain_output")
+    return (
+        {
+            'context': lambda inputs: '\n\n'.join(
+                doc.page_content for doc in inputs['documents']
+            ),
+            'question': lambda inputs: inputs[
+                'question'
+            ],  # Override the user's original question with the condensed one
+        }
+        | rag_prompt
+        | llm
+        | StrOutputParser(name='rag_chain_output')
+    )
+
 
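
Piping a plain dict into a runnable is LCEL shorthand: the dict is coerced to a `RunnableParallel`, so `context` and `question` are both derived from the same `{'documents': ..., 'question': ...}` input before `rag_prompt` formats them. A toy run under that assumption, with a fake LLM that just echoes its prompt:

```python
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate as LCPromptTemplate
from langchain_core.runnables import RunnableLambda

toy_prompt = LCPromptTemplate.from_template('Context:\n{context}\n\nQ: {question}')
toy_llm = RunnableLambda(lambda p: p.to_string().upper())  # fake LLM: echoes the prompt

toy_rag = construct_rag_chain(toy_llm, toy_prompt)
print(
    toy_rag.invoke(
        {
            'documents': [Document(page_content='TOCK is a bot toolkit.')],
            'question': 'What is TOCK?',
        }
    )
)
```
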
-def build_question_condensation_chain(llm, prompt: Optional[PromptTemplate]) -> ChatPromptTemplate:
+def build_question_condensation_chain(
+    llm, prompt: Optional[PromptTemplate]
+) -> ChatPromptTemplate:
     """
     Build the chat chain for contextualizing questions.
     """
     if prompt is None:
         # Default prompt
         prompt = PromptTemplate(
-            formatter=PromptFormatter.F_STRING, inputs={},
-            template="Given a chat history and the latest user question which might reference context in \
+            formatter=PromptFormatter.F_STRING,
+            inputs={},
+            template='Given a chat history and the latest user question which might reference context in \
 the chat history, formulate a standalone question which can be understood without the chat history. \
-Do NOT answer the question, just reformulate it if needed and otherwise return it as is.",
+Do NOT answer the question, just reformulate it if needed and otherwise return it as is.',
         )
 
-    return ChatPromptTemplate.from_messages([
-        ("system", prompt.template),
-        MessagesPlaceholder(variable_name="chat_history"),
-        ("human", "{question}"),
-    ]).partial(**prompt.inputs) | llm | StrOutputParser(name="chat_chain_output")
+    return (
+        ChatPromptTemplate.from_messages(
+            [
+                ('system', prompt.template),
+                MessagesPlaceholder(variable_name='chat_history'),
+                ('human', '{question}'),
+            ]
+        ).partial(**prompt.inputs)
+        | llm
+        | StrOutputParser(name='chat_chain_output')
+    )
+
 
 def contextualize_question(inputs: dict, chat_chain) -> str:
     """
     Contextualize the question based on the chat history.
     """
-    if inputs.get("chat_history") and len(inputs["chat_history"]) > 0:
+    if inputs.get('chat_history') and len(inputs['chat_history']) > 0:
         return chat_chain
-    return inputs["question"]
+    return inputs['question']
+
 
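
A subtlety worth noting in `contextualize_question`: when a function used as an LCEL step returns a `Runnable`, LangChain invokes that runnable on the same input. With a non-empty `chat_history` the condensation chain therefore runs; with no history the raw question string is forwarded untouched. A self-contained toy demonstrating the pattern:

```python
from langchain_core.runnables import RunnableLambda

condense = RunnableLambda(lambda inputs: f"standalone: {inputs['question']}")

def route(inputs: dict):
    # Mirrors contextualize_question: return a runnable or a plain value.
    return condense if inputs.get('chat_history') else inputs['question']

chain = RunnableLambda(route)
print(chain.invoke({'question': 'and them?', 'chat_history': ['previous turn']}))
# standalone: and them?
print(chain.invoke({'question': 'What is TOCK?', 'chat_history': []}))
# What is TOCK?
```
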
 def rag_guard(inputs, response, documents_required):
     """
@@ -387,15 +445,15 @@ def get_rag_documents(handler: RAGCallbackHandler) -> List[RAGDocument]:
     return [
         # Get first 100 char of content
         RAGDocument(
-            content=doc.page_content[0 : len(doc.metadata['title'])+100] + '...',
+            content=doc.page_content[0 : len(doc.metadata['title']) + 100] + '...',
             metadata=RAGDocumentMetadata(**doc.metadata),
         )
         for doc in handler.records['documents']
     ]
 
 
 def get_rag_debug_data(
-        request: RAGRequest, records_callback_handler: RAGCallbackHandler, rag_duration
+    request: RAGRequest, records_callback_handler: RAGCallbackHandler, rag_duration
 ) -> RAGDebugData:
     """RAG debug data assembly"""
 