diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java index ae85e47b70d..6e9b91d7947 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java @@ -36,6 +36,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Import; import org.springframework.util.CollectionUtils; /** @@ -108,6 +109,7 @@ @EnableConfigurationProperties(McpClientCommonProperties.class) @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) +@Import(McpCompositeClientProperties.class) public class McpClientAutoConfiguration { /** @@ -146,7 +148,7 @@ private String connectedClientName(String clientName, String serverConnectionNam @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", matchIfMissing = true) public List mcpSyncClients(McpSyncClientConfigurer mcpSyncClientConfigurer, - McpClientCommonProperties commonProperties, + McpClientCommonProperties commonProperties, McpCompositeClientProperties mcpCompositeClientProperties, ObjectProvider> transportsProvider) { List mcpSyncClients = new ArrayList<>(); @@ -165,7 +167,11 @@ public List mcpSyncClients(McpSyncClientConfigurer mcpSyncClientC .requestTimeout(commonProperties.getRequestTimeout()); spec = mcpSyncClientConfigurer.configure(namedTransport.name(), spec); - + spec.toolAnnotationsHandler(name -> { + // set returnDirect in client level + boolean returnDirect = mcpCompositeClientProperties.getReturnDirect(namedTransport.name()); + return new McpSchema.ToolAnnotations(null, null, null, null, null, returnDirect); + }); var client = spec.build(); if (commonProperties.isInitialized()) { @@ -213,7 +219,7 @@ McpSyncClientConfigurer mcpSyncClientConfigurer(ObjectProvider mcpAsyncClients(McpAsyncClientConfigurer mcpAsyncClientConfigurer, - McpClientCommonProperties commonProperties, + McpClientCommonProperties commonProperties, McpCompositeClientProperties mcpCompositeClientProperties, ObjectProvider> transportsProvider) { List mcpAsyncClients = new ArrayList<>(); @@ -232,7 +238,11 @@ public List mcpAsyncClients(McpAsyncClientConfigurer mcpAsyncCli .requestTimeout(commonProperties.getRequestTimeout()); spec = mcpAsyncClientConfigurer.configure(namedTransport.name(), spec); - + spec.toolAnnotationsHandler(name -> { + // set returnDirect in client level + boolean returnDirect = mcpCompositeClientProperties.getReturnDirect(namedTransport.name()); + return new McpSchema.ToolAnnotations(null, null, null, null, null, returnDirect); + }); var client = spec.build(); if (commonProperties.isInitialized()) { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpCompositeClientProperties.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpCompositeClientProperties.java new file mode 100644 index 00000000000..a41af9d66b1 --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpCompositeClientProperties.java @@ -0,0 +1,44 @@ +package org.springframework.ai.mcp.client.common.autoconfigure; + +import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpSseClientProperties; +import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStdioClientProperties; +import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStreamableHttpClientProperties; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class McpCompositeClientProperties { + + private final ObjectProvider sseClientPropertiesObjectProvider; + + private final ObjectProvider stdioClientPropertiesObjectProvider; + + private final ObjectProvider streamableHttpClientPropertiesObjectProvider; + + public McpCompositeClientProperties(ObjectProvider sseClientPropertiesObjectProvider, + ObjectProvider stdioClientPropertiesObjectProvider, + ObjectProvider streamableHttpClientPropertiesObjectProvider) { + this.sseClientPropertiesObjectProvider = sseClientPropertiesObjectProvider; + this.stdioClientPropertiesObjectProvider = stdioClientPropertiesObjectProvider; + this.streamableHttpClientPropertiesObjectProvider = streamableHttpClientPropertiesObjectProvider; + } + + public boolean getReturnDirect(String connectionName) { + McpSseClientProperties sseClientProperties = sseClientPropertiesObjectProvider.getIfAvailable(); + if (sseClientProperties != null && sseClientProperties.getConnections().containsKey(connectionName)) { + return sseClientProperties.getConnections().get(connectionName).returnDirect(); + } + McpStdioClientProperties stdioClientProperties = stdioClientPropertiesObjectProvider.getIfAvailable(); + if (stdioClientProperties != null && stdioClientProperties.getConnections().containsKey(connectionName)) { + return stdioClientProperties.getConnections().get(connectionName).returnDirect(); + } + McpStreamableHttpClientProperties streamableHttpClientProperties = streamableHttpClientPropertiesObjectProvider + .getIfAvailable(); + if (streamableHttpClientProperties != null + && streamableHttpClientProperties.getConnections().containsKey(connectionName)) { + return streamableHttpClientProperties.getConnections().get(connectionName).returnDirect(); + } + return false; + } + +} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientProperties.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientProperties.java index f23029ddd96..8c05c85b9ee 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientProperties.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientProperties.java @@ -69,7 +69,8 @@ public Map getConnections() { * @param url the URL endpoint for SSE communication with the MCP server * @param sseEndpoint the SSE endpoint for the MCP server */ - public record SseParameters(String url, String sseEndpoint) { + public record SseParameters(String url, String sseEndpoint, boolean returnDirect) { + } } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStdioClientProperties.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStdioClientProperties.java index 7517f45e858..f86405fbd06 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStdioClientProperties.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStdioClientProperties.java @@ -124,7 +124,9 @@ public record Parameters( /** * Map of environment variables for the server process. */ - @JsonProperty("env") Map env) { + @JsonProperty("env") Map env, + + @JsonProperty("returnDirect") boolean returnDirect) { public ServerParameters toServerParameters() { return ServerParameters.builder(this.command()).args(this.args()).env(this.env()).build(); diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStreamableHttpClientProperties.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStreamableHttpClientProperties.java index 312c5af4e2f..c90dde9434c 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStreamableHttpClientProperties.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStreamableHttpClientProperties.java @@ -68,7 +68,8 @@ public Map getConnections() { * @param url the URL endpoint for Streamable Http communication with the MCP server * @param endpoint the endpoint for the MCP server */ - public record ConnectionParameters(String url, String endpoint) { + public record ConnectionParameters(String url, String endpoint, boolean returnDirect) { + } } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientPropertiesTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientPropertiesTests.java index b3c72aa08b3..b33fff61ad4 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientPropertiesTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientPropertiesTests.java @@ -105,7 +105,7 @@ void connectionWithNullUrl() { void sseParametersRecord() { String url = "http://test-server:8080/events"; String sseUrl = "/sse"; - McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl); + McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl, false); assertThat(params.url()).isEqualTo(url); assertThat(params.sseEndpoint()).isEqualTo(sseUrl); @@ -114,7 +114,7 @@ void sseParametersRecord() { @Test void sseParametersRecordWithNullSseEndpoint() { String url = "http://test-server:8080/events"; - McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null); + McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null, false); assertThat(params.url()).isEqualTo(url); assertThat(params.sseEndpoint()).isNull(); @@ -150,21 +150,21 @@ void connectionMapManipulation() { // Add a connection connections.put("server1", - new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse")); + new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse", false)); assertThat(properties.getConnections()).hasSize(1); assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080/events"); assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/sse"); // Add another connection connections.put("server2", - new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null)); + new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null, false)); assertThat(properties.getConnections()).hasSize(2); assertThat(properties.getConnections().get("server2").url()).isEqualTo("http://otherserver:8081/events"); assertThat(properties.getConnections().get("server2").sseEndpoint()).isNull(); // Replace a connection connections.put("server1", - new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events")); + new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events", false)); assertThat(properties.getConnections()).hasSize(2); assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://newserver:8082/events"); assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events"); @@ -209,13 +209,15 @@ void specialCharactersInConnectionName() { void connectionWithSseEndpoint() { this.contextRunner .withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080", - "spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/events") + "spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/events", + "spring.ai.mcp.client.sse.connections.server1.return-direct=true") .run(context -> { McpSseClientProperties properties = context.getBean(McpSseClientProperties.class); assertThat(properties.getConnections()).hasSize(1); assertThat(properties.getConnections()).containsKey("server1"); assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080"); assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events"); + assertThat(properties.getConnections().get("server1").returnDirect()).isEqualTo(true); }); } @@ -224,16 +226,20 @@ void multipleConnectionsWithSseEndpoint() { this.contextRunner .withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080", "spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/events", + "spring.ai.mcp.client.sse.connections.server1.return-direct=true", "spring.ai.mcp.client.sse.connections.server2.url=http://otherserver:8081", - "spring.ai.mcp.client.sse.connections.server2.sse-endpoint=/sse") + "spring.ai.mcp.client.sse.connections.server2.sse-endpoint=/sse", + "spring.ai.mcp.client.sse.connections.server2.return-direct=false") .run(context -> { McpSseClientProperties properties = context.getBean(McpSseClientProperties.class); assertThat(properties.getConnections()).hasSize(2); assertThat(properties.getConnections()).containsKeys("server1", "server2"); assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080"); assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events"); + assertThat(properties.getConnections().get("server1").returnDirect()).isEqualTo(true); assertThat(properties.getConnections().get("server2").url()).isEqualTo("http://otherserver:8081"); assertThat(properties.getConnections().get("server2").sseEndpoint()).isEqualTo("/sse"); + assertThat(properties.getConnections().get("server2").returnDirect()).isEqualTo(false); }); } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java index c4627dcabf5..dec2150f141 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java @@ -30,7 +30,6 @@ import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification; import io.modelcontextprotocol.server.McpStatelessServerFeatures; import io.modelcontextprotocol.server.McpSyncServerExchange; -import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.Role; @@ -203,6 +202,7 @@ private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCal .name(toolCallback.getToolDefinition().name()) .description(toolCallback.getToolDefinition().description()) .inputSchema(toolCallback.getToolDefinition().inputSchema()) + .annotations(toToolAnnotations(toolCallback)) .build(); return new SharedSyncToolSpecification(tool, (exchangeOrContext, request) -> { @@ -222,6 +222,11 @@ private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCal }); } + private static McpSchema.ToolAnnotations toToolAnnotations(ToolCallback toolCallback) { + Boolean returnDirect = toolCallback.getToolMetadata().returnDirect(); + return new McpSchema.ToolAnnotations(null, null, null, null, null, returnDirect); + } + /** * Retrieves the MCP exchange object from the provided tool context if it exists. * @param toolContext the tool context from which to retrieve the MCP exchange diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java index fc61d801df1..caa0f54ddbb 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java @@ -17,10 +17,13 @@ package org.springframework.ai.mcp; import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.Tool; import java.util.Map; +import java.util.Optional; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -30,6 +33,8 @@ import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.ai.tool.metadata.DefaultToolMetadata; +import org.springframework.ai.tool.metadata.ToolMetadata; /** * Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool @@ -80,6 +85,24 @@ public SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool) { } + /** + * Returns the tool metadata for the MCP tool. + *

+ * The tool metadata includes: + *

    + *
  • The tool's return direct flag from the MCP definition
  • + *
+ * @return the tool metadata + */ + @Override + public ToolMetadata getToolMetadata() { + Boolean returnDirect = Optional.ofNullable(tool.annotations()) + .map(McpSchema.ToolAnnotations::returnDirect) + .orElse(false); + + return DefaultToolMetadata.builder().returnDirect(returnDirect).build(); + } + /** * Returns a Spring AI tool definition adapted from the MCP tool. *

diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java index b83db6387b6..7b11baf0db4 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java @@ -68,7 +68,7 @@ static Stream openAiCompatibleApis() { .openAiApi(OpenAiApi.builder() .baseUrl("https://api.groq.com/openai") .apiKey(System.getenv("GROQ_API_KEY")) - .build()) + .build()) .defaultOptions(forModelName("llama3-8b-8192")) .build()); }