Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,16 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;

/**
Expand All @@ -52,10 +44,17 @@
*/
public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<VectorStore> {

public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size";

private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId";

private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType";

/**
* The default chat memory retrieve size to use when no retrieve size is provided.
*/
public static final int DEFAULT_CHAT_MEMORY_RESPONSE_SIZE = 100;

private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate("""
{instructions}

Expand All @@ -69,71 +68,62 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect

private final PromptTemplate systemPromptTemplate;

private VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,
int chatHistoryWindowSize, boolean protectFromBlocking, PromptTemplate systemPromptTemplate, int order) {
super(vectorStore, defaultConversationId, chatHistoryWindowSize, protectFromBlocking, order);
protected final int defaultChatMemoryRetrieveSize;

public VectorStoreChatMemoryAdvisor(VectorStore chatMemory, String defaultConversationId,
int defaultChatMemoryRetrieveSize, boolean protectFromBlocking, PromptTemplate systemPromptTemplate,
int order) {
super(chatMemory, defaultConversationId, protectFromBlocking, order);
this.systemPromptTemplate = systemPromptTemplate;
this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
}

public static Builder builder(VectorStore chatMemory) {
return new Builder(chatMemory);
}

@Override
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
chatClientRequest = this.before(chatClientRequest);

ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest);

this.after(chatClientResponse);

return chatClientResponse;
protected int doGetChatMemoryRetrieveSize(Map<String, Object> context) {
return context.containsKey(CHAT_MEMORY_RETRIEVE_SIZE_KEY)
? Integer.parseInt(context.get(CHAT_MEMORY_RETRIEVE_SIZE_KEY).toString())
: this.defaultChatMemoryRetrieveSize;
}

@Override
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
StreamAdvisorChain streamAdvisorChain) {
Flux<ChatClientResponse> chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest,
streamAdvisorChain, this::before);

return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after);
}

private ChatClientRequest before(ChatClientRequest chatClientRequest) {
String conversationId = this.doGetConversationId(chatClientRequest.context());
int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context());

// 1. Retrieve the chat memory for the current conversation.
var searchRequest = SearchRequest.builder()
.query(chatClientRequest.prompt().getUserMessage().getText())
.topK(chatMemoryRetrieveSize)
.filterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'")
protected ChatClientRequest before(ChatClientRequest request, String conversationId) {
String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";
int topK = doGetChatMemoryRetrieveSize(request.context());
String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'";
var searchRequest = org.springframework.ai.vectorstore.SearchRequest.builder()
.query(query)
.topK(topK)
.filterExpression(filter)
.build();
java.util.List<org.springframework.ai.document.Document> documents = this.getChatMemoryStore()
.similaritySearch(searchRequest);

List<Document> documents = this.getChatMemoryStore().similaritySearch(searchRequest);

// 2. Processed memory messages as a string.
String longTermMemory = documents == null ? ""
: documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
: documents.stream()
.map(org.springframework.ai.document.Document::getText)
.collect(java.util.stream.Collectors.joining(System.lineSeparator()));

// 2. Augment the system message.
SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
org.springframework.ai.chat.messages.SystemMessage systemMessage = request.prompt().getSystemMessage();
String augmentedSystemText = this.systemPromptTemplate
.render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
.render(java.util.Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));

// 3. Create a new request with the augmented system message.
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
.prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
ChatClientRequest processedChatClientRequest = request.mutate()
.prompt(request.prompt().augmentSystemMessage(augmentedSystemText))
.build();

// 4. Add the new user message to the conversation memory.
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
this.getChatMemoryStore().write(toDocuments(List.of(userMessage), conversationId));
org.springframework.ai.chat.messages.UserMessage userMessage = processedChatClientRequest.prompt()
.getUserMessage();
if (userMessage != null) {
this.getChatMemoryStore().write(toDocuments(java.util.List.of(userMessage), conversationId));
}

return processedChatClientRequest;
}

private void after(ChatClientResponse chatClientResponse) {
protected void after(ChatClientResponse chatClientResponse) {
List<Message> assistantMessages = new ArrayList<>();
if (chatClientResponse.chatResponse() != null) {
assistantMessages = chatClientResponse.chatResponse()
Expand Down Expand Up @@ -173,28 +163,71 @@ else if (message instanceof AssistantMessage assistantMessage) {
return docs;
}

public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<VectorStore> {
/**
* Builder for VectorStoreChatMemoryAdvisor.
*/
public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<VectorStore, Builder> {

private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE;

protected Builder(VectorStore chatMemory) {
super(chatMemory);
private Integer defaultChatMemoryRetrieveSize = null;

/**
* Creates a new builder instance.
* @param vectorStore the vector store to use
*/
protected Builder(VectorStore vectorStore) {
super(vectorStore);
}

/**
* Set the system prompt template.
* @param systemPromptTemplate the system prompt template
* @return this builder
*/
public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) {
this.systemPromptTemplate = systemPromptTemplate;
return this;
}

/**
* Set the system prompt template using a text template.
* @param systemTextAdvise the system prompt text template
* @return this builder
*/
public Builder systemTextAdvise(String systemTextAdvise) {
this.systemPromptTemplate = new PromptTemplate(systemTextAdvise);
return this;
}

public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) {
this.systemPromptTemplate = systemPromptTemplate;
/**
* Set the default chat memory retrieve size.
* @param defaultChatMemoryRetrieveSize the default chat memory retrieve size
* @return this builder
*/
public Builder defaultChatMemoryRetrieveSize(int defaultChatMemoryRetrieveSize) {
this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
return this;
}

@Override
protected Builder self() {
return this;
}

@Override
public VectorStoreChatMemoryAdvisor build() {
return new VectorStoreChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize,
this.protectFromBlocking, this.systemPromptTemplate, this.order);
if (defaultChatMemoryRetrieveSize == null) {
// Default to legacy mode for backward compatibility
return new VectorStoreChatMemoryAdvisor(this.chatMemory, this.conversationId,
DEFAULT_CHAT_MEMORY_RESPONSE_SIZE, this.protectFromBlocking, this.systemPromptTemplate,
this.order);
}
else {
return new VectorStoreChatMemoryAdvisor(this.chatMemory, this.conversationId,
this.defaultChatMemoryRetrieveSize, this.protectFromBlocking, this.systemPromptTemplate,
this.order);
}
}

}
Expand Down
Loading