Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.util.CollectionUtils;

/**
Expand Down Expand Up @@ -108,6 +109,7 @@
@EnableConfigurationProperties(McpClientCommonProperties.class)
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
matchIfMissing = true)
@Import(McpCompositeClientProperties.class)
public class McpClientAutoConfiguration {

/**
Expand Down Expand Up @@ -146,7 +148,7 @@ private String connectedClientName(String clientName, String serverConnectionNam
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
matchIfMissing = true)
public List<McpSyncClient> mcpSyncClients(McpSyncClientConfigurer mcpSyncClientConfigurer,
McpClientCommonProperties commonProperties,
McpClientCommonProperties commonProperties, McpCompositeClientProperties mcpCompositeClientProperties,
ObjectProvider<List<NamedClientMcpTransport>> transportsProvider) {

List<McpSyncClient> mcpSyncClients = new ArrayList<>();
Expand All @@ -165,7 +167,11 @@ public List<McpSyncClient> mcpSyncClients(McpSyncClientConfigurer mcpSyncClientC
.requestTimeout(commonProperties.getRequestTimeout());

spec = mcpSyncClientConfigurer.configure(namedTransport.name(), spec);

spec.toolAnnotationsHandler(name -> {
// set returnDirect in client level
boolean returnDirect = mcpCompositeClientProperties.getReturnDirect(namedTransport.name());
return new McpSchema.ToolAnnotations(null, null, null, null, null, returnDirect);
});
var client = spec.build();

if (commonProperties.isInitialized()) {
Expand Down Expand Up @@ -213,7 +219,7 @@ McpSyncClientConfigurer mcpSyncClientConfigurer(ObjectProvider<McpSyncClientCust
@Bean
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
public List<McpAsyncClient> mcpAsyncClients(McpAsyncClientConfigurer mcpAsyncClientConfigurer,
McpClientCommonProperties commonProperties,
McpClientCommonProperties commonProperties, McpCompositeClientProperties mcpCompositeClientProperties,
ObjectProvider<List<NamedClientMcpTransport>> transportsProvider) {

List<McpAsyncClient> mcpAsyncClients = new ArrayList<>();
Expand All @@ -232,7 +238,11 @@ public List<McpAsyncClient> mcpAsyncClients(McpAsyncClientConfigurer mcpAsyncCli
.requestTimeout(commonProperties.getRequestTimeout());

spec = mcpAsyncClientConfigurer.configure(namedTransport.name(), spec);

spec.toolAnnotationsHandler(name -> {
// set returnDirect in client level
boolean returnDirect = mcpCompositeClientProperties.getReturnDirect(namedTransport.name());
return new McpSchema.ToolAnnotations(null, null, null, null, null, returnDirect);
});
var client = spec.build();

if (commonProperties.isInitialized()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package org.springframework.ai.mcp.client.common.autoconfigure;

import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpSseClientProperties;
import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStdioClientProperties;
import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStreamableHttpClientProperties;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.context.annotation.Configuration;

@Configuration
public class McpCompositeClientProperties {

private final ObjectProvider<McpSseClientProperties> sseClientPropertiesObjectProvider;

private final ObjectProvider<McpStdioClientProperties> stdioClientPropertiesObjectProvider;

private final ObjectProvider<McpStreamableHttpClientProperties> streamableHttpClientPropertiesObjectProvider;

public McpCompositeClientProperties(ObjectProvider<McpSseClientProperties> sseClientPropertiesObjectProvider,
ObjectProvider<McpStdioClientProperties> stdioClientPropertiesObjectProvider,
ObjectProvider<McpStreamableHttpClientProperties> streamableHttpClientPropertiesObjectProvider) {
this.sseClientPropertiesObjectProvider = sseClientPropertiesObjectProvider;
this.stdioClientPropertiesObjectProvider = stdioClientPropertiesObjectProvider;
this.streamableHttpClientPropertiesObjectProvider = streamableHttpClientPropertiesObjectProvider;
}

public boolean getReturnDirect(String connectionName) {
McpSseClientProperties sseClientProperties = sseClientPropertiesObjectProvider.getIfAvailable();
if (sseClientProperties != null && sseClientProperties.getConnections().containsKey(connectionName)) {
return sseClientProperties.getConnections().get(connectionName).returnDirect();
}
McpStdioClientProperties stdioClientProperties = stdioClientPropertiesObjectProvider.getIfAvailable();
if (stdioClientProperties != null && stdioClientProperties.getConnections().containsKey(connectionName)) {
return stdioClientProperties.getConnections().get(connectionName).returnDirect();
}
McpStreamableHttpClientProperties streamableHttpClientProperties = streamableHttpClientPropertiesObjectProvider
.getIfAvailable();
if (streamableHttpClientProperties != null
&& streamableHttpClientProperties.getConnections().containsKey(connectionName)) {
return streamableHttpClientProperties.getConnections().get(connectionName).returnDirect();
}
return false;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ public Map<String, SseParameters> getConnections() {
* @param url the URL endpoint for SSE communication with the MCP server
* @param sseEndpoint the SSE endpoint for the MCP server
*/
public record SseParameters(String url, String sseEndpoint) {
public record SseParameters(String url, String sseEndpoint, boolean returnDirect) {

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ public record Parameters(
/**
* Map of environment variables for the server process.
*/
@JsonProperty("env") Map<String, String> env) {
@JsonProperty("env") Map<String, String> env,

@JsonProperty("returnDirect") boolean returnDirect) {

public ServerParameters toServerParameters() {
return ServerParameters.builder(this.command()).args(this.args()).env(this.env()).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ public Map<String, ConnectionParameters> getConnections() {
* @param url the URL endpoint for Streamable Http communication with the MCP server
* @param endpoint the endpoint for the MCP server
*/
public record ConnectionParameters(String url, String endpoint) {
public record ConnectionParameters(String url, String endpoint, boolean returnDirect) {

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ void connectionWithNullUrl() {
void sseParametersRecord() {
String url = "http://test-server:8080/events";
String sseUrl = "/sse";
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl);
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl, false);

assertThat(params.url()).isEqualTo(url);
assertThat(params.sseEndpoint()).isEqualTo(sseUrl);
Expand All @@ -114,7 +114,7 @@ void sseParametersRecord() {
@Test
void sseParametersRecordWithNullSseEndpoint() {
String url = "http://test-server:8080/events";
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null);
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null, false);

assertThat(params.url()).isEqualTo(url);
assertThat(params.sseEndpoint()).isNull();
Expand Down Expand Up @@ -150,21 +150,21 @@ void connectionMapManipulation() {

// Add a connection
connections.put("server1",
new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse"));
new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse", false));
assertThat(properties.getConnections()).hasSize(1);
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080/events");
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/sse");

// Add another connection
connections.put("server2",
new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null));
new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null, false));
assertThat(properties.getConnections()).hasSize(2);
assertThat(properties.getConnections().get("server2").url()).isEqualTo("http://otherserver:8081/events");
assertThat(properties.getConnections().get("server2").sseEndpoint()).isNull();

// Replace a connection
connections.put("server1",
new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events"));
new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events", false));
assertThat(properties.getConnections()).hasSize(2);
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://newserver:8082/events");
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events");
Expand Down Expand Up @@ -209,13 +209,15 @@ void specialCharactersInConnectionName() {
void connectionWithSseEndpoint() {
this.contextRunner
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080",
"spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/events")
"spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/events",
"spring.ai.mcp.client.sse.connections.server1.return-direct=true")
.run(context -> {
McpSseClientProperties properties = context.getBean(McpSseClientProperties.class);
assertThat(properties.getConnections()).hasSize(1);
assertThat(properties.getConnections()).containsKey("server1");
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080");
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events");
assertThat(properties.getConnections().get("server1").returnDirect()).isEqualTo(true);
});
}

Expand All @@ -224,16 +226,20 @@ void multipleConnectionsWithSseEndpoint() {
this.contextRunner
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080",
"spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/events",
"spring.ai.mcp.client.sse.connections.server1.return-direct=true",
"spring.ai.mcp.client.sse.connections.server2.url=http://otherserver:8081",
"spring.ai.mcp.client.sse.connections.server2.sse-endpoint=/sse")
"spring.ai.mcp.client.sse.connections.server2.sse-endpoint=/sse",
"spring.ai.mcp.client.sse.connections.server2.return-direct=false")
.run(context -> {
McpSseClientProperties properties = context.getBean(McpSseClientProperties.class);
assertThat(properties.getConnections()).hasSize(2);
assertThat(properties.getConnections()).containsKeys("server1", "server2");
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080");
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events");
assertThat(properties.getConnections().get("server1").returnDirect()).isEqualTo(true);
assertThat(properties.getConnections().get("server2").url()).isEqualTo("http://otherserver:8081");
assertThat(properties.getConnections().get("server2").sseEndpoint()).isEqualTo("/sse");
assertThat(properties.getConnections().get("server2").returnDirect()).isEqualTo(false);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification;
import io.modelcontextprotocol.server.McpStatelessServerFeatures;
import io.modelcontextprotocol.server.McpSyncServerExchange;
import io.modelcontextprotocol.server.McpTransportContext;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.Role;
Expand Down Expand Up @@ -203,6 +202,7 @@ private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCal
.name(toolCallback.getToolDefinition().name())
.description(toolCallback.getToolDefinition().description())
.inputSchema(toolCallback.getToolDefinition().inputSchema())
.annotations(toToolAnnotations(toolCallback))
.build();

return new SharedSyncToolSpecification(tool, (exchangeOrContext, request) -> {
Expand All @@ -222,6 +222,11 @@ private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCal
});
}

private static McpSchema.ToolAnnotations toToolAnnotations(ToolCallback toolCallback) {
Boolean returnDirect = toolCallback.getToolMetadata().returnDirect();
return new McpSchema.ToolAnnotations(null, null, null, null, null, returnDirect);
}

/**
* Retrieves the MCP exchange object from the provided tool context if it exists.
* @param toolContext the tool context from which to retrieve the MCP exchange
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
package org.springframework.ai.mcp;

import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import java.util.Map;
import java.util.Optional;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -30,6 +33,8 @@
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.metadata.DefaultToolMetadata;
import org.springframework.ai.tool.metadata.ToolMetadata;

/**
* Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool
Expand Down Expand Up @@ -80,6 +85,24 @@ public SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool) {

}

/**
* Returns the tool metadata for the MCP tool.
* <p>
* The tool metadata includes:
* <ul>
* <li>The tool's return direct flag from the MCP definition</li>
* </ul>
* @return the tool metadata
*/
@Override
public ToolMetadata getToolMetadata() {
Boolean returnDirect = Optional.ofNullable(tool.annotations())
.map(McpSchema.ToolAnnotations::returnDirect)
.orElse(false);

return DefaultToolMetadata.builder().returnDirect(returnDirect).build();
}

/**
* Returns a Spring AI tool definition adapted from the MCP tool.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ static Stream<ChatModel> openAiCompatibleApis() {
.openAiApi(OpenAiApi.builder()
.baseUrl("https://api.groq.com/openai")
.apiKey(System.getenv("GROQ_API_KEY"))
.build())
.build())
.defaultOptions(forModelName("llama3-8b-8192"))
.build());
}
Expand Down