Skip to content

refactor(memory): Improve Chat Memory logic and efficiency #4065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,15 @@ public interface ChatMemoryRepository {

void deleteByConversationId(String conversationId);

/**
* Atomically removes the messages in {@code deletes} and adds the messages in {@code adds}
* for the given conversation ID. This provides a more efficient way to update
* the memory than reading the entire history and overwriting it.
*
* @param conversationId The ID of the conversation to update.
* @param deletes A list of messages to be removed from the memory.
* @param adds A list of new messages to be added to the memory.
*/
void refresh(String conversationId, List<Message> deletes, List<Message> adds);

}
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,17 @@ public void deleteByConversationId(String conversationId) {
this.chatMemoryStore.remove(conversationId);
}

@Override
public void refresh(String conversationId, List<Message> deletes, List<Message> adds) {
this.chatMemoryStore.compute(conversationId, (key, currentMessages) -> {
if (currentMessages == null) {
return new ArrayList<>(adds);
}
List<Message> updatedMessages = new ArrayList<>(currentMessages);
updatedMessages.removeAll(deletes);
updatedMessages.addAll(adds);
return updatedMessages;
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -61,8 +62,10 @@ public void add(String conversationId, List<Message> messages) {
Assert.noNullElements(messages, "messages cannot contain null elements");

List<Message> memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId);
List<Message> processedMessages = process(memoryMessages, messages);
this.chatMemoryRepository.saveAll(conversationId, processedMessages);
MessageChanges changes = process(memoryMessages, messages);
if (!changes.toDelete.isEmpty() || !changes.toAdd.isEmpty()) {
this.chatMemoryRepository.refresh(conversationId, changes.toDelete, changes.toAdd);
}
}

@Override
Expand All @@ -77,38 +80,55 @@ public void clear(String conversationId) {
this.chatMemoryRepository.deleteByConversationId(conversationId);
}

private List<Message> process(List<Message> memoryMessages, List<Message> newMessages) {
List<Message> processedMessages = new ArrayList<>();
private MessageChanges process(List<Message> memoryMessages, List<Message> newMessages) {
Set<Message> originalMessageSet = new LinkedHashSet<>(memoryMessages);
List<Message> uniqueNewMessages = newMessages.stream()
.filter(msg -> !originalMessageSet.contains(msg))
.toList();
boolean hasNewSystemMessage = uniqueNewMessages.stream().anyMatch(SystemMessage.class::isInstance);

List<Message> finalMessages = new ArrayList<>();
if (hasNewSystemMessage) {
memoryMessages.stream().filter(msg -> !(msg instanceof SystemMessage)).forEach(finalMessages::add);
} else {
finalMessages.addAll(memoryMessages);
}
finalMessages.addAll(uniqueNewMessages);

if (finalMessages.size() > this.maxMessages) {
List<Message> trimmedMessages = new ArrayList<>();
int messagesToRemove = finalMessages.size() - this.maxMessages;
int removed = 0;
for (Message message : finalMessages) {
if (message instanceof SystemMessage || removed >= messagesToRemove) {
trimmedMessages.add(message);
} else {
removed++;
}
}
finalMessages = trimmedMessages;
}

Set<Message> memoryMessagesSet = new HashSet<>(memoryMessages);
boolean hasNewSystemMessage = newMessages.stream()
.filter(SystemMessage.class::isInstance)
.anyMatch(message -> !memoryMessagesSet.contains(message));
Set<Message> finalMessageSet = new LinkedHashSet<>(finalMessages);

memoryMessages.stream()
.filter(message -> !(hasNewSystemMessage && message instanceof SystemMessage))
.forEach(processedMessages::add);
List<Message> toDelete = originalMessageSet.stream().filter(m -> !finalMessageSet.contains(m)).toList();

processedMessages.addAll(newMessages);
List<Message> toAdd = finalMessageSet.stream().filter(m -> !originalMessageSet.contains(m)).toList();

if (processedMessages.size() <= this.maxMessages) {
return processedMessages;
}
return new MessageChanges(toDelete, toAdd);
}

int messagesToRemove = processedMessages.size() - this.maxMessages;
private static class MessageChanges {

List<Message> trimmedMessages = new ArrayList<>();
int removed = 0;
for (Message message : processedMessages) {
if (message instanceof SystemMessage || removed >= messagesToRemove) {
trimmedMessages.add(message);
}
else {
removed++;
}
final List<Message> toDelete;

final List<Message> toAdd;

MessageChanges(List<Message> toDelete, List<Message> toAdd) {
this.toDelete = toDelete;
this.toAdd = toAdd;
}

return trimmedMessages;
}

public static Builder builder() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,62 @@ void messagesWithNullElementsNotAllowed() {
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("messages cannot contain null elements");
}
// refresh test code
@Test
void refreshAddsMessagesToNewConversation() {
String conversationId = UUID.randomUUID().toString();
List<Message> messagesToAdd = List.of(new UserMessage("Hello"));
this.chatMemoryRepository.refresh(conversationId, List.of(), messagesToAdd);

assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).isEqualTo(messagesToAdd);
}

@Test
void refreshAddsMessagesToExistingConversation() {
String conversationId = UUID.randomUUID().toString();
Message initialMessage = new UserMessage("Initial");
this.chatMemoryRepository.saveAll(conversationId, List.of(initialMessage));

Message newMessage = new AssistantMessage("New");
this.chatMemoryRepository.refresh(conversationId, List.of(), List.of(newMessage));

assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).containsExactly(initialMessage, newMessage);
}

@Test
void refreshDeletesMessagesFromExistingConversation() {
String conversationId = UUID.randomUUID().toString();
Message messageToKeep = new UserMessage("Keep");
Message messageToDelete = new AssistantMessage("Delete");
this.chatMemoryRepository.saveAll(conversationId, List.of(messageToKeep, messageToDelete));

this.chatMemoryRepository.refresh(conversationId, List.of(messageToDelete), List.of());

assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).containsExactly(messageToKeep);
}

@Test
void refreshAddsAndDeletesMessages() {
String conversationId = UUID.randomUUID().toString();
Message messageToKeep = new UserMessage("Keep");
Message messageToDelete = new AssistantMessage("Delete");
this.chatMemoryRepository.saveAll(conversationId, List.of(messageToKeep, messageToDelete));

Message messageToAdd = new UserMessage("Add");
this.chatMemoryRepository.refresh(conversationId, List.of(messageToDelete), List.of(messageToAdd));

assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).containsExactlyInAnyOrder(messageToKeep, messageToAdd);
}

@Test
void refreshWithEmptyChangesDoesNothing() {
String conversationId = UUID.randomUUID().toString();
List<Message> initialMessages = List.of(new UserMessage("Hello"));
this.chatMemoryRepository.saveAll(conversationId, initialMessages);

this.chatMemoryRepository.refresh(conversationId, List.of(), List.of());

assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).isEqualTo(initialMessages);
}

}