|
28 | 28 | import org.springframework.ai.chat.client.ChatClientResponse; |
29 | 29 | import org.springframework.ai.chat.client.advisor.api.Advisor; |
30 | 30 | import org.springframework.ai.chat.client.advisor.api.AdvisorChain; |
| 31 | +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; |
31 | 32 | import org.springframework.ai.chat.memory.ChatMemory; |
32 | 33 | import org.springframework.ai.chat.messages.Message; |
33 | 34 | import org.springframework.ai.chat.messages.MessageType; |
34 | 35 | import org.springframework.ai.chat.messages.SystemMessage; |
35 | 36 | import org.springframework.ai.chat.messages.UserMessage; |
| 37 | +import org.springframework.ai.chat.model.MessageAggregator; |
36 | 38 | import org.springframework.ai.chat.prompt.PromptTemplate; |
37 | 39 |
|
| 40 | +import reactor.core.publisher.Flux; |
| 41 | +import reactor.core.publisher.Mono; |
| 42 | +import reactor.core.scheduler.Scheduler; |
| 43 | + |
38 | 44 | /** |
39 | 45 | * Memory is retrieved added into the prompt's system text. |
40 | 46 | * |
@@ -137,16 +143,39 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh |
137 | 143 | .map(g -> (Message) g.getOutput()) |
138 | 144 | .toList(); |
139 | 145 | } |
140 | | - this.getChatMemoryStore().add(this.doGetConversationId(chatClientResponse.context()), assistantMessages); |
141 | | - logger.debug("[PromptChatMemoryAdvisor.after] Added ASSISTANT messages to memory for conversationId={}: {}", |
142 | | - this.doGetConversationId(chatClientResponse.context()), assistantMessages); |
143 | | - List<Message> memoryMessages = this.getChatMemoryStore() |
144 | | - .get(this.doGetConversationId(chatClientResponse.context())); |
145 | | - logger.debug("[PromptChatMemoryAdvisor.after] Memory after ASSISTANT add for conversationId={}: {}", |
146 | | - this.doGetConversationId(chatClientResponse.context()), memoryMessages); |
| 146 | + // Handle streaming case where we have a single result |
| 147 | + else if (chatClientResponse.chatResponse() != null && chatClientResponse.chatResponse().getResult() != null |
| 148 | + && chatClientResponse.chatResponse().getResult().getOutput() != null) { |
| 149 | + assistantMessages = List.of((Message) chatClientResponse.chatResponse().getResult().getOutput()); |
| 150 | + } |
| 151 | + |
| 152 | + if (!assistantMessages.isEmpty()) { |
| 153 | + this.getChatMemoryStore().add(this.doGetConversationId(chatClientResponse.context()), assistantMessages); |
| 154 | + logger.debug("[PromptChatMemoryAdvisor.after] Added ASSISTANT messages to memory for conversationId={}: {}", |
| 155 | + this.doGetConversationId(chatClientResponse.context()), assistantMessages); |
| 156 | + List<Message> memoryMessages = this.getChatMemoryStore() |
| 157 | + .get(this.doGetConversationId(chatClientResponse.context())); |
| 158 | + logger.debug("[PromptChatMemoryAdvisor.after] Memory after ASSISTANT add for conversationId={}: {}", |
| 159 | + this.doGetConversationId(chatClientResponse.context()), memoryMessages); |
| 160 | + } |
147 | 161 | return chatClientResponse; |
148 | 162 | } |
149 | 163 |
|
| 164 | + @Override |
| 165 | + public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, |
| 166 | + StreamAdvisorChain streamAdvisorChain) { |
| 167 | + // Get the scheduler from BaseAdvisor |
| 168 | + Scheduler scheduler = this.getScheduler(); |
| 169 | + |
| 170 | + // Process the request with the before method |
| 171 | + return Mono.just(chatClientRequest) |
| 172 | + .publishOn(scheduler) |
| 173 | + .map(request -> this.before(request, streamAdvisorChain)) |
| 174 | + .flatMapMany(streamAdvisorChain::nextStream) |
| 175 | + .transform(flux -> new MessageAggregator().aggregateChatClientResponse(flux, |
| 176 | + response -> this.after(response, streamAdvisorChain))); |
| 177 | + } |
| 178 | + |
150 | 179 | public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<ChatMemory, Builder> { |
151 | 180 |
|
152 | 181 | private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE; |
|
0 commit comments