-
Notifications
You must be signed in to change notification settings - Fork 2k
Simplify chat memory advisor hierarchy and remove deprecated API #3121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
4bcfbe0
1d335ed
cb8d68b
e57e62c
1266159
31ddb29
7ba87af
32b3905
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,24 +20,17 @@ | |
| import java.util.HashMap; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.stream.Collectors; | ||
|
|
||
| import reactor.core.publisher.Flux; | ||
|
|
||
| import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor; | ||
| import org.springframework.ai.chat.client.ChatClientRequest; | ||
| import org.springframework.ai.chat.client.ChatClientResponse; | ||
| import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; | ||
| import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; | ||
| import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor; | ||
| import org.springframework.ai.chat.client.advisor.api.AdvisorChain; | ||
| import org.springframework.ai.chat.messages.AssistantMessage; | ||
| import org.springframework.ai.chat.messages.Message; | ||
| import org.springframework.ai.chat.messages.MessageType; | ||
| import org.springframework.ai.chat.messages.SystemMessage; | ||
| import org.springframework.ai.chat.messages.UserMessage; | ||
| import org.springframework.ai.chat.model.MessageAggregator; | ||
| import org.springframework.ai.chat.prompt.PromptTemplate; | ||
| import org.springframework.ai.document.Document; | ||
| import org.springframework.ai.vectorstore.SearchRequest; | ||
| import org.springframework.ai.vectorstore.VectorStore; | ||
|
|
||
| /** | ||
|
|
@@ -48,14 +41,22 @@ | |
| * @author Christian Tzolov | ||
| * @author Thomas Vitale | ||
| * @author Oganes Bozoyan | ||
| * @author Mark Pollack | ||
| * @since 1.0.0 | ||
| */ | ||
| public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<VectorStore> { | ||
|
|
||
| public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size"; | ||
|
|
||
| private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId"; | ||
|
|
||
| private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType"; | ||
|
|
||
| /** | ||
| * The default chat memory retrieve size to use when no retrieve size is provided. | ||
| */ | ||
| public static final int DEFAULT_CHAT_MEMORY_RESPONSE_SIZE = 100; | ||
|
||
|
|
||
| private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate(""" | ||
| {instructions} | ||
|
|
||
|
|
@@ -69,71 +70,64 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect | |
|
|
||
| private final PromptTemplate systemPromptTemplate; | ||
|
|
||
| private VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId, | ||
| int chatHistoryWindowSize, boolean protectFromBlocking, PromptTemplate systemPromptTemplate, int order) { | ||
| super(vectorStore, defaultConversationId, chatHistoryWindowSize, protectFromBlocking, order); | ||
| protected final int defaultChatMemoryRetrieveSize; | ||
|
|
||
| public VectorStoreChatMemoryAdvisor(VectorStore chatMemory, String defaultConversationId, | ||
| int defaultChatMemoryRetrieveSize, boolean protectFromBlocking, PromptTemplate systemPromptTemplate, | ||
| int order) { | ||
| super(chatMemory, defaultConversationId, protectFromBlocking, order); | ||
| this.systemPromptTemplate = systemPromptTemplate; | ||
| this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize; | ||
| } | ||
|
|
||
| public static Builder builder(VectorStore chatMemory) { | ||
| return new Builder(chatMemory); | ||
| } | ||
|
|
||
| @Override | ||
| public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { | ||
| chatClientRequest = this.before(chatClientRequest); | ||
|
|
||
| ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest); | ||
|
|
||
| this.after(chatClientResponse); | ||
|
|
||
| return chatClientResponse; | ||
| } | ||
|
|
||
| @Override | ||
| public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, | ||
| StreamAdvisorChain streamAdvisorChain) { | ||
| Flux<ChatClientResponse> chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest, | ||
| streamAdvisorChain, this::before); | ||
|
|
||
| return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after); | ||
| } | ||
|
|
||
| private ChatClientRequest before(ChatClientRequest chatClientRequest) { | ||
| String conversationId = this.doGetConversationId(chatClientRequest.context()); | ||
| int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context()); | ||
|
|
||
| // 1. Retrieve the chat memory for the current conversation. | ||
| var searchRequest = SearchRequest.builder() | ||
| .query(chatClientRequest.prompt().getUserMessage().getText()) | ||
| .topK(chatMemoryRetrieveSize) | ||
| .filterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'") | ||
| public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorChain) { | ||
| String conversationId = doGetConversationId(request.context()); | ||
| String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : ""; | ||
| int topK = doGetChatMemoryRetrieveSize(request.context()); | ||
| String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'"; | ||
| var searchRequest = org.springframework.ai.vectorstore.SearchRequest.builder() | ||
| .query(query) | ||
| .topK(topK) | ||
| .filterExpression(filter) | ||
| .build(); | ||
| java.util.List<org.springframework.ai.document.Document> documents = this.getChatMemoryStore() | ||
| .similaritySearch(searchRequest); | ||
|
|
||
| List<Document> documents = this.getChatMemoryStore().similaritySearch(searchRequest); | ||
|
|
||
| // 2. Processed memory messages as a string. | ||
| String longTermMemory = documents == null ? "" | ||
| : documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator())); | ||
| : documents.stream() | ||
| .map(org.springframework.ai.document.Document::getText) | ||
| .collect(java.util.stream.Collectors.joining(System.lineSeparator())); | ||
|
|
||
| // 2. Augment the system message. | ||
| SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage(); | ||
| org.springframework.ai.chat.messages.SystemMessage systemMessage = request.prompt().getSystemMessage(); | ||
| String augmentedSystemText = this.systemPromptTemplate | ||
| .render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory)); | ||
| .render(java.util.Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory)); | ||
|
|
||
| // 3. Create a new request with the augmented system message. | ||
| ChatClientRequest processedChatClientRequest = chatClientRequest.mutate() | ||
| .prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText)) | ||
| ChatClientRequest processedChatClientRequest = request.mutate() | ||
| .prompt(request.prompt().augmentSystemMessage(augmentedSystemText)) | ||
| .build(); | ||
|
|
||
| // 4. Add the new user message to the conversation memory. | ||
| UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage(); | ||
| this.getChatMemoryStore().write(toDocuments(List.of(userMessage), conversationId)); | ||
| org.springframework.ai.chat.messages.UserMessage userMessage = processedChatClientRequest.prompt() | ||
| .getUserMessage(); | ||
| if (userMessage != null) { | ||
| this.getChatMemoryStore().write(toDocuments(java.util.List.of(userMessage), conversationId)); | ||
| } | ||
|
|
||
| return processedChatClientRequest; | ||
| } | ||
|
|
||
| private void after(ChatClientResponse chatClientResponse) { | ||
| protected int doGetChatMemoryRetrieveSize(Map<String, Object> context) { | ||
|
||
| return context.containsKey(CHAT_MEMORY_RETRIEVE_SIZE_KEY) | ||
| ? Integer.parseInt(context.get(CHAT_MEMORY_RETRIEVE_SIZE_KEY).toString()) | ||
| : this.defaultChatMemoryRetrieveSize; | ||
| } | ||
|
|
||
| @Override | ||
| public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) { | ||
| List<Message> assistantMessages = new ArrayList<>(); | ||
| if (chatClientResponse.chatResponse() != null) { | ||
| assistantMessages = chatClientResponse.chatResponse() | ||
|
|
@@ -144,6 +138,7 @@ private void after(ChatClientResponse chatClientResponse) { | |
| } | ||
| this.getChatMemoryStore() | ||
| .write(toDocuments(assistantMessages, this.doGetConversationId(chatClientResponse.context()))); | ||
| return chatClientResponse; | ||
| } | ||
|
|
||
| private List<Document> toDocuments(List<Message> messages, String conversationId) { | ||
|
|
@@ -173,22 +168,56 @@ else if (message instanceof AssistantMessage assistantMessage) { | |
| return docs; | ||
| } | ||
|
|
||
| public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<VectorStore> { | ||
| /** | ||
| * Builder for VectorStoreChatMemoryAdvisor. | ||
| */ | ||
| public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<VectorStore, Builder> { | ||
|
|
||
| private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE; | ||
|
|
||
| protected Builder(VectorStore chatMemory) { | ||
| super(chatMemory); | ||
| private Integer chatMemoryRetrieveSize = DEFAULT_CHAT_MEMORY_RESPONSE_SIZE; | ||
|
|
||
| /** | ||
| * Creates a new builder instance. | ||
| * @param vectorStore the vector store to use | ||
| */ | ||
| protected Builder(VectorStore vectorStore) { | ||
| super(vectorStore); | ||
| } | ||
|
|
||
| public Builder systemTextAdvise(String systemTextAdvise) { | ||
| this.systemPromptTemplate = new PromptTemplate(systemTextAdvise); | ||
| @Override | ||
| protected Builder self() { | ||
| return this; | ||
| } | ||
|
|
||
| /** | ||
| * Set the system prompt template. | ||
| * @param systemPromptTemplate the system prompt template | ||
| * @return this builder | ||
| */ | ||
| public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) { | ||
| this.systemPromptTemplate = systemPromptTemplate; | ||
| return this; | ||
| return self(); | ||
| } | ||
|
|
||
| /** | ||
| * Set the system prompt template using a text template. | ||
| * @param systemTextAdvise the system prompt text template | ||
| * @return this builder | ||
| */ | ||
| public Builder systemTextAdvise(String systemTextAdvise) { | ||
|
||
| this.systemPromptTemplate = new PromptTemplate(systemTextAdvise); | ||
| return self(); | ||
| } | ||
|
|
||
| /** | ||
| * Set the chat memory retrieve size. | ||
| * @param chatMemoryRetrieveSize the chat memory retrieve size | ||
| * @return this builder | ||
| */ | ||
| public Builder chatMemoryRetrieveSize(int chatMemoryRetrieveSize) { | ||
|
||
| this.chatMemoryRetrieveSize = chatMemoryRetrieveSize; | ||
| return self(); | ||
| } | ||
|
|
||
| @Override | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could consider renaming this to something more explicit about what it's used for.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, and maybe also move it into the VectorStore interface so it is easier to discover. Let's consider this tomorrow.