Skip to content

Commit 4bcfbe0

Browse files
committed
refactor: Simplify chat memory advisor hierarchy and remove deprecated API
Remove deprecated method in ChatMemory and refactor chat memory advisor classes - Refactored AbstractChatMemoryAdvisor to inherit from BaseAdvisor, leveraging its default implementations for adviseCall and adviseStream methods - Updated MessageChatMemoryAdvisor, PromptChatMemoryAdvisor, and VectorStoreChatMemoryAdvisor to directly implement their own before/after methods - Removed deprecated ChatMemory.get(String conversationId, int lastN) method in favor of using MessageWindowChatMemory - Fixed a bug in PromptChatMemoryAdvisor where it was only storing the last user message from a prompt with multiple messages - Enhanced logging in memory advisors to aid in debugging - Added comprehensive tests for advisor implementations: - Unit tests for MessageChatMemoryAdvisor and PromptChatMemoryAdvisor to check builder behavior - Integration tests for VectorStoreChatMemoryAdvisor with semantic memory retrieval Signed-off-by: Mark Pollack <[email protected]>
1 parent 2d517ee commit 4bcfbe0

File tree

13 files changed

+1483
-270
lines changed

13 files changed

+1483
-270
lines changed

advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java

Lines changed: 89 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,17 @@
2020
import java.util.HashMap;
2121
import java.util.List;
2222
import java.util.Map;
23-
import java.util.stream.Collectors;
2423

25-
import reactor.core.publisher.Flux;
26-
27-
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
2824
import org.springframework.ai.chat.client.ChatClientRequest;
2925
import org.springframework.ai.chat.client.ChatClientResponse;
30-
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
31-
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
26+
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
27+
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
3228
import org.springframework.ai.chat.messages.AssistantMessage;
3329
import org.springframework.ai.chat.messages.Message;
3430
import org.springframework.ai.chat.messages.MessageType;
35-
import org.springframework.ai.chat.messages.SystemMessage;
3631
import org.springframework.ai.chat.messages.UserMessage;
37-
import org.springframework.ai.chat.model.MessageAggregator;
3832
import org.springframework.ai.chat.prompt.PromptTemplate;
3933
import org.springframework.ai.document.Document;
40-
import org.springframework.ai.vectorstore.SearchRequest;
4134
import org.springframework.ai.vectorstore.VectorStore;
4235

