Skip to content

Commit 523315b

Browse files
committed
Add MCP metadata which would contain MCP client and server metadata
- This provides a convenience for the filter to operate against metadata Add McpToolFilter which is of type BiPredicate<McpMetadata, McpSchema.Tool> - The filter configuration would look like this: @configuration static class McpClientFilterConfiguration { @bean McpToolFilter mcpClientFilter() { return new McpToolFilter() { @OverRide public boolean test(McpMetadata metadata, McpSchema.Tool tool) { if (metadata.mcpClientMetadata().clientInfo().name().equals("client1") && tool.name().contains("tool1")) { return false; } return true; } }; } } Signed-off-by: Ilayaperumal Gopinathan <[email protected]>
1 parent a28bb41 commit 523315b

File tree

9 files changed

+98
-40
lines changed

9 files changed

+98
-40
lines changed

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import io.modelcontextprotocol.client.McpSyncClient;
2323

2424
import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
25-
import org.springframework.ai.mcp.McpClientBiPredicate;
25+
import org.springframework.ai.mcp.McpToolFilter;
2626
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
2727
import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties;
2828
import org.springframework.beans.factory.ObjectProvider;
@@ -46,15 +46,15 @@ public class McpToolCallbackAutoConfiguration {
4646
* <p>
4747
* These callbacks enable integration with Spring AI's tool execution framework,
4848
* allowing MCP tools to be used as part of AI interactions.
49-
* @param syncClientsToolFilter list of {@link McpClientBiPredicate}s for the sync
50-
* client to filter the discovered tools
49+
* @param syncClientsToolFilter list of {@link McpToolFilter}s for the sync client to
50+
* filter the discovered tools
5151
* @param syncMcpClients provider of MCP sync clients
5252
* @return list of tool callbacks for MCP integration
5353
*/
5454
@Bean
5555
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
5656
matchIfMissing = true)
57-
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<McpClientBiPredicate> syncClientsToolFilter,
57+
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<McpToolFilter> syncClientsToolFilter,
5858
ObjectProvider<List<McpSyncClient>> syncMcpClients) {
5959
List<McpSyncClient> mcpClients = syncMcpClients.stream().flatMap(List::stream).toList();
6060
return new SyncMcpToolCallbackProvider(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true)),
@@ -63,8 +63,7 @@ public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<McpClientBiPr
6363

