Skip to content

Commit 75f7c08

Browse files
committed
[DERCBOT-1037] Use of PromptTemplate + Rewrite the RAG chain using LCEL
1 parent 0d6289d commit 75f7c08

File tree

15 files changed

+348
-639
lines changed

15 files changed

+348
-639
lines changed

bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,7 @@ import ai.tock.bot.engine.action.SendSentence
3030
import ai.tock.bot.engine.action.SendSentenceWithFootnotes
3131
import ai.tock.bot.engine.dialog.Dialog
3232
import ai.tock.bot.engine.user.PlayerType
33-
import ai.tock.genai.orchestratorclient.requests.ChatMessage
34-
import ai.tock.genai.orchestratorclient.requests.ChatMessageType
35-
import ai.tock.genai.orchestratorclient.requests.DialogDetails
36-
import ai.tock.genai.orchestratorclient.requests.RAGQuery
33+
import ai.tock.genai.orchestratorclient.requests.*
3734
import ai.tock.genai.orchestratorclient.responses.ObservabilityInfo
3835
import ai.tock.genai.orchestratorclient.responses.RAGResponse
3936
import ai.tock.genai.orchestratorclient.responses.TextWithFootnotes
@@ -189,10 +186,14 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler {
189186
)
190187
),
191188
questionAnsweringLlmSetting = ragConfiguration.llmSetting,
192-
questionAnsweringPromptInputs = mapOf(
193-
"question" to action.toString(),
194-
"locale" to userPreferences.locale.displayLanguage,
195-
"no_answer" to ragConfiguration.noAnswerSentence
189+
questionAnsweringPrompt = PromptTemplate(
190+
formatter = Formatter.F_STRING.id,
191+
template = ragConfiguration.llmSetting.prompt,
192+
inputs = mapOf(
193+
"question" to action.toString(),
194+
"locale" to userPreferences.locale.displayLanguage,
195+
"no_answer" to ragConfiguration.noAnswerSentence
196+
)
196197
),
197198
embeddingQuestionEmSetting = ragConfiguration.emSetting,
198199
documentIndexName = indexName,

gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/requests/RAGQuery.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ import ai.tock.genai.orchestratorcore.models.vectorstore.VectorStoreSetting
2424

2525
data class RAGQuery(
2626
// val condenseQuestionLlmSetting: LLMSetting,
27-
// val condenseQuestionPromptInputs: Map<String, String>,
27+
// val condenseQuestionPrompt: PromptTemplate,
2828
val dialog: DialogDetails?,
2929
val questionAnsweringLlmSetting: LLMSetting,
30-
val questionAnsweringPromptInputs: Map<String, String>,
30+
val questionAnsweringPrompt: PromptTemplate,
3131
val embeddingQuestionEmSetting: EMSetting,
3232
val documentIndexName: String,
3333
val documentSearchParams: DocumentSearchParamsBase,

gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/models/llm/llm_setting.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,3 @@ class BaseLLMSetting(BaseModel):
3939
ge=0,
4040
le=2,
4141
)
42-
prompt: str = Field(
43-
description='The prompt to generate completions for.',
44-
examples=['How to learn to ride a bike without wheels!'],
45-
min_length=1,
46-
)

gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/routers/requests/requests.py

Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ class BaseQuery(BaseModel):
8585
observability_setting: Optional[ObservabilitySetting] = Field(
8686
description='The observability settings.', default=None
8787
)
88+
compressor_setting: Optional[DocumentCompressorSetting] = Field(
89+
description='Compressor settings, to rerank relevant documents returned by retriever.',
90+
default=None,
91+
)
8892

8993

9094
class QAQuery(BaseQuery):
@@ -159,43 +163,20 @@ class RagQuery(BaseQuery):
159163
"""The RAG query model"""
160164

161165
dialog: Optional[DialogDetails] = Field(description='The user dialog details.')
162-
question_answering_prompt_inputs: Any = Field(
163-
description='Key-value inputs for the llm prompt when used as a template. Please note that the '
164-
'chat_history field must not be specified here, it will be override by the dialog.history field',
165-
)
166166
# condense_question_llm_setting: LLMSetting =
167167
# Field(description="LLM setting, used to condense the user's question.")
168-
# condense_question_prompt_inputs: Any = (
169-
# Field(
170-
# description='Key-value inputs for the condense question llm prompt, when used as a template.',
171-
# ),
168+
# condense_question_prompt: PromptTemplate = Field(
169+
# description='Prompt template, used to create a prompt with inputs for jinja and fstring format'
172170
# )
173171
question_answering_llm_setting: LLMSetting = Field(
174172
description='LLM setting, used to perform a QA Prompt.'
175173
)
176-
question_answering_prompt_inputs: Any = Field(
177-
description='Key-value inputs for the llm prompt when used as a template. Please note that the '
178-
'chat_history field must not be specified here, it will be override by the dialog.history field',
179-
)
180-
embedding_question_em_setting: EMSetting = Field(
181-
description="Embedding model setting, used to calculate the user's question vector."
182-
)
183-
document_index_name: str = Field(
184-
description='Index name corresponding to a document collection in the vector database.',
185-
)
186-
document_search_params: DocumentSearchParams = Field(
187-
description='The document search parameters. Ex: number of documents, metadata filter',
188-
)
189-
observability_setting: Optional[ObservabilitySetting] = Field(
190-
description='The observability settings.', default=None
174+
question_answering_prompt : PromptTemplate = Field(
175+
description='Prompt template, used to create a prompt with inputs for jinja and fstring format'
191176
)
192177
guardrail_setting: Optional[GuardrailSetting] = Field(
193178
description='Guardrail settings, to classify LLM output toxicity.', default=None
194179
)
195-
compressor_setting: Optional[DocumentCompressorSetting] = Field(
196-
description='Compressor settings, to rerank relevant documents returned by retriever.',
197-
default=None,
198-
)
199180
documents_required: Optional[bool] = Field(
200181
description='Specifies whether the presence of documents is mandatory for generating answers. '
201182
'If set to True, the system will only provide answers when relevant documents are found. '
@@ -223,7 +204,11 @@ class RagQuery(BaseQuery):
223204
'secret': 'ab7***************************A1IV4B',
224205
},
225206
'temperature': 1.2,
226-
'prompt': """Use the following context to answer the question at the end.
207+
'model': 'gpt-3.5-turbo',
208+
},
209+
'question_answering_prompt': {
210+
'formatter': 'f-string',
211+
'template': """Use the following context to answer the question at the end.
227212
If you don't know the answer, just say {no_answer}.
228213
229214
Context:
@@ -233,12 +218,11 @@ class RagQuery(BaseQuery):
233218
{question}
234219
235220
Answer in {locale}:""",
236-
'model': 'gpt-3.5-turbo',
237-
},
238-
'question_answering_prompt_inputs': {
239-
'question': 'How to get started playing guitar ?',
240-
'no_answer': "Sorry, I don't know.",
241-
'locale': 'French',
221+
'inputs': {
222+
'question': 'How to get started playing guitar ?',
223+
'no_answer': "Sorry, I don't know.",
224+
'locale': 'French',
225+
}
242226
},
243227
'embedding_question_em_setting': {
244228
'provider': 'OpenAI',

gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/completion/completion_service.py

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,14 @@
1616

1717
import logging
1818
import time
19-
from typing import Optional
2019

21-
from jinja2 import Template, TemplateError
2220
from langchain_core.output_parsers import NumberedListOutputParser
2321
from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate
24-
from langchain_core.runnables import RunnableConfig
2522

26-
from gen_ai_orchestrator.errors.exceptions.exceptions import (
27-
GenAIPromptTemplateException,
28-
)
2923
from gen_ai_orchestrator.errors.handlers.openai.openai_exception_handler import (
3024
openai_exception_handler,
3125
)
32-
from gen_ai_orchestrator.models.errors.errors_models import ErrorInfo
3326
from gen_ai_orchestrator.models.observability.observability_trace import ObservabilityTrace
34-
from gen_ai_orchestrator.models.prompt.prompt_formatter import PromptFormatter
35-
from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate
3627
from gen_ai_orchestrator.routers.requests.requests import (
3728
SentenceGenerationQuery,
3829
)
@@ -42,6 +33,7 @@
4233
from gen_ai_orchestrator.services.langchain.factories.langchain_factory import (
4334
get_llm_factory, create_observability_callback_handler,
4435
)
36+
from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template
4537

4638
logger = logging.getLogger(__name__)
4739

@@ -90,29 +82,3 @@ async def generate_and_split_sentences(
9082
)
9183

9284
return SentenceGenerationResponse(sentences=sentences)
93-
94-
95-
def validate_prompt_template(prompt: PromptTemplate):
96-
"""
97-
Prompt template validation
98-
99-
Args:
100-
prompt: The prompt template
101-
102-
Returns:
103-
Nothing.
104-
Raises:
105-
GenAIPromptTemplateException: if template is incorrect
106-
"""
107-
if PromptFormatter.JINJA2 == prompt.formatter:
108-
try:
109-
Template(prompt.template).render(prompt.inputs)
110-
except TemplateError as exc:
111-
logger.error('Prompt completion - template validation failed!')
112-
logger.error(exc)
113-
raise GenAIPromptTemplateException(
114-
ErrorInfo(
115-
error=exc.__class__.__name__,
116-
cause=str(exc),
117-
)
118-
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (C) 2023-2024 Credit Mutuel Arkea
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Retriever callback handler for LangChain."""
16+
17+
import logging
18+
from typing import Any, Dict, Optional
19+
20+
from langchain.callbacks.base import BaseCallbackHandler
21+
from langchain_core.messages import SystemMessage, AIMessage
22+
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
23+
24+
logger = logging.getLogger(__name__)
25+
26+
27+
class RAGCallbackHandler(BaseCallbackHandler):
    """Custom RAG callback handler that captures intermediate chain data.

    While the RAG LCEL chain runs, this handler records the rendered
    prompts, the intermediate/final LLM outputs and the retrieved
    documents into the per-instance ``records`` dict, so the caller can
    inspect them after the chain has completed.
    """

    def __init__(self) -> None:
        super().__init__()
        # Per-instance storage. A class-level dict would be shared by every
        # handler instance, leaking data between separate chain invocations.
        self.records: Dict[str, Any] = {
            'chat_prompt': None,
            'chat_chain_output': None,
            'rag_prompt': None,
            'rag_chain_output': None,
            'documents': None,
        }

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Record inputs of interest when a (sub-)chain starts.

        Args:
            serialized: The serialized representation of the chain (unused).
            inputs: The inputs passed to the starting chain; depending on the
                runnable this may actually be an ``AIMessage`` or a dict.
            **kwargs: Extra run metadata; ``name`` identifies the runnable.
        """
        # 'name' is not guaranteed to be present for unnamed runnables;
        # .get() avoids a KeyError that would abort the whole chain run.
        name = kwargs.get('name')

        if name == 'chat_chain_output' and isinstance(inputs, AIMessage):
            self.records['chat_chain_output'] = inputs.content

        if name == 'rag_chain_output' and isinstance(inputs, AIMessage):
            self.records['rag_chain_output'] = inputs.content

        # The RunnableAssign<answer> step receives the retrieved documents
        # as part of its input mapping.
        if name == 'RunnableAssign<answer>' and 'documents' in inputs:
            self.records['documents'] = inputs['documents']

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Record the rendered prompts when a (sub-)chain finishes.

        Args:
            outputs: The outputs of the finished chain; prompt-formatting
                steps emit ``ChatPromptValue`` / ``StringPromptValue``.
            **kwargs: Extra run metadata (unused).
        """
        # A ChatPromptValue comes from the chat prompt template; keep the
        # content of the first SystemMessage, if any.
        if isinstance(outputs, ChatPromptValue):
            self.records['chat_prompt'] = next(
                (
                    msg.content
                    for msg in outputs.messages
                    if isinstance(msg, SystemMessage)
                ),
                None,
            )

        # A StringPromptValue comes from the question-answering prompt.
        if isinstance(outputs, StringPromptValue):
            self.records['rag_prompt'] = outputs.text

0 commit comments

Comments
 (0)