Skip to content

Commit 5bad0c4

Browse files
committed
fix failing test
1 parent de75db8 commit 5bad0c4

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import org.springframework.ai.chat.messages.MessageType;
3636
import org.springframework.ai.chat.messages.SystemMessage;
3737
import org.springframework.ai.chat.messages.UserMessage;
38-
import org.springframework.ai.chat.model.MessageAggregator;
3938
import org.springframework.ai.chat.prompt.PromptTemplate;
4039

4140
/**
@@ -111,43 +110,43 @@ public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest
111110
streamAdvisorChain, this::before);
112111

113112
// Ensure memory is updated after each streamed response
114-
return chatClientResponses.doOnNext(this::after)
115-
.transform(responses -> new MessageAggregator().aggregateChatClientResponse(responses, null));
113+
return chatClientResponses.doOnNext(this::after);
116114
}
117115

118116
@Override
119117
protected ChatClientRequest before(ChatClientRequest chatClientRequest) {
120118
String conversationId = this.doGetConversationId(chatClientRequest.context());
121119

122-
// 1. Add all user messages from the current prompt to memory
123-
List<UserMessage> userMessages = chatClientRequest.prompt().getUserMessages();
124-
for (UserMessage userMessage : userMessages) {
125-
this.getChatMemoryStore().add(conversationId, userMessage);
126-
logger.debug("[PromptChatMemoryAdvisor.before] Added USER message to memory for conversationId={}: {}",
127-
conversationId, userMessage.getText());
128-
}
129-
130-
// 2. Retrieve the chat memory for the current conversation.
120+
// 1. Retrieve the chat memory for the current conversation.
131121
List<Message> memoryMessages = this.getChatMemoryStore().get(conversationId);
132-
logger.debug("[PromptChatMemoryAdvisor.before] Memory after USER add for conversationId={}: {}", conversationId,
133-
memoryMessages);
122+
logger.debug("[PromptChatMemoryAdvisor.before] Memory before processing for conversationId={}: {}",
123+
conversationId, memoryMessages);
134124

135-
// 3. Process memory messages as a string.
125+
// 2. Process memory messages as a string.
136126
String memory = memoryMessages.stream()
137127
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
138128
.map(m -> m.getMessageType() + ":" + m.getText())
139129
.collect(Collectors.joining(System.lineSeparator()));
140130

141-
// 4. Augment the system message.
131+
// 3. Augment the system message.
142132
SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
143133
String augmentedSystemText = this.systemPromptTemplate
144134
.render(Map.of("instructions", systemMessage.getText(), "memory", memory));
145135

146-
// 5. Create a new request with the augmented system message.
136+
// 4. Create a new request with the augmented system message.
147137
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
148138
.prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
149139
.build();
150140

141+
// 5. Add all user messages from the current prompt to memory (after system
142+
// message is generated)
143+
List<UserMessage> userMessages = chatClientRequest.prompt().getUserMessages();
144+
for (UserMessage userMessage : userMessages) {
145+
this.getChatMemoryStore().add(conversationId, userMessage);
146+
logger.debug("[PromptChatMemoryAdvisor.before] Added USER message to memory for conversationId={}: {}",
147+
conversationId, userMessage.getText());
148+
}
149+
151150
return processedChatClientRequest;
152151
}
153152

0 commit comments

Comments
 (0)