Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import org.springframework.ai.autoconfigure.mcp.client.configurer.McpAsyncClientConfigurer;
import org.springframework.ai.autoconfigure.mcp.client.configurer.McpSyncClientConfigurer;
import org.springframework.ai.autoconfigure.mcp.client.properties.McpClientCommonProperties;
import org.springframework.ai.mcp.McpToolUtils;
import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.customizer.McpAsyncClientCustomizer;
import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
Expand Down Expand Up @@ -176,9 +178,22 @@ public List<McpSyncClient> mcpSyncClients(McpSyncClientConfigurer mcpSyncClientC
@Bean
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
matchIfMissing = true)
public List<ToolCallback> toolCallbacks(ObjectProvider<List<McpSyncClient>> mcpClientsProvider) {
public ToolCallbackProvider toolCallbacks(ObjectProvider<List<McpSyncClient>> mcpClientsProvider) {
List<McpSyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
return McpToolUtils.getToolCallbacksFromSyncClients(mcpClients);
return new SyncMcpToolCallbackProvider(mcpClients);
}

/**
* @deprecated replaced by {@link #toolCallbacks(ObjectProvider)} that returns a
* {@link ToolCallbackProvider} instead of a list of {@link ToolCallback}
*/
@Deprecated
@Bean
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
matchIfMissing = true)
public List<ToolCallback> toolCallbacksDeprecated(ObjectProvider<List<McpSyncClient>> mcpClientsProvider) {
List<McpSyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
return List.of(new SyncMcpToolCallbackProvider(mcpClients).getToolCallbacks());
}

