diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java
index c21d566f69e..45f63ac714b 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java
@@ -138,4 +138,63 @@ void shouldHandleNonExistentConversation() {
 		testHandleNonExistentConversation();
 	}
 
+	/**
+	 * Streams a reply through a {@link MessageChatMemoryAdvisor} and checks that the
+	 * conversation memory ends up with the user prompt first and a non-empty assistant
+	 * message last.
+	 */
+	@Test
+	void shouldStoreCompleteContentInStreamingMode() {
+		// Arrange
+		String conversationId = "streaming-test-" + System.currentTimeMillis();
+		ChatMemory chatMemory = MessageWindowChatMemory.builder()
+			.chatMemoryRepository(new InMemoryChatMemoryRepository())
+			.build();
+
+		// Create MessageChatMemoryAdvisor with the conversation ID
+		MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory)
+			.conversationId(conversationId)
+			.build();
+
+		ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
+
+		// Act - Use streaming API
+		String userInput = "Tell me a short joke about programming";
+
+		// Collect the streaming responses
+		List<String> streamedResponses = chatClient.prompt()
+			.user(userInput)
+			.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
+			.stream()
+			.content()
+			.collectList()
+			.block();
+
+		// Wait a moment to ensure all processing is complete
+		try {
+			Thread.sleep(500);
+		}
+		catch (InterruptedException e) {
+			Thread.currentThread().interrupt();
+		}
+
+		// Assert - The stream itself must have produced at least one chunk
+		assertThat(streamedResponses).isNotEmpty();
+
+		// Assert - Check that the memory contains the complete content
+		List<Message> memoryMessages = chatMemory.get(conversationId);
+
+		// Should have at least 2 messages (user + assistant)
+		assertThat(memoryMessages).hasSizeGreaterThanOrEqualTo(2);
+
+		// First message should be the user message
+		assertThat(memoryMessages.get(0).getText()).isEqualTo(userInput);
+
+		// Last message should be the assistant's response and should have content
+		Message assistantMessage = memoryMessages.get(memoryMessages.size() - 1);
+		assertThat(assistantMessage.getText()).isNotEmpty();
+
+		logger.info("Assistant response stored in memory: {}", assistantMessage.getText());
+	}
+
 }
diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java
index b2b0278f2e5..835beb8aab1 100644
--- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java
+++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java
@@ -20,14 +20,18 @@
 import java.util.List;
 
 import org.springframework.util.Assert;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
 import reactor.core.scheduler.Scheduler;
 
+import org.springframework.ai.chat.client.ChatClientMessageAggregator;
 import org.springframework.ai.chat.client.ChatClientRequest;
 import org.springframework.ai.chat.client.ChatClientResponse;
 import org.springframework.ai.chat.client.advisor.api.Advisor;
 import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
 import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
 import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
+import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
 import org.springframework.ai.chat.memory.ChatMemory;
 import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.messages.UserMessage;
@@ -109,6 +113,31 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh
 		return chatClientResponse;
 	}
 
+	/**
+	 * Streaming variant of the advice. Downstream of {@code publishOn(scheduler)} it
+	 * applies {@code before} to the request, continues the call through
+	 * {@code StreamAdvisorChain.nextStream}, and passes the resulting stream to
+	 * {@code ChatClientMessageAggregator.aggregateChatClientResponse} together with a
+	 * callback that invokes {@code after}.
+	 * @param chatClientRequest the incoming request
+	 * @param streamAdvisorChain the chain used to continue the streaming call
+	 * @return the flux returned by the aggregator
+	 */
+	@Override
+	public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
+			StreamAdvisorChain streamAdvisorChain) {
+		// Get the scheduler from BaseAdvisor
+		Scheduler scheduler = this.getScheduler();
+
+		// Process the request with the before method
+		return Mono.just(chatClientRequest)
+			.publishOn(scheduler)
+			.map(request -> this.before(request, streamAdvisorChain))
+			.flatMapMany(streamAdvisorChain::nextStream)
+			.transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux,
+					response -> this.after(response, streamAdvisorChain)));
+	}
+
 	public static Builder builder(ChatMemory chatMemory) {
 		return new Builder(chatMemory);
 	}