Skip to content

Commit 9289ef4

Browse files
committed
refactor(ai-client-chat): optimize conversation memory handling in PromptChatMemoryAdvisor
- When the memory message is empty, the `ChatClientRequest` is returned intact. - Add new user message to conversation memory at the beginning of the process - Reorder and optimize the steps for processing memory messages Signed-off-by: Ahoo Wang <[email protected]>
1 parent 49df56c commit 9289ef4

File tree

2 files changed

+13
-30
lines changed

2 files changed

+13
-30
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -113,29 +113,28 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChai
113113
logger.debug("[PromptChatMemoryAdvisor.before] Memory before processing for conversationId={}: {}",
114114
conversationId, memoryMessages);
115115

116-
// 2. Process memory messages as a string.
116+
// 2. Add the new user message to the conversation memory.
117+
UserMessage userMessage = chatClientRequest.prompt().getUserMessage();
118+
this.chatMemory.add(conversationId, userMessage);
119+
// 3. Check if memory is empty and return the request as is.
120+
if (memoryMessages.isEmpty()) {
121+
return chatClientRequest;
122+
}
123+
// 4. Process memory messages as a string.
117124
String memory = memoryMessages.stream()
118125
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
119126
.map(m -> m.getMessageType() + ":" + m.getText())
120127
.collect(Collectors.joining(System.lineSeparator()));
121128

122-
// 3. Augment the system message.
129+
// 5. Augment the system message.
123130
SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
124131
String augmentedSystemText = this.systemPromptTemplate
125132
.render(Map.of("instructions", systemMessage.getText(), "memory", memory));
126133

127-
// 4. Create a new request with the augmented system message.
128-
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
134+
// 6. Create a new request with the augmented system message.
135+
return chatClientRequest.mutate()
129136
.prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
130137
.build();
131-
132-
// 5. Add all user messages from the current prompt to memory (after system
133-
// message is generated)
134-
// 4. Add the new user message to the conversation memory.
135-
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
136-
this.chatMemory.add(conversationId, userMessage);
137-
138-
return processedChatClientRequest;
139138
}
140139

141140
@Override

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,7 @@ public void promptChatMemory() {
9797

9898
// Capture and verify the system message instructions
9999
Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0);
100-
assertThat(systemMessage.getText()).isEqualToIgnoringWhitespace("""
101-
Default system text.
102-
103-
Use the conversation memory from the MEMORY section to provide accurate answers.
104-
105-
---------------------
106-
MEMORY:
107-
---------------------
108-
""");
100+
assertThat(systemMessage.getText()).isEqualToIgnoringWhitespace("Default system text.");
109101
assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM);
110102

111103
// Capture and verify the user message instructions
@@ -175,15 +167,7 @@ public void streamingPromptChatMemory() {
175167

176168
// Capture and verify the system message instructions
177169
Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0);
178-
assertThat(systemMessage.getText()).isEqualToIgnoringWhitespace("""
179-
Default system text.
180-
181-
Use the conversation memory from the MEMORY section to provide accurate answers.
182-
183-
---------------------
184-
MEMORY:
185-
---------------------
186-
""");
170+
assertThat(systemMessage.getText()).isEqualToIgnoringWhitespace("Default system text.");
187171
assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM);
188172

189173
// Capture and verify the user message instructions

0 commit comments

Comments
 (0)