Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.model.Content;
Expand All @@ -47,6 +48,7 @@
* user text.
*
* @author Christian Tzolov
* @author Timo Salm
* @since 1.0.0
*/
public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
Expand Down Expand Up @@ -106,7 +108,7 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
* @param vectorStore The vector store to use
* @param searchRequest The search request defined using the portable filter
* expression syntax
* @param userTextAdvise the user text to append to the existing user prompt. The text
* @param userTextAdvise The user text to append to the existing user prompt. The text
* should contain a placeholder named "question_answer_context".
*/
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise) {
Expand All @@ -119,9 +121,9 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
* @param vectorStore The vector store to use
* @param searchRequest The search request defined using the portable filter
* expression syntax
* @param userTextAdvise the user text to append to the existing user prompt. The text
* @param userTextAdvise The user text to append to the existing user prompt. The text
* should contain a placeholder named "question_answer_context".
* @param protectFromBlocking if true the advisor will protect the execution from
* @param protectFromBlocking If true the advisor will protect the execution from
* blocking threads. If false the advisor will not protect the execution from blocking
* threads. This is useful when the advisor is used in a non-blocking environment. It
* is true by default.
Expand All @@ -137,13 +139,13 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
* @param vectorStore The vector store to use
* @param searchRequest The search request defined using the portable filter
* expression syntax
* @param userTextAdvise the user text to append to the existing user prompt. The text
* @param userTextAdvise The user text to append to the existing user prompt. The text
* should contain a placeholder named "question_answer_context".
* @param protectFromBlocking if true the advisor will protect the execution from
* @param protectFromBlocking If true the advisor will protect the execution from
* blocking threads. If false the advisor will not protect the execution from blocking
* threads. This is useful when the advisor is used in a non-blocking environment. It
* is true by default.
* @param order the order of the advisor.
* @param order The order of the advisor.
*/
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
boolean protectFromBlocking, int order) {
Expand Down Expand Up @@ -213,16 +215,17 @@ private AdvisedRequest before(AdvisedRequest request) {
// 1. Advise the system text.
String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise;

// 2. Search for similar documents in the vector store.
String query = new PromptTemplate(request.userText(), request.userParams()).render();
var searchRequestToUse = SearchRequest.from(this.searchRequest)
.withQuery(request.userText())
.withQuery(query)
.withFilterExpression(doGetFilterExpression(context));

// 2. Search for similar documents in the vector store.
List<Document> documents = this.vectorStore.similaritySearch(searchRequestToUse);

// 3. Create the context from the documents.
context.put(RETRIEVED_DOCUMENTS, documents);

// 3. Create the context from the documents.
String documentContext = documents.stream()
.map(Content::getContent)
.collect(Collectors.joining(System.lineSeparator()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
Expand All @@ -48,6 +49,7 @@

/**
* @author Christian Tzolov
* @author Timo Salm
*/
@ExtendWith(MockitoExtension.class)
public class QuestionAnswerAdvisorTests {
Expand Down Expand Up @@ -178,7 +180,63 @@ public Duration getTokensReset() {
assertThat(this.vectorSearchCaptor.getValue().getFilterExpression()).isEqualTo(new FilterExpressionBuilder().eq("type", "Spring").build());
assertThat(this.vectorSearchCaptor.getValue().getSimilarityThreshold()).isEqualTo(0.99d);
assertThat(this.vectorSearchCaptor.getValue().getTopK()).isEqualTo(6);
}

@Test
public void qaAdvisorTakesUserTextParametersIntoAccountForSimilaritySearch() {
given(this.chatModel.call(this.promptCaptor.capture()))
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))),
ChatResponseMetadata.builder().build()));

given(this.vectorStore.similaritySearch(this.vectorSearchCaptor.capture()))
.willReturn(List.of(new Document("doc1"), new Document("doc2")));

var chatClient = ChatClient.builder(this.chatModel).build();
var qaAdvisor = new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults());

var userTextTemplate = "Please answer my question {question}";
// @formatter:off
chatClient.prompt()
.user(u -> u.text(userTextTemplate).param("question", "XYZ"))
.advisors(qaAdvisor)
.call()
.chatResponse();
//formatter:on

var expectedQuery = "Please answer my question XYZ";
var userPrompt = this.promptCaptor.getValue().getInstructions().get(0).getContent();
assertThat(userPrompt).doesNotContain(userTextTemplate);
assertThat(userPrompt).contains(expectedQuery);
assertThat(this.vectorSearchCaptor.getValue().getQuery()).isEqualTo(expectedQuery);
}

@Test
public void qaAdvisorTakesUserParameterizedUserMessagesIntoAccountForSimilaritySearch() {
given(this.chatModel.call(this.promptCaptor.capture()))
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))),
ChatResponseMetadata.builder().build()));

given(this.vectorStore.similaritySearch(this.vectorSearchCaptor.capture()))
.willReturn(List.of(new Document("doc1"), new Document("doc2")));

var chatClient = ChatClient.builder(this.chatModel).build();
var qaAdvisor = new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults());

var userTextTemplate = "Please answer my question {question}";
var userPromptTemplate = new PromptTemplate(userTextTemplate, Map.of("question", "XYZ"));
var userMessage = userPromptTemplate.createMessage();
// @formatter:off
chatClient.prompt(new Prompt(userMessage))
.advisors(qaAdvisor)
.call()
.chatResponse();
//formatter:on

var expectedQuery = "Please answer my question XYZ";
var userPrompt = this.promptCaptor.getValue().getInstructions().get(0).getContent();
assertThat(userPrompt).doesNotContain(userTextTemplate);
assertThat(userPrompt).contains(expectedQuery);
assertThat(this.vectorSearchCaptor.getValue().getQuery()).isEqualTo(expectedQuery);
}

}