1818
1919import java .util .ArrayList ;
2020import java .util .HashSet ;
21+ import java .util .LinkedHashSet ;
2122import java .util .List ;
2223import java .util .Set ;
2324
@@ -61,8 +62,10 @@ public void add(String conversationId, List<Message> messages) {
6162 Assert .noNullElements (messages , "messages cannot contain null elements" );
6263
6364 List <Message > memoryMessages = this .chatMemoryRepository .findByConversationId (conversationId );
64- List <Message > processedMessages = process (memoryMessages , messages );
65- this .chatMemoryRepository .saveAll (conversationId , processedMessages );
65+ MessageChanges changes = process (memoryMessages , messages );
66+ if (!changes .toDelete .isEmpty () || !changes .toAdd .isEmpty ()) {
67+ this .chatMemoryRepository .refresh (conversationId , changes .toDelete , changes .toAdd );
68+ }
6669 }
6770
6871 @ Override
@@ -77,38 +80,60 @@ public void clear(String conversationId) {
7780 this .chatMemoryRepository .deleteByConversationId (conversationId );
7881 }
7982
80- private List <Message > process (List <Message > memoryMessages , List <Message > newMessages ) {
81- List <Message > processedMessages = new ArrayList <>();
83+ private MessageChanges process (List <Message > memoryMessages , List <Message > newMessages ) {
84+ Set <Message > originalMessageSet = new LinkedHashSet <>(memoryMessages );
85+ List <Message > uniqueNewMessages = newMessages .stream ()
86+ .filter (msg -> !originalMessageSet .contains (msg ))
87+ .toList ();
88+ boolean hasNewSystemMessage = uniqueNewMessages .stream ()
89+ .anyMatch (SystemMessage .class ::isInstance );
90+
91+ List <Message > finalMessages = new ArrayList <>();
92+ if (hasNewSystemMessage ) {
93+ memoryMessages .stream ()
94+ .filter (msg -> !(msg instanceof SystemMessage ))
95+ .forEach (finalMessages ::add );
96+ finalMessages .addAll (uniqueNewMessages );
97+ } else {
98+ finalMessages .addAll (memoryMessages );
99+ finalMessages .addAll (uniqueNewMessages );
100+ }
82101
83- Set <Message > memoryMessagesSet = new HashSet <>(memoryMessages );
84- boolean hasNewSystemMessage = newMessages .stream ()
85- .filter (SystemMessage .class ::isInstance )
86- .anyMatch (message -> !memoryMessagesSet .contains (message ));
102+ if (finalMessages .size () > this .maxMessages ) {
103+ List <Message > trimmedMessages = new ArrayList <>();
104+ int messagesToRemove = finalMessages .size () - this .maxMessages ;
105+ int removed = 0 ;
106+ for (Message message : finalMessages ) {
107+ if (message instanceof SystemMessage || removed >= messagesToRemove ) {
108+ trimmedMessages .add (message );
109+ } else {
110+ removed ++;
111+ }
112+ }
113+ finalMessages = trimmedMessages ;
114+ }
87115
88- memoryMessages .stream ()
89- .filter (message -> !(hasNewSystemMessage && message instanceof SystemMessage ))
90- .forEach (processedMessages ::add );
116+ Set <Message > finalMessageSet = new LinkedHashSet <>(finalMessages );
91117
92- processedMessages .addAll (newMessages );
118+ List <Message > toDelete = originalMessageSet .stream ()
119+ .filter (m -> !finalMessageSet .contains (m ))
120+ .toList ();
93121
94- if ( processedMessages . size () <= this . maxMessages ) {
95- return processedMessages ;
96- }
122+ List < Message > toAdd = finalMessageSet . stream ()
123+ . filter ( m -> ! originalMessageSet . contains ( m ))
124+ . toList ();
97125
98- int messagesToRemove = processedMessages .size () - this .maxMessages ;
126+ return new MessageChanges (toDelete , toAdd );
127+ }
99128
100- List <Message > trimmedMessages = new ArrayList <>();
101- int removed = 0 ;
102- for (Message message : processedMessages ) {
103- if (message instanceof SystemMessage || removed >= messagesToRemove ) {
104- trimmedMessages .add (message );
105- }
106- else {
107- removed ++;
108- }
109- }
129+ private static class MessageChanges {
130+ final List <Message > toDelete ;
131+ final List <Message > toAdd ;
110132
111- return trimmedMessages ;
133+ MessageChanges (List <Message > toDelete , List <Message > toAdd ) {
134+ this .toDelete = toDelete ;
135+ this .toAdd = toAdd ;
136+ }
112137 }
113138
114139 public static Builder builder () {
0 commit comments