Skip to content

Commit cc6f1c4

Browse files
Benviiassouktimscezen
authored
[DERCBOT-1425] RAGAS Evaluation + Contextualization + Improvements (#1900)
* [DERCBOT-1425] RAGAS Evaluation + Contextualization + Improvements [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip [DERCBOT-1425] wip test wip ok Improvement : Optional NoRagStat metric to Excel export (#376) * add optional NoRagStat metric to Excel export * fix Feature : Modular settings export (#377) * partial settings export * Update models.py Co-authored-by: Benjamin BERNARD <[email protected]> * Exclude keys example settings * Rename export_settings_exclude_keys.example.jsonc to export_settings_exclude_keys.example.json Typo in file name was named ".jsonc" no ".json". --------- Co-authored-by: Benjamin BERNARD <[email protected]> :wrench: [Gen AI Tooling] Make the script executable outside of Pycharm. * Fix Smarttribune Script + .env * Fix * :fix: Bug na present in question input dataframes * :fix: Small fixes --------- Co-authored-by: Mohamed ASSOUKTI <[email protected]> Co-authored-by: scezen <[email protected]>
1 parent 1c66b83 commit cc6f1c4

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

47 files changed

+2868
-1502
lines changed

gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/factories/langchain_factory.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,7 @@
160160
from gen_ai_orchestrator.services.langchain.factories.vector_stores.vector_store_factory import (
161161
LangChainVectorStoreFactory,
162162
)
163-
from gen_ai_orchestrator.utils.secret_manager.secret_manager_service import (
164-
vector_store_credentials,
165-
)
163+
from gen_ai_orchestrator.utils.secret_manager.secret_manager_service import fetch_default_vector_store_credentials
166164

167165
logger = logging.getLogger(__name__)
168166

@@ -237,14 +235,15 @@ def get_vector_store_factory(
237235
The LangChain Vector Store Factory, or raise an exception otherwise
238236
"""
239237
logger.info('Get Vector Store Factory for the given setting')
238+
vector_store_credentials = fetch_default_vector_store_credentials()
240239

241240
# Helper function to create OpenSearchFactory
242241
def create_opensearch_factory(
243-
vs_setting: Optional[OpenSearchVectorStoreSetting],
242+
vs_setting: Optional[OpenSearchVectorStoreSetting],
244243
) -> OpenSearchFactory:
245244
return OpenSearchFactory(
246245
setting=vs_setting
247-
or OpenSearchVectorStoreSetting(
246+
or OpenSearchVectorStoreSetting(
248247
host=application_settings.vector_store_host,
249248
port=application_settings.vector_store_port,
250249
username=vector_store_credentials.username,
@@ -256,11 +255,12 @@ def create_opensearch_factory(
256255

257256
# Helper function to create PGVectorFactory
258257
def create_pgvector_factory(
259-
vs_setting: Optional[PGVectorStoreSetting],
258+
vs_setting: Optional[PGVectorStoreSetting],
260259
) -> PGVectorFactory:
260+
vector_store_credentials = fetch_default_vector_store_credentials()
261261
return PGVectorFactory(
262262
setting=vs_setting
263-
or PGVectorStoreSetting(
263+
or PGVectorStoreSetting(
264264
host=application_settings.vector_store_host,
265265
port=application_settings.vector_store_port,
266266
username=vector_store_credentials.username,
@@ -277,8 +277,8 @@ def create_pgvector_factory(
277277

278278
# Validate required default settings
279279
if (
280-
application_settings.vector_store_provider is None
281-
or vector_store_credentials is None
280+
application_settings.vector_store_provider is None
281+
or vector_store_credentials is None
282282
):
283283
logger.error('No default Vector Store defined!')
284284
raise GenAIUnknownVectorStoreProviderSettingException()

gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py

Lines changed: 106 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (C) 2023-2024 Credit Mutuel Arkea
1+
# Copyright (C) 2023-2025 Credit Mutuel Arkea
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -23,18 +23,23 @@
2323
from logging import ERROR, WARNING
2424
from typing import List, Optional
2525

26-
from gen_ai_orchestrator.services.observability.observabilty_service import get_observability_info
2726
from langchain.chains.conversational_retrieval.base import (
2827
ConversationalRetrievalChain,
2928
)
3029
from langchain.retrievers.contextual_compression import (
3130
ContextualCompressionRetriever,
3231
)
3332
from langchain_community.chat_message_histories import ChatMessageHistory
33+
from langchain_core.callbacks import BaseCallbackHandler
3434
from langchain_core.documents import Document
3535
from langchain_core.output_parsers import StrOutputParser
36-
from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate, ChatPromptTemplate, MessagesPlaceholder
37-
from langchain_core.runnables import RunnablePassthrough, RunnableParallel, RunnableSerializable
36+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
37+
from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate
38+
from langchain_core.runnables import (
39+
RunnableParallel,
40+
RunnablePassthrough,
41+
RunnableSerializable,
42+
)
3843
from langchain_core.vectorstores import VectorStoreRetriever
3944
from langfuse.callback import CallbackHandler as LangfuseCallbackHandler
4045
from typing_extensions import Any
@@ -48,7 +53,9 @@
4853
from gen_ai_orchestrator.errors.handlers.opensearch.opensearch_exception_handler import (
4954
opensearch_exception_handler,
5055
)
51-
from gen_ai_orchestrator.models.document_compressor.document_compressor_setting import BaseDocumentCompressorSetting
56+
from gen_ai_orchestrator.models.document_compressor.document_compressor_setting import (
57+
BaseDocumentCompressorSetting,
58+
)
5259
from gen_ai_orchestrator.models.errors.errors_models import ErrorInfo
5360
from gen_ai_orchestrator.models.observability.observability_trace import (
5461
ObservabilityTrace,
@@ -64,10 +71,13 @@
6471
TextWithFootnotes,
6572
)
6673
from gen_ai_orchestrator.routers.requests.requests import RAGRequest
74+
from gen_ai_orchestrator.routers.responses.responses import (
75+
ObservabilityInfo,
76+
RAGResponse,
77+
)
6778
from gen_ai_orchestrator.services.langchain.callbacks.rag_callback_handler import (
6879
RAGCallbackHandler,
6980
)
70-
from gen_ai_orchestrator.routers.responses.responses import RAGResponse, ObservabilityInfo
7181
from gen_ai_orchestrator.services.langchain.factories.langchain_factory import (
7282
create_observability_callback_handler,
7383
get_compressor_factory,
@@ -76,20 +86,30 @@
7686
get_llm_factory,
7787
get_vector_store_factory,
7888
)
79-
from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template
89+
from gen_ai_orchestrator.services.observability.observabilty_service import (
90+
get_observability_info,
91+
)
92+
from gen_ai_orchestrator.services.utils.prompt_utility import (
93+
validate_prompt_template,
94+
)
8095

8196
logger = logging.getLogger(__name__)
8297

8398

8499
@opensearch_exception_handler
85100
@openai_exception_handler(provider='OpenAI or AzureOpenAIService')
86-
async def execute_rag_chain(request: RAGRequest, debug: bool) -> RAGResponse:
101+
async def execute_rag_chain(
102+
request: RAGRequest,
103+
debug: bool,
104+
custom_observability_handler: Optional[BaseCallbackHandler] = None,
105+
) -> RAGResponse:
87106
"""
88107
RAG chain execution, using the LLM and Embedding settings specified in the request
89108
90109
Args:
91110
request: The RAG request
92111
debug: True if RAG data debug should be returned with the response.
112+
custom_observability_handler: Custom observability handler
93113
Returns:
94114
The RAG response (Answer and document sources)
95115
"""
@@ -109,17 +129,18 @@ async def execute_rag_chain(request: RAGRequest, debug: bool) -> RAGResponse:
109129
message_history.add_user_message(msg.text)
110130
else:
111131
message_history.add_ai_message(msg.text)
112-
session_id = request.dialog.dialog_id,
113-
user_id = request.dialog.user_id,
114-
tags = request.dialog.tags,
132+
session_id = (request.dialog.dialog_id,)
133+
user_id = (request.dialog.user_id,)
134+
tags = (request.dialog.tags,)
115135

116136
logger.debug(
117-
'RAG chain - Use chat history: %s', 'Yes' if len(message_history.messages) > 0 else 'No'
137+
'RAG chain - Use chat history: %s',
138+
'Yes' if len(message_history.messages) > 0 else 'No',
118139
)
119140

120141
inputs = {
121142
**request.question_answering_prompt.inputs,
122-
'chat_history': message_history.messages
143+
'chat_history': message_history.messages,
123144
}
124145

125146
logger.debug(
@@ -133,6 +154,8 @@ async def execute_rag_chain(request: RAGRequest, debug: bool) -> RAGResponse:
133154
if debug:
134155
# Debug callback handler
135156
callback_handlers.append(records_callback_handler)
157+
if custom_observability_handler is not None:
158+
callback_handlers.append(custom_observability_handler)
136159
if request.observability_setting is not None:
137160
# Langfuse callback handler
138161
observability_handler = create_observability_callback_handler(
@@ -154,7 +177,9 @@ async def execute_rag_chain(request: RAGRequest, debug: bool) -> RAGResponse:
154177

155178
# Guardrail
156179
if request.guardrail_setting:
157-
guardrail = get_guardrail_factory(setting=request.guardrail_setting).get_parser()
180+
guardrail = get_guardrail_factory(
181+
setting=request.guardrail_setting
182+
).get_parser()
158183
guardrail_output = guardrail.parse(response['answer'])
159184
check_guardrail_output(guardrail_output)
160185

@@ -173,20 +198,19 @@ async def execute_rag_chain(request: RAGRequest, debug: bool) -> RAGResponse:
173198
title=doc.metadata['title'],
174199
url=doc.metadata['source'],
175200
content=get_source_content(doc),
176-
score=doc.metadata.get('retriever_score', None)
201+
score=doc.metadata.get('retriever_score', None),
177202
),
178203
response['documents'],
179204
)
180205
),
181206
),
182207
observability_info=get_observability_info(observability_handler),
183-
debug=get_rag_debug_data(
184-
request, records_callback_handler, rag_duration
185-
)
208+
debug=get_rag_debug_data(request, records_callback_handler, rag_duration)
186209
if debug
187210
else None,
188211
)
189212

213+
190214
def get_source_content(doc: Document) -> str:
191215
"""
192216
Find and delete the title followed by two line breaks
@@ -203,8 +227,9 @@ def get_source_content(doc: Document) -> str:
203227
return doc.page_content
204228

205229

206-
def create_rag_chain(request: RAGRequest, vector_db_async_mode: Optional[bool] = True) -> RunnableSerializable[
207-
Any, dict[str, Any]]:
230+
def create_rag_chain(
231+
request: RAGRequest, vector_db_async_mode: Optional[bool] = True
232+
) -> RunnableSerializable[Any, dict[str, Any]]:
208233
"""
209234
Create the RAG chain from RAGRequest, using the LLM and Embedding settings specified in the request.
210235
@@ -217,14 +242,22 @@ def create_rag_chain(request: RAGRequest, vector_db_async_mode: Optional[bool] =
217242

218243
# Log progress and validate prompt template
219244
logger.info('RAG chain - Validating LLM prompt template')
220-
validate_prompt_template(request.question_answering_prompt, 'Question answering prompt')
245+
validate_prompt_template(
246+
request.question_answering_prompt, 'Question answering prompt'
247+
)
221248
if request.question_condensing_prompt is not None:
222-
validate_prompt_template(request.question_condensing_prompt, 'Question condensing prompt')
249+
validate_prompt_template(
250+
request.question_condensing_prompt, 'Question condensing prompt'
251+
)
223252

224253
question_condensing_llm_factory = None
225254
if request.question_condensing_llm_setting is not None:
226-
question_condensing_llm_factory = get_llm_factory(setting=request.question_condensing_llm_setting)
227-
question_answering_llm_factory = get_llm_factory(setting=request.question_answering_llm_setting)
255+
question_condensing_llm_factory = get_llm_factory(
256+
setting=request.question_condensing_llm_setting
257+
)
258+
question_answering_llm_factory = get_llm_factory(
259+
setting=request.question_answering_llm_setting
260+
)
228261
em_factory = get_em_factory(setting=request.embedding_question_em_setting)
229262
vector_store_factory = get_vector_store_factory(
230263
setting=request.vector_store_setting,
@@ -234,7 +267,7 @@ def create_rag_chain(request: RAGRequest, vector_db_async_mode: Optional[bool] =
234267

235268
retriever = vector_store_factory.get_vector_store_retriever(
236269
search_kwargs=request.document_search_params.to_dict(),
237-
async_mode=vector_db_async_mode
270+
async_mode=vector_db_async_mode,
238271
)
239272
if request.compressor_setting:
240273
retriever = add_document_compressor(retriever, request.compressor_setting)
@@ -254,17 +287,20 @@ def create_rag_chain(request: RAGRequest, vector_db_async_mode: Optional[bool] =
254287

255288
# Build the chat chain for question contextualization
256289
chat_chain = build_question_condensation_chain(
257-
question_condensing_llm if question_condensing_llm is not None else question_answering_llm,
258-
request.question_condensing_prompt)
290+
question_condensing_llm
291+
if question_condensing_llm is not None
292+
else question_answering_llm,
293+
request.question_condensing_prompt,
294+
)
259295

260296
# Function to contextualize the question based on chat history
261297
contextualize_question_fn = partial(contextualize_question, chat_chain=chat_chain)
262298

263299
# Final RAG chain with retriever and source documents
264300
rag_chain_with_retriever = (
265-
contextualize_question_fn |
266-
RunnableParallel( {"documents": retriever, "question": RunnablePassthrough()} ) |
267-
RunnablePassthrough.assign(answer=rag_chain)
301+
contextualize_question_fn
302+
| RunnableParallel({'documents': retriever, 'question': RunnablePassthrough()})
303+
| RunnablePassthrough.assign(answer=rag_chain)
268304
)
269305

270306
return rag_chain_with_retriever
@@ -277,44 +313,66 @@ def build_rag_prompt(request: RAGRequest) -> LangChainPromptTemplate:
277313
return LangChainPromptTemplate.from_template(
278314
template=request.question_answering_prompt.template,
279315
template_format=request.question_answering_prompt.formatter.value,
280-
partial_variables=request.question_answering_prompt.inputs
316+
partial_variables=request.question_answering_prompt.inputs,
281317
)
282318

319+
283320
def construct_rag_chain(llm, rag_prompt):
284321
"""
285322
Construct the RAG chain from LLM and prompt.
286323
"""
287-
return {
288-
"context": lambda inputs: "\n\n".join(doc.page_content for doc in inputs["documents"]),
289-
"question": lambda inputs: inputs["question"] # Override the user's original question with the condensed one
290-
} | rag_prompt | llm | StrOutputParser(name="rag_chain_output")
324+
return (
325+
{
326+
'context': lambda inputs: '\n\n'.join(
327+
doc.page_content for doc in inputs['documents']
328+
),
329+
'question': lambda inputs: inputs[
330+
'question'
331+
], # Override the user's original question with the condensed one
332+
}
333+
| rag_prompt
334+
| llm
335+
| StrOutputParser(name='rag_chain_output')
336+
)
337+
291338

292-
def build_question_condensation_chain(llm, prompt: Optional[PromptTemplate]) -> ChatPromptTemplate:
339+
def build_question_condensation_chain(
340+
llm, prompt: Optional[PromptTemplate]
341+
) -> ChatPromptTemplate:
293342
"""
294343
Build the chat chain for contextualizing questions.
295344
"""
296345
if prompt is None:
297346
# Default prompt
298347
prompt = PromptTemplate(
299-
formatter = PromptFormatter.F_STRING, inputs = {},
300-
template = "Given a chat history and the latest user question which might reference context in \
348+
formatter=PromptFormatter.F_STRING,
349+
inputs={},
350+
template='Given a chat history and the latest user question which might reference context in \
301351
the chat history, formulate a standalone question which can be understood without the chat history. \
302-
Do NOT answer the question, just reformulate it if needed and otherwise return it as is.",
352+
Do NOT answer the question, just reformulate it if needed and otherwise return it as is.',
303353
)
304354

305-
return ChatPromptTemplate.from_messages([
306-
("system", prompt.template),
307-
MessagesPlaceholder(variable_name="chat_history"),
308-
("human", "{question}"),
309-
]).partial(**prompt.inputs) | llm | StrOutputParser(name="chat_chain_output")
355+
return (
356+
ChatPromptTemplate.from_messages(
357+
[
358+
('system', prompt.template),
359+
MessagesPlaceholder(variable_name='chat_history'),
360+
('human', '{question}'),
361+
]
362+
).partial(**prompt.inputs)
363+
| llm
364+
| StrOutputParser(name='chat_chain_output')
365+
)
366+
310367

311368
def contextualize_question(inputs: dict, chat_chain) -> str:
312369
"""
313370
Contextualize the question based on the chat history.
314371
"""
315-
if inputs.get("chat_history") and len(inputs["chat_history"]) > 0:
372+
if inputs.get('chat_history') and len(inputs['chat_history']) > 0:
316373
return chat_chain
317-
return inputs["question"]
374+
return inputs['question']
375+
318376

319377
def rag_guard(inputs, response, documents_required):
320378
"""
@@ -387,15 +445,15 @@ def get_rag_documents(handler: RAGCallbackHandler) -> List[RAGDocument]:
387445
return [
388446
# Get first 100 char of content
389447
RAGDocument(
390-
content=doc.page_content[0:len(doc.metadata['title'])+100] + '...',
448+
content=doc.page_content[0 : len(doc.metadata['title']) + 100] + '...',
391449
metadata=RAGDocumentMetadata(**doc.metadata),
392450
)
393451
for doc in handler.records['documents']
394452
]
395453

396454

397455
def get_rag_debug_data(
398-
request: RAGRequest, records_callback_handler: RAGCallbackHandler, rag_duration
456+
request: RAGRequest, records_callback_handler: RAGCallbackHandler, rag_duration
399457
) -> RAGDebugData:
400458
"""RAG debug data assembly"""
401459

0 commit comments

Comments
 (0)