Skip to content

Commit a28bb41

Browse files
committed
Introduce MCP ClientMetadata and use it for filtering
- Add McpClientMetadata record which contains MCP client/server meta data that can be used for filtering the toolcallbacks - This provides a convenient approach to handling just the metadata from the client - Update the auto-configuration and tests Signed-off-by: Ilayaperumal Gopinathan <[email protected]>
1 parent e5be113 commit a28bb41

File tree

7 files changed

+51
-68
lines changed

7 files changed

+51
-68
lines changed

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,14 @@
1717
package org.springframework.ai.mcp.client.common.autoconfigure;
1818

1919
import java.util.List;
20-
import java.util.function.BiPredicate;
2120

2221
import io.modelcontextprotocol.client.McpAsyncClient;
2322
import io.modelcontextprotocol.client.McpSyncClient;
24-
import io.modelcontextprotocol.spec.McpSchema;
2523

2624
import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
27-
import org.springframework.ai.mcp.McpAsyncClientBiPredicate;
28-
import org.springframework.ai.mcp.McpSyncClientBiPredicate;
25+
import org.springframework.ai.mcp.McpClientBiPredicate;
2926
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
3027
import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties;
31-
import org.springframework.ai.tool.annotation.Tool;
3228
import org.springframework.beans.factory.ObjectProvider;
3329
import org.springframework.boot.autoconfigure.AutoConfiguration;
3430
import org.springframework.boot.autoconfigure.condition.AllNestedConditions;
@@ -50,15 +46,15 @@ public class McpToolCallbackAutoConfiguration {
5046
* <p>
5147
* These callbacks enable integration with Spring AI's tool execution framework,
5248
* allowing MCP tools to be used as part of AI interactions.
53-
* @param syncClientsToolFilter list of {@link McpSyncClientBiPredicate}s for the sync
49+
* @param syncClientsToolFilter list of {@link McpClientBiPredicate}s for the sync
5450
* client to filter the discovered tools
5551
* @param syncMcpClients provider of MCP sync clients
5652
* @return list of tool callbacks for MCP integration
5753
*/
5854
@Bean
5955
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
6056
matchIfMissing = true)
61-
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<McpSyncClientBiPredicate> syncClientsToolFilter,
57+
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<McpClientBiPredicate> syncClientsToolFilter,
6258
ObjectProvider<List<McpSyncClient>> syncMcpClients) {
6359
List<McpSyncClient> mcpClients = syncMcpClients.stream().flatMap(List::stream).toList();
6460
return new SyncMcpToolCallbackProvider(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true)),
@@ -68,7 +64,7 @@ public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<McpSyncClient
6864
@Bean
6965
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
7066
public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(
71-
ObjectProvider<McpAsyncClientBiPredicate> asyncClientsToolFilter,
67+
ObjectProvider<McpClientBiPredicate> asyncClientsToolFilter,
7268
ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
7369
List<McpAsyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
7470
return new AsyncMcpToolCallbackProvider(

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
import reactor.core.publisher.Mono;
2727

2828
import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
29-
import org.springframework.ai.mcp.McpAsyncClientBiPredicate;
30-
import org.springframework.ai.mcp.McpSyncClientBiPredicate;
29+
import org.springframework.ai.mcp.McpClientBiPredicate;
30+
import org.springframework.ai.mcp.McpClientMetadata;
3131
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
3232
import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration.McpToolCallbackAutoConfigurationCondition;
3333
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -89,14 +89,14 @@ void doesMatchWhenBothPropertiesAreMissing() {
8989
@Test
9090
void verifySyncToolCallbackFilterConfiguration() {
9191
this.contextRunner
92-
.withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpSyncClientFilterConfiguration.class)
92+
.withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpClientFilterConfiguration.class)
9393
.withPropertyValues("spring.ai.mcp.client.type=SYNC")
9494
.run(context -> {
95-
assertThat(context).hasBean("syncClientFilter");
95+
assertThat(context).hasBean("mcpClientFilter");
9696
SyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(SyncMcpToolCallbackProvider.class);
9797
Field field = SyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter");
9898
field.setAccessible(true);
99-
McpSyncClientBiPredicate toolFilter = (McpSyncClientBiPredicate) field.get(toolCallbackProvider);
99+
McpClientBiPredicate toolFilter = (McpClientBiPredicate) field.get(toolCallbackProvider);
100100
McpSyncClient syncClient1 = mock(McpSyncClient.class);
101101
var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0");
102102
when(syncClient1.getClientInfo()).thenReturn(clientInfo1);
@@ -107,22 +107,24 @@ void verifySyncToolCallbackFilterConfiguration() {
107107
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
108108
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
109109
when(syncClient1.listTools()).thenReturn(listToolsResult1);
110-
assertThat(toolFilter.test(syncClient1, tool1)).isFalse();
111-
assertThat(toolFilter.test(syncClient1, tool2)).isTrue();
110+
assertThat(toolFilter.test(new McpClientMetadata(null, syncClient1.getClientInfo(), null), tool1))
111+
.isFalse();
112+
assertThat(toolFilter.test(new McpClientMetadata(null, syncClient1.getClientInfo(), null), tool2))
113+
.isTrue();
112114
});
113115
}
114116

115117
@Test
116118
void verifyASyncToolCallbackFilterConfiguration() {
117119
this.contextRunner
118-
.withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpAsyncClientFilterConfiguration.class)
120+
.withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpClientFilterConfiguration.class)
119121
.withPropertyValues("spring.ai.mcp.client.type=ASYNC")
120122
.run(context -> {
121-
assertThat(context).hasBean("asyncClientFilter");
123+
assertThat(context).hasBean("mcpClientFilter");
122124
AsyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(AsyncMcpToolCallbackProvider.class);
123125
Field field = AsyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter");
124126
field.setAccessible(true);
125-
McpAsyncClientBiPredicate toolFilter = (McpAsyncClientBiPredicate) field.get(toolCallbackProvider);
127+
McpClientBiPredicate toolFilter = (McpClientBiPredicate) field.get(toolCallbackProvider);
126128
McpAsyncClient asyncClient1 = mock(McpAsyncClient.class);
127129
var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0");
128130
when(asyncClient1.getClientInfo()).thenReturn(clientInfo1);
@@ -133,8 +135,10 @@ void verifyASyncToolCallbackFilterConfiguration() {
133135
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
134136
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
135137
when(asyncClient1.listTools()).thenReturn(Mono.just(listToolsResult1));
136-
assertThat(toolFilter.test(asyncClient1, tool1)).isFalse();
137-
assertThat(toolFilter.test(asyncClient1, tool2)).isTrue();
138+
assertThat(toolFilter.test(new McpClientMetadata(null, asyncClient1.getClientInfo(), null), tool1))
139+
.isFalse();
140+
assertThat(toolFilter.test(new McpClientMetadata(null, asyncClient1.getClientInfo(), null), tool2))
141+
.isTrue();
138142
});
139143
}
140144

@@ -150,32 +154,14 @@ String testBean() {
150154
}
151155

152156
@Configuration
153-
static class McpSyncClientFilterConfiguration {
157+
static class McpClientFilterConfiguration {
154158

155159
@Bean
156-
McpSyncClientBiPredicate syncClientFilter() {
157-
return new McpSyncClientBiPredicate() {
160+
McpClientBiPredicate mcpClientFilter() {
161+
return new McpClientBiPredicate() {
158162
@Override
159-
public boolean test(McpSyncClient mcpSyncClient, McpSchema.Tool tool) {
160-
if (mcpSyncClient.getClientInfo().name().equals("client1") && tool.name().contains("tool1")) {
161-
return false;
162-
}
163-
return true;
164-
}
165-
};
166-
}
167-
168-
}
169-
170-
@Configuration
171-
static class McpAsyncClientFilterConfiguration {
172-
173-
@Bean
174-
McpAsyncClientBiPredicate asyncClientFilter() {
175-
return new McpAsyncClientBiPredicate() {
176-
@Override
177-
public boolean test(McpAsyncClient mcpAsyncClient, McpSchema.Tool tool) {
178-
if (mcpAsyncClient.getClientInfo().name().equals("client1") && tool.name().contains("tool1")) {
163+
public boolean test(McpClientMetadata clientMetadata, McpSchema.Tool tool) {
164+
if (clientMetadata.clientInfo().name().equals("client1") && tool.name().contains("tool1")) {
179165
return false;
180166
}
181167
return true;

mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.List;
2121

2222
import io.modelcontextprotocol.client.McpAsyncClient;
23+
import io.modelcontextprotocol.spec.McpSchema;
2324
import io.modelcontextprotocol.util.Assert;
2425
import reactor.core.publisher.Flux;
2526

@@ -72,21 +73,21 @@
7273
*/
7374
public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider {
7475

75-
private final List<McpAsyncClient> mcpClients;
76+
private final McpClientBiPredicate toolFilter;
7677

77-
private final McpAsyncClientBiPredicate toolFilter;
78+
private final List<McpAsyncClient> mcpClients;
7879

7980
/**
8081
* Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP
8182
* clients.
8283
* @param toolFilter a filter to apply to each discovered tool
8384
* @param mcpClients the list of MCP clients to use for discovering tools
8485
*/
85-
public AsyncMcpToolCallbackProvider(McpAsyncClientBiPredicate toolFilter, List<McpAsyncClient> mcpClients) {
86+
public AsyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, List<McpAsyncClient> mcpClients) {
8687
Assert.notNull(mcpClients, "MCP clients must not be null");
8788
Assert.notNull(toolFilter, "Tool filter must not be null");
88-
this.mcpClients = mcpClients;
8989
this.toolFilter = toolFilter;
90+
this.mcpClients = mcpClients;
9091
}
9192

9293
/**
@@ -107,7 +108,7 @@ public AsyncMcpToolCallbackProvider(List<McpAsyncClient> mcpClients) {
107108
* @param toolFilter a filter to apply to each discovered tool
108109
* @param mcpClients the MCP clients to use for discovering tools
109110
*/
110-
public AsyncMcpToolCallbackProvider(McpAsyncClientBiPredicate toolFilter, McpAsyncClient... mcpClients) {
111+
public AsyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, McpAsyncClient... mcpClients) {
111112
this(toolFilter, List.of(mcpClients));
112113
}
113114

@@ -145,7 +146,8 @@ public ToolCallback[] getToolCallbacks() {
145146
ToolCallback[] toolCallbacks = mcpClient.listTools()
146147
.map(response -> response.tools()
147148
.stream()
148-
.filter(tool -> this.toolFilter.test(mcpClient, tool))
149+
.filter(tool -> this.toolFilter.test(new McpClientMetadata(mcpClient.getClientCapabilities(),
150+
mcpClient.getClientInfo(), mcpClient.initialize().block()), tool))
149151
.map(tool -> new AsyncMcpToolCallback(mcpClient, tool))
150152
.toArray(ToolCallback[]::new))
151153
.block();

mcp/common/src/main/java/org/springframework/ai/mcp/McpSyncClientBiPredicate.java renamed to mcp/common/src/main/java/org/springframework/ai/mcp/McpClientBiPredicate.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222
import io.modelcontextprotocol.spec.McpSchema;
2323

2424
/**
25-
* A {@link BiPredicate} for {@link SyncMcpToolCallbackProvider} to filter the discovered
26-
* tool for the given {@link McpSyncClient}.
25+
* A {@link BiPredicate} for {@link SyncMcpToolCallbackProvider} and the
26+
* {@link AsyncMcpToolCallbackProvider} to filter the discovered tool for the given
27+
* {@link McpClientMetadata}.
2728
*
2829
* @author Ilayaperumal Gopinathan
2930
*/
30-
public interface McpSyncClientBiPredicate extends BiPredicate<McpSyncClient, McpSchema.Tool> {
31+
public interface McpClientBiPredicate extends BiPredicate<McpClientMetadata, McpSchema.Tool> {
3132

3233
}
Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,13 @@
1616

1717
package org.springframework.ai.mcp;
1818

19-
import java.util.function.BiPredicate;
20-
21-
import io.modelcontextprotocol.client.McpAsyncClient;
2219
import io.modelcontextprotocol.spec.McpSchema;
2320

2421
/**
25-
* A {@link BiPredicate} for {@link AsyncMcpToolCallbackProvider} to filter the discovered
26-
* tool for the given {@link McpAsyncClient}.
22+
* MCP client metadata record containing the client/server specific data.
2723
*
2824
* @author Ilayaperumal Gopinathan
2925
*/
30-
public interface McpAsyncClientBiPredicate extends BiPredicate<McpAsyncClient, McpSchema.Tool> {
31-
26+
public record McpClientMetadata(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo,
27+
McpSchema.InitializeResult initializeResult) {
3228
}

mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,15 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider {
7070

7171
private final List<McpSyncClient> mcpClients;
7272

73-
private final McpSyncClientBiPredicate toolFilter;
73+
private final McpClientBiPredicate toolFilter;
7474

7575
/**
7676
* Creates a new {@code SyncMcpToolCallbackProvider} instance with a list of MCP
7777
* clients.
7878
* @param mcpClients the list of MCP clients to use for discovering tools
7979
* @param toolFilter a filter to apply to each discovered tool
8080
*/
81-
public SyncMcpToolCallbackProvider(McpSyncClientBiPredicate toolFilter, List<McpSyncClient> mcpClients) {
81+
public SyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, List<McpSyncClient> mcpClients) {
8282
Assert.notNull(mcpClients, "MCP clients must not be null");
8383
Assert.notNull(toolFilter, "Tool filter must not be null");
8484
this.mcpClients = mcpClients;
@@ -100,7 +100,7 @@ public SyncMcpToolCallbackProvider(List<McpSyncClient> mcpClients) {
100100
* @param mcpClients the MCP clients to use for discovering tools
101101
* @param toolFilter a filter to apply to each discovered tool
102102
*/
103-
public SyncMcpToolCallbackProvider(McpSyncClientBiPredicate toolFilter, McpSyncClient... mcpClients) {
103+
public SyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, McpSyncClient... mcpClients) {
104104
this(toolFilter, List.of(mcpClients));
105105
}
106106

@@ -131,7 +131,8 @@ public ToolCallback[] getToolCallbacks() {
131131
.flatMap(mcpClient -> mcpClient.listTools()
132132
.tools()
133133
.stream()
134-
.filter(tool -> this.toolFilter.test(mcpClient, tool))
134+
.filter(tool -> this.toolFilter.test(new McpClientMetadata(mcpClient.getClientCapabilities(),
135+
mcpClient.getClientInfo(), mcpClient.initialize()), tool))
135136
.map(tool -> new SyncMcpToolCallback(mcpClient, tool)))
136137
.toArray(ToolCallback[]::new);
137138
validateToolCallbacks(array);

mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ void toolFilterShouldRejectAllToolsWhenConfigured() {
163163
when(this.mcpClient.listTools()).thenReturn(listToolsResult);
164164

165165
// Create a filter that rejects all tools
166-
McpSyncClientBiPredicate rejectAllFilter = (client, tool) -> false;
166+
McpClientBiPredicate rejectAllFilter = (client, tool) -> false;
167167

168168
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(rejectAllFilter, this.mcpClient);
169169

@@ -191,7 +191,7 @@ void toolFilterShouldFilterToolsByNameWhenConfigured() {
191191
when(this.mcpClient.listTools()).thenReturn(listToolsResult);
192192

193193
// Create a filter that only accepts tools with names containing "2" or "3"
194-
McpSyncClientBiPredicate nameFilter = (client, tool) -> tool.name().contains("2") || tool.name().contains("3");
194+
McpClientBiPredicate nameFilter = (client, tool) -> tool.name().contains("2") || tool.name().contains("3");
195195

196196
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(nameFilter, this.mcpClient);
197197

@@ -226,7 +226,8 @@ void toolFilterShouldFilterToolsByClientWhenConfigured() {
226226
when(mcpClient2.getClientInfo()).thenReturn(clientInfo2);
227227

228228
// Create a filter that only accepts tools from client1
229-
McpSyncClientBiPredicate clientFilter = (client, tool) -> client.getClientInfo().name().equals("testClient1");
229+
McpClientBiPredicate clientFilter = (clientMetadata,
230+
tool) -> clientMetadata.clientInfo().name().equals("testClient1");
230231

231232
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(clientFilter, mcpClient1, mcpClient2);
232233

@@ -253,8 +254,8 @@ void toolFilterShouldCombineClientAndToolCriteriaWhenConfigured() {
253254
when(weatherClient.getClientInfo()).thenReturn(weatherClientInfo);
254255

255256
// Create a filter that only accepts weather tools from the weather service
256-
McpSyncClientBiPredicate complexFilter = (client,
257-
tool) -> client.getClientInfo().name().equals("weather-service") && tool.name().equals("weather");
257+
McpClientBiPredicate complexFilter = (client, tool) -> client.clientInfo().name().equals("weather-service")
258+
&& tool.name().equals("weather");
258259

259260
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(complexFilter, weatherClient);
260261

0 commit comments

Comments
 (0)