Commit 867cc30

Fix VectorStoreChatMemoryAdvisor streaming bug
- Override adviseStream method in VectorStoreChatMemoryAdvisor to properly handle streaming responses
- Add tests to verify the fix works with both normal and problematic streaming scenarios

Fixes #3152

Signed-off-by: Mark Pollack <[email protected]>
1 parent 008a760 commit 867cc30

5 files changed: 308 additions, 0 deletions


advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java

Lines changed: 19 additions & 0 deletions
@@ -21,7 +21,10 @@
 import java.util.List;
 import java.util.Map;
 
+import org.springframework.ai.chat.client.ChatClientMessageAggregator;
 import org.springframework.util.Assert;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
 import reactor.core.scheduler.Scheduler;
 
 import org.springframework.ai.chat.client.ChatClientRequest;
@@ -30,10 +33,12 @@
 import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
 import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
 import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
+import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
 import org.springframework.ai.chat.memory.ChatMemory;
 import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.messages.MessageType;
+import org.springframework.ai.chat.messages.SystemMessage;
 import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.prompt.PromptTemplate;
 import org.springframework.ai.document.Document;
@@ -167,6 +172,20 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh
 		return chatClientResponse;
 	}
 
+	@Override
+	public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
+			StreamAdvisorChain streamAdvisorChain) {
+		// Get the scheduler from BaseAdvisor
+		Scheduler scheduler = this.getScheduler();
+		// Process the request with the before method
+		return Mono.just(chatClientRequest)
+			.publishOn(scheduler)
+			.map(request -> this.before(request, streamAdvisorChain))
+			.flatMapMany(streamAdvisorChain::nextStream)
+			.transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux,
+					response -> this.after(response, streamAdvisorChain)));
+	}
+
 	private List<Document> toDocuments(List<Message> messages, String conversationId) {
 		List<Document> docs = messages.stream()
 			.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
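
For orientation, a minimal usage sketch of the streaming path this override fixes, assembled from the test code elsewhere in this commit (chatModel, vectorStore, and conversationId are assumed to be in scope; this is illustrative, not part of the change):

	// Sketch: stream a prompt through a ChatClient configured with the advisor.
	ChatClient chatClient = ChatClient.builder(chatModel)
		.defaultAdvisors(VectorStoreChatMemoryAdvisor.builder(vectorStore).build())
		.build();

	// Collect the streamed chunks. Once the Flux completes, the aggregated
	// assistant message is persisted by after(...), so long-term memory sees
	// the full response text rather than a single (possibly empty) chunk.
	String answer = String.join("", chatClient.prompt()
		.user("joke")
		.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
		.stream()
		.content()
		.collectList()
		.block());

The key design choice, visible in the override above: ChatClientMessageAggregator.aggregateChatClientResponse wraps the response Flux and invokes the supplied callback with the aggregated response, so after() runs on the complete assistant message once the stream ends instead of on individual chunks.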

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/AbstractChatMemoryAdvisorIT.java

Lines changed: 71 additions & 0 deletions
@@ -22,6 +22,7 @@
 import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Flux;
 
 import org.springframework.ai.chat.client.ChatClient;
 import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
@@ -412,4 +413,74 @@ protected void testHandleMultipleMessagesInReactiveMode() {
 		assertThat(memoryMessages.get(6).getText()).isEqualTo("What is my name and where do I live?");
 	}
 
+	/**
+	 * Tests that the advisor correctly handles streaming responses and updates the
+	 * memory. This verifies that the adviseStream method in chat memory advisors is
+	 * working correctly.
+	 */
+	protected void testStreamingWithChatMemory() {
+		// Arrange
+		String conversationId = "streaming-conversation-" + System.currentTimeMillis();
+		ChatMemory chatMemory = MessageWindowChatMemory.builder()
+			.chatMemoryRepository(new InMemoryChatMemoryRepository())
+			.build();
+
+		// Create advisor with the conversation ID
+		var advisor = createAdvisor(chatMemory);
+
+		ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
+
+		// Act - Send a message using streaming
+		String initialQuestion = "My name is David and I live in Seattle.";
+
+		// Collect all streaming chunks
+		List<String> streamingChunks = new ArrayList<>();
+		Flux<String> responseStream = chatClient.prompt()
+			.user(initialQuestion)
+			.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
+			.stream()
+			.content();
+
+		// Block and collect all streaming chunks
+		responseStream.doOnNext(streamingChunks::add).blockLast();
+
+		// Join all chunks to get the complete response
+		String completeResponse = String.join("", streamingChunks);
+
+		logger.info("Streaming response: {}", completeResponse);
+
+		// Verify memory contains the initial question and the response
+		List<Message> memoryMessages = chatMemory.get(conversationId);
+		assertThat(memoryMessages).hasSize(2); // 1 user message + 1 assistant response
+		assertThat(memoryMessages.get(0).getText()).isEqualTo(initialQuestion);
+
+		// Send a follow-up question using streaming
+		String followUpQuestion = "Where do I live?";
+
+		// Collect all streaming chunks for the follow-up
+		List<String> followUpStreamingChunks = new ArrayList<>();
+		Flux<String> followUpResponseStream = chatClient.prompt()
+			.user(followUpQuestion)
+			.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
+			.stream()
+			.content();
+
+		// Block and collect all streaming chunks
+		followUpResponseStream.doOnNext(followUpStreamingChunks::add).blockLast();
+
+		// Join all chunks to get the complete follow-up response
+		String followUpCompleteResponse = String.join("", followUpStreamingChunks);
+
+		logger.info("Follow-up streaming response: {}", followUpCompleteResponse);
+
+		// Verify the follow-up response contains the location
+		assertThat(followUpCompleteResponse).containsIgnoringCase("Seattle");
+
+		// Verify memory now contains all messages
+		memoryMessages = chatMemory.get(conversationId);
+		assertThat(memoryMessages).hasSize(4); // 2 user messages + 2 assistant responses
+		assertThat(memoryMessages.get(0).getText()).isEqualTo(initialQuestion);
+		assertThat(memoryMessages.get(2).getText()).isEqualTo(followUpQuestion);
+	}
+
 }
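
(A note on the blocking calls in this test: the memory assertions are only valid after blockLast() returns, because with an aggregating adviseStream like the override above, the assistant message is written to memory in after(), which fires once the response Flux completes.)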

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java

Lines changed: 5 additions & 0 deletions
@@ -190,4 +190,9 @@ void shouldStoreCompleteContentInStreamingMode() {
 		logger.info("Assistant response stored in memory: {}", assistantMessage.getText());
 	}
 
+	@Test
+	void shouldHandleStreamingWithChatMemory() {
+		testStreamingWithChatMemory();
+	}
+
 }

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/PromptChatMemoryAdvisorIT.java

Lines changed: 5 additions & 0 deletions
@@ -135,4 +135,9 @@ void shouldHandleMultipleUserMessagesInPrompt() {
 		testMultipleUserMessagesInPrompt();
 	}
 
+	@Test
+	void shouldHandleStreamingWithChatMemory() {
+		testStreamingWithChatMemory();
+	}
+
 }

vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreWithChatMemoryAdvisorIT.java

Lines changed: 208 additions & 0 deletions
@@ -30,6 +30,7 @@
 import org.testcontainers.containers.PostgreSQLContainer;
 import org.testcontainers.junit.jupiter.Container;
 import org.testcontainers.junit.jupiter.Testcontainers;
+import reactor.core.publisher.Flux;
 
 import org.springframework.ai.chat.client.ChatClient;
 import org.springframework.ai.chat.client.advisor.vectorstore.VectorStoreChatMemoryAdvisor;
@@ -42,9 +43,11 @@
 import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.document.Document;
 import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.vectorstore.SearchRequest;
 import org.springframework.jdbc.core.JdbcTemplate;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.fail;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
@@ -117,6 +120,78 @@ private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(ChatM
 				""");
 	}
 
+	/**
+	 * Create a mock ChatModel that supports streaming responses for testing.
+	 * @return A mock ChatModel that returns a predefined streaming response
+	 */
+	private static @NotNull ChatModel chatModelWithStreamingSupport() {
+		ChatModel chatModel = mock(ChatModel.class);
+
+		// Mock the regular call method
+		ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class);
+		ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
+				Why don't scientists trust atoms?
+				Because they make up everything!
+				"""))));
+		given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse);
+
+		// Mock the streaming method
+		ArgumentCaptor<Prompt> streamArgumentCaptor = ArgumentCaptor.forClass(Prompt.class);
+		Flux<ChatResponse> streamingResponse = Flux.just(
+				new ChatResponse(List.of(new Generation(new AssistantMessage("Why")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" don't")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" scientists")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" trust")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" atoms?")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage("\nBecause")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" they")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" make")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" up")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" everything!")))));
+		given(chatModel.stream(streamArgumentCaptor.capture())).willReturn(streamingResponse);
+
+		return chatModel;
+	}
+
+	/**
+	 * Create a mock ChatModel that simulates the problematic streaming behavior. This
+	 * mock includes a final empty message that triggers the bug in
+	 * VectorStoreChatMemoryAdvisor.
+	 * @return A mock ChatModel that returns a problematic streaming response
+	 */
+	private static @NotNull ChatModel chatModelWithProblematicStreamingBehavior() {
+		ChatModel chatModel = mock(ChatModel.class);
+
+		// Mock the regular call method
+		ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class);
+		ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
+				Why don't scientists trust atoms?
+				Because they make up everything!
+				"""))));
+		given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse);
+
+		// Mock the streaming method with a problematic final message (empty content)
+		// This simulates the real-world condition that triggers the bug
+		ArgumentCaptor<Prompt> streamArgumentCaptor = ArgumentCaptor.forClass(Prompt.class);
+		Flux<ChatResponse> streamingResponse = Flux.just(
+				new ChatResponse(List.of(new Generation(new AssistantMessage("Why")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" don't")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" scientists")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" trust")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" atoms?")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage("\nBecause")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" they")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" make")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" up")))),
+				new ChatResponse(List.of(new Generation(new AssistantMessage(" everything!")))),
+				// This final empty message triggers the bug in
+				// VectorStoreChatMemoryAdvisor
+				new ChatResponse(List.of(new Generation(new AssistantMessage("")))));
+		given(chatModel.stream(streamArgumentCaptor.capture())).willReturn(streamingResponse);
+
+		return chatModel;
+	}
+
 	/**
 	 * Test that chats with {@link VectorStoreChatMemoryAdvisor} get advised with similar
 	 * messages from the (gp)vector store.
@@ -182,6 +257,139 @@ void advisedChatShouldHaveSimilarMessagesFromVectorStoreWhenSystemMessageProvide
 				""");
 	}
 
+	/**
+	 * Test that streaming chats with {@link VectorStoreChatMemoryAdvisor} get advised
+	 * with similar messages from the vector store and properly handle streaming
+	 * responses.
+	 *
+	 * This test verifies that the fix for the bug reported in
+	 * https://github.com/spring-projects/spring-ai/issues/3152 works correctly. The
+	 * VectorStoreChatMemoryAdvisor now properly handles streaming responses and saves
+	 * the assistant's messages to the vector store.
+	 */
+	@Test
+	void advisedStreamingChatShouldHaveSimilarMessagesFromVectorStore() throws Exception {
+		// Create a ChatModel with streaming support
+		ChatModel chatModel = chatModelWithStreamingSupport();
+
+		// Create the embedding model
+		EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed();
+
+		// Create and initialize the vector store
+		PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel);
+		String conversationId = UUID.randomUUID().toString();
+		initStore(store, conversationId);
+
+		// Create a chat client with the VectorStoreChatMemoryAdvisor
+		ChatClient chatClient = ChatClient.builder(chatModel).build();
+
+		// Execute a streaming chat request
+		Flux<String> responseStream = chatClient.prompt()
+			.user("joke")
+			.advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build())
+				.param(ChatMemory.CONVERSATION_ID, conversationId))
+			.stream()
+			.content();
+
+		// Collect all streaming chunks
+		List<String> streamingChunks = responseStream.collectList().block();
+
+		// Verify the streaming response
+		assertThat(streamingChunks).isNotNull();
+		String completeResponse = String.join("", streamingChunks);
+		assertThat(completeResponse).contains("scientists", "atoms", "everything");
+
+		// Verify the request was properly advised with vector store content
+		ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
+		verify(chatModel).stream(promptCaptor.capture());
+		Prompt capturedPrompt = promptCaptor.getValue();
+		assertThat(capturedPrompt.getInstructions().get(0)).isInstanceOf(SystemMessage.class);
+		assertThat(capturedPrompt.getInstructions().get(0).getText()).isEqualToIgnoringWhitespace("""
+
+				Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.
+
+				---------------------
+				LONG_TERM_MEMORY:
+				Tell me a good joke
+				Tell me a bad joke
+				---------------------
+				""");
+
+		// Verify that the assistant's response was properly added to the vector store
+		// after streaming completed
+		// This verifies that the fix for the adviseStream implementation works correctly
+		String filter = "conversationId=='" + conversationId + "' && messageType=='ASSISTANT'";
+		var searchRequest = SearchRequest.builder().query("atoms").filterExpression(filter).build();
+
+		List<Document> assistantDocuments = store.similaritySearch(searchRequest);
+
+		// With our fix, the assistant's response should be saved to the vector store
+		assertThat(assistantDocuments).isNotEmpty();
+		assertThat(assistantDocuments.get(0).getText()).contains("scientists", "atoms", "everything");
+	}
+
+	/**
+	 * Test that verifies the fix for the bug reported in
+	 * https://github.com/spring-projects/spring-ai/issues/3152. The
+	 * VectorStoreChatMemoryAdvisor now properly handles streaming responses with empty
+	 * messages by using ChatClientMessageAggregator to aggregate messages before calling
+	 * the after method.
+	 */
+	@Test
+	void vectorStoreChatMemoryAdvisorShouldHandleEmptyMessagesInStream() throws Exception {
+		// Create a ChatModel with problematic streaming behavior
+		ChatModel chatModel = chatModelWithProblematicStreamingBehavior();
+
+		// Create the embedding model
+		EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed();
+
+		// Create and initialize the vector store
+		PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel);
+		String conversationId = UUID.randomUUID().toString();
+		initStore(store, conversationId);
+
+		// Create a chat client with the VectorStoreChatMemoryAdvisor
+		ChatClient chatClient = ChatClient.builder(chatModel).build();
+
+		// Execute a streaming chat request
+		// This should now succeed with our fix
+		Flux<String> responseStream = chatClient.prompt()
+			.user("joke")
+			.advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build())
+				.param(ChatMemory.CONVERSATION_ID, conversationId))
+			.stream()
+			.content();
+
+		// Collect all streaming chunks - this should no longer throw an exception
+		List<String> streamingChunks = responseStream.collectList().block();
+
+		// Verify the streaming response
+		assertThat(streamingChunks).isNotNull();
+		String completeResponse = String.join("", streamingChunks);
+		assertThat(completeResponse).contains("scientists", "atoms", "everything");
+
+		// Verify that the assistant's response was properly added to the vector store
+		// This verifies that our fix works correctly
+		String filter = "conversationId=='" + conversationId + "' && messageType=='ASSISTANT'";
+		var searchRequest = SearchRequest.builder().query("atoms").filterExpression(filter).build();
+
+		List<Document> assistantDocuments = store.similaritySearch(searchRequest);
+		assertThat(assistantDocuments).isNotEmpty();
+		assertThat(assistantDocuments.get(0).getText()).contains("scientists", "atoms", "everything");
+	}
+
+	/**
+	 * Helper method to get the root cause of an exception.
+	 */
+	private Throwable getRootCause(Throwable throwable) {
+		Throwable cause = throwable;
+		while (cause.getCause() != null && cause.getCause() != cause) {
+			cause = cause.getCause();
+		}
+		return cause;
+	}
+
 	@SuppressWarnings("unchecked")
 	private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() {
 		EmbeddingModel embeddingModel = mock(EmbeddingModel.class);