4336
/**
@@ -48,14 +41,22 @@
4841
* @author Christian Tzolov
4942
* @author Thomas Vitale
5043
* @author Oganes Bozoyan
44+
* @author Mark Pollack
5145
* @since 1.0.0
5246
*/
5347
public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<VectorStore> {
5448

49+
public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size";
50+
5551
private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId";
5652

5753
private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType";
5854

55+
/**
56+
* The default chat memory retrieve size to use when no retrieve size is provided.
57+
*/
58+
public static final int DEFAULT_CHAT_MEMORY_RESPONSE_SIZE = 100;
59+
5960
private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate("""
6061
{instructions}
6162
@@ -69,71 +70,64 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect
6970

7071
private final PromptTemplate systemPromptTemplate;
7172

72-
private VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,
73-
int chatHistoryWindowSize, boolean protectFromBlocking, PromptTemplate systemPromptTemplate, int order) {
74-
super(vectorStore, defaultConversationId, chatHistoryWindowSize, protectFromBlocking, order);
73+
protected final int defaultChatMemoryRetrieveSize;
74+
75+
public VectorStoreChatMemoryAdvisor(VectorStore chatMemory, String defaultConversationId,
76+
int defaultChatMemoryRetrieveSize, boolean protectFromBlocking, PromptTemplate systemPromptTemplate,
77+
int order) {
78+
super(chatMemory, defaultConversationId, protectFromBlocking, order);
7579
this.systemPromptTemplate = systemPromptTemplate;
80+
this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
7681
}
7782

7883
public static Builder builder(VectorStore chatMemory) {
7984
return new Builder(chatMemory);
8085
}
8186

8287
@Override
83-
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
84-
chatClientRequest = this.before(chatClientRequest);
85-
86-
ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest);
87-
88-
this.after(chatClientResponse);
89-
90-
return chatClientResponse;
91-
}
92-
93-
@Override
94-
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
95-
StreamAdvisorChain streamAdvisorChain) {
96-
Flux<ChatClientResponse> chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest,
97-
streamAdvisorChain, this::before);
98-
99-
return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after);
100-
}
101-
102-
private ChatClientRequest before(ChatClientRequest chatClientRequest) {
103-
String conversationId = this.doGetConversationId(chatClientRequest.context());
104-
int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context());
105-
106-
// 1. Retrieve the chat memory for the current conversation.
107-
var searchRequest = SearchRequest.builder()
108-
.query(chatClientRequest.prompt().getUserMessage().getText())
109-
.topK(chatMemoryRetrieveSize)
110-
.filterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'")
88+
public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorChain) {
89+
String conversationId = doGetConversationId(request.context());
90+
String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";
91+
int topK = doGetChatMemoryRetrieveSize(request.context());
92+
String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'";
93+
var searchRequest = org.springframework.ai.vectorstore.SearchRequest.builder()
94+
.query(query)
95+
.topK(topK)
96+
.filterExpression(filter)
11197
.build();
98+
java.util.List<org.springframework.ai.document.Document> documents = this.getChatMemoryStore()
99+
.similaritySearch(searchRequest);
112100

113-
List<Document> documents = this.getChatMemoryStore().similaritySearch(searchRequest);
114-
115-
// 2. Processed memory messages as a string.
116101
String longTermMemory = documents == null ? ""
117-
: documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
102+
: documents.stream()
103+
.map(org.springframework.ai.document.Document::getText)
104+
.collect(java.util.stream.Collectors.joining(System.lineSeparator()));
118105

119-
// 2. Augment the system message.
120-
SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
106+
org.springframework.ai.chat.messages.SystemMessage systemMessage = request.prompt().getSystemMessage();
121107
String augmentedSystemText = this.systemPromptTemplate
122-
.render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
108+
.render(java.util.Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
123109

124-
// 3. Create a new request with the augmented system message.
125-
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
126-
.prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
110+
ChatClientRequest processedChatClientRequest = request.mutate()
111+
.prompt(request.prompt().augmentSystemMessage(augmentedSystemText))
127112
.build();
128113

129-
// 4. Add the new user message to the conversation memory.
130-
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
131-
this.getChatMemoryStore().write(toDocuments(List.of(userMessage), conversationId));
114+
org.springframework.ai.chat.messages.UserMessage userMessage = processedChatClientRequest.prompt()
115+
.getUserMessage();
116+
if (userMessage != null) {
117+
this.getChatMemoryStore().write(toDocuments(java.util.List.of(userMessage), conversationId));
118+
}
132119

133120
return processedChatClientRequest;
134121
}
135122

136-
private void after(ChatClientResponse chatClientResponse) {
123+
protected int doGetChatMemoryRetrieveSize(Map<String, Object> context) {
124+
return context.containsKey(CHAT_MEMORY_RETRIEVE_SIZE_KEY)
125+
? Integer.parseInt(context.get(CHAT_MEMORY_RETRIEVE_SIZE_KEY).toString())
126+
: this.defaultChatMemoryRetrieveSize;
127+
}
128+
129+
@Override
130+
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
137131
List<Message> assistantMessages = new ArrayList<>();
138132
if (chatClientResponse.chatResponse() != null) {
139133
assistantMessages = chatClientResponse.chatResponse()
@@ -144,6 +138,7 @@ private void after(ChatClientResponse chatClientResponse) {
144138
}
145139
this.getChatMemoryStore()
146140
.write(toDocuments(assistantMessages, this.doGetConversationId(chatClientResponse.context())));
141+
return chatClientResponse;
147142
}
148143

149144
private List<Document> toDocuments(List<Message> messages, String conversationId) {
@@ -173,22 +168,56 @@ else if (message instanceof AssistantMessage assistantMessage) {
173168
return docs;
174169
}
175170

176-
public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<VectorStore> {
171+
/**
172+
* Builder for VectorStoreChatMemoryAdvisor.
173+
*/
174+
public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<VectorStore, Builder> {
177175

178176
private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE;
179177

180-
protected Builder(VectorStore chatMemory) {
181-
super(chatMemory);
178+
private Integer chatMemoryRetrieveSize = DEFAULT_CHAT_MEMORY_RESPONSE_SIZE;
179+
180+
/**
181+
* Creates a new builder instance.
182+
* @param vectorStore the vector store to use
183+
*/
184+
protected Builder(VectorStore vectorStore) {
185+
super(vectorStore);
182186
}
183187

184-
public Builder systemTextAdvise(String systemTextAdvise) {
185-
this.systemPromptTemplate = new PromptTemplate(systemTextAdvise);
188+
@Override
189+
protected Builder self() {
186190
return this;
187191
}
188192

193+
/**
194+
* Set the system prompt template.
195+
* @param systemPromptTemplate the system prompt template
196+
* @return this builder
197+
*/
189198
public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) {
190199
this.systemPromptTemplate = systemPromptTemplate;
191-
return this;
200+
return self();
201+
}
202+
203+
/**
204+
* Set the system prompt template using a text template.
205+
* @param systemTextAdvise the system prompt text template
206+
* @return this builder
207+
*/
208+
public Builder systemTextAdvise(String systemTextAdvise) {
209+
this.systemPromptTemplate = new PromptTemplate(systemTextAdvise);
210+
return self();
211+
}
212+
213+
/**
214+
* Set the chat memory retrieve size.
215+
* @param chatMemoryRetrieveSize the chat memory retrieve size
216+
* @return this builder
217+
*/
218+
public Builder chatMemoryRetrieveSize(int chatMemoryRetrieveSize) {
219+
this.chatMemoryRetrieveSize = chatMemoryRetrieveSize;
220+
return self();
192221
}
193222

194223
@Override

0 commit comments

Comments
 (0)