Skip to content

Commit 2a61b37

Browse files
committed
Issue: 3930
Signed-off-by: Mattia Pasetto [email protected]
1 parent c122fe1 commit 2a61b37

File tree

4 files changed

+213
-2
lines changed

4 files changed

+213
-2
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
3333
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
3434
import org.springframework.ai.chat.memory.ChatMemory;
35+
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
3536
import org.springframework.ai.chat.messages.Message;
3637
import org.springframework.ai.chat.messages.UserMessage;
3738
import org.springframework.util.Assert;
@@ -91,9 +92,10 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChai
9192
.prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build())
9293
.build();
9394

94-
// 4. Add the new user message to the conversation memory.
95+
// 4. Handle message updates and add the new user message to the conversation
96+
// memory.
9597
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
96-
this.chatMemory.add(conversationId, userMessage);
98+
handleMessageUpdate(conversationId, userMessage);
9799

98100
return processedChatClientRequest;
99101
}
@@ -128,6 +130,74 @@ public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest
128130
response -> this.after(response, streamAdvisorChain)));
129131
}
130132

133+
/**
134+
* Handle message updates by checking if a user message with the same ID already
135+
* exists. If it does, remove the old message and its corresponding assistant response
136+
* before adding the new message.
137+
* @param conversationId the conversation ID
138+
* @param userMessage the user message to add
139+
*/
140+
private void handleMessageUpdate(String conversationId, UserMessage userMessage) {
141+
// Ensure the user message has a unique messageId for tracking
142+
UserMessage messageWithId = ensureMessageId(userMessage);
143+
144+
String messageId = (String) messageWithId.getMetadata().get("messageId");
145+
146+
// Check if this is an update (messageId already exists in memory)
147+
if (this.chatMemory instanceof MessageWindowChatMemory windowMemory) {
148+
// If we have an existing message with this ID, remove it and its response
149+
if (hasExistingMessage(conversationId, messageId)) {
150+
windowMemory.removeMessageAndResponse(conversationId, messageId);
151+
}
152+
}
153+
154+
// Add the new/updated message to memory
155+
this.chatMemory.add(conversationId, messageWithId);
156+
}
157+
158+
/**
159+
* Ensure the user message has a unique messageId in its metadata. If no messageId
160+
* exists, generate one based on content hash.
161+
* @param userMessage the user message
162+
* @return the user message with messageId in metadata
163+
*/
164+
private UserMessage ensureMessageId(UserMessage userMessage) {
165+
String existingMessageId = (String) userMessage.getMetadata().get("messageId");
166+
if (existingMessageId != null) {
167+
return userMessage;
168+
}
169+
170+
// Generate a messageId based on content hash for tracking updates
171+
String messageId = generateMessageId(userMessage);
172+
173+
// Merge with existing metadata
174+
java.util.Map<String, Object> metadata = new java.util.HashMap<>(userMessage.getMetadata());
175+
metadata.put("messageId", messageId);
176+
177+
return userMessage.mutate().metadata(metadata).build();
178+
}
179+
180+
/**
181+
* Generate a unique message ID based on the user message content.
182+
* @param userMessage the user message
183+
* @return a unique message ID
184+
*/
185+
private String generateMessageId(UserMessage userMessage) {
186+
// Use content hash as a stable identifier for the same logical message
187+
return String.valueOf(userMessage.getText().hashCode());
188+
}
189+
190+
/**
191+
* Check if a message with the given ID already exists in the conversation memory.
192+
* @param conversationId the conversation ID
193+
* @param messageId the message ID to check
194+
* @return true if the message exists, false otherwise
195+
*/
196+
private boolean hasExistingMessage(String conversationId, String messageId) {
197+
List<Message> memoryMessages = this.chatMemory.get(conversationId);
198+
return memoryMessages.stream().anyMatch(message -> messageId.equals(message.getMetadata().get("messageId")));
199+
}
200+
131201
public static Builder builder(ChatMemory chatMemory) {
132202
return new Builder(chatMemory);
133203
}

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisorTests.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
import org.springframework.ai.chat.memory.ChatMemory;
2424
import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
2525
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
26+
import org.springframework.ai.chat.messages.Message;
27+
import org.springframework.ai.chat.messages.UserMessage;
28+
import org.springframework.ai.chat.messages.AssistantMessage;
29+
30+
import java.util.Map;
2631

