Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@

import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.McpToolFilter;
import org.springframework.ai.mcp.McpToolNamePrefixGenerator;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.AllNestedConditions;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
Expand All @@ -40,6 +42,23 @@
@Conditional(McpToolCallbackAutoConfiguration.McpToolCallbackAutoConfigurationCondition.class)
public class McpToolCallbackAutoConfiguration {

/**
* Provides a default {@link McpToolNamePrefixGenerator} bean if none is already
* defined.
* <p>
* This generator is used to create uniquely prefixed tool names based on the MCP
* connection information, helping to avoid name collisions when integrating tools
* from multiple MCP servers.
*
* Register the McpToolNamePrefixGenerator.noPrefix() bean to disable the prefixing.
* @return the default McpToolNamePrefixGenerator
*/
@Bean
@ConditionalOnMissingBean
public McpToolNamePrefixGenerator mcpToolNamePrefixGenerator() {
return McpToolNamePrefixGenerator.defaultGenerator();
}

/**
* Creates tool callbacks for all configured MCP clients.
*
Expand All @@ -49,25 +68,28 @@ public class McpToolCallbackAutoConfiguration {
* @param syncClientsToolFilter list of {@link McpToolFilter}s for the sync client to
* filter the discovered tools
* @param syncMcpClients provider of MCP sync clients
* @param mcpToolNamePrefixGenerator the tool name prefix generator
* @return list of tool callbacks for MCP integration
*/
@Bean
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
matchIfMissing = true)
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<McpToolFilter> syncClientsToolFilter,
ObjectProvider<List<McpSyncClient>> syncMcpClients) {
ObjectProvider<List<McpSyncClient>> syncMcpClients, McpToolNamePrefixGenerator mcpToolNamePrefixGenerator) {
List<McpSyncClient> mcpClients = syncMcpClients.stream().flatMap(List::stream).toList();
return new SyncMcpToolCallbackProvider(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true)),
mcpClients);
mcpToolNamePrefixGenerator, mcpClients);
}

@Bean
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider<McpToolFilter> asyncClientsToolFilter,
ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
ObjectProvider<List<McpAsyncClient>> mcpClientsProvider,
McpToolNamePrefixGenerator toolNamePrefixGenerator) {
List<McpAsyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
return new AsyncMcpToolCallbackProvider(
asyncClientsToolFilter.getIfUnique(() -> (McpAsyncClient, tool) -> true), mcpClients);
asyncClientsToolFilter.getIfUnique(() -> (McpAsyncClient, tool) -> true), toolNamePrefixGenerator,
mcpClients);
}

public static class McpToolCallbackAutoConfigurationCondition extends AllNestedConditions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import reactor.core.publisher.Mono;

import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.McpMetadata;
import org.springframework.ai.mcp.McpConnectionInfo;
import org.springframework.ai.mcp.McpToolFilter;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration.McpToolCallbackAutoConfigurationCondition;
Expand Down Expand Up @@ -107,8 +107,10 @@ void verifySyncToolCallbackFilterConfiguration() {
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
when(syncClient1.listTools()).thenReturn(listToolsResult1);
assertThat(toolFilter.test(new McpMetadata(null, syncClient1.getClientInfo(), null), tool1)).isFalse();
assertThat(toolFilter.test(new McpMetadata(null, syncClient1.getClientInfo(), null), tool2)).isTrue();
assertThat(toolFilter.test(new McpConnectionInfo(null, syncClient1.getClientInfo(), null), tool1))
.isFalse();
assertThat(toolFilter.test(new McpConnectionInfo(null, syncClient1.getClientInfo(), null), tool2))
.isTrue();
});
}

Expand All @@ -133,8 +135,10 @@ void verifyASyncToolCallbackFilterConfiguration() {
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
when(asyncClient1.listTools()).thenReturn(Mono.just(listToolsResult1));
assertThat(toolFilter.test(new McpMetadata(null, asyncClient1.getClientInfo(), null), tool1)).isFalse();
assertThat(toolFilter.test(new McpMetadata(null, asyncClient1.getClientInfo(), null), tool2)).isTrue();
assertThat(toolFilter.test(new McpConnectionInfo(null, asyncClient1.getClientInfo(), null), tool1))
.isFalse();
assertThat(toolFilter.test(new McpConnectionInfo(null, asyncClient1.getClientInfo(), null), tool2))
.isTrue();
});
}

