Skip to content

Commit 6bf4f28

Browse files
committed
refactor(memory): Improve Chat Memory logic and efficiency
Refactors `MessageWindowChatMemory` to improve performance and extensibility by abstracting the storage layer. Introduces a `ChatMemoryRepository` interface to decouple memory logic from the underlying storage. This allows for future persistent implementations (e.g., Redis, JDBC). The default `InMemoryChatMemoryRepository` now uses an efficient `refresh` method to apply only deltas, reducing overhead in long conversations. Signed-off-by: Juntar Park <[email protected]> Signed-off-by: potato <[email protected]>
1 parent e0ccc13 commit 6bf4f28

File tree

4 files changed

+127
-26
lines changed

4 files changed

+127
-26
lines changed

spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,15 @@ public interface ChatMemoryRepository {
4040

4141
void deleteByConversationId(String conversationId);
4242

43+
/**
44+
* Atomically removes the messages in {@code deletes} and adds the messages in {@code adds}
45+
* for the given conversation ID. This provides a more efficient way to update
46+
* the memory than reading the entire history and overwriting it.
47+
*
48+
* @param conversationId The ID of the conversation to update.
49+
* @param deletes A list of messages to be removed from the memory.
50+
* @param adds A list of new messages to be added to the memory.
51+
*/
52+
void refresh(String conversationId, List<Message> deletes, List<Message> adds);
53+
4354
}

spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,17 @@ public void deleteByConversationId(String conversationId) {
6060
this.chatMemoryStore.remove(conversationId);
6161
}
6262

63+
@Override
64+
public void refresh(String conversationId, List<Message> deletes, List<Message> adds) {
65+
this.chatMemoryStore.compute(conversationId, (key, currentMessages) -> {
66+
if (currentMessages == null) {
67+
return new ArrayList<>(adds);
68+
}
69+
List<Message> updatedMessages = new ArrayList<>(currentMessages);
70+
updatedMessages.removeAll(deletes);
71+
updatedMessages.addAll(adds);
72+
return updatedMessages;
73+
});
74+
}
75+
6376
}

spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.HashSet;
21+
import java.util.LinkedHashSet;
2122
import java.util.List;
2223
import java.util.Set;
2324

@@ -61,8 +62,10 @@ public void add(String conversationId, List<Message> messages) {
6162
Assert.noNullElements(messages, "messages cannot contain null elements");
6263

6364
List<Message> memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId);
64-
List<Message> processedMessages = process(memoryMessages, messages);
65-
this.chatMemoryRepository.saveAll(conversationId, processedMessages);
65+
MessageChanges changes = process(memoryMessages, messages);
66+
if (!changes.toDelete.isEmpty() || !changes.toAdd.isEmpty()) {
67+
this.chatMemoryRepository.refresh(conversationId, changes.toDelete, changes.toAdd);
68+
}
6669
}
6770

6871
@Override
@@ -77,38 +80,55 @@ public void clear(String conversationId) {
7780
this.chatMemoryRepository.deleteByConversationId(conversationId);
7881
}
7982

