diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java index 5f8da416109..bfb999b503e 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java @@ -19,8 +19,8 @@ import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; -import java.util.Map; - +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder; @@ -29,6 +29,8 @@ import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; +import java.util.Map; + /** * Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool * interface with asynchronous execution support. @@ -61,6 +63,8 @@ */ public class AsyncMcpToolCallback implements ToolCallback { + private static final Logger logger = LoggerFactory.getLogger(AsyncMcpToolCallback.class); + private final McpAsyncClient asyncMcpClient; private final Tool tool; @@ -109,6 +113,13 @@ public ToolDefinition getToolDefinition() { */ @Override public String call(String functionInput) { + // Handle the possible null parameter situation in streaming mode. + if (functionInput == null || functionInput.trim().isEmpty()) { + logger.debug("Tool call arguments are null or empty for MCP tool: {}. Using empty JSON object as default.", + this.tool.name()); + functionInput = "{}"; + } + Map arguments = ModelOptionsUtils.jsonToMap(functionInput); // Note that we use the original tool name here, not the adapted one from // getToolDefinition diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java index fc61d801df1..df6cfa23cd5 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java @@ -20,10 +20,8 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.Tool; -import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.tool.ToolCallback; @@ -31,6 +29,8 @@ import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; +import java.util.Map; + /** * Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool * interface. @@ -114,6 +114,13 @@ public ToolDefinition getToolDefinition() { */ @Override public String call(String functionInput) { + // Handle the possible null parameter situation in streaming mode. + if (functionInput == null || functionInput.trim().isEmpty()) { + logger.debug("Tool call arguments are null or empty for MCP tool: {}. Using empty JSON object as default.", + this.tool.name()); + functionInput = "{}"; + } + Map arguments = ModelOptionsUtils.jsonToMap(functionInput); CallToolResult response; diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java index 5149a98a85c..53bad12512a 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java @@ -16,16 +16,9 @@ package org.springframework.ai.model.tool; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; - import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -47,6 +40,12 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + /** * Default implementation of {@link ToolCallingManager}. * @@ -189,6 +188,17 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess String toolName = toolCall.name(); String toolInputArguments = toolCall.arguments(); + // Handle the possible null parameter situation in streaming mode. + final String finalToolInputArguments; + if (toolInputArguments == null || toolInputArguments.trim().isEmpty()) { + logger.debug("Tool call arguments are null or empty for tool: {}. Using empty JSON object as default.", + toolName); + finalToolInputArguments = "{}"; + } + else { + finalToolInputArguments = toolInputArguments; + } + ToolCallback toolCallback = toolCallbacks.stream() .filter(tool -> toolName.equals(tool.getToolDefinition().name())) .findFirst() @@ -208,7 +218,7 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder() .toolDefinition(toolCallback.getToolDefinition()) .toolMetadata(toolCallback.getToolMetadata()) - .toolCallArguments(toolInputArguments) + .toolCallArguments(finalToolInputArguments) .build(); String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL @@ -217,7 +227,7 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess .observe(() -> { String toolResult; try { - toolResult = toolCallback.call(toolInputArguments, toolContext); + toolResult = toolCallback.call(finalToolInputArguments, toolContext); } catch (ToolExecutionException ex) { toolResult = this.toolExecutionExceptionProcessor.process(ex); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java new file mode 100644 index 00000000000..ceb84a0a941 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java @@ -0,0 +1,164 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.tool; + +import io.micrometer.observation.ObservationRegistry; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.DefaultToolDefinition; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.metadata.ToolMetadata; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; + +/** + * Tests for {@link DefaultToolCallingManager} with empty/null arguments handling. + * + * @author Spring AI Team + */ +class DefaultToolCallingManagerTest { + + @Test + void shouldHandleNullArgumentsInStreamMode() { + // Create a mock tool callback + ToolCallback mockToolCallback = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("testTool") + .description("A test tool") + .inputSchema("{}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + // Verify the input is not null or empty + assertThat(toolInput).isNotNull(); + assertThat(toolInput).isNotEmpty(); + return "{\"result\": \"success\"}"; + } + }; + + // Create DefaultToolCallingManager with tool callback + DefaultToolCallingManager manager = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .build(); + + // Create a ToolCall with empty parameters + AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "function", "testTool", null); + + // Create a ChatResponse + AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), List.of(toolCall)); + Generation generation = new Generation(assistantMessage); + ChatResponse chatResponse = new ChatResponse(List.of(generation)); + + // Create a Prompt with tool callbacks + Prompt prompt = new Prompt(List.of(new UserMessage("test"))); + + // Mock the tool callbacks resolution by creating a custom ToolCallbackResolver + DefaultToolCallingManager managerWithCallback = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .toolCallbackResolver(toolName -> { + if ("testTool".equals(toolName)) { + return mockToolCallback; + } + return null; + }) + .build(); + + // Verify that no exception is thrown + assertThatNoException().isThrownBy(() -> { + managerWithCallback.executeToolCalls(prompt, chatResponse); + }); + } + + @Test + void shouldHandleEmptyArgumentsInStreamMode() { + // Create a mock tool callback + ToolCallback mockToolCallback = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("testTool") + .description("A test tool") + .inputSchema("{}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + // Verify the input is not null or empty + assertThat(toolInput).isNotNull(); + assertThat(toolInput).isNotEmpty(); + return "{\"result\": \"success\"}"; + } + }; + + // Create DefaultToolCallingManager with tool callback + DefaultToolCallingManager manager = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .build(); + + // Create a ToolCall with empty parameters + AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "function", "testTool", ""); + + // Create a ChatResponse + AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), List.of(toolCall)); + Generation generation = new Generation(assistantMessage); + ChatResponse chatResponse = new ChatResponse(List.of(generation)); + + // Create a Prompt with tool callbacks + Prompt prompt = new Prompt(List.of(new UserMessage("test"))); + + // Mock the tool callbacks resolution by creating a custom ToolCallbackResolver + DefaultToolCallingManager managerWithCallback = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .toolCallbackResolver(toolName -> { + if ("testTool".equals(toolName)) { + return mockToolCallback; + } + return null; + }) + .build(); + + // Verify that no exception is thrown + assertThatNoException().isThrownBy(() -> { + managerWithCallback.executeToolCalls(prompt, chatResponse); + }); + } + +}