Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,13 @@
import java.util.Map;
import java.util.stream.Collectors;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponseStreamUtils;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
Expand All @@ -53,7 +49,7 @@
* @author Thomas Vitale
* @since 1.0.0
*/
public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
public class QuestionAnswerAdvisor implements BaseAdvisor {

public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents";

Expand All @@ -80,198 +76,96 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv

private final SearchRequest searchRequest;

private final boolean protectFromBlocking;
private final Scheduler scheduler;

private final int order;

/**
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
* combines it with the user's text.
* @param vectorStore The vector store to use
*/
public QuestionAnswerAdvisor(VectorStore vectorStore) {
this(vectorStore, SearchRequest.builder().build(), DEFAULT_PROMPT_TEMPLATE, true, DEFAULT_ORDER);
}

/**
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
* combines it with the user's text.
* @param vectorStore The vector store to use
* @param searchRequest The search request defined using the portable filter
* expression syntax
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
*/
@Deprecated
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest) {
this(vectorStore, searchRequest, DEFAULT_PROMPT_TEMPLATE, true, DEFAULT_ORDER);
}

/**
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
* combines it with the user's text.
* @param vectorStore The vector store to use
* @param searchRequest The search request defined using the portable filter
* expression syntax
* @param userTextAdvise The user text to append to the existing user prompt. The text
* should contain a placeholder named "question_answer_context".
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
*/
@Deprecated
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise) {
this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), true,
this(vectorStore, SearchRequest.builder().build(), DEFAULT_PROMPT_TEMPLATE, BaseAdvisor.DEFAULT_SCHEDULER,
DEFAULT_ORDER);
}

/**
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
* combines it with the user's text.
* @param vectorStore The vector store to use
* @param searchRequest The search request defined using the portable filter
* expression syntax
* @param userTextAdvise The user text to append to the existing user prompt. The text
* should contain a placeholder named "question_answer_context".
* @param protectFromBlocking If true the advisor will protect the execution from
* blocking threads. If false the advisor will not protect the execution from blocking
* threads. This is useful when the advisor is used in a non-blocking environment. It
* is true by default.
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
*/
@Deprecated
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
boolean protectFromBlocking) {
this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), protectFromBlocking,
DEFAULT_ORDER);
}

/**
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
* combines it with the user's text.
* @param vectorStore The vector store to use
* @param searchRequest The search request defined using the portable filter
* expression syntax
* @param userTextAdvise The user text to append to the existing user prompt. The text
* should contain a placeholder named "question_answer_context".
* @param protectFromBlocking If true the advisor will protect the execution from
* blocking threads. If false the advisor will not protect the execution from blocking
* threads. This is useful when the advisor is used in a non-blocking environment. It
* is true by default.
* @param order The order of the advisor.
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
*/
@Deprecated
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
boolean protectFromBlocking, int order) {
this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), protectFromBlocking,
order);
}

QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, @Nullable PromptTemplate promptTemplate,
boolean protectFromBlocking, int order) {
@Nullable Scheduler scheduler, int order) {
Assert.notNull(vectorStore, "vectorStore cannot be null");
Assert.notNull(searchRequest, "searchRequest cannot be null");

this.vectorStore = vectorStore;
this.searchRequest = searchRequest;
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
this.protectFromBlocking = protectFromBlocking;
this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER;
this.order = order;
}

public static Builder builder(VectorStore vectorStore) {
return new Builder(vectorStore);
}

@Override
public String getName() {
return this.getClass().getSimpleName();
}

@Override
public int getOrder() {
return this.order;
}

@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {

AdvisedRequest advisedRequest2 = before(advisedRequest);

AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest2);

return after(advisedResponse);
}

@Override
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {

// This can be executed by both blocking and non-blocking Threads
// E.g. a command line or Tomcat blocking Thread implementation
// or by a WebFlux dispatch in a non-blocking manner.
Flux<AdvisedResponse> advisedResponses = (this.protectFromBlocking) ?
// @formatter:off
Mono.just(advisedRequest)
.publishOn(Schedulers.boundedElastic())
.map(this::before)
.flatMapMany(request -> chain.nextAroundStream(request))
: chain.nextAroundStream(before(advisedRequest));
// @formatter:on

return advisedResponses.map(ar -> {
if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) {
ar = after(ar);
}
return ar;
});
}

