Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,56 @@ void shouldHandleNonExistentConversation() {
testHandleNonExistentConversation();
}

@Test
void shouldStoreCompleteContentInStreamingMode() {
// Arrange
String conversationId = "streaming-test-" + System.currentTimeMillis();
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();

// Create MessageChatMemoryAdvisor with the conversation ID
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory)
.conversationId(conversationId)
.build();

ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();

// Act - Use streaming API
String userInput = "Tell me a short joke about programming";

// Collect the streaming responses
List<String> streamedResponses = new ArrayList<>();
chatClient.prompt()
.user(userInput)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.stream()
.content()
.collectList()
.block();

// Wait a moment to ensure all processing is complete
try {
Thread.sleep(500);
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
}

// Assert - Check that the memory contains the complete content
List<Message> memoryMessages = chatMemory.get(conversationId);

// Should have at least 2 messages (user + assistant)
assertThat(memoryMessages).hasSizeGreaterThanOrEqualTo(2);

// First message should be the user message
assertThat(memoryMessages.get(0).getText()).isEqualTo(userInput);

// Last message should be the assistant's response and should have content
Message assistantMessage = memoryMessages.get(memoryMessages.size() - 1);
assertThat(assistantMessage.getText()).isNotEmpty();

logger.info("Assistant response stored in memory: {}", assistantMessage.getText());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,18 @@
import java.util.List;

import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;

import org.springframework.ai.chat.client.ChatClientMessageAggregator;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
Expand Down Expand Up @@ -109,6 +113,21 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh
return chatClientResponse;
}

@Override
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
StreamAdvisorChain streamAdvisorChain) {
// Get the scheduler from BaseAdvisor
Scheduler scheduler = this.getScheduler();

// Process the request with the before method
return Mono.just(chatClientRequest)
.publishOn(scheduler)
.map(request -> this.before(request, streamAdvisorChain))
.flatMapMany(streamAdvisorChain::nextStream)
.transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux,
response -> this.after(response, streamAdvisorChain)));
}

public static Builder builder(ChatMemory chatMemory) {
return new Builder(chatMemory);
}
Expand Down