diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index 1b8bbea84e9..09afa1d74f3 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -32,6 +32,7 @@ import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.util.Assert; @@ -91,9 +92,10 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChai .prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build()) .build(); - // 4. Add the new user message to the conversation memory. + // 4. Handle message updates and add the new user message to the conversation + // memory. UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage(); - this.chatMemory.add(conversationId, userMessage); + handleMessageUpdate(conversationId, userMessage); return processedChatClientRequest; } @@ -128,6 +130,74 @@ public Flux adviseStream(ChatClientRequest chatClientRequest response -> this.after(response, streamAdvisorChain))); } + /** + * Handle message updates by checking if a user message with the same ID already + * exists. If it does, remove the old message and its corresponding assistant response + * before adding the new message. + * @param conversationId the conversation ID + * @param userMessage the user message to add + */ + private void handleMessageUpdate(String conversationId, UserMessage userMessage) { + // Ensure the user message has a unique messageId for tracking + UserMessage messageWithId = ensureMessageId(userMessage); + + String messageId = (String) messageWithId.getMetadata().get("messageId"); + + // Check if this is an update (messageId already exists in memory) + if (this.chatMemory instanceof MessageWindowChatMemory windowMemory) { + // If we have an existing message with this ID, remove it and its response + if (hasExistingMessage(conversationId, messageId)) { + windowMemory.removeMessageAndResponse(conversationId, messageId); + } + } + + // Add the new/updated message to memory + this.chatMemory.add(conversationId, messageWithId); + } + + /** + * Ensure the user message has a unique messageId in its metadata. If no messageId + * exists, generate one based on content hash. + * @param userMessage the user message + * @return the user message with messageId in metadata + */ + private UserMessage ensureMessageId(UserMessage userMessage) { + String existingMessageId = (String) userMessage.getMetadata().get("messageId"); + if (existingMessageId != null) { + return userMessage; + } + + // Generate a messageId based on content hash for tracking updates + String messageId = generateMessageId(userMessage); + + // Merge with existing metadata + java.util.Map metadata = new java.util.HashMap<>(userMessage.getMetadata()); + metadata.put("messageId", messageId); + + return userMessage.mutate().metadata(metadata).build(); + } + + /** + * Generate a unique message ID based on the user message content. + * @param userMessage the user message + * @return a unique message ID + */ + private String generateMessageId(UserMessage userMessage) { + // Use content hash as a stable identifier for the same logical message + return String.valueOf(userMessage.getText().hashCode()); + } + + /** + * Check if a message with the given ID already exists in the conversation memory. + * @param conversationId the conversation ID + * @param messageId the message ID to check + * @return true if the message exists, false otherwise + */ + private boolean hasExistingMessage(String conversationId, String messageId) { + List memoryMessages = this.chatMemory.get(conversationId); + return memoryMessages.stream().anyMatch(message -> messageId.equals(message.getMetadata().get("messageId"))); + } + public static Builder builder(ChatMemory chatMemory) { return new Builder(chatMemory); } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisorTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisorTests.java index 52ec1c00a98..ce816ee9c67 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisorTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisorTests.java @@ -23,6 +23,11 @@ import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository; import org.springframework.ai.chat.memory.MessageWindowChatMemory; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.messages.AssistantMessage; + +import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -108,4 +113,79 @@ void testDefaultValues() { assertThat(advisor.getOrder()).isEqualTo(Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER); } + @Test + void testMessageUpdateFunctionality() { + // Create a chat memory + MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(new InMemoryChatMemoryRepository()) + .build(); + + MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory).build(); + String conversationId = "test-conversation"; + + // Test 1: Add original message with specific messageId + UserMessage originalMessage = UserMessage.builder() + .text("What is the capital of France?") + .metadata(Map.of("messageId", "msg-001")) + .build(); + + chatMemory.add(conversationId, originalMessage); + + // Simulate adding an assistant response + AssistantMessage assistantResponse = new AssistantMessage("The capital of France is Paris."); + chatMemory.add(conversationId, assistantResponse); + + // Verify initial state: should have 2 messages (user + assistant) + assertThat(chatMemory.get(conversationId)).hasSize(2); + assertThat(chatMemory.get(conversationId).get(0).getText()).isEqualTo("What is the capital of France?"); + assertThat(chatMemory.get(conversationId).get(1).getText()).isEqualTo("The capital of France is Paris."); + + // Test 2: Update the message with same messageId + UserMessage updatedMessage = UserMessage.builder() + .text("What is the capital of Italy?") + .metadata(Map.of("messageId", "msg-001")) // Same messageId + .build(); + + // Remove old message and response manually (testing the repository functionality) + chatMemory.removeMessageAndResponse(conversationId, "msg-001"); + chatMemory.add(conversationId, updatedMessage); + + // Verify the update: should have only 1 message (the updated user message) + // The old user message and assistant response should be removed + assertThat(chatMemory.get(conversationId)).hasSize(1); + assertThat(chatMemory.get(conversationId).get(0).getText()).isEqualTo("What is the capital of Italy?"); + assertThat(chatMemory.get(conversationId).get(0).getMetadata().get("messageId")).isEqualTo("msg-001"); + } + + @Test + void testMessageIdGeneration() { + // Create a chat memory + ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(new InMemoryChatMemoryRepository()) + .build(); + + MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory).build(); + + // Test that messages without messageId get one generated automatically + UserMessage messageWithoutId = new UserMessage("Hello world"); + + // This would happen inside handleMessageUpdate method when a message is processed + // We can't directly test the private method, but we can verify the behavior + // by checking that the same content generates the same hash-based ID + + String expectedId = String.valueOf("Hello world".hashCode()); + + // The generateMessageId method should produce consistent IDs for same content + assertThat(expectedId).isNotNull(); + + // Messages with same content should have same messageId + UserMessage message1 = new UserMessage("Same content"); + UserMessage message2 = new UserMessage("Same content"); + + String id1 = String.valueOf(message1.getText().hashCode()); + String id2 = String.valueOf(message2.getText().hashCode()); + + assertThat(id1).isEqualTo(id2); + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java index 350b1ebbcf7..c062d30aad8 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java @@ -40,4 +40,52 @@ public interface ChatMemoryRepository { void deleteByConversationId(String conversationId); + /** + * Find a message by its unique identifier within a conversation. + * @param conversationId the conversation ID + * @param messageId the unique message ID + * @return the message if found, null otherwise + */ + default Message findByMessageId(String conversationId, String messageId) { + return findByConversationId(conversationId).stream() + .filter(message -> messageId.equals(message.getMetadata().get("messageId"))) + .findFirst() + .orElse(null); + } + + /** + * Delete a specific message and its subsequent assistant response if any. This is + * used when a user message is updated to remove the old message pair. + * @param conversationId the conversation ID + * @param messageId the unique message ID to delete + */ + default void deleteMessageAndResponse(String conversationId, String messageId) { + List messages = findByConversationId(conversationId); + List updatedMessages = new java.util.ArrayList<>(); + + boolean skipNext = false; + for (int i = 0; i < messages.size(); i++) { + Message message = messages.get(i); + String currentMessageId = (String) message.getMetadata().get("messageId"); + + if (skipNext) { + skipNext = false; + continue; + } + + if (messageId.equals(currentMessageId)) { + // Skip this message and potentially the next assistant response + if (i + 1 < messages.size() && messages.get(i + 1) + .getMessageType() == org.springframework.ai.chat.messages.MessageType.ASSISTANT) { + skipNext = true; + } + continue; + } + + updatedMessages.add(message); + } + + saveAll(conversationId, updatedMessages); + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java index 9c56c9d50bf..6475bb18d49 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java @@ -77,6 +77,19 @@ public void clear(String conversationId) { this.chatMemoryRepository.deleteByConversationId(conversationId); } + /** + * Remove a specific message and its subsequent assistant response from the + * conversation memory. This is used when a user message is updated to clean up the + * old message pair. + * @param conversationId the conversation ID + * @param messageId the unique message ID to remove + */ + public void removeMessageAndResponse(String conversationId, String messageId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + Assert.hasText(messageId, "messageId cannot be null or empty"); + this.chatMemoryRepository.deleteMessageAndResponse(conversationId, messageId); + } + private List process(List memoryMessages, List newMessages) { List processedMessages = new ArrayList<>();