diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java index 3525b9593e3..7de29fdb111 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Set; import java.util.function.BiPredicate; import io.modelcontextprotocol.client.McpAsyncClient; @@ -90,7 +91,33 @@ public AsyncMcpToolCallbackProvider(BiPredicate toolFilter this.mcpClients = mcpClients; this.toolFilter = toolFilter; } - + /** + * Creates a new {@code AsyncMcpToolCallbackProvider} instance that includes only clients + * from the specified allowed servers. + *

+ * This constructor: + *

    + *
  1. Filters the provided MCP clients to only those matching allowed server names
  2. + *
  3. Retains all tools from the selected clients (no additional tool filtering)
  4. + *
  5. Ensures no null parameters are passed
  6. + *
  7. Maintains full asynchronous operation capability
  8. + *
+ * + * @param mcpClients complete list of available MCP async clients + * @param allowedServerNames set of server names to include (case-sensitive) + * @throws IllegalArgumentException if parameters are null or empty + * @since 1.1.0 + */ + public AsyncMcpToolCallbackProvider(List mcpClients, Set allowedServerNames) { + Assert.notNull(mcpClients, "MCP clients list must not be null"); + Assert.notNull(allowedServerNames, "Allowed server names set must not be null"); + Assert.notEmpty(allowedServerNames, "Allowed server names set must not be empty"); + + this.mcpClients = mcpClients.stream() + .filter(client -> allowedServerNames.contains(client.getServerInfo().name())) + .collect(Collectors.toList()); + this.toolFilter = (client, tool) -> true; + } /** * Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP * clients. diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java index 7d0aa4276a1..4e916eb023c 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java @@ -17,6 +17,7 @@ package org.springframework.ai.mcp; import java.util.List; +import java.util.Set; import java.util.function.BiPredicate; import io.modelcontextprotocol.client.McpSyncClient; @@ -86,7 +87,32 @@ public SyncMcpToolCallbackProvider(BiPredicate toolFilter, this.mcpClients = mcpClients; this.toolFilter = toolFilter; } - + /** + * Creates a new {@code SyncMcpToolCallbackProvider} instance that includes only clients + * from the specified allowed servers. + *

+ * This constructor: + *

    + *
  1. Filters the provided MCP clients to only those matching allowed server names
  2. + *
  3. Retains all tools from the selected clients (no additional tool filtering)
  4. + *
  5. Ensures no null parameters are passed
  6. + *
+ * + * @param mcpClients complete list of available MCP clients + * @param allowedServerNames set of server names to include (case-sensitive) + * @throws IllegalArgumentException if parameters are null or empty + * @since 1.1.0 + */ + public SyncMcpToolCallbackProvider(List mcpClients, Set allowedServerNames) { + Assert.notNull(mcpClients, "MCP clients list must not be null"); + Assert.notNull(allowedServerNames, "Allowed server names set must not be null"); + Assert.notEmpty(allowedServerNames, "Allowed server names set must not be empty"); + + this.mcpClients = mcpClients.stream() + .filter(client -> allowedServerNames.contains(client.getServerInfo().name())) + .toList(); + this.toolFilter = (client, tool) -> true; // No additional filtering + } /** * Creates a new {@code SyncMcpToolCallbackProvider} instance with a list of MCP * clients.