private AdvisedRequest before(AdvisedRequest request) {

var context = new HashMap<>(request.adviseContext());

public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
// 1. Search for similar documents in the vector store.
var searchRequestToUse = SearchRequest.from(this.searchRequest)
.query(request.userText())
.filterExpression(doGetFilterExpression(context))
.query(chatClientRequest.prompt().getUserMessage().getText())
.filterExpression(doGetFilterExpression(chatClientRequest.context()))
.build();

List<Document> documents = this.vectorStore.similaritySearch(searchRequestToUse);

// 2. Create the context from the documents.
Map<String, Object> context = new HashMap<>(chatClientRequest.context());
context.put(RETRIEVED_DOCUMENTS, documents);

String documentContext = documents.stream()
.map(Document::getText)
.collect(Collectors.joining(System.lineSeparator()));
String documentContext = documents == null ? ""
: documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));

// 3. Augment the user prompt with the document context.
String augmentedUserText = this.promptTemplate.mutate()
.template(request.userText() + System.lineSeparator() + this.promptTemplate.getTemplate())
.template(chatClientRequest.prompt().getUserMessage().getText() + System.lineSeparator()
+ this.promptTemplate.getTemplate())
.variables(Map.of("question_answer_context", documentContext))
.build()
.render();

AdvisedRequest advisedRequest = AdvisedRequest.from(request)
.userText(augmentedUserText)
.adviseContext(context)
// 4. Update ChatClientRequest with augmented prompt.
return chatClientRequest.mutate()
.prompt(chatClientRequest.prompt().augmentUserMessage(augmentedUserText))
.context(context)
.build();

return advisedRequest;
}

private AdvisedResponse after(AdvisedResponse advisedResponse) {
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response());
chatResponseBuilder.metadata(RETRIEVED_DOCUMENTS, advisedResponse.adviseContext().get(RETRIEVED_DOCUMENTS));
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
@Override
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
ChatResponse.Builder chatResponseBuilder;
if (chatClientResponse.chatResponse() == null) {
chatResponseBuilder = ChatResponse.builder();
}
else {
chatResponseBuilder = ChatResponse.builder().from(chatClientResponse.chatResponse());
}
chatResponseBuilder.metadata(RETRIEVED_DOCUMENTS, chatClientResponse.context().get(RETRIEVED_DOCUMENTS));
return ChatClientResponse.builder()
.chatResponse(chatResponseBuilder.build())
.context(chatClientResponse.context())
.build();
}

@Nullable
protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {

if (!context.containsKey(FILTER_EXPRESSION)
|| !StringUtils.hasText(context.get(FILTER_EXPRESSION).toString())) {
return this.searchRequest.getFilterExpression();
}
return new FilterExpressionTextParser().parse(context.get(FILTER_EXPRESSION).toString());
}

@Override
public Scheduler getScheduler() {
return this.scheduler;
}

public static final class Builder {
Expand All @@ -282,7 +176,7 @@ public static final class Builder {

private PromptTemplate promptTemplate;

private boolean protectFromBlocking = true;
private Scheduler scheduler;

private int order = DEFAULT_ORDER;

Expand All @@ -303,18 +197,13 @@ public Builder searchRequest(SearchRequest searchRequest) {
return this;
}

/**
* @deprecated in favour of {@link #promptTemplate(PromptTemplate)}
*/
@Deprecated
public Builder userTextAdvise(String userTextAdvise) {
Assert.hasText(userTextAdvise, "The userTextAdvise must not be empty!");
this.promptTemplate = PromptTemplate.builder().template(userTextAdvise).build();
public Builder protectFromBlocking(boolean protectFromBlocking) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is superseded by the new Scheduler input argument. I kept it for backward compatibility. And I haven't marked it as deprecated since we're trying to avoid that for RC1. But this is a candidate to be eventually removed in future releases.

this.scheduler = protectFromBlocking ? BaseAdvisor.DEFAULT_SCHEDULER : Schedulers.immediate();
return this;
}

public Builder protectFromBlocking(boolean protectFromBlocking) {
this.protectFromBlocking = protectFromBlocking;
public Builder scheduler(Scheduler scheduler) {
this.scheduler = scheduler;
return this;
}

Expand All @@ -324,8 +213,8 @@ public Builder order(int order) {
}

public QuestionAnswerAdvisor build() {
return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.promptTemplate,
this.protectFromBlocking, this.order);
return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.promptTemplate, this.scheduler,
this.order);
}

}
Expand Down
Loading