diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java index ce775b20cd5..bd60639c323 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java @@ -155,4 +155,272 @@ public String call(String toolInput) { assertThatNoException().isThrownBy(() -> managerWithCallback.executeToolCalls(prompt, chatResponse)); } + @Test + void shouldHandleMultipleToolCallsInSingleResponse() { + // Create mock tool callbacks + ToolCallback toolCallback1 = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("tool1") + .description("First tool") + .inputSchema("{\"type\": \"object\", \"properties\": {\"param\": {\"type\": \"string\"}}}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + return "{\"result\": \"tool1_success\"}"; + } + }; + + ToolCallback toolCallback2 = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("tool2") + .description("Second tool") + .inputSchema("{\"type\": \"object\", \"properties\": {\"value\": {\"type\": \"number\"}}}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + return "{\"result\": \"tool2_success\"}"; + } + }; + + // Create multiple ToolCalls + AssistantMessage.ToolCall toolCall1 = new AssistantMessage.ToolCall("1", "function", "tool1", + "{\"param\": \"test\"}"); + AssistantMessage.ToolCall toolCall2 = new AssistantMessage.ToolCall("2", "function", "tool2", + "{\"value\": 42}"); + + // Create ChatResponse with multiple tool calls + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall1, toolCall2)) + .build(); + Generation generation = new Generation(assistantMessage); + ChatResponse chatResponse = new ChatResponse(List.of(generation)); + + Prompt prompt = new Prompt(List.of(new UserMessage("test multiple tools"))); + + DefaultToolCallingManager manager = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .toolCallbackResolver(toolName -> { + if ("tool1".equals(toolName)) { + return toolCallback1; + } + if ("tool2".equals(toolName)) { + return toolCallback2; + } + return null; + }) + .build(); + + assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse)); + } + + @Test + void shouldHandleToolCallWithComplexJsonArguments() { + ToolCallback complexToolCallback = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("complexTool") + .description("A tool with complex JSON input") + .inputSchema("{\"type\": \"object\", \"properties\": {\"nested\": {\"type\": \"object\"}}}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + assertThat(toolInput).contains("nested"); + assertThat(toolInput).contains("array"); + return "{\"result\": \"processed\"}"; + } + }; + + String complexJson = "{\"nested\": {\"level1\": {\"level2\": \"value\"}}, \"array\": [1, 2, 3]}"; + AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "function", "complexTool", complexJson); + + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); + Generation generation = new Generation(assistantMessage); + ChatResponse chatResponse = new ChatResponse(List.of(generation)); + + Prompt prompt = new Prompt(List.of(new UserMessage("test complex json"))); + + DefaultToolCallingManager manager = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .toolCallbackResolver(toolName -> "complexTool".equals(toolName) ? complexToolCallback : null) + .build(); + + assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse)); + } + + @Test + void shouldHandleToolCallWithMalformedJson() { + ToolCallback toolCallback = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("testTool") + .description("Test tool") + .inputSchema("{}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + // Should still receive some input even if malformed + assertThat(toolInput).isNotNull(); + return "{\"result\": \"handled\"}"; + } + }; + + // Malformed JSON as tool arguments + AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "function", "testTool", + "{invalid json}"); + + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); + Generation generation = new Generation(assistantMessage); + ChatResponse chatResponse = new ChatResponse(List.of(generation)); + + Prompt prompt = new Prompt(List.of(new UserMessage("test malformed json"))); + + DefaultToolCallingManager manager = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .toolCallbackResolver(toolName -> "testTool".equals(toolName) ? toolCallback : null) + .build(); + + assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse)); + } + + @Test + void shouldHandleToolCallReturningNull() { + ToolCallback toolCallback = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("nullReturningTool") + .description("Tool that returns null") + .inputSchema("{}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + return null; // Return null + } + }; + + AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "function", "nullReturningTool", "{}"); + + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); + Generation generation = new Generation(assistantMessage); + ChatResponse chatResponse = new ChatResponse(List.of(generation)); + + Prompt prompt = new Prompt(List.of(new UserMessage("test null return"))); + + DefaultToolCallingManager manager = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .toolCallbackResolver(toolName -> "nullReturningTool".equals(toolName) ? toolCallback : null) + .build(); + + assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse)); + } + + @Test + void shouldHandleMultipleGenerationsWithToolCalls() { + ToolCallback toolCallback = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("multiGenTool") + .description("Tool for multiple generations") + .inputSchema("{}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + return "{\"result\": \"success\"}"; + } + }; + + // Create multiple generations with tool calls + AssistantMessage.ToolCall toolCall1 = new AssistantMessage.ToolCall("1", "function", "multiGenTool", "{}"); + AssistantMessage.ToolCall toolCall2 = new AssistantMessage.ToolCall("2", "function", "multiGenTool", "{}"); + + AssistantMessage assistantMessage1 = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall1)) + .build(); + + AssistantMessage assistantMessage2 = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall2)) + .build(); + + Generation generation1 = new Generation(assistantMessage1); + Generation generation2 = new Generation(assistantMessage2); + + ChatResponse chatResponse = new ChatResponse(List.of(generation1, generation2)); + + Prompt prompt = new Prompt(List.of(new UserMessage("test multiple generations"))); + + DefaultToolCallingManager manager = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .toolCallbackResolver(toolName -> "multiGenTool".equals(toolName) ? toolCallback : null) + .build(); + + assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse)); + } + }