|
35 | 35 | import org.springframework.ai.chat.messages.MessageType; |
36 | 36 | import org.springframework.ai.chat.messages.SystemMessage; |
37 | 37 | import org.springframework.ai.chat.messages.UserMessage; |
38 | | -import org.springframework.ai.chat.model.MessageAggregator; |
39 | 38 | import org.springframework.ai.chat.prompt.PromptTemplate; |
40 | 39 |
|
41 | 40 | /** |
@@ -111,43 +110,43 @@ public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest |
111 | 110 | streamAdvisorChain, this::before); |
112 | 111 |
|
113 | 112 | // Ensure memory is updated after each streamed response |
114 | | - return chatClientResponses.doOnNext(this::after) |
115 | | - .transform(responses -> new MessageAggregator().aggregateChatClientResponse(responses, null)); |
| 113 | + return chatClientResponses.doOnNext(this::after); |
116 | 114 | } |
117 | 115 |
|
118 | 116 | @Override |
119 | 117 | protected ChatClientRequest before(ChatClientRequest chatClientRequest) { |
120 | 118 | String conversationId = this.doGetConversationId(chatClientRequest.context()); |
121 | 119 |
|
122 | | - // 1. Add all user messages from the current prompt to memory |
123 | | - List<UserMessage> userMessages = chatClientRequest.prompt().getUserMessages(); |
124 | | - for (UserMessage userMessage : userMessages) { |
125 | | - this.getChatMemoryStore().add(conversationId, userMessage); |
126 | | - logger.debug("[PromptChatMemoryAdvisor.before] Added USER message to memory for conversationId={}: {}", |
127 | | - conversationId, userMessage.getText()); |
128 | | - } |
129 | | - |
130 | | - // 2. Retrieve the chat memory for the current conversation. |
| 120 | + // 1. Retrieve the chat memory for the current conversation. |
131 | 121 | List<Message> memoryMessages = this.getChatMemoryStore().get(conversationId); |
132 | | - logger.debug("[PromptChatMemoryAdvisor.before] Memory after USER add for conversationId={}: {}", conversationId, |
133 | | - memoryMessages); |
| 122 | + logger.debug("[PromptChatMemoryAdvisor.before] Memory before processing for conversationId={}: {}", |
| 123 | + conversationId, memoryMessages); |
134 | 124 |
|
135 | | - // 3. Process memory messages as a string. |
| 125 | + // 2. Process memory messages as a string. |
136 | 126 | String memory = memoryMessages.stream() |
137 | 127 | .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) |
138 | 128 | .map(m -> m.getMessageType() + ":" + m.getText()) |
139 | 129 | .collect(Collectors.joining(System.lineSeparator())); |
140 | 130 |
|
141 | | - // 4. Augment the system message. |
| 131 | + // 3. Augment the system message. |
142 | 132 | SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage(); |
143 | 133 | String augmentedSystemText = this.systemPromptTemplate |
144 | 134 | .render(Map.of("instructions", systemMessage.getText(), "memory", memory)); |
145 | 135 |
|
146 | | - // 5. Create a new request with the augmented system message. |
| 136 | + // 4. Create a new request with the augmented system message. |
147 | 137 | ChatClientRequest processedChatClientRequest = chatClientRequest.mutate() |
148 | 138 | .prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText)) |
149 | 139 | .build(); |
150 | 140 |
|
| 141 | + // 5. Add all user messages from the current prompt to memory (after system |
| 142 | + // message is generated) |
| 143 | + List<UserMessage> userMessages = chatClientRequest.prompt().getUserMessages(); |
| 144 | + for (UserMessage userMessage : userMessages) { |
| 145 | + this.getChatMemoryStore().add(conversationId, userMessage); |
| 146 | + logger.debug("[PromptChatMemoryAdvisor.before] Added USER message to memory for conversationId={}: {}", |
| 147 | + conversationId, userMessage.getText()); |
| 148 | + } |
| 149 | + |
151 | 150 | return processedChatClientRequest; |
152 | 151 | } |
153 | 152 |
|
|
0 commit comments