6464
@Bean
6565
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
66-
public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(
67-
ObjectProvider<McpClientBiPredicate> asyncClientsToolFilter,
66+
public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider<McpToolFilter> asyncClientsToolFilter,
6867
ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
6968
List<McpAsyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
7069
return new AsyncMcpToolCallbackProvider(

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
import reactor.core.publisher.Mono;
2727

2828
import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
29-
import org.springframework.ai.mcp.McpClientBiPredicate;
29+
import org.springframework.ai.mcp.McpToolFilter;
3030
import org.springframework.ai.mcp.McpClientMetadata;
31+
import org.springframework.ai.mcp.McpMetadata;
32+
import org.springframework.ai.mcp.McpServerMetadata;
3133
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
3234
import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration.McpToolCallbackAutoConfigurationCondition;
3335
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -96,7 +98,7 @@ void verifySyncToolCallbackFilterConfiguration() {
9698
SyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(SyncMcpToolCallbackProvider.class);
9799
Field field = SyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter");
98100
field.setAccessible(true);
99-
McpClientBiPredicate toolFilter = (McpClientBiPredicate) field.get(toolCallbackProvider);
101+
McpToolFilter toolFilter = (McpToolFilter) field.get(toolCallbackProvider);
100102
McpSyncClient syncClient1 = mock(McpSyncClient.class);
101103
var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0");
102104
when(syncClient1.getClientInfo()).thenReturn(clientInfo1);
@@ -107,9 +109,11 @@ void verifySyncToolCallbackFilterConfiguration() {
107109
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
108110
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
109111
when(syncClient1.listTools()).thenReturn(listToolsResult1);
110-
assertThat(toolFilter.test(new McpClientMetadata(null, syncClient1.getClientInfo(), null), tool1))
112+
assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, syncClient1.getClientInfo()),
113+
new McpServerMetadata(null)), tool1))
111114
.isFalse();
112-
assertThat(toolFilter.test(new McpClientMetadata(null, syncClient1.getClientInfo(), null), tool2))
115+
assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, syncClient1.getClientInfo()),
116+
new McpServerMetadata(null)), tool2))
113117
.isTrue();
114118
});
115119
}
@@ -124,7 +128,7 @@ void verifyASyncToolCallbackFilterConfiguration() {
124128
AsyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(AsyncMcpToolCallbackProvider.class);
125129
Field field = AsyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter");
126130
field.setAccessible(true);
127-
McpClientBiPredicate toolFilter = (McpClientBiPredicate) field.get(toolCallbackProvider);
131+
McpToolFilter toolFilter = (McpToolFilter) field.get(toolCallbackProvider);
128132
McpAsyncClient asyncClient1 = mock(McpAsyncClient.class);
129133
var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0");
130134
when(asyncClient1.getClientInfo()).thenReturn(clientInfo1);
@@ -135,9 +139,11 @@ void verifyASyncToolCallbackFilterConfiguration() {
135139
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
136140
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
137141
when(asyncClient1.listTools()).thenReturn(Mono.just(listToolsResult1));
138-
assertThat(toolFilter.test(new McpClientMetadata(null, asyncClient1.getClientInfo(), null), tool1))
142+
assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, asyncClient1.getClientInfo()),
143+
new McpServerMetadata(null)), tool1))
139144
.isFalse();
140-
assertThat(toolFilter.test(new McpClientMetadata(null, asyncClient1.getClientInfo(), null), tool2))
145+
assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, asyncClient1.getClientInfo()),
146+
new McpServerMetadata(null)), tool2))
141147
.isTrue();
142148
});
143149
}
@@ -157,11 +163,12 @@ String testBean() {
157163
static class McpClientFilterConfiguration {
158164

159165
@Bean
160-
McpClientBiPredicate mcpClientFilter() {
161-
return new McpClientBiPredicate() {
166+
McpToolFilter mcpClientFilter() {
167+
return new McpToolFilter() {
162168
@Override
163-
public boolean test(McpClientMetadata clientMetadata, McpSchema.Tool tool) {
164-
if (clientMetadata.clientInfo().name().equals("client1") && tool.name().contains("tool1")) {
169+
public boolean test(McpMetadata metadata, McpSchema.Tool tool) {
170+
if (metadata.mcpClientMetadata().clientInfo().name().equals("client1")
171+
&& tool.name().contains("tool1")) {
165172
return false;
166173
}
167174
return true;

mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import java.util.List;
2121

2222
import io.modelcontextprotocol.client.McpAsyncClient;
23-
import io.modelcontextprotocol.spec.McpSchema;
2423
import io.modelcontextprotocol.util.Assert;
2524
import reactor.core.publisher.Flux;
2625

@@ -73,7 +72,7 @@
7372
*/
7473
public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider {
7574

76-
private final McpClientBiPredicate toolFilter;
75+
private final McpToolFilter toolFilter;
7776

7877
private final List<McpAsyncClient> mcpClients;
7978

@@ -83,7 +82,7 @@ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider {
8382
* @param toolFilter a filter to apply to each discovered tool
8483
* @param mcpClients the list of MCP clients to use for discovering tools
8584
*/
86-
public AsyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, List<McpAsyncClient> mcpClients) {
85+
public AsyncMcpToolCallbackProvider(McpToolFilter toolFilter, List<McpAsyncClient> mcpClients) {
8786
Assert.notNull(mcpClients, "MCP clients must not be null");
8887
Assert.notNull(toolFilter, "Tool filter must not be null");
8988
this.toolFilter = toolFilter;
@@ -108,7 +107,7 @@ public AsyncMcpToolCallbackProvider(List<McpAsyncClient> mcpClients) {
108107
* @param toolFilter a filter to apply to each discovered tool
109108
* @param mcpClients the MCP clients to use for discovering tools
110109
*/
111-
public AsyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, McpAsyncClient... mcpClients) {
110+
public AsyncMcpToolCallbackProvider(McpToolFilter toolFilter, McpAsyncClient... mcpClients) {
112111
this(toolFilter, List.of(mcpClients));
113112
}
114113

@@ -146,8 +145,9 @@ public ToolCallback[] getToolCallbacks() {
146145
ToolCallback[] toolCallbacks = mcpClient.listTools()
147146
.map(response -> response.tools()
148147
.stream()
149-
.filter(tool -> this.toolFilter.test(new McpClientMetadata(mcpClient.getClientCapabilities(),
150-
mcpClient.getClientInfo(), mcpClient.initialize().block()), tool))
148+
.filter(tool -> this.toolFilter.test(new McpMetadata(
149+
new McpClientMetadata(mcpClient.getClientCapabilities(), mcpClient.getClientInfo()),
150+
new McpServerMetadata(mcpClient.initialize().block())), tool))
151151
.map(tool -> new AsyncMcpToolCallback(mcpClient, tool))
152152
.toArray(ToolCallback[]::new))
153153
.block();

mcp/common/src/main/java/org/springframework/ai/mcp/McpClientMetadata.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@
1919
import io.modelcontextprotocol.spec.McpSchema;
2020

2121
/**
22-
* MCP client metadata record containing the client/server specific data.
22+
* MCP client metadata record.
2323
*
2424
* @author Ilayaperumal Gopinathan
2525
*/
26-
public record McpClientMetadata(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo,
27-
McpSchema.InitializeResult initializeResult) {
26+
public record McpClientMetadata(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) {
2827
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.mcp;
18+
19+
/**
20+
* MCP metadata record containing the client/server specific meta data.
21+
*
22+
* @author Ilayaperumal Gopinathan
23+
*/
24+
public record McpMetadata(McpClientMetadata mcpClientMetadata, McpServerMetadata mcpServermetadata) {
25+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.mcp;
18+
19+
import io.modelcontextprotocol.spec.McpSchema;
20+
21+
/**
22+
* MCP server metadata record.
23+
*
24+
* @author Ilayaperumal Gopinathan
25+
*/
26+
public record McpServerMetadata(McpSchema.InitializeResult initializeResult) {
27+
}

mcp/common/src/main/java/org/springframework/ai/mcp/McpClientBiPredicate.java renamed to mcp/common/src/main/java/org/springframework/ai/mcp/McpToolFilter.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,15 @@
1818

1919
import java.util.function.BiPredicate;
2020

21-
import io.modelcontextprotocol.client.McpSyncClient;
2221
import io.modelcontextprotocol.spec.McpSchema;
2322

2423
/**
2524
* A {@link BiPredicate} for {@link SyncMcpToolCallbackProvider} and the
2625
* {@link AsyncMcpToolCallbackProvider} to filter the discovered tool for the given
27-
* {@link McpClientMetadata}.
26+
* {@link McpMetadata}.
2827
*
2928
* @author Ilayaperumal Gopinathan
3029
*/
31-
public interface McpClientBiPredicate extends BiPredicate<McpClientMetadata, McpSchema.Tool> {
30+
public interface McpToolFilter extends BiPredicate<McpMetadata, McpSchema.Tool> {
3231

3332
}

mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,15 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider {
7070

7171
private final List<McpSyncClient> mcpClients;
7272

73-
private final McpClientBiPredicate toolFilter;
73+
private final McpToolFilter toolFilter;
7474

7575
/**
7676
* Creates a new {@code SyncMcpToolCallbackProvider} instance with a list of MCP
7777
* clients.
7878
* @param mcpClients the list of MCP clients to use for discovering tools
7979
* @param toolFilter a filter to apply to each discovered tool
8080
*/
81-
public SyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, List<McpSyncClient> mcpClients) {
81+
public SyncMcpToolCallbackProvider(McpToolFilter toolFilter, List<McpSyncClient> mcpClients) {
8282
Assert.notNull(mcpClients, "MCP clients must not be null");
8383
Assert.notNull(toolFilter, "Tool filter must not be null");
8484
this.mcpClients = mcpClients;
@@ -100,7 +100,7 @@ public SyncMcpToolCallbackProvider(List<McpSyncClient> mcpClients) {
100100
* @param mcpClients the MCP clients to use for discovering tools
101101
* @param toolFilter a filter to apply to each discovered tool
102102
*/
103-
public SyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, McpSyncClient... mcpClients) {
103+
public SyncMcpToolCallbackProvider(McpToolFilter toolFilter, McpSyncClient... mcpClients) {
104104
this(toolFilter, List.of(mcpClients));
105105
}
106106

@@ -131,8 +131,9 @@ public ToolCallback[] getToolCallbacks() {
131131
.flatMap(mcpClient -> mcpClient.listTools()
132132
.tools()
133133
.stream()
134-
.filter(tool -> this.toolFilter.test(new McpClientMetadata(mcpClient.getClientCapabilities(),
135-
mcpClient.getClientInfo(), mcpClient.initialize()), tool))
134+
.filter(tool -> this.toolFilter.test(new McpMetadata(
135+
new McpClientMetadata(mcpClient.getClientCapabilities(), mcpClient.getClientInfo()),
136+
new McpServerMetadata(mcpClient.initialize())), tool))
136137
.map(tool -> new SyncMcpToolCallback(mcpClient, tool)))
137138
.toArray(ToolCallback[]::new);
138139
validateToolCallbacks(array);

mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ void toolFilterShouldRejectAllToolsWhenConfigured() {
163163
when(this.mcpClient.listTools()).thenReturn(listToolsResult);
164164

165165
// Create a filter that rejects all tools
166-
McpClientBiPredicate rejectAllFilter = (client, tool) -> false;
166+
McpToolFilter rejectAllFilter = (client, tool) -> false;
167167

168168
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(rejectAllFilter, this.mcpClient);
169169

@@ -191,7 +191,7 @@ void toolFilterShouldFilterToolsByNameWhenConfigured() {
191191
when(this.mcpClient.listTools()).thenReturn(listToolsResult);
192192

193193
// Create a filter that only accepts tools with names containing "2" or "3"
194-
McpClientBiPredicate nameFilter = (client, tool) -> tool.name().contains("2") || tool.name().contains("3");
194+
McpToolFilter nameFilter = (client, tool) -> tool.name().contains("2") || tool.name().contains("3");
195195

196196
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(nameFilter, this.mcpClient);
197197

@@ -226,8 +226,8 @@ void toolFilterShouldFilterToolsByClientWhenConfigured() {
226226
when(mcpClient2.getClientInfo()).thenReturn(clientInfo2);
227227

228228
// Create a filter that only accepts tools from client1
229-
McpClientBiPredicate clientFilter = (clientMetadata,
230-
tool) -> clientMetadata.clientInfo().name().equals("testClient1");
229+
McpToolFilter clientFilter = (mcpMetadata,
230+
tool) -> mcpMetadata.mcpClientMetadata().clientInfo().name().equals("testClient1");
231231

232232
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(clientFilter, mcpClient1, mcpClient2);
233233

@@ -254,8 +254,9 @@ void toolFilterShouldCombineClientAndToolCriteriaWhenConfigured() {
254254
when(weatherClient.getClientInfo()).thenReturn(weatherClientInfo);
255255

256256
// Create a filter that only accepts weather tools from the weather service
257-
McpClientBiPredicate complexFilter = (client, tool) -> client.clientInfo().name().equals("weather-service")
258-
&& tool.name().equals("weather");
257+
McpToolFilter complexFilter = (mcpMetadata,
258+
tool) -> mcpMetadata.mcpClientMetadata().clientInfo().name().equals("weather-service")
259+
&& tool.name().equals("weather");
259260

260261
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(complexFilter, weatherClient);
261262

0 commit comments

Comments
 (0)