Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import org.springframework.ai.mcp.McpToolFilter;
import org.springframework.ai.mcp.McpToolNamePrefixGenerator;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.ToolContextToMcpMetaConverter;
import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.AllNestedConditions;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
Expand All @@ -42,23 +42,6 @@
@Conditional(McpToolCallbackAutoConfiguration.McpToolCallbackAutoConfigurationCondition.class)
public class McpToolCallbackAutoConfiguration {

/**
* Provides a default {@link McpToolNamePrefixGenerator} bean if none is already
* defined.
* <p>
* This generator is used to create uniquely prefixed tool names based on the MCP
* connection information, helping to avoid name collisions when integrating tools
* from multiple MCP servers.
*
* Register the McpToolNamePrefixGenerator.noPrefix() bean to disable the prefixing.
* @return the default McpToolNamePrefixGenerator
*/
@Bean
@ConditionalOnMissingBean
public McpToolNamePrefixGenerator mcpToolNamePrefixGenerator() {
return McpToolNamePrefixGenerator.defaultGenerator();
}

/**
* Creates tool callbacks for all configured MCP clients.
*
Expand All @@ -75,21 +58,35 @@ public McpToolNamePrefixGenerator mcpToolNamePrefixGenerator() {
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
matchIfMissing = true)
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<McpToolFilter> syncClientsToolFilter,
ObjectProvider<List<McpSyncClient>> syncMcpClients, McpToolNamePrefixGenerator mcpToolNamePrefixGenerator) {
ObjectProvider<List<McpSyncClient>> syncMcpClients,
ObjectProvider<McpToolNamePrefixGenerator> mcpToolNamePrefixGenerator,
ObjectProvider<ToolContextToMcpMetaConverter> toolContextToMcpMetaConverter) {
List<McpSyncClient> mcpClients = syncMcpClients.stream().flatMap(List::stream).toList();
return new SyncMcpToolCallbackProvider(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true)),
mcpToolNamePrefixGenerator, mcpClients);
return SyncMcpToolCallbackProvider.builder()
.mcpClients(mcpClients)
.toolFilter(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true)))
.toolNamePrefixGenerator(
mcpToolNamePrefixGenerator.getIfUnique(() -> McpToolNamePrefixGenerator.defaultGenerator()))
.toolContextToMcpMetaConverter(
toolContextToMcpMetaConverter.getIfUnique(() -> ToolContextToMcpMetaConverter.defaultConverter()))
.build();
}

@Bean
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider<McpToolFilter> asyncClientsToolFilter,
ObjectProvider<List<McpAsyncClient>> mcpClientsProvider,
McpToolNamePrefixGenerator toolNamePrefixGenerator) {
ObjectProvider<McpToolNamePrefixGenerator> toolNamePrefixGenerator,
ObjectProvider<ToolContextToMcpMetaConverter> toolContextToMcpMetaConverter) { // TODO
List<McpAsyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
return new AsyncMcpToolCallbackProvider(
asyncClientsToolFilter.getIfUnique(() -> (McpAsyncClient, tool) -> true), toolNamePrefixGenerator,
mcpClients);
return AsyncMcpToolCallbackProvider.builder()
.toolFilter(asyncClientsToolFilter.getIfUnique(() -> (McpAsyncClient, tool) -> true))
.toolNamePrefixGenerator(
toolNamePrefixGenerator.getIfUnique(() -> McpToolNamePrefixGenerator.defaultGenerator()))
.toolContextToMcpMetaConverter(
toolContextToMcpMetaConverter.getIfUnique(() -> ToolContextToMcpMetaConverter.defaultConverter()))
.mcpClients(mcpClients)
.build();
}

public static class McpToolCallbackAutoConfigurationCondition extends AllNestedConditions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ void verifySyncToolCallbackFilterConfiguration() {
}

@Test
void verifyASyncToolCallbackFilterConfiguration() {
void verifyAsyncToolCallbackFilterConfiguration() {
this.contextRunner
.withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpClientFilterConfiguration.class)
.withPropertyValues("spring.ai.mcp.client.type=ASYNC")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,29 @@

package org.springframework.ai.mcp.client.common.autoconfigure;

import java.util.List;
import java.util.Map;

import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import org.junit.jupiter.api.Test;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.McpConnectionInfo;
import org.springframework.ai.mcp.McpToolFilter;
import org.springframework.ai.mcp.McpToolNamePrefixGenerator;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.ToolContextToMcpMetaConverter;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;

public class McpToolCallbackAutoConfigurationTests {

Expand Down Expand Up @@ -93,20 +103,26 @@ void enabledMcpToolCallbackAutoConfiguration() {
}

@Test
void defaultMcpToolNamePrefixGeneratorIsCreated() {
// Test with SYNC mode (default)
this.applicationContext.run(context -> {
assertThat(context).hasBean("mcpToolNamePrefixGenerator");
McpToolNamePrefixGenerator generator = context.getBean(McpToolNamePrefixGenerator.class);
assertThat(generator).isNotNull();
void disabledMcpToolCallbackAutoConfiguration() {
// Test when MCP client is disabled
this.applicationContext.withPropertyValues("spring.ai.mcp.client.enabled=false").run(context -> {
assertThat(context).doesNotHaveBean("mcpToolCallbacks");
assertThat(context).doesNotHaveBean("mcpAsyncToolCallbacks");
});

// Test with ASYNC mode
this.applicationContext.withPropertyValues("spring.ai.mcp.client.type=ASYNC").run(context -> {
assertThat(context).hasBean("mcpToolNamePrefixGenerator");
McpToolNamePrefixGenerator generator = context.getBean(McpToolNamePrefixGenerator.class);
assertThat(generator).isNotNull();
// Test when toolcallback is disabled
this.applicationContext.withPropertyValues("spring.ai.mcp.client.toolcallback.enabled=false").run(context -> {
assertThat(context).doesNotHaveBean("mcpToolCallbacks");
assertThat(context).doesNotHaveBean("mcpAsyncToolCallbacks");
});

// Test when both are disabled
this.applicationContext
.withPropertyValues("spring.ai.mcp.client.enabled=false", "spring.ai.mcp.client.toolcallback.enabled=false")
.run(context -> {
assertThat(context).doesNotHaveBean("mcpToolCallbacks");
assertThat(context).doesNotHaveBean("mcpAsyncToolCallbacks");
});
}

@Test
Expand Down Expand Up @@ -137,28 +153,97 @@ void customMcpToolNamePrefixGeneratorOverridesDefault() {
}

@Test
void mcpToolNamePrefixGeneratorIsInjectedIntoProviders() {
// Test SYNC provider receives the generator
this.applicationContext.run(context -> {
assertThat(context).hasBean("mcpToolNamePrefixGenerator");
void customMcpToolFilterOverridesDefault() {
// Test with SYNC mode
this.applicationContext.withUserConfiguration(CustomToolFilterConfig.class).run(context -> {
assertThat(context).hasBean("customToolFilter");
McpToolFilter filter = context.getBean("customToolFilter", McpToolFilter.class);
assertThat(filter).isInstanceOf(CustomToolFilter.class);
assertThat(context).hasBean("mcpToolCallbacks");
// Verify the custom filter is injected into the provider
SyncMcpToolCallbackProvider provider = context.getBean(SyncMcpToolCallbackProvider.class);
assertThat(provider).isNotNull();
});

McpToolNamePrefixGenerator generator = context.getBean(McpToolNamePrefixGenerator.class);
// Test with ASYNC mode
this.applicationContext.withUserConfiguration(CustomToolFilterConfig.class)
.withPropertyValues("spring.ai.mcp.client.type=ASYNC")
.run(context -> {
assertThat(context).hasBean("customToolFilter");
McpToolFilter filter = context.getBean("customToolFilter", McpToolFilter.class);
assertThat(filter).isInstanceOf(CustomToolFilter.class);
assertThat(context).hasBean("mcpAsyncToolCallbacks");
// Verify the custom filter is injected into the provider
AsyncMcpToolCallbackProvider provider = context.getBean(AsyncMcpToolCallbackProvider.class);
assertThat(provider).isNotNull();
});
}

@Test
void customToolContextToMcpMetaConverterOverridesDefault() {
// Test with SYNC mode
this.applicationContext.withUserConfiguration(CustomConverterConfig.class).run(context -> {
assertThat(context).hasBean("customConverter");
ToolContextToMcpMetaConverter converter = context.getBean("customConverter",
ToolContextToMcpMetaConverter.class);
assertThat(converter).isInstanceOf(CustomToolContextToMcpMetaConverter.class);
assertThat(context).hasBean("mcpToolCallbacks");
// Verify the custom converter is injected into the provider
SyncMcpToolCallbackProvider provider = context.getBean(SyncMcpToolCallbackProvider.class);
assertThat(provider).isNotNull();
});

// Test with ASYNC mode
this.applicationContext.withUserConfiguration(CustomConverterConfig.class)
.withPropertyValues("spring.ai.mcp.client.type=ASYNC")
.run(context -> {
assertThat(context).hasBean("customConverter");
ToolContextToMcpMetaConverter converter = context.getBean("customConverter",
ToolContextToMcpMetaConverter.class);
assertThat(converter).isInstanceOf(CustomToolContextToMcpMetaConverter.class);
assertThat(context).hasBean("mcpAsyncToolCallbacks");
// Verify the custom converter is injected into the provider
AsyncMcpToolCallbackProvider provider = context.getBean(AsyncMcpToolCallbackProvider.class);
assertThat(provider).isNotNull();
});
}

@Test
void providersCreatedWithMcpClients() {
// Test SYNC mode with MCP clients
this.applicationContext.withUserConfiguration(McpSyncClientConfig.class).run(context -> {
assertThat(context).hasBean("mcpToolCallbacks");
assertThat(context).hasBean("mcpSyncClient1");
assertThat(context).hasBean("mcpSyncClient2");
SyncMcpToolCallbackProvider provider = context.getBean(SyncMcpToolCallbackProvider.class);
assertThat(provider).isNotNull();
});

// Test ASYNC mode with MCP clients
this.applicationContext.withUserConfiguration(McpAsyncClientConfig.class)
.withPropertyValues("spring.ai.mcp.client.type=ASYNC")
.run(context -> {
assertThat(context).hasBean("mcpAsyncToolCallbacks");
assertThat(context).hasBean("mcpAsyncClient1");
assertThat(context).hasBean("mcpAsyncClient2");
AsyncMcpToolCallbackProvider provider = context.getBean(AsyncMcpToolCallbackProvider.class);
assertThat(provider).isNotNull();
});
}

assertThat(generator).isNotNull();
@Test
void providersCreatedWithoutMcpClients() {
// Test SYNC mode without MCP clients
this.applicationContext.run(context -> {
assertThat(context).hasBean("mcpToolCallbacks");
SyncMcpToolCallbackProvider provider = context.getBean(SyncMcpToolCallbackProvider.class);
assertThat(provider).isNotNull();
});

// Test ASYNC provider receives the generator
// Test ASYNC mode without MCP clients
this.applicationContext.withPropertyValues("spring.ai.mcp.client.type=ASYNC").run(context -> {
assertThat(context).hasBean("mcpToolNamePrefixGenerator");
assertThat(context).hasBean("mcpAsyncToolCallbacks");

McpToolNamePrefixGenerator generator = context.getBean(McpToolNamePrefixGenerator.class);
AsyncMcpToolCallbackProvider provider = context.getBean(AsyncMcpToolCallbackProvider.class);

assertThat(generator).isNotNull();
assertThat(provider).isNotNull();
});
}
Expand All @@ -182,4 +267,84 @@ public String prefixedToolName(McpConnectionInfo mcpConnInfo, Tool tool) {

}

@Configuration
static class CustomToolFilterConfig {

@Bean
public McpToolFilter customToolFilter() {
return new CustomToolFilter();
}

}

static class CustomToolFilter implements McpToolFilter {

@Override
public boolean test(McpConnectionInfo metadata, McpSchema.Tool tool) {
// Custom filter logic
return !tool.name().startsWith("excluded_");
}

}

@Configuration
static class CustomConverterConfig {

@Bean
public ToolContextToMcpMetaConverter customConverter() {
return new CustomToolContextToMcpMetaConverter();
}

}

static class CustomToolContextToMcpMetaConverter implements ToolContextToMcpMetaConverter {

@Override
public Map<String, Object> convert(ToolContext toolContext) {
// Custom conversion logic
return Map.of("custom", "metadata");
}

}

@Configuration
static class McpSyncClientConfig {

@Bean
public List<McpSyncClient> mcpSyncClients() {
return List.of(mcpSyncClient1(), mcpSyncClient2());
}

@Bean
public McpSyncClient mcpSyncClient1() {
return mock(McpSyncClient.class);
}

@Bean
public McpSyncClient mcpSyncClient2() {
return mock(McpSyncClient.class);
}

}

@Configuration
static class McpAsyncClientConfig {

@Bean
public List<McpAsyncClient> mcpAsyncClients() {
return List.of(mcpAsyncClient1(), mcpAsyncClient2());
}

@Bean
public McpAsyncClient mcpAsyncClient1() {
return mock(McpAsyncClient.class);
}

@Bean
public McpAsyncClient mcpAsyncClient2() {
return mock(McpAsyncClient.class);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,11 @@ List<ToolCallback> testTool() {
Mockito.when(mockClient.callTool(Mockito.any(McpSchema.CallToolRequest.class))).thenReturn(mockResult);
when(mockClient.getClientInfo()).thenReturn(new McpSchema.Implementation("testClient", "1.0.0"));

return List.of(new SyncMcpToolCallback(mockClient, mockTool, mockTool.name()));
return List.of(SyncMcpToolCallback.builder()
.mcpClient(mockClient)
.tool(mockTool)
.prefixedToolName(mockTool.name())
.build());
}

}
Expand All @@ -413,7 +417,11 @@ ToolCallbackProvider testToolCallbackProvider() {
Mockito.when(mockTool.description()).thenReturn("Provider Tool");
when(mockClient.getClientInfo()).thenReturn(new McpSchema.Implementation("testClient", "1.0.0"));

return new ToolCallback[] { new SyncMcpToolCallback(mockClient, mockTool, mockTool.name()) };
return new ToolCallback[] { SyncMcpToolCallback.builder()
.mcpClient(mockClient)
.tool(mockTool)
.prefixedToolName(mockTool.name())
.build() };
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ List<ToolCallback> testTool() {
Mockito.when(mockClient.callTool(Mockito.any(McpSchema.CallToolRequest.class))).thenReturn(mockResult);
when(mockClient.getClientInfo()).thenReturn(new McpSchema.Implementation("testClient", "1.0.0"));

return List.of(new SyncMcpToolCallback(mockClient, mockTool));
return List.of(SyncMcpToolCallback.builder().mcpClient(mockClient).tool(mockTool).build());
}

}
Expand All @@ -363,7 +363,8 @@ ToolCallbackProvider testToolCallbackProvider() {
Mockito.when(mockTool.description()).thenReturn("Provider Tool");
when(mockClient.getClientInfo()).thenReturn(new McpSchema.Implementation("testClient", "1.0.0"));

return new ToolCallback[] { new SyncMcpToolCallback(mockClient, mockTool) };
return new ToolCallback[] {
SyncMcpToolCallback.builder().mcpClient(mockClient).tool(mockTool).build() };
};
}

Expand Down
Loading