diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java index de88715e896..8a12f2af250 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import org.slf4j.Logger; @@ -40,6 +41,8 @@ import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.util.Assert; @@ -142,17 +145,11 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChai public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) { List assistantMessages = new ArrayList<>(); // Handle streaming case where we have a single result - if (chatClientResponse.chatResponse() != null && chatClientResponse.chatResponse().getResult() != null - && chatClientResponse.chatResponse().getResult().getOutput() != null) { - assistantMessages = List.of((Message) chatClientResponse.chatResponse().getResult().getOutput()); - } - else if (chatClientResponse.chatResponse() != null) { - assistantMessages = chatClientResponse.chatResponse() - .getResults() - .stream() - .map(g -> (Message) g.getOutput()) - .toList(); - } + Optional.of(chatClientResponse) + .map(ChatClientResponse::chatResponse) + .map(ChatResponse::getResult) + .map(Generation::getOutput) + .ifPresent(assistantMessages::add); if (!assistantMessages.isEmpty()) { this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),