diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java index 350b1ebbcf7..4132bcff83c 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java @@ -40,4 +40,15 @@ public interface ChatMemoryRepository { void deleteByConversationId(String conversationId); + /** + * Atomically removes the messages in {@code deletes} and adds the messages in {@code adds} + * for the given conversation ID. This provides a more efficient way to update + * the memory than reading the entire history and overwriting it. + * + * @param conversationId The ID of the conversation to update. + * @param deletes A list of messages to be removed from the memory. + * @param adds A list of new messages to be added to the memory. + */ + void refresh(String conversationId, List deletes, List adds); + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java index 0843af8ffa2..756e19b34c5 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java @@ -60,4 +60,17 @@ public void deleteByConversationId(String conversationId) { this.chatMemoryStore.remove(conversationId); } + @Override + public void refresh(String conversationId, List deletes, List adds) { + this.chatMemoryStore.compute(conversationId, (key, currentMessages) -> { + if (currentMessages == null) { + return new ArrayList<>(adds); + } + List updatedMessages = new ArrayList<>(currentMessages); + updatedMessages.removeAll(deletes); + updatedMessages.addAll(adds); + return updatedMessages; + }); + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java index 9c56c9d50bf..7538ec93035 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.List; import java.util.Set; @@ -61,8 +62,10 @@ public void add(String conversationId, List messages) { Assert.noNullElements(messages, "messages cannot contain null elements"); List memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId); - List processedMessages = process(memoryMessages, messages); - this.chatMemoryRepository.saveAll(conversationId, processedMessages); + MessageChanges changes = process(memoryMessages, messages); + if (!changes.toDelete.isEmpty() || !changes.toAdd.isEmpty()) { + this.chatMemoryRepository.refresh(conversationId, changes.toDelete, changes.toAdd); + } } @Override @@ -77,38 +80,55 @@ public void clear(String conversationId) { this.chatMemoryRepository.deleteByConversationId(conversationId); } - private List process(List memoryMessages, List newMessages) { - List processedMessages = new ArrayList<>(); + private MessageChanges process(List memoryMessages, List newMessages) { + Set originalMessageSet = new LinkedHashSet<>(memoryMessages); + List uniqueNewMessages = newMessages.stream() + .filter(msg -> !originalMessageSet.contains(msg)) + .toList(); + boolean hasNewSystemMessage = uniqueNewMessages.stream().anyMatch(SystemMessage.class::isInstance); + + List finalMessages = new ArrayList<>(); + if (hasNewSystemMessage) { + memoryMessages.stream().filter(msg -> !(msg instanceof SystemMessage)).forEach(finalMessages::add); + } else { + finalMessages.addAll(memoryMessages); + } + finalMessages.addAll(uniqueNewMessages); + + if (finalMessages.size() > this.maxMessages) { + List trimmedMessages = new ArrayList<>(); + int messagesToRemove = finalMessages.size() - this.maxMessages; + int removed = 0; + for (Message message : finalMessages) { + if (message instanceof SystemMessage || removed >= messagesToRemove) { + trimmedMessages.add(message); + } else { + removed++; + } + } + finalMessages = trimmedMessages; + } - Set memoryMessagesSet = new HashSet<>(memoryMessages); - boolean hasNewSystemMessage = newMessages.stream() - .filter(SystemMessage.class::isInstance) - .anyMatch(message -> !memoryMessagesSet.contains(message)); + Set finalMessageSet = new LinkedHashSet<>(finalMessages); - memoryMessages.stream() - .filter(message -> !(hasNewSystemMessage && message instanceof SystemMessage)) - .forEach(processedMessages::add); + List toDelete = originalMessageSet.stream().filter(m -> !finalMessageSet.contains(m)).toList(); - processedMessages.addAll(newMessages); + List toAdd = finalMessageSet.stream().filter(m -> !originalMessageSet.contains(m)).toList(); - if (processedMessages.size() <= this.maxMessages) { - return processedMessages; - } + return new MessageChanges(toDelete, toAdd); + } - int messagesToRemove = processedMessages.size() - this.maxMessages; + private static class MessageChanges { - List trimmedMessages = new ArrayList<>(); - int removed = 0; - for (Message message : processedMessages) { - if (message instanceof SystemMessage || removed >= messagesToRemove) { - trimmedMessages.add(message); - } - else { - removed++; - } + final List toDelete; + + final List toAdd; + + MessageChanges(List toDelete, List toAdd) { + this.toDelete = toDelete; + this.toAdd = toAdd; } - return trimmedMessages; } public static Builder builder() { diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepositoryTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepositoryTests.java index 90283e3e9c2..e591bcba180 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepositoryTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepositoryTests.java @@ -152,5 +152,62 @@ void messagesWithNullElementsNotAllowed() { .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("messages cannot contain null elements"); } + // refresh test code + @Test + void refreshAddsMessagesToNewConversation() { + String conversationId = UUID.randomUUID().toString(); + List messagesToAdd = List.of(new UserMessage("Hello")); + this.chatMemoryRepository.refresh(conversationId, List.of(), messagesToAdd); + + assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).isEqualTo(messagesToAdd); + } + + @Test + void refreshAddsMessagesToExistingConversation() { + String conversationId = UUID.randomUUID().toString(); + Message initialMessage = new UserMessage("Initial"); + this.chatMemoryRepository.saveAll(conversationId, List.of(initialMessage)); + + Message newMessage = new AssistantMessage("New"); + this.chatMemoryRepository.refresh(conversationId, List.of(), List.of(newMessage)); + + assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).containsExactly(initialMessage, newMessage); + } + + @Test + void refreshDeletesMessagesFromExistingConversation() { + String conversationId = UUID.randomUUID().toString(); + Message messageToKeep = new UserMessage("Keep"); + Message messageToDelete = new AssistantMessage("Delete"); + this.chatMemoryRepository.saveAll(conversationId, List.of(messageToKeep, messageToDelete)); + + this.chatMemoryRepository.refresh(conversationId, List.of(messageToDelete), List.of()); + + assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).containsExactly(messageToKeep); + } + + @Test + void refreshAddsAndDeletesMessages() { + String conversationId = UUID.randomUUID().toString(); + Message messageToKeep = new UserMessage("Keep"); + Message messageToDelete = new AssistantMessage("Delete"); + this.chatMemoryRepository.saveAll(conversationId, List.of(messageToKeep, messageToDelete)); + + Message messageToAdd = new UserMessage("Add"); + this.chatMemoryRepository.refresh(conversationId, List.of(messageToDelete), List.of(messageToAdd)); + + assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).containsExactlyInAnyOrder(messageToKeep, messageToAdd); + } + + @Test + void refreshWithEmptyChangesDoesNothing() { + String conversationId = UUID.randomUUID().toString(); + List initialMessages = List.of(new UserMessage("Hello")); + this.chatMemoryRepository.saveAll(conversationId, initialMessages); + + this.chatMemoryRepository.refresh(conversationId, List.of(), List.of()); + + assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).isEqualTo(initialMessages); + } }