Skip to content

Commit 60356fb

Browse files
committed
Refactor Chat Memory Advisors to reduce code duplication and add extensive tests
Core Architecture Changes: 1. New Abstract Class: AbstractConversationHistoryAdvisor - Created a new abstract class that extends AbstractChatMemoryAdvisor<ChatMemory> - Provides common functionality for managing conversation history - Implements methods for retrieving messages and applying them to requests - Simplifies the implementation of concrete advisors 2. Refactored AbstractChatMemoryAdvisor - Removed the defaultChatMemoryRetrieveSize parameter from constructors - Enhanced the builder pattern with generic type parameters for better type safety - Added abstract before(ChatClientRequest, String) method for subclasses to implement - Improved logging for conversation ID handling 3. Refactored MessageChatMemoryAdvisor - Now extends AbstractConversationHistoryAdvisor instead of directly extending AbstractChatMemoryAdvisor - Simplified implementation by leveraging parent class methods - Fixed handling of multiple user messages in a single prompt - Updated the builder to properly extend AbstractChatMemoryAdvisor.AbstractBuilder<ChatMemory, Builder> 4. Refactored PromptChatMemoryAdvisor - Now extends AbstractConversationHistoryAdvisor instead of directly extending AbstractChatMemoryAdvisor - Fixed handling of multiple user messages by using getUserMessages() instead of getUserMessage() - Enhanced logging for better debugging - Updated the builder to properly extend AbstractChatMemoryAdvisor.AbstractBuilder<ChatMemory, Builder> Test Changes: 1. New Test Classes - Added MessageChatMemoryAdvisorIT for integration testing of MessageChatMemoryAdvisor - Added PromptChatMemoryAdvisorIT for integration testing of PromptChatMemoryAdvisor - Both extend from AbstractChatMemoryAdvisorIT to share common test logic 2. Test Coverage - Added tests for handling multiple user messages in a single prompt - Added tests for custom conversation IDs - Added tests for maintaining separate conversations - Added tests for reactive mode operation Key Improvements: 1. Code Duplication Reduction: Moved common functionality to the parent class 2. Bug Fix: Fixed a bug in PromptChatMemoryAdvisor where it was only storing the last user message 3. Enhanced Type Safety: Improved the builder pattern with proper generic type parameters 4. Better Logging: Added detailed logging for better debugging and traceability 5. Simplified API: Removed unnecessary parameters from constructors 6. Improved Test Coverage: Added comprehensive tests for various scenarios Signed-off-by: Mark Pollack <[email protected]>
1 parent 8f879aa commit 60356fb

File tree

14 files changed

+1491
-235
lines changed

14 files changed

