diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java index b7c73fb9246..77df628c992 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java @@ -19,6 +19,7 @@ import java.util.Map; import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.Tool; @@ -33,6 +34,7 @@ import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.ai.tool.metadata.ToolMetadata; import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; @@ -45,15 +47,20 @@ * * @author Christian Tzolov * @author YunKui Lu + * @author Sun Yuhan */ public class AsyncMcpToolCallback implements ToolCallback { private static final Logger logger = LoggerFactory.getLogger(AsyncMcpToolCallback.class); + private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build(); + private final McpAsyncClient mcpClient; private final Tool tool; + private final ToolMetadata toolMetadata; + private final String prefixedToolName; private final ToolContextToMcpMetaConverter toolContextToMcpMetaConverter; @@ -88,6 +95,14 @@ private AsyncMcpToolCallback(McpAsyncClient mcpClient, Tool tool, String prefixe this.tool = tool; this.prefixedToolName = prefixedToolName; this.toolContextToMcpMetaConverter = toolContextToMcpMetaConverter; + McpSchema.ToolAnnotations annotations = tool.annotations(); + Boolean returnDirect = (annotations != null) ? annotations.returnDirect() : null; + if (returnDirect != null) { + this.toolMetadata = ToolMetadata.builder().returnDirect(returnDirect).build(); + } + else { + this.toolMetadata = DEFAULT_TOOL_METADATA; + } } @Override @@ -149,6 +164,11 @@ public String call(String toolCallInput, @Nullable ToolContext toolContext) { return ModelOptionsUtils.toJsonString(response.content()); } + @Override + public ToolMetadata getToolMetadata() { + return this.toolMetadata; + } + /** * Creates a builder for constructing AsyncMcpToolCallback instances. * @return a new builder diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java index 858c8be575a..74939f217d8 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java @@ -40,6 +40,7 @@ import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.metadata.ToolMetadata; import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; @@ -63,6 +64,7 @@ * * * @author Christian Tzolov + * @author Sun Yuhan */ public final class McpToolUtils { @@ -228,12 +230,18 @@ public static McpStatelessServerFeatures.SyncToolSpecification toStatelessSyncTo private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCallback toolCallback, MimeType mimeType) { + boolean returnDirect = Optional.ofNullable(toolCallback.getToolMetadata()) + .map(ToolMetadata::returnDirect) + .orElse(false); + McpSchema.ToolAnnotations toolAnnotations = new McpSchema.ToolAnnotations(null, null, null, null, null, + returnDirect); var tool = McpSchema.Tool.builder() .name(toolCallback.getToolDefinition().name()) .description(toolCallback.getToolDefinition().description()) .inputSchema(ModelOptionsUtils.jsonToObject(toolCallback.getToolDefinition().inputSchema(), McpSchema.JsonSchema.class)) + .annotations(toolAnnotations) .build(); return new SharedSyncToolSpecification(tool, (exchangeOrContext, request) -> { diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java index d7f980d08d9..6b560e23a13 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java @@ -19,6 +19,7 @@ import java.util.Map; import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.Tool; @@ -32,6 +33,7 @@ import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.ai.tool.metadata.ToolMetadata; import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; @@ -41,16 +43,21 @@ * * @author Christian Tzolov * @author YunKui Lu + * @author Sun Yuhan * @since 1.0.0 */ public class SyncMcpToolCallback implements ToolCallback { private static final Logger logger = LoggerFactory.getLogger(SyncMcpToolCallback.class); + private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build(); + private final McpSyncClient mcpClient; private final Tool tool; + private final ToolMetadata toolMetadata; + private final String prefixedToolName; private final ToolContextToMcpMetaConverter toolContextToMcpMetaConverter; @@ -85,6 +92,14 @@ private SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool, String prefixedT this.tool = tool; this.prefixedToolName = prefixedToolName; this.toolContextToMcpMetaConverter = toolContextToMcpMetaConverter; + McpSchema.ToolAnnotations annotations = tool.annotations(); + Boolean returnDirect = (annotations != null) ? annotations.returnDirect() : null; + if (returnDirect != null) { + this.toolMetadata = ToolMetadata.builder().returnDirect(returnDirect).build(); + } + else { + this.toolMetadata = DEFAULT_TOOL_METADATA; + } } @Override @@ -149,6 +164,11 @@ public String call(String toolCallInput, @Nullable ToolContext toolContext) { return ModelOptionsUtils.toJsonString(response.content()); } + @Override + public ToolMetadata getToolMetadata() { + return this.toolMetadata; + } + /** * Creates a builder for constructing {@code SyncMcpToolCallback} instances. * @return a new builder diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java index f29743e0573..0a57648aa9d 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java @@ -204,6 +204,8 @@ void getToolDefinitionShouldReturnCorrectDefinition() { when(this.tool.description()).thenReturn("Test tool description"); var jsonSchema = mock(McpSchema.JsonSchema.class); when(this.tool.inputSchema()).thenReturn(jsonSchema); + var toolAnnotations = new McpSchema.ToolAnnotations(null, false, false, false, false, true); + when(this.tool.annotations()).thenReturn(toolAnnotations); // Act var callback = AsyncMcpToolCallback.builder() @@ -213,11 +215,13 @@ void getToolDefinitionShouldReturnCorrectDefinition() { .build(); ToolDefinition definition = callback.getToolDefinition(); + var toolMetadata = callback.getToolMetadata(); // Assert assertThat(definition.name()).isEqualTo("prefix_testTool"); assertThat(definition.description()).isEqualTo("Test tool description"); assertThat(definition.inputSchema()).isNotNull(); + assertThat(toolMetadata.returnDirect()).isEqualTo(true); } @Test diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java index 3dae5d0eba1..5f86a07a505 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java @@ -51,6 +51,7 @@ class SyncMcpToolCallbackTests { @Test void getToolDefinitionShouldReturnCorrectDefinition() { var clientInfo = new Implementation("testClient", "1.0.0"); + when(this.tool.name()).thenReturn("testTool"); when(this.tool.description()).thenReturn("Test tool description"); diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java index 12c1f6023e7..f18ca0032c2 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java @@ -37,6 +37,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.metadata.ToolMetadata; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -314,8 +315,10 @@ private ToolCallback createMockToolCallback(String name, String result) { .description("Test tool") .inputSchema("{}") .build(); + ToolMetadata metadata = ToolMetadata.builder().build(); when(callback.getToolDefinition()).thenReturn(definition); when(callback.call(anyString(), any())).thenReturn(result); + when(callback.getToolMetadata()).thenReturn(metadata); return callback; } @@ -326,8 +329,10 @@ private ToolCallback createMockToolCallback(String name, RuntimeException error) .description("Test tool") .inputSchema("{}") .build(); + ToolMetadata metadata = ToolMetadata.builder().build(); when(callback.getToolDefinition()).thenReturn(definition); when(callback.call(anyString(), any())).thenThrow(error); + when(callback.getToolMetadata()).thenReturn(metadata); return callback; }