diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java index 972aca72ef5..03065622cbe 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java @@ -32,6 +32,7 @@ import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.document.Document; import org.springframework.ai.model.Content; @@ -47,6 +48,7 @@ * user text. * * @author Christian Tzolov + * @author Timo Salm * @since 1.0.0 */ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { @@ -106,7 +108,7 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques * @param vectorStore The vector store to use * @param searchRequest The search request defined using the portable filter * expression syntax - * @param userTextAdvise the user text to append to the existing user prompt. The text + * @param userTextAdvise The user text to append to the existing user prompt. The text * should contain a placeholder named "question_answer_context". */ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise) { @@ -119,9 +121,9 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques * @param vectorStore The vector store to use * @param searchRequest The search request defined using the portable filter * expression syntax - * @param userTextAdvise the user text to append to the existing user prompt. The text + * @param userTextAdvise The user text to append to the existing user prompt. The text * should contain a placeholder named "question_answer_context". - * @param protectFromBlocking if true the advisor will protect the execution from + * @param protectFromBlocking If true the advisor will protect the execution from * blocking threads. If false the advisor will not protect the execution from blocking * threads. This is useful when the advisor is used in a non-blocking environment. It * is true by default. @@ -137,13 +139,13 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques * @param vectorStore The vector store to use * @param searchRequest The search request defined using the portable filter * expression syntax - * @param userTextAdvise the user text to append to the existing user prompt. The text + * @param userTextAdvise The user text to append to the existing user prompt. The text * should contain a placeholder named "question_answer_context". - * @param protectFromBlocking if true the advisor will protect the execution from + * @param protectFromBlocking If true the advisor will protect the execution from * blocking threads. If false the advisor will not protect the execution from blocking * threads. This is useful when the advisor is used in a non-blocking environment. It * is true by default. - * @param order the order of the advisor. + * @param order The order of the advisor. */ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise, boolean protectFromBlocking, int order) { @@ -213,16 +215,17 @@ private AdvisedRequest before(AdvisedRequest request) { // 1. Advise the system text. String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise; + // 2. Search for similar documents in the vector store. + String query = new PromptTemplate(request.userText(), request.userParams()).render(); var searchRequestToUse = SearchRequest.from(this.searchRequest) - .withQuery(request.userText()) + .withQuery(query) .withFilterExpression(doGetFilterExpression(context)); - // 2. Search for similar documents in the vector store. List documents = this.vectorStore.similaritySearch(searchRequestToUse); + // 3. Create the context from the documents. context.put(RETRIEVED_DOCUMENTS, documents); - // 3. Create the context from the documents. String documentContext = documents.stream() .map(Content::getContent) .collect(Collectors.joining(System.lineSeparator())); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java index c63fa9ec206..bf03d06bf38 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java @@ -38,6 +38,7 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; @@ -48,6 +49,7 @@ /** * @author Christian Tzolov + * @author Timo Salm */ @ExtendWith(MockitoExtension.class) public class QuestionAnswerAdvisorTests { @@ -178,7 +180,63 @@ public Duration getTokensReset() { assertThat(this.vectorSearchCaptor.getValue().getFilterExpression()).isEqualTo(new FilterExpressionBuilder().eq("type", "Spring").build()); assertThat(this.vectorSearchCaptor.getValue().getSimilarityThreshold()).isEqualTo(0.99d); assertThat(this.vectorSearchCaptor.getValue().getTopK()).isEqualTo(6); + } + + @Test + public void qaAdvisorTakesUserTextParametersIntoAccountForSimilaritySearch() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))), + ChatResponseMetadata.builder().build())); + + given(this.vectorStore.similaritySearch(this.vectorSearchCaptor.capture())) + .willReturn(List.of(new Document("doc1"), new Document("doc2"))); + var chatClient = ChatClient.builder(this.chatModel).build(); + var qaAdvisor = new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults()); + var userTextTemplate = "Please answer my question {question}"; + // @formatter:off + chatClient.prompt() + .user(u -> u.text(userTextTemplate).param("question", "XYZ")) + .advisors(qaAdvisor) + .call() + .chatResponse(); + //formatter:on + + var expectedQuery = "Please answer my question XYZ"; + var userPrompt = this.promptCaptor.getValue().getInstructions().get(0).getContent(); + assertThat(userPrompt).doesNotContain(userTextTemplate); + assertThat(userPrompt).contains(expectedQuery); + assertThat(this.vectorSearchCaptor.getValue().getQuery()).isEqualTo(expectedQuery); } + + @Test + public void qaAdvisorTakesUserParameterizedUserMessagesIntoAccountForSimilaritySearch() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))), + ChatResponseMetadata.builder().build())); + + given(this.vectorStore.similaritySearch(this.vectorSearchCaptor.capture())) + .willReturn(List.of(new Document("doc1"), new Document("doc2"))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + var qaAdvisor = new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults()); + + var userTextTemplate = "Please answer my question {question}"; + var userPromptTemplate = new PromptTemplate(userTextTemplate, Map.of("question", "XYZ")); + var userMessage = userPromptTemplate.createMessage(); + // @formatter:off + chatClient.prompt(new Prompt(userMessage)) + .advisors(qaAdvisor) + .call() + .chatResponse(); + //formatter:on + + var expectedQuery = "Please answer my question XYZ"; + var userPrompt = this.promptCaptor.getValue().getInstructions().get(0).getContent(); + assertThat(userPrompt).doesNotContain(userTextTemplate); + assertThat(userPrompt).contains(expectedQuery); + assertThat(this.vectorSearchCaptor.getValue().getQuery()).isEqualTo(expectedQuery); + } + }