+1491
-235
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package org.springframework.ai.chat.client.advisor.vectorstore;
2+
3+
import java.util.List;
4+
5+
import org.springframework.ai.chat.memory.ChatMemory;
6+
import org.springframework.ai.chat.messages.Message;
7+
8+
/**
9+
* Interface for chat memories that support parameterized retrieval. Implementations can
10+
* define their own parameter types.
11+
*
12+
* @param <P> The type of parameters used for retrieval
13+
*/
14+
public interface ParameterizedChatMemory<P> extends ChatMemory {
15+
16+
/**
17+
* Retrieve messages based on the provided parameters.
18+
* @param conversationId The conversation identifier
19+
* @param parameters The retrieval parameters
20+
* @return List of retrieved messages
21+
*/
22+
List<Message> retrieve(String conversationId, P parameters);
23+
24+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package org.springframework.ai.chat.client.advisor.vectorstore;
2+
3+
import org.springframework.ai.vectorstore.SearchRequest;
4+
5+
/**
6+
* Parameters for vector store similarity search with default values.
7+
*/
8+
public record VectorSearchParameters(int topK, String filter) {
9+
10+
public static final int DEFAULT_TOP_K = 100;
11+
12+
public VectorSearchParameters() {
13+
this(VectorStoreChatMemoryAdvisor.DEFAULT_CHAT_MEMORY_RESPONSE_SIZE, null);
14+
}
15+
16+
public static VectorSearchParameters of(int topK) {
17+
return new VectorSearchParameters(topK, null);
18+
}
19+
20+
public static VectorSearchParameters forConversation(String conversationId) {
21+
return new VectorSearchParameters(DEFAULT_TOP_K, "conversationId=='" + conversationId + "'");
22+
}
23+
24+
public VectorSearchParameters withTopK(int topK) {
25+
return new VectorSearchParameters(topK, this.filter);
26+
}
27+
28+
/**
29+
* Create a new instance with a different filter.
30+
* @param filter the new filter expression
31+
* @return a new VectorSearchParameters instance with the updated filter
32+
*/
33+
public VectorSearchParameters withFilter(String filter) {
34+
return new VectorSearchParameters(this.topK, filter);
35+
}
36+
37+
}

advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java

Lines changed: 110 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424

2525
import reactor.core.publisher.Flux;
2626

27-
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
2827
import org.springframework.ai.chat.client.ChatClientRequest;
2928
import org.springframework.ai.chat.client.ChatClientResponse;
29+
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
3030
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
3131
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
3232
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -52,10 +52,17 @@
5252
*/
5353
public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<VectorStore> {
5454

55+
public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size";
56+
5557
private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId";
5658

5759
private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType";
5860

61+
/**
62+
* The default chat memory retrieve size to use when no retrieve size is provided.
63+
*/
64+
public static final int DEFAULT_CHAT_MEMORY_RESPONSE_SIZE = 100;
65+
5966
private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate("""
6067
{instructions}
6168
@@ -69,71 +76,62 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect
6976

7077
private final PromptTemplate systemPromptTemplate;
7178

72-
private VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,
73-
int chatHistoryWindowSize, boolean protectFromBlocking, PromptTemplate systemPromptTemplate, int order) {
74-
super(vectorStore, defaultConversationId, chatHistoryWindowSize, protectFromBlocking, order);
79+
protected final int defaultChatMemoryRetrieveSize;
80+
81+
public VectorStoreChatMemoryAdvisor(VectorStore chatMemory, String defaultConversationId,
82+
int defaultChatMemoryRetrieveSize, boolean protectFromBlocking, PromptTemplate systemPromptTemplate,
83+
int order) {
84+
super(chatMemory, defaultConversationId, protectFromBlocking, order);
7585
this.systemPromptTemplate = systemPromptTemplate;
86+
this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
7687
}
7788

7889
public static Builder builder(VectorStore chatMemory) {
7990
return new Builder(chatMemory);
8091
}
8192

82-
@Override
83-
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
84-
chatClientRequest = this.before(chatClientRequest);
85-
86-
ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest);
87-
88-
this.after(chatClientResponse);
89-
90-
return chatClientResponse;
93+
protected int doGetChatMemoryRetrieveSize(Map<String, Object> context) {
94+
return context.containsKey(CHAT_MEMORY_RETRIEVE_SIZE_KEY)
95+
? Integer.parseInt(context.get(CHAT_MEMORY_RETRIEVE_SIZE_KEY).toString())
96+
: this.defaultChatMemoryRetrieveSize;
9197
}
9298

9399
@Override
94-
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
95-
StreamAdvisorChain streamAdvisorChain) {
96-
Flux<ChatClientResponse> chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest,
97-
streamAdvisorChain, this::before);
98-
99-
return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after);
100-
}
101-
102-
private ChatClientRequest before(ChatClientRequest chatClientRequest) {
103-
String conversationId = this.doGetConversationId(chatClientRequest.context());
104-
int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context());
105-
106-
// 1. Retrieve the chat memory for the current conversation.
107-
var searchRequest = SearchRequest.builder()
108-
.query(chatClientRequest.prompt().getUserMessage().getText())
109-
.topK(chatMemoryRetrieveSize)
110-
.filterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'")
100+
protected ChatClientRequest before(ChatClientRequest request, String conversationId) {
101+
String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";
102+
int topK = doGetChatMemoryRetrieveSize(request.context());
103+
String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'";
104+
var searchRequest = org.springframework.ai.vectorstore.SearchRequest.builder()
105+
.query(query)
106+
.topK(topK)
107+
.filterExpression(filter)
111108
.build();
109+
java.util.List<org.springframework.ai.document.Document> documents = this.getChatMemoryStore()
110+
.similaritySearch(searchRequest);
112111

113-
List<Document> documents = this.getChatMemoryStore().similaritySearch(searchRequest);
114-
115-
// 2. Processed memory messages as a string.
116112
String longTermMemory = documents == null ? ""
117-
: documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
113+
: documents.stream()
114+
.map(org.springframework.ai.document.Document::getText)
115+
.collect(java.util.stream.Collectors.joining(System.lineSeparator()));
118116

119-
// 2. Augment the system message.
120-
SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
117+
org.springframework.ai.chat.messages.SystemMessage systemMessage = request.prompt().getSystemMessage();
121118
String augmentedSystemText = this.systemPromptTemplate
122-
.render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
119+
.render(java.util.Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
123120

124-
// 3. Create a new request with the augmented system message.
125-
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
126-
.prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
121+
ChatClientRequest processedChatClientRequest = request.mutate()
122+
.prompt(request.prompt().augmentSystemMessage(augmentedSystemText))
127123
.build();
128124

129-
// 4. Add the new user message to the conversation memory.
130-
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
131-
this.getChatMemoryStore().write(toDocuments(List.of(userMessage), conversationId));
125+
org.springframework.ai.chat.messages.UserMessage userMessage = processedChatClientRequest.prompt()
126+
.getUserMessage();
127+
if (userMessage != null) {
128+
this.getChatMemoryStore().write(toDocuments(java.util.List.of(userMessage), conversationId));
129+
}
132130

133131
return processedChatClientRequest;
134132
}
135133

136-
private void after(ChatClientResponse chatClientResponse) {
134+
protected void after(ChatClientResponse chatClientResponse) {
137135
List<Message> assistantMessages = new ArrayList<>();
138136
if (chatClientResponse.chatResponse() != null) {
139137
assistantMessages = chatClientResponse.chatResponse()
@@ -146,6 +144,24 @@ private void after(ChatClientResponse chatClientResponse) {
146144
.write(toDocuments(assistantMessages, this.doGetConversationId(chatClientResponse.context())));
147145
}
148146

147+
protected ChatClientRequest applyMessagesToRequest(ChatClientRequest request, List<Message> memoryMessages) {
148+
if (memoryMessages == null || memoryMessages.isEmpty()) {
149+
return request;
150+
}
151+
// Convert memory messages to a string for the system prompt
152+
String longTermMemory = memoryMessages.stream()
153+
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
154+
.map(m -> m.getMessageType() + ":" + m.getText())
155+
.collect(Collectors.joining(System.lineSeparator()));
156+
157+
SystemMessage systemMessage = request.prompt().getSystemMessage();
158+
String augmentedSystemText = this.systemPromptTemplate
159+
.render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
160+
161+
// Create a new request with the augmented system message
162+
return request.mutate().prompt(request.prompt().augmentSystemMessage(augmentedSystemText)).build();
163+
}
164+
149165
private List<Document> toDocuments(List<Message> messages, String conversationId) {
150166
List<Document> docs = messages.stream()
151167
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
@@ -173,28 +189,71 @@ else if (message instanceof AssistantMessage assistantMessage) {
173189
return docs;
174190
}
175191

176-
public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<VectorStore> {
192+
/**
193+
* Builder for VectorStoreChatMemoryAdvisor.
194+
*/
195+
public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<VectorStore, Builder> {
177196

178197
private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE;
179198

180-
protected Builder(VectorStore chatMemory) {
181-
super(chatMemory);
199+
private Integer defaultChatMemoryRetrieveSize = null;
200+
201+
/**
202+
* Creates a new builder instance.
203+
* @param vectorStore the vector store to use
204+
*/
205+
protected Builder(VectorStore vectorStore) {
206+
super(vectorStore);
207+
}
208+
209+
/**
210+
* Set the system prompt template.
211+
* @param systemPromptTemplate the system prompt template
212+
* @return this builder
213+
*/
214+
public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) {
215+
this.systemPromptTemplate = systemPromptTemplate;
216+
return this;
182217
}
183218

219+
/**
220+
* Set the system prompt template using a text template.
221+
* @param systemTextAdvise the system prompt text template
222+
* @return this builder
223+
*/
184224
public Builder systemTextAdvise(String systemTextAdvise) {
185225
this.systemPromptTemplate = new PromptTemplate(systemTextAdvise);
186226
return this;
187227
}
188228

189-
public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) {
190-
this.systemPromptTemplate = systemPromptTemplate;
229+
/**
230+
* Set the default chat memory retrieve size.
231+
* @param defaultChatMemoryRetrieveSize the default chat memory retrieve size
232+
* @return this builder
233+
*/
234+
public Builder defaultChatMemoryRetrieveSize(int defaultChatMemoryRetrieveSize) {
235+
this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
236+
return this;
237+
}
238+
239+
@Override
240+
protected Builder self() {
191241
return this;
192242
}
193243

194244
@Override
195245
public VectorStoreChatMemoryAdvisor build() {
196-
return new VectorStoreChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize,
197-
this.protectFromBlocking, this.systemPromptTemplate, this.order);
246+
if (defaultChatMemoryRetrieveSize == null) {
247+
// Default to legacy mode for backward compatibility
248+
return new VectorStoreChatMemoryAdvisor(this.chatMemory, this.conversationId,
249+
DEFAULT_CHAT_MEMORY_RESPONSE_SIZE, this.protectFromBlocking, this.systemPromptTemplate,
250+
this.order);
251+
}
252+
else {
253+
return new VectorStoreChatMemoryAdvisor(this.chatMemory, this.conversationId,
254+
this.defaultChatMemoryRetrieveSize, this.protectFromBlocking, this.systemPromptTemplate,
255+
this.order);
256+
}
198257
}
199258

200259
}

0 commit comments

Comments
 (0)