18
18
19
19
import java .util .ArrayList ;
20
20
import java .util .HashSet ;
21
+ import java .util .LinkedHashSet ;
21
22
import java .util .List ;
22
23
import java .util .Set ;
23
24
@@ -61,8 +62,10 @@ public void add(String conversationId, List<Message> messages) {
61
62
Assert .noNullElements (messages , "messages cannot contain null elements" );
62
63
63
64
List <Message > memoryMessages = this .chatMemoryRepository .findByConversationId (conversationId );
64
- List <Message > processedMessages = process (memoryMessages , messages );
65
- this .chatMemoryRepository .saveAll (conversationId , processedMessages );
65
+ MessageChanges changes = process (memoryMessages , messages );
66
+ if (!changes .toDelete .isEmpty () || !changes .toAdd .isEmpty ()) {
67
+ this .chatMemoryRepository .refresh (conversationId , changes .toDelete , changes .toAdd );
68
+ }
66
69
}
67
70
68
71
@ Override
@@ -77,38 +80,55 @@ public void clear(String conversationId) {
77
80
this .chatMemoryRepository .deleteByConversationId (conversationId );
78
81
}
79
82
80
- private List <Message > process (List <Message > memoryMessages , List <Message > newMessages ) {
81
- List <Message > processedMessages = new ArrayList <>();
83
+ private MessageChanges process (List <Message > memoryMessages , List <Message > newMessages ) {
84
+ Set <Message > originalMessageSet = new LinkedHashSet <>(memoryMessages );
85
+ List <Message > uniqueNewMessages = newMessages .stream ()
86
+ .filter (msg -> !originalMessageSet .contains (msg ))
87
+ .toList ();
88
+ boolean hasNewSystemMessage = uniqueNewMessages .stream ().anyMatch (SystemMessage .class ::isInstance );
89
+
90
+ List <Message > finalMessages = new ArrayList <>();
91
+ if (hasNewSystemMessage ) {
92
+ memoryMessages .stream ().filter (msg -> !(msg instanceof SystemMessage )).forEach (finalMessages ::add );
93
+ } else {
94
+ finalMessages .addAll (memoryMessages );
95
+ }
96
+ finalMessages .addAll (uniqueNewMessages );
97
+
98
+ if (finalMessages .size () > this .maxMessages ) {
99
+ List <Message > trimmedMessages = new ArrayList <>();
100
+ int messagesToRemove = finalMessages .size () - this .maxMessages ;
101
+ int removed = 0 ;
102
+ for (Message message : finalMessages ) {
103
+ if (message instanceof SystemMessage || removed >= messagesToRemove ) {
104
+ trimmedMessages .add (message );
105
+ } else {
106
+ removed ++;
107
+ }
108
+ }
109
+ finalMessages = trimmedMessages ;
110
+ }
82
111
83
- Set <Message > memoryMessagesSet = new HashSet <>(memoryMessages );
84
- boolean hasNewSystemMessage = newMessages .stream ()
85
- .filter (SystemMessage .class ::isInstance )
86
- .anyMatch (message -> !memoryMessagesSet .contains (message ));
112
+ Set <Message > finalMessageSet = new LinkedHashSet <>(finalMessages );
87
113
88
- memoryMessages .stream ()
89
- .filter (message -> !(hasNewSystemMessage && message instanceof SystemMessage ))
90
- .forEach (processedMessages ::add );
114
+ List <Message > toDelete = originalMessageSet .stream ().filter (m -> !finalMessageSet .contains (m )).toList ();
91
115
92
- processedMessages . addAll ( newMessages );
116
+ List < Message > toAdd = finalMessageSet . stream (). filter ( m -> ! originalMessageSet . contains ( m )). toList ( );
93
117
94
- if (processedMessages .size () <= this .maxMessages ) {
95
- return processedMessages ;
96
- }
118
+ return new MessageChanges (toDelete , toAdd );
119
+ }
97
120
98
- int messagesToRemove = processedMessages . size () - this . maxMessages ;
121
+ private static class MessageChanges {
99
122
100
- List <Message > trimmedMessages = new ArrayList <>();
101
- int removed = 0 ;
102
- for (Message message : processedMessages ) {
103
- if (message instanceof SystemMessage || removed >= messagesToRemove ) {
104
- trimmedMessages .add (message );
105
- }
106
- else {
107
- removed ++;
108
- }
123
+ final List <Message > toDelete ;
124
+
125
+ final List <Message > toAdd ;
126
+
127
+ MessageChanges (List <Message > toDelete , List <Message > toAdd ) {
128
+ this .toDelete = toDelete ;
129
+ this .toAdd = toAdd ;
109
130
}
110
131
111
- return trimmedMessages ;
112
132
}
113
133
114
134
public static Builder builder () {
0 commit comments