/**
Expand All @@ -189,7 +204,7 @@ public List<ToolCallback> toolCallbacks(ObjectProvider<List<McpSyncClient>> mcpC
* This class is responsible for closing all MCP sync clients when the application
* context is closed, preventing resource leaks.
*/
public record ClosebleMcpSyncClients(List<McpSyncClient> clients) implements AutoCloseable {
public record CloseableMcpSyncClients(List<McpSyncClient> clients) implements AutoCloseable {

@Override
public void close() {
Expand All @@ -205,8 +220,8 @@ public void close() {
@Bean
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
matchIfMissing = true)
public ClosebleMcpSyncClients makeSyncClientsClosable(List<McpSyncClient> clients) {
return new ClosebleMcpSyncClients(clients);
public CloseableMcpSyncClients makeSyncClientsClosable(List<McpSyncClient> clients) {
return new CloseableMcpSyncClients(clients);
}

/**
Expand Down Expand Up @@ -263,14 +278,26 @@ public List<McpAsyncClient> mcpAsyncClients(McpAsyncClientConfigurer mcpSyncClie
return mcpSyncClients;
}

/**
* @deprecated replaced by {@link #asyncToolCallbacks(ObjectProvider)} that returns a
* {@link ToolCallbackProvider} instead of a list of {@link ToolCallback}
*/
@Deprecated
@Bean
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
public List<ToolCallback> asyncToolCallbacksDeprecated(ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
List<McpAsyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
return List.of(new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks());
}

@Bean
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
public List<ToolCallback> asyncToolCallbacks(ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
public ToolCallbackProvider asyncToolCallbacks(ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
List<McpAsyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
return McpToolUtils.getToolCallbacksFromAsyncClinents(mcpClients);
return new AsyncMcpToolCallbackProvider(mcpClients);
}

public record ClosebleMcpAsyncClients(List<McpAsyncClient> clients) implements AutoCloseable {
public record CloseableMcpAsyncClients(List<McpAsyncClient> clients) implements AutoCloseable {
@Override
public void close() {
this.clients.forEach(McpAsyncClient::close);
Expand All @@ -279,8 +306,8 @@ public void close() {

@Bean
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
public ClosebleMcpAsyncClients makeAsynClientsClosable(List<McpAsyncClient> clients) {
return new ClosebleMcpAsyncClients(clients);
public CloseableMcpAsyncClients makeAsynClientsClosable(List<McpAsyncClient> clients) {
return new CloseableMcpAsyncClients(clients);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void toolCallbacksCreation() {
@Test
void closeableWrappersCreation() {
this.contextRunner.withUserConfiguration(TestTransportConfiguration.class).run(context -> {
assertThat(context).hasSingleBean(McpClientAutoConfiguration.ClosebleMcpSyncClients.class);
assertThat(context).hasSingleBean(McpClientAutoConfiguration.CloseableMcpSyncClients.class);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

package org.springframework.ai.autoconfigure.mcp.server;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;

import io.modelcontextprotocol.server.McpAsyncServer;
import io.modelcontextprotocol.server.McpServer;
Expand All @@ -39,7 +41,9 @@
import reactor.core.publisher.Mono;

import org.springframework.ai.mcp.McpToolUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
Expand Down Expand Up @@ -135,15 +139,23 @@ public McpSyncServer mcpSyncServer(ServerMcpTransport transport,
McpSchema.ServerCapabilities.Builder capabilitiesBuilder, McpServerProperties serverProperties,
ObjectProvider<List<SyncToolRegistration>> tools, ObjectProvider<List<SyncResourceRegistration>> resources,
ObjectProvider<List<SyncPromptRegistration>> prompts,
ObjectProvider<Consumer<List<McpSchema.Root>>> rootsChangeConsumers) {
ObjectProvider<Consumer<List<McpSchema.Root>>> rootsChangeConsumers,
List<ToolCallbackProvider> toolCallbackProvider) {

McpSchema.Implementation serverInfo = new Implementation(serverProperties.getName(),
serverProperties.getVersion());

// Create the server with both tool and resource capabilities
SyncSpec serverBuilder = McpServer.sync(transport).serverInfo(serverInfo);

List<SyncToolRegistration> toolResgistrations = tools.stream().flatMap(List::stream).toList();
List<SyncToolRegistration> toolResgistrations = new ArrayList<>(tools.stream().flatMap(List::stream).toList());
List<ToolCallback> providerToolCallbacks = toolCallbackProvider.stream()
.map(pr -> List.of(pr.getToolCallbacks()))
.flatMap(List::stream)
.filter(fc -> fc instanceof ToolCallback)
.map(fc -> (ToolCallback) fc)
.toList();
toolResgistrations.addAll(McpToolUtils.toSyncToolRegistration(providerToolCallbacks));
if (!CollectionUtils.isEmpty(toolResgistrations)) {
serverBuilder.tools(toolResgistrations);
capabilitiesBuilder.tools(serverProperties.isToolChangeNotification());
Expand Down Expand Up @@ -191,15 +203,23 @@ public McpAsyncServer mcpAsyncServer(ServerMcpTransport transport,
ObjectProvider<List<AsyncToolRegistration>> tools,
ObjectProvider<List<AsyncResourceRegistration>> resources,
ObjectProvider<List<AsyncPromptRegistration>> prompts,
ObjectProvider<Consumer<List<McpSchema.Root>>> rootsChangeConsumer) {
ObjectProvider<Consumer<List<McpSchema.Root>>> rootsChangeConsumer,
List<ToolCallbackProvider> toolCallbackProvider) {

McpSchema.Implementation serverInfo = new Implementation(serverProperties.getName(),
serverProperties.getVersion());

// Create the server with both tool and resource capabilities
AsyncSpec serverBilder = McpServer.async(transport).serverInfo(serverInfo);

List<AsyncToolRegistration> toolResgistrations = tools.stream().flatMap(List::stream).toList();
List<AsyncToolRegistration> toolResgistrations = new ArrayList<>(tools.stream().flatMap(List::stream).toList());
List<ToolCallback> providerToolCallbacks = toolCallbackProvider.stream()
.map(pr -> List.of(pr.getToolCallbacks()))
.flatMap(List::stream)
.filter(fc -> fc instanceof ToolCallback)
.map(fc -> (ToolCallback) fc)
.toList();
toolResgistrations.addAll(McpToolUtils.toAsyncToolRegistration(providerToolCallbacks));
if (!CollectionUtils.isEmpty(toolResgistrations)) {
serverBilder.tools(toolResgistrations);
capabilitiesBuilder.tools(serverProperties.isToolChangeNotification());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
*/
package org.springframework.ai.mcp;

import java.util.ArrayList;
import java.util.List;

import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.util.Assert;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
Expand All @@ -28,25 +29,40 @@

/**
* Implementation of {@link ToolCallbackProvider} that discovers and provides MCP tools
* asynchronously.
* asynchronously from one or more MCP servers.
* <p>
* This class acts as a tool provider for Spring AI, automatically discovering tools from
* an MCP server and making them available as Spring AI tools. It:
* multiple MCP servers and making them available as Spring AI tools. It:
* <ul>
* <li>Connects to an MCP server through an async client</li>
* <li>Lists and retrieves available tools from the server</li>
* <li>Connects to MCP servers through async clients</li>
* <li>Lists and retrieves available tools from each server asynchronously</li>
* <li>Creates {@link AsyncMcpToolCallback} instances for each discovered tool</li>
* <li>Validates tool names to prevent duplicates</li>
* <li>Validates tool names to prevent duplicates across all servers</li>
* </ul>
* <p>
* Example usage: <pre>{@code
* Example usage with a single client:
*
* <pre>{@code
* McpAsyncClient mcpClient = // obtain MCP client
* ToolCallbackProvider provider = new AsyncMcpToolCallbackProvider(mcpClient);
*
* // Get all available tools
* ToolCallback[] tools = provider.getToolCallbacks();
* }</pre>
*
* Example usage with multiple clients:
*
* <pre>{@code
* List<McpAsyncClient> mcpClients = // obtain multiple MCP clients
* ToolCallbackProvider provider = new AsyncMcpToolCallbackProvider(mcpClients);
*
* // Get tools from all clients
* ToolCallback[] tools = provider.getToolCallbacks();
*
* // Or use the reactive API
* Flux<ToolCallback> toolsFlux = AsyncMcpToolCallbackProvider.asyncToolCallbacks(mcpClients);
* }</pre>
*
* @author Christian Tzolov
* @since 1.0.0
* @see ToolCallbackProvider
Expand All @@ -55,40 +71,61 @@
*/
public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider {

private final McpAsyncClient mcpClient;
private final List<McpAsyncClient> mcpClients;

/**
* Creates a new {@code AsyncMcpToolCallbackProvider} instance.
* @param mcpClient the MCP client to use for discovering tools
* Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP
* clients.
* @param mcpClients the list of MCP clients to use for discovering tools. Each client
* typically connects to a different MCP server, allowing tool discovery from multiple
* sources.
* @throws IllegalArgumentException if mcpClients is null
*/
public AsyncMcpToolCallbackProvider(McpAsyncClient mcpClient) {
this.mcpClient = mcpClient;
public AsyncMcpToolCallbackProvider(List<McpAsyncClient> mcpClients) {
Assert.notNull(mcpClients, "McpClients must not be null");
this.mcpClients = mcpClients;
}

public AsyncMcpToolCallbackProvider(McpAsyncClient... mcpClients) {
Assert.notNull(mcpClients, "McpClients must not be null");
this.mcpClients = List.of(mcpClients);
}

/**
* Discovers and returns all available tools from the MCP server asynchronously.
* Discovers and returns all available tools from the configured MCP servers.
* <p>
* This method:
* <ol>
* <li>Retrieves the list of tools from the MCP server</li>
* <li>Creates a {@link AsyncMcpToolCallback} for each tool</li>
* <li>Validates that there are no duplicate tool names</li>
* <li>Retrieves the list of tools from each MCP server asynchronously</li>
* <li>Creates a {@link AsyncMcpToolCallback} for each discovered tool</li>
* <li>Validates that there are no duplicate tool names across all servers</li>
* </ol>
* <p>
* Note: While the underlying tool discovery is asynchronous, this method blocks until
* all tools are discovered from all servers.
* @return an array of tool callbacks, one for each discovered tool
* @throws IllegalStateException if duplicate tool names are found
*/
@Override
public ToolCallback[] getToolCallbacks() {
var toolCallbacks = this.mcpClient.listTools()
.map(response -> response.tools()
.stream()
.map(tool -> new AsyncMcpToolCallback(this.mcpClient, tool))
.toArray(ToolCallback[]::new))
.block();

validateToolCallbacks(toolCallbacks);
List<ToolCallback> toolCallbackList = new ArrayList<>();

for (McpAsyncClient mcpClient : this.mcpClients) {

ToolCallback[] toolCallbacks = mcpClient.listTools()
.map(response -> response.tools()
.stream()
.map(tool -> new AsyncMcpToolCallback(mcpClient, tool))
.toArray(ToolCallback[]::new))
.block();

return toolCallbacks;
validateToolCallbacks(toolCallbacks);

toolCallbackList.addAll(List.of(toolCallbacks));
}

return toolCallbackList.toArray(new ToolCallback[0]);
}

/**
Expand All @@ -110,12 +147,19 @@ private void validateToolCallbacks(ToolCallback[] toolCallbacks) {
/**
* Creates a reactive stream of tool callbacks from multiple MCP clients.
* <p>
* This utility method:
* This utility method provides a reactive way to work with tool callbacks from
* multiple MCP clients in a single operation. It:
* <ol>
* <li>Takes a list of MCP clients</li>
* <li>Creates a provider for each client</li>
* <li>Retrieves and flattens all tool callbacks into a single stream</li>
* <li>Takes a list of MCP clients as input</li>
* <li>Creates a provider instance to manage all clients</li>
* <li>Retrieves tools from all clients asynchronously</li>
* <li>Combines them into a single reactive stream</li>
* <li>Ensures there are no naming conflicts between tools from different clients</li>
* </ol>
* <p>
* Unlike {@link #getToolCallbacks()}, this method provides a fully reactive way to
* work with tool callbacks, making it suitable for non-blocking applications. Any
* errors during tool discovery will be propagated through the returned Flux.
* @param mcpClients the list of MCP clients to create callbacks from
* @return a Flux of tool callbacks from all provided clients
*/
Expand All @@ -124,9 +168,7 @@ public static Flux<ToolCallback> asyncToolCallbacks(List<McpAsyncClient> mcpClie
return Flux.empty();
}

return Flux.fromIterable(mcpClients)
.flatMap(mcpClient -> Mono.just(new AsyncMcpToolCallbackProvider(mcpClient).getToolCallbacks()))
.flatMap(callbacks -> Flux.fromArray(callbacks));
return Flux.fromArray(new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,7 @@ public static List<ToolCallback> getToolCallbacksFromSyncClients(List<McpSyncCli
if (CollectionUtils.isEmpty(mcpClients)) {
return List.of();
}
return mcpClients.stream()
.map(mcpClient -> List.of((new SyncMcpToolCallbackProvider(mcpClient).getToolCallbacks())))
.flatMap(List::stream)
.toList();
return List.of((new SyncMcpToolCallbackProvider(mcpClients).getToolCallbacks()));
}

/**
Expand Down Expand Up @@ -247,10 +244,7 @@ public static List<ToolCallback> getToolCallbacksFromAsyncClinents(List<McpAsync
if (CollectionUtils.isEmpty(asynMcpClients)) {
return List.of();
}
return asynMcpClients.stream()
.map(mcpClient -> List.of((new AsyncMcpToolCallbackProvider(mcpClient).getToolCallbacks())))
.flatMap(List::stream)
.toList();
return List.of((new AsyncMcpToolCallbackProvider(asynMcpClients).getToolCallbacks()));
}

}
Loading