diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/ChatModelTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/ChatModelTests.java index 3de278296a1..27878a5ca52 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/ChatModelTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/ChatModelTests.java @@ -26,6 +26,7 @@ import org.springframework.ai.chat.prompt.Prompt; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; @@ -139,4 +140,139 @@ void generateWithWhitespaceOnlyStringHandlesCorrectly() { verify(mockClient, times(1)).call(eq(userMessage)); } + @Test + void generateWhenPromptCallThrowsExceptionPropagatesCorrectly() { + String userMessage = "Test message"; + RuntimeException expectedException = new RuntimeException("API call failed"); + + ChatModel mockClient = Mockito.mock(ChatModel.class); + + doCallRealMethod().when(mockClient).call(anyString()); + given(mockClient.call(any(Prompt.class))).willThrow(expectedException); + + assertThatThrownBy(() -> mockClient.call(userMessage)).isEqualTo(expectedException); + + verify(mockClient, times(1)).call(eq(userMessage)); + verify(mockClient, times(1)).call(isA(Prompt.class)); + } + + @Test + void generateWhenResponseIsNullHandlesGracefully() { + String userMessage = "Test message"; + + ChatModel mockClient = Mockito.mock(ChatModel.class); + + doCallRealMethod().when(mockClient).call(anyString()); + given(mockClient.call(any(Prompt.class))).willReturn(null); + + assertThatThrownBy(() -> mockClient.call(userMessage)).isInstanceOf(NullPointerException.class); + + verify(mockClient, times(1)).call(eq(userMessage)); + verify(mockClient, times(1)).call(isA(Prompt.class)); + } + + @Test + void generateWhenAssistantMessageIsNullHandlesGracefully() { + String userMessage = "Test message"; + + ChatModel mockClient = Mockito.mock(ChatModel.class); + + Generation generation = Mockito.mock(Generation.class); + given(generation.getOutput()).willReturn(null); + + ChatResponse response = Mockito.mock(ChatResponse.class); + given(response.getResult()).willReturn(generation); + + doCallRealMethod().when(mockClient).call(anyString()); + given(mockClient.call(any(Prompt.class))).willReturn(response); + + assertThatThrownBy(() -> mockClient.call(userMessage)).isInstanceOf(NullPointerException.class); + + verify(mockClient, times(1)).call(eq(userMessage)); + verify(generation, times(1)).getOutput(); + } + + @Test + void generateWhenAssistantMessageTextIsNullReturnsNull() { + String userMessage = "Test message"; + + ChatModel mockClient = Mockito.mock(ChatModel.class); + + AssistantMessage mockAssistantMessage = Mockito.mock(AssistantMessage.class); + given(mockAssistantMessage.getText()).willReturn(null); + + Generation generation = Mockito.mock(Generation.class); + given(generation.getOutput()).willReturn(mockAssistantMessage); + + ChatResponse response = Mockito.mock(ChatResponse.class); + given(response.getResult()).willReturn(generation); + + doCallRealMethod().when(mockClient).call(anyString()); + given(mockClient.call(any(Prompt.class))).willReturn(response); + + String result = mockClient.call(userMessage); + + assertThat(result).isNull(); + verify(mockClient, times(1)).call(eq(userMessage)); + verify(mockAssistantMessage, times(1)).getText(); + } + + @Test + void generateWithMultilineStringHandlesCorrectly() { + String userMessage = "Line 1\nLine 2\r\nLine 3\rLine 4"; + String responseMessage = "Multiline input processed"; + + ChatModel mockClient = Mockito.mock(ChatModel.class); + + AssistantMessage mockAssistantMessage = Mockito.mock(AssistantMessage.class); + given(mockAssistantMessage.getText()).willReturn(responseMessage); + + Generation generation = Mockito.mock(Generation.class); + given(generation.getOutput()).willReturn(mockAssistantMessage); + + ChatResponse response = Mockito.mock(ChatResponse.class); + given(response.getResult()).willReturn(generation); + + doCallRealMethod().when(mockClient).call(anyString()); + given(mockClient.call(any(Prompt.class))).willReturn(response); + + String result = mockClient.call(userMessage); + + assertThat(result).isEqualTo(responseMessage); + verify(mockClient, times(1)).call(eq(userMessage)); + } + + @Test + void generateMultipleTimesWithSameClientMaintainsState() { + ChatModel mockClient = Mockito.mock(ChatModel.class); + + doCallRealMethod().when(mockClient).call(anyString()); + + // First call + setupMockResponse(mockClient, "Response 1"); + String result1 = mockClient.call("Message 1"); + assertThat(result1).isEqualTo("Response 1"); + + // Second call + setupMockResponse(mockClient, "Response 2"); + String result2 = mockClient.call("Message 2"); + assertThat(result2).isEqualTo("Response 2"); + + verify(mockClient, times(2)).call(anyString()); + verify(mockClient, times(2)).call(any(Prompt.class)); + } + + private void setupMockResponse(ChatModel mockClient, String responseText) { + AssistantMessage mockAssistantMessage = Mockito.mock(AssistantMessage.class); + given(mockAssistantMessage.getText()).willReturn(responseText); + + Generation generation = Mockito.mock(Generation.class); + given(generation.getOutput()).willReturn(mockAssistantMessage); + + ChatResponse response = Mockito.mock(ChatResponse.class); + given(response.getResult()).willReturn(generation); + + given(mockClient.call(any(Prompt.class))).willReturn(response); + } + } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/converter/ListOutputConverterTest.java b/spring-ai-model/src/test/java/org/springframework/ai/converter/ListOutputConverterTest.java index 76b7c16dfef..05ec5ee1e2f 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/converter/ListOutputConverterTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/converter/ListOutputConverterTest.java @@ -178,4 +178,43 @@ void csvWithOnlyWhitespace() { assertThat(list.get(0)).isBlank(); } + @Test + void csvWithCommasInQuotedValues() { + String csvAsString = "\"value, with, commas\", normal, \"another, comma\""; + List list = this.listOutputConverter.convert(csvAsString); + assertThat(list).isNotEmpty(); + assertThat(list).doesNotContainNull(); + } + + @Test + void csvWithNewlinesInQuotedValues() { + String csvAsString = "\"line1\nline2\", normal, \"another\nline\""; + List list = this.listOutputConverter.convert(csvAsString); + assertThat(list).isNotEmpty(); + assertThat(list).doesNotContainNull(); + } + + @Test + void csvWithMixedQuotingStyles() { + String csvAsString = "'single quoted', \"double quoted\", `backtick quoted`, unquoted"; + List list = this.listOutputConverter.convert(csvAsString); + assertThat(list).hasSize(4); + assertThat(list).doesNotContainNull(); + } + + @Test + void csvWithOnlyCommasAndSpaces() { + String csvAsString = " , , , "; + List list = this.listOutputConverter.convert(csvAsString); + assertThat(list).hasSize(4); + assertThat(list).allMatch(String::isEmpty); + } + + @Test + void csvWithMalformedQuoting() { + String csvAsString = "\"unclosed quote, normal, \"properly closed\""; + List list = this.listOutputConverter.convert(csvAsString); + assertThat(list).isNotEmpty(); + } + } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java index d7695b39ce7..ca1862f985f 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java @@ -140,4 +140,51 @@ void whenBuilderCalledWithoutConversationHistoryThenThrowsException() { assertThat(toolExecutionResult.conversationHistory()).isEmpty(); } + @Test + void whenMultipleToolResponseMessagesOnlyLastOneIsProcessed() { + var toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory(List.of(new AssistantMessage("First response"), + new ToolResponseMessage( + List.of(new ToolResponseMessage.ToolResponse("1", "old_tool", "Old response"))), + new AssistantMessage("Second response"), + new ToolResponseMessage( + List.of(new ToolResponseMessage.ToolResponse("2", "new_tool", "New response"))))) + .build(); + + var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); + + assertThat(generations).hasSize(1); + assertThat(generations.get(0).getOutput().getText()).isEqualTo("New response"); + assertThat((String) generations.get(0).getMetadata().get(ToolExecutionResult.METADATA_TOOL_NAME)) + .isEqualTo("new_tool"); + } + + @Test + void whenToolResponseWithEmptyToolNameThenMetadataContainsEmptyString() { + var toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory(List.of(new ToolResponseMessage( + List.of(new ToolResponseMessage.ToolResponse("1", "", "Response content"))))) + .build(); + + var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); + + assertThat(generations).hasSize(1); + assertThat((String) generations.get(0).getMetadata().get(ToolExecutionResult.METADATA_TOOL_NAME)).isEmpty(); + } + + @Test + void whenToolResponseWithNullToolIdThenGenerationStillCreated() { + var toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory(List.of(new ToolResponseMessage( + List.of(new ToolResponseMessage.ToolResponse(null, "tool", "Response content"))))) + .build(); + + var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); + + assertThat(generations).hasSize(1); + assertThat(generations.get(0).getOutput().getText()).isEqualTo("Response content"); + assertThat((String) generations.get(0).getMetadata().get(ToolExecutionResult.METADATA_TOOL_NAME)) + .isEqualTo("tool"); + } + }