Skip to content

Commit 5f187bc

Browse files
committed
refactor: process method make return MessageChange
Signed-off-by: potato <[email protected]>
1 parent 16812f2 commit 5f187bc

File tree

1 file changed

+52
-27
lines changed

1 file changed

+52
-27
lines changed

spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.HashSet;
21+
import java.util.LinkedHashSet;
2122
import java.util.List;
2223
import java.util.Set;
2324

@@ -61,8 +62,10 @@ public void add(String conversationId, List<Message> messages) {
6162
Assert.noNullElements(messages, "messages cannot contain null elements");
6263

6364
List<Message> memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId);
64-
List<Message> processedMessages = process(memoryMessages, messages);
65-
this.chatMemoryRepository.saveAll(conversationId, processedMessages);
65+
MessageChanges changes = process(memoryMessages, messages);
66+
if (!changes.toDelete.isEmpty() || !changes.toAdd.isEmpty()) {
67+
this.chatMemoryRepository.refresh(conversationId, changes.toDelete, changes.toAdd);
68+
}
6669
}
6770

6871
@Override
@@ -77,38 +80,60 @@ public void clear(String conversationId) {
7780
this.chatMemoryRepository.deleteByConversationId(conversationId);
7881
}
7982

80-
private List<Message> process(List<Message> memoryMessages, List<Message> newMessages) {
81-
List<Message> processedMessages = new ArrayList<>();
83+
private MessageChanges process(List<Message> memoryMessages, List<Message> newMessages) {
84+
Set<Message> originalMessageSet = new LinkedHashSet<>(memoryMessages);
85+
List<Message> uniqueNewMessages = newMessages.stream()
86+
.filter(msg -> !originalMessageSet.contains(msg))
87+
.toList();
88+
boolean hasNewSystemMessage = uniqueNewMessages.stream()
89+
.anyMatch(SystemMessage.class::isInstance);
90+
91+
List<Message> finalMessages = new ArrayList<>();
92+
if(hasNewSystemMessage) {
93+
memoryMessages.stream()
94+
.filter(msg -> !(msg instanceof SystemMessage))
95+
.forEach(finalMessages::add);
96+
finalMessages.addAll(uniqueNewMessages);
97+
} else {
98+
finalMessages.addAll(memoryMessages);
99+
finalMessages.addAll(uniqueNewMessages);
100+
}
82101

83-
Set<Message> memoryMessagesSet = new HashSet<>(memoryMessages);
84-
boolean hasNewSystemMessage = newMessages.stream()
85-
.filter(SystemMessage.class::isInstance)
86-
.anyMatch(message -> !memoryMessagesSet.contains(message));
102+
if (finalMessages.size() > this.maxMessages) {
103+
List<Message> trimmedMessages = new ArrayList<>();
104+
int messagesToRemove = finalMessages.size() - this.maxMessages;
105+
int removed = 0;
106+
for (Message message : finalMessages) {
107+
if (message instanceof SystemMessage || removed >= messagesToRemove) {
108+
trimmedMessages.add(message);
109+
} else {
110+
removed++;
111+
}
112+
}
113+
finalMessages = trimmedMessages;
114+
}
87115

88-
memoryMessages.stream()
89-
.filter(message -> !(hasNewSystemMessage && message instanceof SystemMessage))
90-
.forEach(processedMessages::add);
116+
Set<Message> finalMessageSet = new LinkedHashSet<>(finalMessages);
91117

92-
processedMessages.addAll(newMessages);
118+
List<Message> toDelete = originalMessageSet.stream()
119+
.filter(m -> !finalMessageSet.contains(m))
120+
.toList();
93121

94-
if (processedMessages.size() <= this.maxMessages) {
95-
return processedMessages;
96-
}
122+
List<Message> toAdd = finalMessageSet.stream()
123+
.filter(m -> !originalMessageSet.contains(m))
124+
.toList();
97125

98-
int messagesToRemove = processedMessages.size() - this.maxMessages;
126+
return new MessageChanges(toDelete, toAdd);
127+
}
99128

100-
List<Message> trimmedMessages = new ArrayList<>();
101-
int removed = 0;
102-
for (Message message : processedMessages) {
103-
if (message instanceof SystemMessage || removed >= messagesToRemove) {
104-
trimmedMessages.add(message);
105-
}
106-
else {
107-
removed++;
108-
}
109-
}
129+
private static class MessageChanges {
130+
final List<Message> toDelete;
131+
final List<Message> toAdd;
110132

111-
return trimmedMessages;
133+
MessageChanges(List<Message> toDelete, List<Message> toAdd) {
134+
this.toDelete = toDelete;
135+
this.toAdd = toAdd;
136+
}
112137
}
113138

114139
public static Builder builder() {

0 commit comments

Comments
 (0)