80-
private List<Message> process(List<Message> memoryMessages, List<Message> newMessages) {
81-
List<Message> processedMessages = new ArrayList<>();
83+
private MessageChanges process(List<Message> memoryMessages, List<Message> newMessages) {
84+
Set<Message> originalMessageSet = new LinkedHashSet<>(memoryMessages);
85+
List<Message> uniqueNewMessages = newMessages.stream()
86+
.filter(msg -> !originalMessageSet.contains(msg))
87+
.toList();
88+
boolean hasNewSystemMessage = uniqueNewMessages.stream().anyMatch(SystemMessage.class::isInstance);
89+
90+
List<Message> finalMessages = new ArrayList<>();
91+
if (hasNewSystemMessage) {
92+
memoryMessages.stream().filter(msg -> !(msg instanceof SystemMessage)).forEach(finalMessages::add);
93+
} else {
94+
finalMessages.addAll(memoryMessages);
95+
}
96+
finalMessages.addAll(uniqueNewMessages);
97+
98+
if (finalMessages.size() > this.maxMessages) {
99+
List<Message> trimmedMessages = new ArrayList<>();
100+
int messagesToRemove = finalMessages.size() - this.maxMessages;
101+
int removed = 0;
102+
for (Message message : finalMessages) {
103+
if (message instanceof SystemMessage || removed >= messagesToRemove) {
104+
trimmedMessages.add(message);
105+
} else {
106+
removed++;
107+
}
108+
}
109+
finalMessages = trimmedMessages;
110+
}
82111

83-
Set<Message> memoryMessagesSet = new HashSet<>(memoryMessages);
84-
boolean hasNewSystemMessage = newMessages.stream()
85-
.filter(SystemMessage.class::isInstance)
86-
.anyMatch(message -> !memoryMessagesSet.contains(message));
112+
Set<Message> finalMessageSet = new LinkedHashSet<>(finalMessages);
87113

88-
memoryMessages.stream()
89-
.filter(message -> !(hasNewSystemMessage && message instanceof SystemMessage))
90-
.forEach(processedMessages::add);
114+
List<Message> toDelete = originalMessageSet.stream().filter(m -> !finalMessageSet.contains(m)).toList();
91115

92-
processedMessages.addAll(newMessages);
116+
List<Message> toAdd = finalMessageSet.stream().filter(m -> !originalMessageSet.contains(m)).toList();
93117

94-
if (processedMessages.size() <= this.maxMessages) {
95-
return processedMessages;
96-
}
118+
return new MessageChanges(toDelete, toAdd);
119+
}
97120

98-
int messagesToRemove = processedMessages.size() - this.maxMessages;
121+
private static class MessageChanges {
99122

100-
List<Message> trimmedMessages = new ArrayList<>();
101-
int removed = 0;
102-
for (Message message : processedMessages) {
103-
if (message instanceof SystemMessage || removed >= messagesToRemove) {
104-
trimmedMessages.add(message);
105-
}
106-
else {
107-
removed++;
108-
}
123+
final List<Message> toDelete;
124+
125+
final List<Message> toAdd;
126+
127+
MessageChanges(List<Message> toDelete, List<Message> toAdd) {
128+
this.toDelete = toDelete;
129+
this.toAdd = toAdd;
109130
}
110131

111-
return trimmedMessages;
112132
}
113133

114134
public static Builder builder() {

spring-ai-model/src/test/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepositoryTests.java

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,5 +152,62 @@ void messagesWithNullElementsNotAllowed() {
152152
.isInstanceOf(IllegalArgumentException.class)
153153
.hasMessageContaining("messages cannot contain null elements");
154154
}
155+
// refresh test code
156+
@Test
157+
void refreshAddsMessagesToNewConversation() {
158+
String conversationId = UUID.randomUUID().toString();
159+
List<Message> messagesToAdd = List.of(new UserMessage("Hello"));
160+
this.chatMemoryRepository.refresh(conversationId, List.of(), messagesToAdd);
161+
162+
assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).isEqualTo(messagesToAdd);
163+
}
164+
165+
@Test
166+
void refreshAddsMessagesToExistingConversation() {
167+
String conversationId = UUID.randomUUID().toString();
168+
Message initialMessage = new UserMessage("Initial");
169+
this.chatMemoryRepository.saveAll(conversationId, List.of(initialMessage));
170+
171+
Message newMessage = new AssistantMessage("New");
172+
this.chatMemoryRepository.refresh(conversationId, List.of(), List.of(newMessage));
173+
174+
assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).containsExactly(initialMessage, newMessage);
175+
}
176+
177+
@Test
178+
void refreshDeletesMessagesFromExistingConversation() {
179+
String conversationId = UUID.randomUUID().toString();
180+
Message messageToKeep = new UserMessage("Keep");
181+
Message messageToDelete = new AssistantMessage("Delete");
182+
this.chatMemoryRepository.saveAll(conversationId, List.of(messageToKeep, messageToDelete));
183+
184+
this.chatMemoryRepository.refresh(conversationId, List.of(messageToDelete), List.of());
185+
186+
assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).containsExactly(messageToKeep);
187+
}
188+
189+
@Test
190+
void refreshAddsAndDeletesMessages() {
191+
String conversationId = UUID.randomUUID().toString();
192+
Message messageToKeep = new UserMessage("Keep");
193+
Message messageToDelete = new AssistantMessage("Delete");
194+
this.chatMemoryRepository.saveAll(conversationId, List.of(messageToKeep, messageToDelete));
195+
196+
Message messageToAdd = new UserMessage("Add");
197+
this.chatMemoryRepository.refresh(conversationId, List.of(messageToDelete), List.of(messageToAdd));
198+
199+
assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).containsExactlyInAnyOrder(messageToKeep, messageToAdd);
200+
}
201+
202+
@Test
203+
void refreshWithEmptyChangesDoesNothing() {
204+
String conversationId = UUID.randomUUID().toString();
205+
List<Message> initialMessages = List.of(new UserMessage("Hello"));
206+
this.chatMemoryRepository.saveAll(conversationId, initialMessages);
207+
208+
this.chatMemoryRepository.refresh(conversationId, List.of(), List.of());
209+
210+
assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).isEqualTo(initialMessages);
211+
}
155212

156213
}

0 commit comments

Comments
 (0)