From fcb63d860a04735cfb040b4a17794dd2a79e72ae Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Thu, 18 Sep 2025 15:15:20 -0400 Subject: [PATCH] refactor: simplify assistant message extraction using Optional chaining Replace nested null checks and redundant branching with streamlined Optional-based approach. Since getResult() == getResults().get(0), processing all results handles both single and multiple result cases. Adding tests to verify. Fixes #4292 Signed-off-by: Soby Chacko --- .../advisor/PromptChatMemoryAdvisor.java | 23 ++-- .../advisor/PromptChatMemoryAdvisorTests.java | 127 ++++++++++++++++++ 2 files changed, 140 insertions(+), 10 deletions(-) diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java index de88715e896..3bf56e8d401 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import org.slf4j.Logger; @@ -141,18 +142,20 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChai @Override public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) { List assistantMessages = new ArrayList<>(); - // Handle streaming case where we have a single result - if (chatClientResponse.chatResponse() != null && chatClientResponse.chatResponse().getResult() != null - && chatClientResponse.chatResponse().getResult().getOutput() != null) { - assistantMessages = List.of((Message) chatClientResponse.chatResponse().getResult().getOutput()); - } - else if (chatClientResponse.chatResponse() != null) { - assistantMessages = chatClientResponse.chatResponse() - .getResults() + // Extract assistant messages from chat client response. + // Processes all results from getResults() which automatically handles both single + // and multiple + // result scenarios (since getResult() == getResults().get(0)). Uses Optional + // chaining for + // null safety and returns empty list if no results are available. + assistantMessages = Optional.ofNullable(chatClientResponse) + .map(ChatClientResponse::chatResponse) + .filter(response -> response.getResults() != null && !response.getResults().isEmpty()) + .map(response -> response.getResults() .stream() .map(g -> (Message) g.getOutput()) - .toList(); - } + .collect(Collectors.toList())) + .orElse(List.of()); if (!assistantMessages.isEmpty()) { this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId), diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisorTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisorTests.java index f875a7bf803..bc1f0ce7db3 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisorTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisorTests.java @@ -16,23 +16,34 @@ package org.springframework.ai.chat.client.advisor; +import java.util.List; + import org.junit.jupiter.api.Test; import reactor.core.scheduler.Schedulers; +import org.springframework.ai.chat.client.ChatClientResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; +import org.springframework.ai.chat.client.advisor.api.AdvisorChain; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository; import org.springframework.ai.chat.memory.MessageWindowChatMemory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.PromptTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Unit tests for {@link PromptChatMemoryAdvisor}. * * @author Mark Pollack * @author Thomas Vitale + * @author Soby Chacko */ public class PromptChatMemoryAdvisorTests { @@ -138,4 +149,120 @@ void testDefaultValues() { assertThat(advisor.getOrder()).isEqualTo(Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER); } + @Test + void testAfterMethodHandlesSingleGeneration() { + ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(new InMemoryChatMemoryRepository()) + .build(); + + PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory) + .conversationId("test-conversation") + .build(); + + ChatClientResponse mockResponse = mock(ChatClientResponse.class); + ChatResponse mockChatResponse = mock(ChatResponse.class); + Generation mockGeneration = mock(Generation.class); + AdvisorChain mockChain = mock(AdvisorChain.class); + + when(mockResponse.chatResponse()).thenReturn(mockChatResponse); + when(mockChatResponse.getResults()).thenReturn(List.of(mockGeneration)); // Single + // result + when(mockGeneration.getOutput()).thenReturn(new AssistantMessage("Single response")); + + ChatClientResponse result = advisor.after(mockResponse, mockChain); + + assertThat(result).isEqualTo(mockResponse); // Should return the same response + + // Verify single message stored in memory + List messages = chatMemory.get("test-conversation"); + assertThat(messages).hasSize(1); + assertThat(messages.get(0).getText()).isEqualTo("Single response"); + } + + @Test + void testAfterMethodHandlesMultipleGenerations() { + ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(new InMemoryChatMemoryRepository()) + .build(); + + PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory) + .conversationId("test-conversation") + .build(); + + ChatClientResponse mockResponse = mock(ChatClientResponse.class); + ChatResponse mockChatResponse = mock(ChatResponse.class); + Generation mockGen1 = mock(Generation.class); + Generation mockGen2 = mock(Generation.class); + Generation mockGen3 = mock(Generation.class); + AdvisorChain mockChain = mock(AdvisorChain.class); + + when(mockResponse.chatResponse()).thenReturn(mockChatResponse); + when(mockChatResponse.getResults()).thenReturn(List.of(mockGen1, mockGen2, mockGen3)); // Multiple + // results + when(mockGen1.getOutput()).thenReturn(new AssistantMessage("Response 1")); + when(mockGen2.getOutput()).thenReturn(new AssistantMessage("Response 2")); + when(mockGen3.getOutput()).thenReturn(new AssistantMessage("Response 3")); + + ChatClientResponse result = advisor.after(mockResponse, mockChain); + + assertThat(result).isEqualTo(mockResponse); // Should return the same response + + // Verify all messages were stored in memory + List messages = chatMemory.get("test-conversation"); + assertThat(messages).hasSize(3); + assertThat(messages.get(0).getText()).isEqualTo("Response 1"); + assertThat(messages.get(1).getText()).isEqualTo("Response 2"); + assertThat(messages.get(2).getText()).isEqualTo("Response 3"); + } + + @Test + void testAfterMethodHandlesEmptyResults() { + ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(new InMemoryChatMemoryRepository()) + .build(); + + PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory) + .conversationId("test-conversation") + .build(); + + ChatClientResponse mockResponse = mock(ChatClientResponse.class); + ChatResponse mockChatResponse = mock(ChatResponse.class); + AdvisorChain mockChain = mock(AdvisorChain.class); + + when(mockResponse.chatResponse()).thenReturn(mockChatResponse); + when(mockChatResponse.getResults()).thenReturn(List.of()); + + ChatClientResponse result = advisor.after(mockResponse, mockChain); + + assertThat(result).isEqualTo(mockResponse); + + // Verify no messages were stored in memory + List messages = chatMemory.get("test-conversation"); + assertThat(messages).isEmpty(); + } + + @Test + void testAfterMethodHandlesNullChatResponse() { + ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(new InMemoryChatMemoryRepository()) + .build(); + + PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory) + .conversationId("test-conversation") + .build(); + + ChatClientResponse mockResponse = mock(ChatClientResponse.class); + AdvisorChain mockChain = mock(AdvisorChain.class); + + when(mockResponse.chatResponse()).thenReturn(null); + + ChatClientResponse result = advisor.after(mockResponse, mockChain); + + assertThat(result).isEqualTo(mockResponse); + + // Verify no messages were stored in memory + List messages = chatMemory.get("test-conversation"); + assertThat(messages).isEmpty(); + } + }