diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionResultTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionResultTests.java index c6c06c1e75e..947483cf374 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionResultTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionResultTests.java @@ -17,6 +17,7 @@ package org.springframework.ai.model.tool; import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.Test; @@ -190,4 +191,34 @@ void whenEqualsAndHashCodeAreConsistent() { assertThat(result1.hashCode()).isEqualTo(result2.hashCode()); } + @Test + void whenConversationHistoryIsImmutableList() { + List conversationHistory = List.of(new org.springframework.ai.chat.messages.UserMessage("Hello"), + new org.springframework.ai.chat.messages.UserMessage("Hi!")); + + var result = DefaultToolExecutionResult.builder() + .conversationHistory(conversationHistory) + .returnDirect(false) + .build(); + + assertThat(result.conversationHistory()).hasSize(2); + assertThat(result.conversationHistory()).isEqualTo(conversationHistory); + } + + @Test + void whenReturnDirectIsChangedMultipleTimes() { + var conversationHistory = new ArrayList(); + conversationHistory.add(new org.springframework.ai.chat.messages.UserMessage("Test")); + + var builder = DefaultToolExecutionResult.builder() + .conversationHistory(conversationHistory) + .returnDirect(true) + .returnDirect(false) + .returnDirect(true); + + var result = builder.build(); + + assertThat(result.returnDirect()).isTrue(); + } + } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java index d347f9190f1..5cef425f942 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java @@ -16,6 +16,7 @@ package org.springframework.ai.model.tool; +import java.util.Collections; import java.util.List; import org.junit.jupiter.api.Test; @@ -75,6 +76,32 @@ void whenTestMethodCalledDirectly() { assertThat(result).isTrue(); } + @Test + void whenChatResponseHasEmptyGenerations() { + ToolExecutionEligibilityPredicate predicate = new TestToolExecutionEligibilityPredicate(); + ChatOptions promptOptions = ChatOptions.builder().build(); + ChatResponse emptyResponse = new ChatResponse(Collections.emptyList()); + + boolean result = predicate.isToolExecutionRequired(promptOptions, emptyResponse); + assertThat(result).isTrue(); + } + + @Test + void whenChatOptionsHasModel() { + ModelCheckingPredicate predicate = new ModelCheckingPredicate(); + + ChatOptions optionsWithModel = ChatOptions.builder().model("gpt-4").build(); + + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("test")))); + + boolean result = predicate.isToolExecutionRequired(optionsWithModel, chatResponse); + assertThat(result).isTrue(); + + ChatOptions optionsWithoutModel = ChatOptions.builder().build(); + result = predicate.isToolExecutionRequired(optionsWithoutModel, chatResponse); + assertThat(result).isFalse(); + } + /** * Test implementation of {@link ToolExecutionEligibilityPredicate} that always * returns true. @@ -88,4 +115,13 @@ public boolean test(ChatOptions promptOptions, ChatResponse chatResponse) { } + private static class ModelCheckingPredicate implements ToolExecutionEligibilityPredicate { + + @Override + public boolean test(ChatOptions promptOptions, ChatResponse chatResponse) { + return promptOptions.getModel() != null && !promptOptions.getModel().isEmpty(); + } + + } + } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java index acb3bfca0c5..d7695b39ce7 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java @@ -25,6 +25,7 @@ import org.springframework.ai.chat.messages.UserMessage; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; /** * Unit tests for {@link ToolExecutionResult}. @@ -80,4 +81,63 @@ void whenMultipleToolCallsThenMultipleGenerations() { assertThat(generations.get(1).getMetadata().getFinishReason()).isEqualTo(ToolExecutionResult.FINISH_REASON); } + @Test + void whenEmptyConversationHistoryThenThrowsException() { + var toolExecutionResult = ToolExecutionResult.builder().conversationHistory(List.of()).build(); + + assertThatThrownBy(() -> ToolExecutionResult.buildGenerations(toolExecutionResult)) + .isInstanceOf(ArrayIndexOutOfBoundsException.class); + } + + @Test + void whenToolResponseWithEmptyResponseListThenEmptyGenerations() { + var toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory( + List.of(new AssistantMessage("Processing request"), new ToolResponseMessage(List.of()))) + .build(); + + var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); + + assertThat(generations).isEmpty(); + } + + @Test + void whenToolResponseWithNullContentThenGenerationWithNullText() { + var toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory( + List.of(new ToolResponseMessage(List.of(new ToolResponseMessage.ToolResponse("1", "tool", null))))) + .build(); + + var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); + + assertThat(generations).hasSize(1); + assertThat(generations.get(0).getOutput().getText()).isNull(); + } + + @Test + void whenToolResponseWithEmptyStringContentThenGenerationWithEmptyText() { + var toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory( + List.of(new ToolResponseMessage(List.of(new ToolResponseMessage.ToolResponse("1", "tool", ""))))) + .build(); + + var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); + + assertThat(generations).hasSize(1); + assertThat(generations.get(0).getOutput().getText()).isEmpty(); + assertThat((String) generations.get(0).getMetadata().get(ToolExecutionResult.METADATA_TOOL_NAME)) + .isEqualTo("tool"); + } + + @Test + void whenBuilderCalledWithoutConversationHistoryThenThrowsException() { + var toolExecutionResult = ToolExecutionResult.builder().build(); + + assertThatThrownBy(() -> ToolExecutionResult.buildGenerations(toolExecutionResult)) + .isInstanceOf(ArrayIndexOutOfBoundsException.class); + + assertThat(toolExecutionResult.conversationHistory()).isNotNull(); + assertThat(toolExecutionResult.conversationHistory()).isEmpty(); + } + }