2732
import static org.assertj.core.api.Assertions.assertThat;
2833
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -108,4 +113,79 @@ void testDefaultValues() {
108113
assertThat(advisor.getOrder()).isEqualTo(Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
109114
}
110115

116+
@Test
117+
void testMessageUpdateFunctionality() {
118+
// Create a chat memory
119+
MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder()
120+
.chatMemoryRepository(new InMemoryChatMemoryRepository())
121+
.build();
122+
123+
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory).build();
124+
String conversationId = "test-conversation";
125+
126+
// Test 1: Add original message with specific messageId
127+
UserMessage originalMessage = UserMessage.builder()
128+
.text("What is the capital of France?")
129+
.metadata(Map.of("messageId", "msg-001"))
130+
.build();
131+
132+
chatMemory.add(conversationId, originalMessage);
133+
134+
// Simulate adding an assistant response
135+
AssistantMessage assistantResponse = new AssistantMessage("The capital of France is Paris.");
136+
chatMemory.add(conversationId, assistantResponse);
137+
138+
// Verify initial state: should have 2 messages (user + assistant)
139+
assertThat(chatMemory.get(conversationId)).hasSize(2);
140+
assertThat(chatMemory.get(conversationId).get(0).getText()).isEqualTo("What is the capital of France?");
141+
assertThat(chatMemory.get(conversationId).get(1).getText()).isEqualTo("The capital of France is Paris.");
142+
143+
// Test 2: Update the message with same messageId
144+
UserMessage updatedMessage = UserMessage.builder()
145+
.text("What is the capital of Italy?")
146+
.metadata(Map.of("messageId", "msg-001")) // Same messageId
147+
.build();
148+
149+
// Remove old message and response manually (testing the repository functionality)
150+
chatMemory.removeMessageAndResponse(conversationId, "msg-001");
151+
chatMemory.add(conversationId, updatedMessage);
152+
153+
// Verify the update: should have only 1 message (the updated user message)
154+
// The old user message and assistant response should be removed
155+
assertThat(chatMemory.get(conversationId)).hasSize(1);
156+
assertThat(chatMemory.get(conversationId).get(0).getText()).isEqualTo("What is the capital of Italy?");
157+
assertThat(chatMemory.get(conversationId).get(0).getMetadata().get("messageId")).isEqualTo("msg-001");
158+
}
159+
160+
@Test
161+
void testMessageIdGeneration() {
162+
// Create a chat memory
163+
ChatMemory chatMemory = MessageWindowChatMemory.builder()
164+
.chatMemoryRepository(new InMemoryChatMemoryRepository())
165+
.build();
166+
167+
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory).build();
168+
169+
// Test that messages without messageId get one generated automatically
170+
UserMessage messageWithoutId = new UserMessage("Hello world");
171+
172+
// This would happen inside handleMessageUpdate method when a message is processed
173+
// We can't directly test the private method, but we can verify the behavior
174+
// by checking that the same content generates the same hash-based ID
175+
176+
String expectedId = String.valueOf("Hello world".hashCode());
177+
178+
// The generateMessageId method should produce consistent IDs for same content
179+
assertThat(expectedId).isNotNull();
180+
181+
// Messages with same content should have same messageId
182+
UserMessage message1 = new UserMessage("Same content");
183+
UserMessage message2 = new UserMessage("Same content");
184+
185+
String id1 = String.valueOf(message1.getText().hashCode());
186+
String id2 = String.valueOf(message2.getText().hashCode());
187+
188+
assertThat(id1).isEqualTo(id2);
189+
}
190+
111191
}

spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,52 @@ public interface ChatMemoryRepository {
4040

4141
void deleteByConversationId(String conversationId);
4242

43+
/**
44+
* Find a message by its unique identifier within a conversation.
45+
* @param conversationId the conversation ID
46+
* @param messageId the unique message ID
47+
* @return the message if found, null otherwise
48+
*/
49+
default Message findByMessageId(String conversationId, String messageId) {
50+
return findByConversationId(conversationId).stream()
51+
.filter(message -> messageId.equals(message.getMetadata().get("messageId")))
52+
.findFirst()
53+
.orElse(null);
54+
}
55+
56+
/**
57+
* Delete a specific message and its subsequent assistant response if any. This is
58+
* used when a user message is updated to remove the old message pair.
59+
* @param conversationId the conversation ID
60+
* @param messageId the unique message ID to delete
61+
*/
62+
default void deleteMessageAndResponse(String conversationId, String messageId) {
63+
List<Message> messages = findByConversationId(conversationId);
64+
List<Message> updatedMessages = new java.util.ArrayList<>();
65+
66+
boolean skipNext = false;
67+
for (int i = 0; i < messages.size(); i++) {
68+
Message message = messages.get(i);
69+
String currentMessageId = (String) message.getMetadata().get("messageId");
70+
71+
if (skipNext) {
72+
skipNext = false;
73+
continue;
74+
}
75+
76+
if (messageId.equals(currentMessageId)) {
77+
// Skip this message and potentially the next assistant response
78+
if (i + 1 < messages.size() && messages.get(i + 1)
79+
.getMessageType() == org.springframework.ai.chat.messages.MessageType.ASSISTANT) {
80+
skipNext = true;
81+
}
82+
continue;
83+
}
84+
85+
updatedMessages.add(message);
86+
}
87+
88+
saveAll(conversationId, updatedMessages);
89+
}
90+
4391
}

spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,19 @@ public void clear(String conversationId) {
7777
this.chatMemoryRepository.deleteByConversationId(conversationId);
7878
}
7979

80+
/**
81+
* Remove a specific message and its subsequent assistant response from the
82+
* conversation memory. This is used when a user message is updated to clean up the
83+
* old message pair.
84+
* @param conversationId the conversation ID
85+
* @param messageId the unique message ID to remove
86+
*/
87+
public void removeMessageAndResponse(String conversationId, String messageId) {
88+
Assert.hasText(conversationId, "conversationId cannot be null or empty");
89+
Assert.hasText(messageId, "messageId cannot be null or empty");
90+
this.chatMemoryRepository.deleteMessageAndResponse(conversationId, messageId);
91+
}
92+
8093
private List<Message> process(List<Message> memoryMessages, List<Message> newMessages) {
8194
List<Message> processedMessages = new ArrayList<>();
8295

0 commit comments

Comments
 (0)