Skip to content

Commit b1eec85

Browse files
committed
Add MCP Client tool predicate for filtering the MCP tools
- Introduce MCP Sync/Async client BiPredicate interface as a tool filter for the MCP Sync/Async ToolCallbackProvider to use when filtering the MCP tools - Update MCP ToolCallbackAutoConfiguration to use these BiPredicate beans when defined (default is to allow all) - Add test verifying the tool filter configuration on both sync and async toolcallback provider auto-configuration - Update the unit tests for the MCP toolcallback provider Signed-off-by: Ilayaperumal Gopinathan <[email protected]>
1 parent 128c45a commit b1eec85

File tree

7 files changed

+189
-21
lines changed

7 files changed

+189
-21
lines changed

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,18 @@
1717
package org.springframework.ai.mcp.client.common.autoconfigure;
1818

1919
import java.util.List;
20+
import java.util.function.BiPredicate;
2021

2122
import io.modelcontextprotocol.client.McpAsyncClient;
2223
import io.modelcontextprotocol.client.McpSyncClient;
24+
import io.modelcontextprotocol.spec.McpSchema;
2325

2426
import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
27+
import org.springframework.ai.mcp.McpAsyncClientBiPredicate;
28+
import org.springframework.ai.mcp.McpSyncClientBiPredicate;
2529
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
2630
import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties;
31+
import org.springframework.ai.tool.annotation.Tool;
2732
import org.springframework.beans.factory.ObjectProvider;
2833
import org.springframework.boot.autoconfigure.AutoConfiguration;
2934
import org.springframework.boot.autoconfigure.condition.AllNestedConditions;
@@ -51,16 +56,21 @@ public class McpToolCallbackAutoConfiguration {
5156
@Bean
5257
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
5358
matchIfMissing = true)
54-
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<List<McpSyncClient>> syncMcpClients) {
59+
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<McpSyncClientBiPredicate> syncClientsToolFilter,
60+
ObjectProvider<List<McpSyncClient>> syncMcpClients) {
5561
List<McpSyncClient> mcpClients = syncMcpClients.stream().flatMap(List::stream).toList();
56-
return new SyncMcpToolCallbackProvider(mcpClients);
62+
return new SyncMcpToolCallbackProvider(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true)),
63+
mcpClients);
5764
}
5865

5966
@Bean
6067
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
61-
public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
68+
public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(
69+
ObjectProvider<McpAsyncClientBiPredicate> asyncClientsToolFilter,
70+
ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
6271
List<McpAsyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
63-
return new AsyncMcpToolCallbackProvider(mcpClients);
72+
return new AsyncMcpToolCallbackProvider(
73+
asyncClientsToolFilter.getIfUnique(() -> (McpAsyncClient, tool) -> true), mcpClients);
6474
}
6575