Expand All @@ -156,7 +160,7 @@ static class McpClientFilterConfiguration {
McpToolFilter mcpClientFilter() {
return new McpToolFilter() {
@Override
public boolean test(McpMetadata metadata, McpSchema.Tool tool) {
public boolean test(McpConnectionInfo metadata, McpSchema.Tool tool) {
if (metadata.clientInfo().name().equals("client1") && tool.name().contains("tool1")) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@

package org.springframework.ai.mcp.client.common.autoconfigure;

import io.modelcontextprotocol.spec.McpSchema.Tool;
import org.junit.jupiter.api.Test;

import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.McpConnectionInfo;
import org.springframework.ai.mcp.McpToolNamePrefixGenerator;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -85,4 +92,94 @@ void enabledMcpToolCallbackAutoConfiguration() {
});
}

@Test
void defaultMcpToolNamePrefixGeneratorIsCreated() {
// Test with SYNC mode (default)
this.applicationContext.run(context -> {
assertThat(context).hasBean("mcpToolNamePrefixGenerator");
McpToolNamePrefixGenerator generator = context.getBean(McpToolNamePrefixGenerator.class);
assertThat(generator).isNotNull();
});

// Test with ASYNC mode
this.applicationContext.withPropertyValues("spring.ai.mcp.client.type=ASYNC").run(context -> {
assertThat(context).hasBean("mcpToolNamePrefixGenerator");
McpToolNamePrefixGenerator generator = context.getBean(McpToolNamePrefixGenerator.class);
assertThat(generator).isNotNull();
});
}

@Test
void customMcpToolNamePrefixGeneratorOverridesDefault() {
// Test with SYNC mode
this.applicationContext.withUserConfiguration(CustomPrefixGeneratorConfig.class).run(context -> {
assertThat(context).hasBean("mcpToolNamePrefixGenerator");
McpToolNamePrefixGenerator generator = context.getBean(McpToolNamePrefixGenerator.class);
assertThat(generator).isInstanceOf(CustomPrefixGenerator.class);
assertThat(context).hasBean("mcpToolCallbacks");
// Verify the custom generator is injected into the provider
SyncMcpToolCallbackProvider provider = context.getBean(SyncMcpToolCallbackProvider.class);
assertThat(provider).isNotNull();
});

// Test with ASYNC mode
this.applicationContext.withUserConfiguration(CustomPrefixGeneratorConfig.class)
.withPropertyValues("spring.ai.mcp.client.type=ASYNC")
.run(context -> {
assertThat(context).hasBean("mcpToolNamePrefixGenerator");
McpToolNamePrefixGenerator generator = context.getBean(McpToolNamePrefixGenerator.class);
assertThat(generator).isInstanceOf(CustomPrefixGenerator.class);
assertThat(context).hasBean("mcpAsyncToolCallbacks");
// Verify the custom generator is injected into the provider
AsyncMcpToolCallbackProvider provider = context.getBean(AsyncMcpToolCallbackProvider.class);
assertThat(provider).isNotNull();
});
}

@Test
void mcpToolNamePrefixGeneratorIsInjectedIntoProviders() {
// Test SYNC provider receives the generator
this.applicationContext.run(context -> {
assertThat(context).hasBean("mcpToolNamePrefixGenerator");
assertThat(context).hasBean("mcpToolCallbacks");

McpToolNamePrefixGenerator generator = context.getBean(McpToolNamePrefixGenerator.class);
SyncMcpToolCallbackProvider provider = context.getBean(SyncMcpToolCallbackProvider.class);

assertThat(generator).isNotNull();
assertThat(provider).isNotNull();
});

// Test ASYNC provider receives the generator
this.applicationContext.withPropertyValues("spring.ai.mcp.client.type=ASYNC").run(context -> {
assertThat(context).hasBean("mcpToolNamePrefixGenerator");
assertThat(context).hasBean("mcpAsyncToolCallbacks");

McpToolNamePrefixGenerator generator = context.getBean(McpToolNamePrefixGenerator.class);
AsyncMcpToolCallbackProvider provider = context.getBean(AsyncMcpToolCallbackProvider.class);

assertThat(generator).isNotNull();
assertThat(provider).isNotNull();
});
}

@Configuration
static class CustomPrefixGeneratorConfig {

@Bean
public McpToolNamePrefixGenerator mcpToolNamePrefixGenerator() {
return new CustomPrefixGenerator();
}

}

static class CustomPrefixGenerator implements McpToolNamePrefixGenerator {

@Override
public String prefixedToolName(McpConnectionInfo mcpConnInfo, Tool tool) {
return "custom_" + tool.name();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ List<ToolCallback> testTool() {
Mockito.when(mockClient.callTool(Mockito.any(McpSchema.CallToolRequest.class))).thenReturn(mockResult);
when(mockClient.getClientInfo()).thenReturn(new McpSchema.Implementation("testClient", "1.0.0"));

return List.of(new SyncMcpToolCallback(mockClient, mockTool));
return List.of(new SyncMcpToolCallback(mockClient, mockTool, mockTool.name()));
}

}
Expand All @@ -413,7 +413,7 @@ ToolCallbackProvider testToolCallbackProvider() {
Mockito.when(mockTool.description()).thenReturn("Provider Tool");
when(mockClient.getClientInfo()).thenReturn(new McpSchema.Implementation("testClient", "1.0.0"));

return new ToolCallback[] { new SyncMcpToolCallback(mockClient, mockTool) };
return new ToolCallback[] { new SyncMcpToolCallback(mockClient, mockTool, mockTool.name()) };
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
Expand Down Expand Up @@ -71,14 +72,33 @@ public class AsyncMcpToolCallback implements ToolCallback {

private final Tool tool;

private final String prefixedToolName;

/**
* Creates a new {@code AsyncMcpToolCallback} instance.
* @param mcpClient the MCP client to use for tool execution
* @param tool the MCP tool definition to adapt
* @deprecated use {@link #AsyncMcpToolCallback(McpAsyncClient, Tool, String)}
*/
@Deprecated
public AsyncMcpToolCallback(McpAsyncClient mcpClient, Tool tool) {
this(mcpClient, tool, McpToolUtils.prefixedToolName(mcpClient.getClientInfo().name(), tool.name()));
}

/**
* Creates a new {@code AsyncMcpToolCallback} instance.
* @param mcpClient the MCP client to use for tool execution
* @param tool the MCP tool definition to adapt
* @param prefixedToolName the prefixed tool name to use in the tool definition.
*/
public AsyncMcpToolCallback(McpAsyncClient mcpClient, Tool tool, String prefixedToolName) {
Assert.notNull(mcpClient, "MCP client must not be null");
Assert.notNull(tool, "MCP tool must not be null");
Assert.hasText(prefixedToolName, "Prefixed tool name must not be empty");

this.asyncMcpClient = mcpClient;
this.tool = tool;
this.prefixedToolName = prefixedToolName;
}

/**
Expand All @@ -95,7 +115,7 @@ public AsyncMcpToolCallback(McpAsyncClient mcpClient, Tool tool) {
@Override
public ToolDefinition getToolDefinition() {
return DefaultToolDefinition.builder()
.name(McpToolUtils.prefixedToolName(this.asyncMcpClient.getClientInfo().name(), this.tool.name()))
.name(this.prefixedToolName)
.description(this.tool.description())
.inputSchema(ModelOptionsUtils.toJsonString(this.tool.inputSchema()))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,37 @@ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider {

private final List<McpAsyncClient> mcpClients;

private final McpToolNamePrefixGenerator toolNamePrefixGenerator;

/**
* Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP
* clients.
* @param toolFilter a filter to apply to each discovered tool
* @param mcpClients the list of MCP clients to use for discovering tools
* @deprecated use
* {@link #AsyncMcpToolCallbackProvider(McpToolFilter, McpToolNamePrefixGenerator, List)}
*/
@Deprecated
public AsyncMcpToolCallbackProvider(McpToolFilter toolFilter, List<McpAsyncClient> mcpClients) {
this(toolFilter, McpToolNamePrefixGenerator.defaultGenerator(), mcpClients);
}

/**
* Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP
* clients.
* @param toolFilter a filter to apply to each discovered tool
* @param toolNamePrefixGenerator the tool name prefix generator to use when creating
* tool callbacks.
* @param mcpClients the list of MCP clients to use for discovering tools
*/
public AsyncMcpToolCallbackProvider(McpToolFilter toolFilter, McpToolNamePrefixGenerator toolNamePrefixGenerator,
List<McpAsyncClient> mcpClients) {
Assert.notNull(mcpClients, "MCP clients must not be null");
Assert.notNull(toolFilter, "Tool filter must not be null");
Assert.notNull(toolNamePrefixGenerator, "Tool name prefix generator must not be null");
this.toolFilter = toolFilter;
this.mcpClients = mcpClients;
this.toolNamePrefixGenerator = toolNamePrefixGenerator;
}

/**
Expand Down Expand Up @@ -145,9 +165,20 @@ public ToolCallback[] getToolCallbacks() {
ToolCallback[] toolCallbacks = mcpClient.listTools()
.map(response -> response.tools()
.stream()
.filter(tool -> this.toolFilter.test(new McpMetadata(mcpClient.getClientCapabilities(),
mcpClient.getClientInfo(), mcpClient.getCurrentInitializationResult()), tool))
.map(tool -> new AsyncMcpToolCallback(mcpClient, tool))
.filter(tool -> this.toolFilter.test(McpConnectionInfo.builder()
.clientCapabilities(mcpClient.getClientCapabilities())
.clientInfo(mcpClient.getClientInfo())
.initializeResult(mcpClient.getCurrentInitializationResult())
.build(), tool))
.map(tool -> {
McpConnectionInfo connectionInfo = McpConnectionInfo.builder()
.clientCapabilities(mcpClient.getClientCapabilities())
.clientInfo(mcpClient.getClientInfo())
.initializeResult(mcpClient.getCurrentInitializationResult())
.build();
return new AsyncMcpToolCallback(mcpClient, tool,
this.toolNamePrefixGenerator.prefixedToolName(connectionInfo, tool));
})
.toArray(ToolCallback[]::new))
.block();

Expand Down
Loading