Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import org.slf4j.Logger;
Expand Down Expand Up @@ -141,18 +142,20 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChai
@Override
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
List<Message> assistantMessages = new ArrayList<>();
// Handle streaming case where we have a single result
if (chatClientResponse.chatResponse() != null && chatClientResponse.chatResponse().getResult() != null
&& chatClientResponse.chatResponse().getResult().getOutput() != null) {
assistantMessages = List.of((Message) chatClientResponse.chatResponse().getResult().getOutput());
}
else if (chatClientResponse.chatResponse() != null) {
assistantMessages = chatClientResponse.chatResponse()
.getResults()
// Extract assistant messages from chat client response.
// Processes all results from getResults() which automatically handles both single
// and multiple
// result scenarios (since getResult() == getResults().get(0)). Uses Optional
// chaining for
// null safety and returns empty list if no results are available.
assistantMessages = Optional.ofNullable(chatClientResponse)
.map(ChatClientResponse::chatResponse)
.filter(response -> response.getResults() != null && !response.getResults().isEmpty())
.map(response -> response.getResults()
.stream()
.map(g -> (Message) g.getOutput())
.toList();
}
.collect(Collectors.toList()))
.orElse(List.of());

if (!assistantMessages.isEmpty()) {
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,34 @@

package org.springframework.ai.chat.client.advisor;

import java.util.List;

import org.junit.jupiter.api.Test;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.PromptTemplate;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
* Unit tests for {@link PromptChatMemoryAdvisor}.
*
* @author Mark Pollack
* @author Thomas Vitale
* @author Soby Chacko
*/
public class PromptChatMemoryAdvisorTests {

Expand Down Expand Up @@ -138,4 +149,120 @@ void testDefaultValues() {
assertThat(advisor.getOrder()).isEqualTo(Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
}

@Test
void testAfterMethodHandlesSingleGeneration() {
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();

PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)
.conversationId("test-conversation")
.build();

ChatClientResponse mockResponse = mock(ChatClientResponse.class);
ChatResponse mockChatResponse = mock(ChatResponse.class);
Generation mockGeneration = mock(Generation.class);
AdvisorChain mockChain = mock(AdvisorChain.class);

when(mockResponse.chatResponse()).thenReturn(mockChatResponse);
when(mockChatResponse.getResults()).thenReturn(List.of(mockGeneration)); // Single
// result
when(mockGeneration.getOutput()).thenReturn(new AssistantMessage("Single response"));

ChatClientResponse result = advisor.after(mockResponse, mockChain);

assertThat(result).isEqualTo(mockResponse); // Should return the same response

// Verify single message stored in memory
List<Message> messages = chatMemory.get("test-conversation");
assertThat(messages).hasSize(1);
assertThat(messages.get(0).getText()).isEqualTo("Single response");
}

@Test
void testAfterMethodHandlesMultipleGenerations() {
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();

PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)
.conversationId("test-conversation")
.build();

ChatClientResponse mockResponse = mock(ChatClientResponse.class);
ChatResponse mockChatResponse = mock(ChatResponse.class);
Generation mockGen1 = mock(Generation.class);
Generation mockGen2 = mock(Generation.class);
Generation mockGen3 = mock(Generation.class);
AdvisorChain mockChain = mock(AdvisorChain.class);

when(mockResponse.chatResponse()).thenReturn(mockChatResponse);
when(mockChatResponse.getResults()).thenReturn(List.of(mockGen1, mockGen2, mockGen3)); // Multiple
// results
when(mockGen1.getOutput()).thenReturn(new AssistantMessage("Response 1"));
when(mockGen2.getOutput()).thenReturn(new AssistantMessage("Response 2"));
when(mockGen3.getOutput()).thenReturn(new AssistantMessage("Response 3"));

ChatClientResponse result = advisor.after(mockResponse, mockChain);

assertThat(result).isEqualTo(mockResponse); // Should return the same response

// Verify all messages were stored in memory
List<Message> messages = chatMemory.get("test-conversation");
assertThat(messages).hasSize(3);
assertThat(messages.get(0).getText()).isEqualTo("Response 1");
assertThat(messages.get(1).getText()).isEqualTo("Response 2");
assertThat(messages.get(2).getText()).isEqualTo("Response 3");
}

@Test
void testAfterMethodHandlesEmptyResults() {
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();

PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)
.conversationId("test-conversation")
.build();

ChatClientResponse mockResponse = mock(ChatClientResponse.class);
ChatResponse mockChatResponse = mock(ChatResponse.class);
AdvisorChain mockChain = mock(AdvisorChain.class);

when(mockResponse.chatResponse()).thenReturn(mockChatResponse);
when(mockChatResponse.getResults()).thenReturn(List.of());

ChatClientResponse result = advisor.after(mockResponse, mockChain);

assertThat(result).isEqualTo(mockResponse);

// Verify no messages were stored in memory
List<Message> messages = chatMemory.get("test-conversation");
assertThat(messages).isEmpty();
}

@Test
void testAfterMethodHandlesNullChatResponse() {
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();

PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)
.conversationId("test-conversation")
.build();

ChatClientResponse mockResponse = mock(ChatClientResponse.class);
AdvisorChain mockChain = mock(AdvisorChain.class);

when(mockResponse.chatResponse()).thenReturn(null);

ChatClientResponse result = advisor.after(mockResponse, mockChain);

assertThat(result).isEqualTo(mockResponse);

// Verify no messages were stored in memory
List<Message> messages = chatMemory.get("test-conversation");
assertThat(messages).isEmpty();
}

}