6676
public static class McpToolCallbackAutoConfigurationCondition extends AllNestedConditions {

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,28 @@
1616

1717
package org.springframework.ai.mcp.client.common.autoconfigure;
1818

19+
import java.lang.reflect.Field;
20+
import java.util.List;
21+
22+
import io.modelcontextprotocol.client.McpAsyncClient;
23+
import io.modelcontextprotocol.client.McpSyncClient;
24+
import io.modelcontextprotocol.spec.McpSchema;
1925
import org.junit.jupiter.api.Test;
26+
import reactor.core.publisher.Mono;
2027

28+
import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
29+
import org.springframework.ai.mcp.McpAsyncClientBiPredicate;
30+
import org.springframework.ai.mcp.McpSyncClientBiPredicate;
31+
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
2132
import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration.McpToolCallbackAutoConfigurationCondition;
2233
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
2334
import org.springframework.context.annotation.Bean;
2435
import org.springframework.context.annotation.Conditional;
2536
import org.springframework.context.annotation.Configuration;
2637

2738
import static org.assertj.core.api.Assertions.assertThat;
39+
import static org.mockito.Mockito.mock;
40+
import static org.mockito.Mockito.when;
2841

2942
/**
3043
* Tests for {@link McpToolCallbackAutoConfigurationCondition}.
@@ -73,6 +86,58 @@ void doesMatchWhenBothPropertiesAreMissing() {
7386
this.contextRunner.run(context -> assertThat(context).hasBean("testBean"));
7487
}
7588

89+
@Test
90+
void verifySyncToolCallbackFilterConfiguration() {
91+
this.contextRunner
92+
.withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpSyncClientFilterConfiguration.class)
93+
.withPropertyValues("spring.ai.mcp.client.type=SYNC")
94+
.run(context -> {
95+
assertThat(context).hasBean("syncClientFilter");
96+
SyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(SyncMcpToolCallbackProvider.class);
97+
Field field = SyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter");
98+
field.setAccessible(true);
99+
McpSyncClientBiPredicate toolFilter = (McpSyncClientBiPredicate) field.get(toolCallbackProvider);
100+
McpSyncClient syncClient1 = mock(McpSyncClient.class);
101+
var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0");
102+
when(syncClient1.getClientInfo()).thenReturn(clientInfo1);
103+
McpSchema.Tool tool1 = mock(McpSchema.Tool.class);
104+
when(tool1.name()).thenReturn("tool1");
105+
McpSchema.Tool tool2 = mock(McpSchema.Tool.class);
106+
when(tool2.name()).thenReturn("tool2");
107+
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
108+
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
109+
when(syncClient1.listTools()).thenReturn(listToolsResult1);
110+
assertThat(toolFilter.test(syncClient1, tool1)).isFalse();
111+
assertThat(toolFilter.test(syncClient1, tool2)).isTrue();
112+
});
113+
}
114+
115+
@Test
116+
void verifyASyncToolCallbackFilterConfiguration() {
117+
this.contextRunner
118+
.withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpAsyncClientFilterConfiguration.class)
119+
.withPropertyValues("spring.ai.mcp.client.type=ASYNC")
120+
.run(context -> {
121+
assertThat(context).hasBean("asyncClientFilter");
122+
AsyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(AsyncMcpToolCallbackProvider.class);
123+
Field field = AsyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter");
124+
field.setAccessible(true);
125+
McpAsyncClientBiPredicate toolFilter = (McpAsyncClientBiPredicate) field.get(toolCallbackProvider);
126+
McpAsyncClient asyncClient1 = mock(McpAsyncClient.class);
127+
var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0");
128+
when(asyncClient1.getClientInfo()).thenReturn(clientInfo1);
129+
McpSchema.Tool tool1 = mock(McpSchema.Tool.class);
130+
when(tool1.name()).thenReturn("tool1");
131+
McpSchema.Tool tool2 = mock(McpSchema.Tool.class);
132+
when(tool2.name()).thenReturn("tool2");
133+
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
134+
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
135+
when(asyncClient1.listTools()).thenReturn(Mono.just(listToolsResult1));
136+
assertThat(toolFilter.test(asyncClient1, tool1)).isFalse();
137+
assertThat(toolFilter.test(asyncClient1, tool2)).isTrue();
138+
});
139+
}
140+
76141
@Configuration
77142
@Conditional(McpToolCallbackAutoConfigurationCondition.class)
78143
static class TestConfiguration {
@@ -84,4 +149,40 @@ String testBean() {
84149

85150
}
86151

152+
@Configuration
153+
static class McpSyncClientFilterConfiguration {
154+
155+
@Bean
156+
McpSyncClientBiPredicate syncClientFilter() {
157+
return new McpSyncClientBiPredicate() {
158+
@Override
159+
public boolean test(McpSyncClient mcpSyncClient, McpSchema.Tool tool) {
160+
if (mcpSyncClient.getClientInfo().name().equals("client1") && tool.name().contains("tool1")) {
161+
return false;
162+
}
163+
return true;
164+
}
165+
};
166+
}
167+
168+
}
169+
170+
@Configuration
171+
static class McpAsyncClientFilterConfiguration {
172+
173+
@Bean
174+
McpAsyncClientBiPredicate asyncClientFilter() {
175+
return new McpAsyncClientBiPredicate() {
176+
@Override
177+
public boolean test(McpAsyncClient mcpAsyncClient, McpSchema.Tool tool) {
178+
if (mcpAsyncClient.getClientInfo().name().equals("client1") && tool.name().contains("tool1")) {
179+
return false;
180+
}
181+
return true;
182+
}
183+
};
184+
}
185+
186+
}
187+
87188
}

mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,8 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.List;
21-
import java.util.function.BiPredicate;
2221

2322
import io.modelcontextprotocol.client.McpAsyncClient;
24-
import io.modelcontextprotocol.spec.McpSchema.Tool;
2523
import io.modelcontextprotocol.util.Assert;
2624
import reactor.core.publisher.Flux;
2725

@@ -76,15 +74,15 @@ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider {
7674

7775
private final List<McpAsyncClient> mcpClients;
7876

79-
private final BiPredicate<McpAsyncClient, Tool> toolFilter;
77+
private final McpAsyncClientBiPredicate toolFilter;
8078

8179
/**
8280
* Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP
8381
* clients.
8482
* @param mcpClients the list of MCP clients to use for discovering tools
8583
* @param toolFilter a filter to apply to each discovered tool
8684
*/
87-
public AsyncMcpToolCallbackProvider(BiPredicate<McpAsyncClient, Tool> toolFilter, List<McpAsyncClient> mcpClients) {
85+
public AsyncMcpToolCallbackProvider(McpAsyncClientBiPredicate toolFilter, List<McpAsyncClient> mcpClients) {
8886
Assert.notNull(mcpClients, "MCP clients must not be null");
8987
Assert.notNull(toolFilter, "Tool filter must not be null");
9088
this.mcpClients = mcpClients;
@@ -109,7 +107,7 @@ public AsyncMcpToolCallbackProvider(List<McpAsyncClient> mcpClients) {
109107
* @param mcpClients the MCP clients to use for discovering tools
110108
* @param toolFilter a filter to apply to each discovered tool
111109
*/
112-
public AsyncMcpToolCallbackProvider(BiPredicate<McpAsyncClient, Tool> toolFilter, McpAsyncClient... mcpClients) {
110+
public AsyncMcpToolCallbackProvider(McpAsyncClientBiPredicate toolFilter, McpAsyncClient... mcpClients) {
113111
this(toolFilter, List.of(mcpClients));
114112
}
115113

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.mcp;
18+
19+
import java.util.function.BiPredicate;
20+
21+
import io.modelcontextprotocol.client.McpAsyncClient;
22+
import io.modelcontextprotocol.spec.McpSchema;
23+
24+
/**
25+
* A {@link BiPredicate} for {@link AsyncMcpToolCallbackProvider} to filter the discovered
26+
* tool for the given {@link McpAsyncClient}.
27+
*
28+
* @author Ilayaperumal Gopinathan
29+
*/
30+
public interface McpAsyncClientBiPredicate extends BiPredicate<McpAsyncClient, McpSchema.Tool> {
31+
32+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.mcp;
18+
19+
import java.util.function.BiPredicate;
20+
21+
import io.modelcontextprotocol.client.McpSyncClient;
22+
import io.modelcontextprotocol.spec.McpSchema;
23+
24+
/**
25+
* A {@link BiPredicate} for {@link SyncMcpToolCallbackProvider} to filter the discovered
26+
* tool for the given {@link McpSyncClient}.
27+
*
28+
* @author Ilayaperumal Gopinathan
29+
*/
30+
public interface McpSyncClientBiPredicate extends BiPredicate<McpSyncClient, McpSchema.Tool> {
31+
32+
}

mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717
package org.springframework.ai.mcp;
1818

1919
import java.util.List;
20-
import java.util.function.BiPredicate;
2120

2221
import io.modelcontextprotocol.client.McpSyncClient;
23-
import io.modelcontextprotocol.spec.McpSchema.Tool;
2422

2523
import org.springframework.ai.tool.ToolCallback;
2624
import org.springframework.ai.tool.ToolCallbackProvider;
@@ -72,15 +70,15 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider {
7270

7371
private final List<McpSyncClient> mcpClients;
7472

75-
private final BiPredicate<McpSyncClient, Tool> toolFilter;
73+
private final McpSyncClientBiPredicate toolFilter;
7674

7775
/**
7876
* Creates a new {@code SyncMcpToolCallbackProvider} instance with a list of MCP
7977
* clients.
8078
* @param mcpClients the list of MCP clients to use for discovering tools
8179
* @param toolFilter a filter to apply to each discovered tool
8280
*/
83-
public SyncMcpToolCallbackProvider(BiPredicate<McpSyncClient, Tool> toolFilter, List<McpSyncClient> mcpClients) {
81+
public SyncMcpToolCallbackProvider(McpSyncClientBiPredicate toolFilter, List<McpSyncClient> mcpClients) {
8482
Assert.notNull(mcpClients, "MCP clients must not be null");
8583
Assert.notNull(toolFilter, "Tool filter must not be null");
8684
this.mcpClients = mcpClients;
@@ -102,7 +100,7 @@ public SyncMcpToolCallbackProvider(List<McpSyncClient> mcpClients) {
102100
* @param mcpClients the MCP clients to use for discovering tools
103101
* @param toolFilter a filter to apply to each discovered tool
104102
*/
105-
public SyncMcpToolCallbackProvider(BiPredicate<McpSyncClient, Tool> toolFilter, McpSyncClient... mcpClients) {
103+
public SyncMcpToolCallbackProvider(McpSyncClientBiPredicate toolFilter, McpSyncClient... mcpClients) {
106104
this(toolFilter, List.of(mcpClients));
107105
}
108106

mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package org.springframework.ai.mcp;
1818

1919
import java.util.List;
20-
import java.util.function.BiPredicate;
2120

2221
import io.modelcontextprotocol.client.McpSyncClient;
2322
import io.modelcontextprotocol.spec.McpSchema.Implementation;
@@ -164,7 +163,7 @@ void toolFilterShouldRejectAllToolsWhenConfigured() {
164163
when(this.mcpClient.listTools()).thenReturn(listToolsResult);
165164

166165
// Create a filter that rejects all tools
167-
BiPredicate<McpSyncClient, Tool> rejectAllFilter = (client, tool) -> false;
166+
McpSyncClientBiPredicate rejectAllFilter = (client, tool) -> false;
168167

169168
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(rejectAllFilter, this.mcpClient);
170169

@@ -192,8 +191,7 @@ void toolFilterShouldFilterToolsByNameWhenConfigured() {
192191
when(this.mcpClient.listTools()).thenReturn(listToolsResult);
193192

194193
// Create a filter that only accepts tools with names containing "2" or "3"
195-
BiPredicate<McpSyncClient, Tool> nameFilter = (client, tool) -> tool.name().contains("2")
196-
|| tool.name().contains("3");
194+
McpSyncClientBiPredicate nameFilter = (client, tool) -> tool.name().contains("2") || tool.name().contains("3");
197195

198196
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(nameFilter, this.mcpClient);
199197

@@ -228,8 +226,7 @@ void toolFilterShouldFilterToolsByClientWhenConfigured() {
228226
when(mcpClient2.getClientInfo()).thenReturn(clientInfo2);
229227

230228
// Create a filter that only accepts tools from client1
231-
BiPredicate<McpSyncClient, Tool> clientFilter = (client,
232-
tool) -> client.getClientInfo().name().equals("testClient1");
229+
McpSyncClientBiPredicate clientFilter = (client, tool) -> client.getClientInfo().name().equals("testClient1");
233230

234231
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(clientFilter, mcpClient1, mcpClient2);
235232

@@ -256,7 +253,7 @@ void toolFilterShouldCombineClientAndToolCriteriaWhenConfigured() {
256253
when(weatherClient.getClientInfo()).thenReturn(weatherClientInfo);
257254

258255
// Create a filter that only accepts weather tools from the weather service
259-
BiPredicate<McpSyncClient, Tool> complexFilter = (client,
256+
McpSyncClientBiPredicate complexFilter = (client,
260257
tool) -> client.getClientInfo().name().equals("weather-service") && tool.name().equals("weather");
261258

262259
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(complexFilter, weatherClient);

0 commit comments

Comments
 (0)