From 892403893190117f096beaf4c932de5706f031ae Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 13 Feb 2025 13:11:09 +0100 Subject: [PATCH 1/4] feat(autoconfigure): Support both FunctionCallback and ToolCallback in ToolCallingAutoConfiguration - Extends the ToolCallingAutoConfiguration to support both FunctionCallback and ToolCallback types. - The toolCallbackResolver bean now handles both callback types through ObjectProvider injection. - Added comprehensive tests to verify the resolution of multiple function and tool callbacks. Signed-off-by: Christian Tzolov --- .../model/ToolCallingAutoConfiguration.java | 14 +- .../ToolCallingAutoConfigurationTests.java | 124 ++++++++++++++++++ 2 files changed, 136 insertions(+), 2 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java index 4a3dbd60951..9c6de9ca419 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java @@ -20,6 +20,7 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor; import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; @@ -33,12 +34,14 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.support.GenericApplicationContext; +import java.util.ArrayList; import java.util.List; /** * Auto-configuration for common tool calling features of {@link ChatModel}. * * @author Thomas Vitale + * @author Christian Tzolov * @since 1.0.0 */ @AutoConfiguration @@ -48,8 +51,15 @@ public class ToolCallingAutoConfiguration { @Bean @ConditionalOnMissingBean ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext, - List toolCallbacks) { - var staticToolCallbackResolver = new StaticToolCallbackResolver(toolCallbacks); + ObjectProvider> functionCallbacksProvider, + ObjectProvider> toolCallbacksProvider) { + + List allFunctionAndToolCallbacks = new ArrayList<>( + functionCallbacksProvider.stream().flatMap(List::stream).toList()); + allFunctionAndToolCallbacks.addAll(toolCallbacksProvider.stream().flatMap(List::stream).toList()); + + var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionAndToolCallbacks); + var springBeanToolCallbackResolver = SpringBeanToolCallbackResolver.builder() .applicationContext(applicationContext) .build(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java index dd38a7b15a2..75538615e5c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java @@ -16,15 +16,30 @@ package org.springframework.ai.autoconfigure.chat.model; +import java.util.List; +import java.util.function.Function; + import org.junit.jupiter.api.Test; + +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.DefaultToolCallingManager; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbacks; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor; +import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.ai.tool.method.MethodToolCallback; import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; import org.springframework.ai.tool.resolution.ToolCallbackResolver; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Description; +import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -32,6 +47,7 @@ * Unit tests for {@link ToolCallingAutoConfiguration}. * * @author Thomas Vitale + * @author Christian Tzolov */ class ToolCallingAutoConfigurationTests { @@ -50,4 +66,112 @@ void beansAreCreated() { }); } + @Test + void resolveMultipleFuncitonAndToolCallbacks() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .run(context -> { + var toolCallbackResolver = context.getBean(ToolCallbackResolver.class); + assertThat(toolCallbackResolver).isInstanceOf(DelegatingToolCallbackResolver.class); + + assertThat(toolCallbackResolver.resolve("getForecast")).isNotNull(); + assertThat(toolCallbackResolver.resolve("getForecast").getName()).isEqualTo("getForecast"); + + assertThat(toolCallbackResolver.resolve("getAlert")).isNotNull(); + assertThat(toolCallbackResolver.resolve("getAlert").getName()).isEqualTo("getAlert"); + + assertThat(toolCallbackResolver.resolve("weatherFunction1")).isNotNull(); + assertThat(toolCallbackResolver.resolve("weatherFunction1").getName()).isEqualTo("weatherFunction1"); + + assertThat(toolCallbackResolver.resolve("getCurrentWeather3")).isNotNull(); + assertThat(toolCallbackResolver.resolve("getCurrentWeather3").getName()) + .isEqualTo("getCurrentWeather3"); + + assertThat(toolCallbackResolver.resolve("getCurrentWeather4")).isNotNull(); + assertThat(toolCallbackResolver.resolve("getCurrentWeather4").getName()) + .isEqualTo("getCurrentWeather4"); + + assertThat(toolCallbackResolver.resolve("getCurrentWeather5")).isNotNull(); + assertThat(toolCallbackResolver.resolve("getCurrentWeather5").getName()) + .isEqualTo("getCurrentWeather5"); + }); + } + + static class WeatherService { + + @Tool(description = "Get the weather in location. Return temperature in 36°F or 36°C format.") + public String getForecast(String location) { + return "30"; + } + + public String getAlert(String usState) { + return "Alergt"; + } + + } + + @Configuration + static class Config { + + // Note: Currently we do not have ToolCallbackResolver implementation that can + // resolve the ToolCallback from the Tool annotation. + // Therefore we need to provide the ToolCallback instances explicitly using the + // ToolCallbacks.from(...) utility method. + @Bean + public List toolCallbacks() { + return List.of(ToolCallbacks.from(new WeatherService())); + } + + public record Request(String location) { + } + + public record Response(String temperature) { + } + + @Bean + @Description("Get the weather in location. Return temperature in 36°F or 36°C format.") + public Function weatherFunction1() { + return request -> new Response("30"); + } + + @Bean + public List functionCallbacks3() { + return List.of(FunctionCallback.builder() + .function("getCurrentWeather3", (Request request) -> "15.0°C") + .description("Gets the weather in location") + .inputType(Request.class) + .build()); + } + + @Bean + public List functionCallbacks4() { + return List.of(FunctionCallback.builder() + .function("getCurrentWeather4", (Request request) -> "15.0°C") + .description("Gets the weather in location") + .inputType(Request.class) + .build()); + + } + + @Bean + public List toolCallbacks5() { + return List.of(FunctionToolCallback.builder("getCurrentWeather5", (Request request) -> "15.0°C") + .description("Gets the weather in location") + .inputType(Request.class) + .build()); + + } + + @Bean + public List toolCallbacks6() { + var toolMethod = ReflectionUtils.findMethod(WeatherService.class, "getAlert", String.class); + return List.of(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod).build()) + .toolMethod(toolMethod) + .toolObject(new WeatherService()) + .build()); + } + + } + } From 0c1abecfc028e04911a6e750cefe8ced11479e12 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 13 Feb 2025 17:11:59 +0100 Subject: [PATCH 2/4] refactor: introduce new StaticToolCallbackProvider implementation - Update ToolCallbackProvider to return FunctionCallback[] - Migrate from List to ToolCallbackProvider in configurations - Update tests to use new provider pattern Signed-off-by: Christian Tzolov --- .../client/McpClientAutoConfiguration.java | 10 ++-- .../server/MpcServerAutoConfiguration.java | 28 +++++++++-- ...texAiGeminiPaymentTransactionMethodIT.java | 11 ++-- .../ai/tool/StaticToolCallbackProvider.java | 46 +++++++++++++++++ .../ai/tool/ToolCallbackProvider.java | 14 +++++- .../model/ToolCallingAutoConfiguration.java | 17 +++---- .../ToolCallingAutoConfigurationTests.java | 50 ++++++++++++------- 7 files changed, 134 insertions(+), 42 deletions(-) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java diff --git a/auto-configurations/spring-ai-mcp-client/src/main/java/org/springframework/ai/autoconfigure/mcp/client/McpClientAutoConfiguration.java b/auto-configurations/spring-ai-mcp-client/src/main/java/org/springframework/ai/autoconfigure/mcp/client/McpClientAutoConfiguration.java index 23a3d10130c..c427feef1c1 100644 --- a/auto-configurations/spring-ai-mcp-client/src/main/java/org/springframework/ai/autoconfigure/mcp/client/McpClientAutoConfiguration.java +++ b/auto-configurations/spring-ai-mcp-client/src/main/java/org/springframework/ai/autoconfigure/mcp/client/McpClientAutoConfiguration.java @@ -30,7 +30,7 @@ import org.springframework.ai.mcp.McpToolUtils; import org.springframework.ai.mcp.customizer.McpAsyncClientCustomizer; import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer; -import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; @@ -176,9 +176,9 @@ public List mcpSyncClients(McpSyncClientConfigurer mcpSyncClientC @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", matchIfMissing = true) - public List toolCallbacks(ObjectProvider> mcpClientsProvider) { + public ToolCallbackProvider toolCallbacks(ObjectProvider> mcpClientsProvider) { List mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList(); - return McpToolUtils.getToolCallbacksFromSyncClients(mcpClients); + return ToolCallbackProvider.from(McpToolUtils.getToolCallbacksFromSyncClients(mcpClients)); } /** @@ -265,9 +265,9 @@ public List mcpAsyncClients(McpAsyncClientConfigurer mcpSyncClie @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") - public List asyncToolCallbacks(ObjectProvider> mcpClientsProvider) { + public ToolCallbackProvider asyncToolCallbacks(ObjectProvider> mcpClientsProvider) { List mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList(); - return McpToolUtils.getToolCallbacksFromAsyncClinents(mcpClients); + return ToolCallbackProvider.from(McpToolUtils.getToolCallbacksFromAsyncClinents(mcpClients)); } public record ClosebleMcpAsyncClients(List clients) implements AutoCloseable { diff --git a/auto-configurations/spring-ai-mcp-server/src/main/java/org/springframework/ai/autoconfigure/mcp/server/MpcServerAutoConfiguration.java b/auto-configurations/spring-ai-mcp-server/src/main/java/org/springframework/ai/autoconfigure/mcp/server/MpcServerAutoConfiguration.java index dfb8f9760f7..d7303cc9ba1 100644 --- a/auto-configurations/spring-ai-mcp-server/src/main/java/org/springframework/ai/autoconfigure/mcp/server/MpcServerAutoConfiguration.java +++ b/auto-configurations/spring-ai-mcp-server/src/main/java/org/springframework/ai/autoconfigure/mcp/server/MpcServerAutoConfiguration.java @@ -16,9 +16,11 @@ package org.springframework.ai.autoconfigure.mcp.server; +import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; import java.util.function.Function; +import java.util.stream.Stream; import io.modelcontextprotocol.server.McpAsyncServer; import io.modelcontextprotocol.server.McpServer; @@ -39,7 +41,9 @@ import reactor.core.publisher.Mono; import org.springframework.ai.mcp.McpToolUtils; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; @@ -135,7 +139,8 @@ public McpSyncServer mcpSyncServer(ServerMcpTransport transport, McpSchema.ServerCapabilities.Builder capabilitiesBuilder, McpServerProperties serverProperties, ObjectProvider> tools, ObjectProvider> resources, ObjectProvider> prompts, - ObjectProvider>> rootsChangeConsumers) { + ObjectProvider>> rootsChangeConsumers, + List toolCallbackProvider) { McpSchema.Implementation serverInfo = new Implementation(serverProperties.getName(), serverProperties.getVersion()); @@ -143,7 +148,14 @@ public McpSyncServer mcpSyncServer(ServerMcpTransport transport, // Create the server with both tool and resource capabilities SyncSpec serverBuilder = McpServer.sync(transport).serverInfo(serverInfo); - List toolResgistrations = tools.stream().flatMap(List::stream).toList(); + List toolResgistrations = new ArrayList<>(tools.stream().flatMap(List::stream).toList()); + List providerToolCallbacks = toolCallbackProvider.stream() + .map(pr -> List.of(pr.getToolCallbacks())) + .flatMap(List::stream) + .filter(fc -> fc instanceof ToolCallback) + .map(fc -> (ToolCallback) fc) + .toList(); + toolResgistrations.addAll(McpToolUtils.toSyncToolRegistration(providerToolCallbacks)); if (!CollectionUtils.isEmpty(toolResgistrations)) { serverBuilder.tools(toolResgistrations); capabilitiesBuilder.tools(serverProperties.isToolChangeNotification()); @@ -191,7 +203,8 @@ public McpAsyncServer mcpAsyncServer(ServerMcpTransport transport, ObjectProvider> tools, ObjectProvider> resources, ObjectProvider> prompts, - ObjectProvider>> rootsChangeConsumer) { + ObjectProvider>> rootsChangeConsumer, + List toolCallbackProvider) { McpSchema.Implementation serverInfo = new Implementation(serverProperties.getName(), serverProperties.getVersion()); @@ -199,7 +212,14 @@ public McpAsyncServer mcpAsyncServer(ServerMcpTransport transport, // Create the server with both tool and resource capabilities AsyncSpec serverBilder = McpServer.async(transport).serverInfo(serverInfo); - List toolResgistrations = tools.stream().flatMap(List::stream).toList(); + List toolResgistrations = new ArrayList<>(tools.stream().flatMap(List::stream).toList()); + List providerToolCallbacks = toolCallbackProvider.stream() + .map(pr -> List.of(pr.getToolCallbacks())) + .flatMap(List::stream) + .filter(fc -> fc instanceof ToolCallback) + .map(fc -> (ToolCallback) fc) + .toList(); + toolResgistrations.addAll(McpToolUtils.toAsyncToolRegistration(providerToolCallbacks)); if (!CollectionUtils.isEmpty(toolResgistrations)) { serverBilder.tools(toolResgistrations); capabilitiesBuilder.tools(serverProperties.isToolChangeNotification()); diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java index 85c9c5218f8..5d5fe556a4d 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java @@ -38,7 +38,7 @@ import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingManager; -import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.ai.tool.ToolCallbacks; import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; @@ -183,9 +183,8 @@ public List statusespaymentStatuses(List transactions) { public static class TestConfiguration { @Bean - public List paymentServiceTools() { - var tools = List.of(ToolCallbacks.from(new PaymentService())); - return tools; + public ToolCallbackProvider paymentServiceTools() { + return ToolCallbackProvider.from(List.of(ToolCallbacks.from(new PaymentService()))); } @Bean @@ -221,11 +220,11 @@ public VertexAiGeminiChatModel vertexAiChatModel(VertexAI vertexAi, ToolCallingM @Bean ToolCallingManager toolCallingManager(GenericApplicationContext applicationContext, - List toolCallbacks, List functionCallbacks, + List tcps, List functionCallbacks, ObjectProvider observationRegistry) { List allFunctionCallbacks = new ArrayList(functionCallbacks); - allFunctionCallbacks.addAll(toolCallbacks.stream().map(tc -> (FunctionCallback) tc).toList()); + tcps.stream().map(pr -> List.of(pr.getToolCallbacks())).forEach(allFunctionCallbacks::addAll); var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionCallbacks); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java new file mode 100644 index 00000000000..51e326d8299 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java @@ -0,0 +1,46 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.tool; + +import java.util.List; + +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.util.Assert; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ +public class StaticToolCallbackProvider implements ToolCallbackProvider { + + private final FunctionCallback[] toolCallbacks; + + public StaticToolCallbackProvider(FunctionCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "ToolCallbacks must not be null"); + this.toolCallbacks = toolCallbacks; + } + + public StaticToolCallbackProvider(List toolCallbacks) { + Assert.notNull(toolCallbacks, "ToolCallbacks must not be null"); + this.toolCallbacks = toolCallbacks.toArray(new FunctionCallback[0]); + } + + @Override + public FunctionCallback[] getToolCallbacks() { + return this.toolCallbacks; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallbackProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallbackProvider.java index e5e4d01319c..df17efffa46 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallbackProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallbackProvider.java @@ -16,6 +16,10 @@ package org.springframework.ai.tool; +import java.util.List; + +import org.springframework.ai.model.function.FunctionCallback; + /** * Provides {@link ToolCallback} instances for tools defined in different sources. * @@ -24,6 +28,14 @@ */ public interface ToolCallbackProvider { - ToolCallback[] getToolCallbacks(); + FunctionCallback[] getToolCallbacks(); + + public static ToolCallbackProvider from(List toolCallbacks) { + return new StaticToolCallbackProvider(toolCallbacks); + } + + public static ToolCallbackProvider from(FunctionCallback... toolCallbacks) { + return new StaticToolCallbackProvider(toolCallbacks); + } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java index 9c6de9ca419..a530bf7dc5a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java @@ -16,11 +16,15 @@ package org.springframework.ai.autoconfigure.chat.model; +import java.util.ArrayList; +import java.util.List; + import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingManager; -import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor; import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; @@ -34,9 +38,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.support.GenericApplicationContext; -import java.util.ArrayList; -import java.util.List; - /** * Auto-configuration for common tool calling features of {@link ChatModel}. * @@ -51,12 +52,10 @@ public class ToolCallingAutoConfiguration { @Bean @ConditionalOnMissingBean ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext, - ObjectProvider> functionCallbacksProvider, - ObjectProvider> toolCallbacksProvider) { + List functionCallbacks, List tcbProviders) { - List allFunctionAndToolCallbacks = new ArrayList<>( - functionCallbacksProvider.stream().flatMap(List::stream).toList()); - allFunctionAndToolCallbacks.addAll(toolCallbacksProvider.stream().flatMap(List::stream).toList()); + List allFunctionAndToolCallbacks = new ArrayList<>(functionCallbacks); + tcbProviders.stream().map(pr -> List.of(pr.getToolCallbacks())).forEach(allFunctionAndToolCallbacks::addAll); var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionAndToolCallbacks); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java index 75538615e5c..3c6cca70d08 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java @@ -16,7 +16,6 @@ package org.springframework.ai.autoconfigure.chat.model; -import java.util.List; import java.util.function.Function; import org.junit.jupiter.api.Test; @@ -24,14 +23,16 @@ import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.DefaultToolCallingManager; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.tool.StaticToolCallbackProvider; import org.springframework.ai.tool.ToolCallback; -import org.springframework.ai.tool.ToolCallbacks; +import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.ai.tool.method.MethodToolCallback; +import org.springframework.ai.tool.method.MethodToolCallbackProvider; import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; import org.springframework.ai.tool.resolution.ToolCallbackResolver; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -104,8 +105,13 @@ public String getForecast(String location) { return "30"; } + @Tool(description = "Get the weather in location. Return temperature in 36°F or 36°C format.") + public String getForecast2(String location) { + return "30"; + } + public String getAlert(String usState) { - return "Alergt"; + return "Alert"; } } @@ -118,8 +124,8 @@ static class Config { // Therefore we need to provide the ToolCallback instances explicitly using the // ToolCallbacks.from(...) utility method. @Bean - public List toolCallbacks() { - return List.of(ToolCallbacks.from(new WeatherService())); + public ToolCallbackProvider toolCallbacks() { + return MethodToolCallbackProvider.builder().toolObjects(new WeatherService()).build(); } public record Request(String location) { @@ -135,41 +141,51 @@ public Function weatherFunction1() { } @Bean - public List functionCallbacks3() { - return List.of(FunctionCallback.builder() + public FunctionCallback functionCallbacks3() { + return FunctionCallback.builder() .function("getCurrentWeather3", (Request request) -> "15.0°C") .description("Gets the weather in location") .inputType(Request.class) - .build()); + .build(); } @Bean - public List functionCallbacks4() { - return List.of(FunctionCallback.builder() + public FunctionCallback functionCallbacks4() { + return FunctionCallback.builder() .function("getCurrentWeather4", (Request request) -> "15.0°C") .description("Gets the weather in location") .inputType(Request.class) - .build()); + .build(); } @Bean - public List toolCallbacks5() { - return List.of(FunctionToolCallback.builder("getCurrentWeather5", (Request request) -> "15.0°C") + public ToolCallback toolCallbacks5() { + return FunctionToolCallback.builder("getCurrentWeather5", (Request request) -> "15.0°C") .description("Gets the weather in location") .inputType(Request.class) - .build()); + .build(); + + } + + @Bean + public ToolCallbackProvider blabla() { + return new StaticToolCallbackProvider( + FunctionToolCallback.builder("getCurrentWeather5", (Request request) -> "15.0°C") + .description("Gets the weather in location") + .inputType(Request.class) + .build()); } @Bean - public List toolCallbacks6() { + public ToolCallback toolCallbacks6() { var toolMethod = ReflectionUtils.findMethod(WeatherService.class, "getAlert", String.class); - return List.of(MethodToolCallback.builder() + return MethodToolCallback.builder() .toolDefinition(ToolDefinition.builder(toolMethod).build()) .toolMethod(toolMethod) .toolObject(new WeatherService()) - .build()); + .build(); } } From d18262e8473378a45f42cfef99886d09d03aba19 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 13 Feb 2025 20:06:20 +0100 Subject: [PATCH 3/4] refactor(mcp): enhance tool callback providers to support multiple clients Refactor AsyncMcpToolCallbackProvider and SyncMcpToolCallbackProvider to handle multiple MCP clients Add ToolCallbackProvider support to ChatClient API Deprecate direct tool callback list methods in favor of providers Fix typos in Closeable class names Update MCP documentation with new examples and usage patterns Signed-off-by: Christian Tzolov --- .../client/McpClientAutoConfiguration.java | 45 ++++++-- .../client/McpClientAutoConfigurationIT.java | 2 +- .../ai/mcp/AsyncMcpToolCallbackProvider.java | 104 ++++++++++++------ .../springframework/ai/mcp/McpToolUtils.java | 10 +- .../ai/mcp/SyncMcpToolCallbackProvider.java | 88 +++++++++------ .../ai/chat/client/ChatClient.java | 5 + .../ai/chat/client/DefaultChatClient.java | 11 ++ .../chat/client/DefaultChatClientBuilder.java | 7 ++ .../ai/tool/StaticToolCallbackProvider.java | 2 +- .../api/mcp/mcp-client-boot-starter-docs.adoc | 6 +- .../api/mcp/mcp-server-boot-starter-docs.adoc | 14 +-- 11 files changed, 201 insertions(+), 93 deletions(-) diff --git a/auto-configurations/spring-ai-mcp-client/src/main/java/org/springframework/ai/autoconfigure/mcp/client/McpClientAutoConfiguration.java b/auto-configurations/spring-ai-mcp-client/src/main/java/org/springframework/ai/autoconfigure/mcp/client/McpClientAutoConfiguration.java index c427feef1c1..c6471ed3a25 100644 --- a/auto-configurations/spring-ai-mcp-client/src/main/java/org/springframework/ai/autoconfigure/mcp/client/McpClientAutoConfiguration.java +++ b/auto-configurations/spring-ai-mcp-client/src/main/java/org/springframework/ai/autoconfigure/mcp/client/McpClientAutoConfiguration.java @@ -27,9 +27,11 @@ import org.springframework.ai.autoconfigure.mcp.client.configurer.McpAsyncClientConfigurer; import org.springframework.ai.autoconfigure.mcp.client.configurer.McpSyncClientConfigurer; import org.springframework.ai.autoconfigure.mcp.client.properties.McpClientCommonProperties; -import org.springframework.ai.mcp.McpToolUtils; +import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; +import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; import org.springframework.ai.mcp.customizer.McpAsyncClientCustomizer; import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer; +import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; @@ -178,7 +180,20 @@ public List mcpSyncClients(McpSyncClientConfigurer mcpSyncClientC matchIfMissing = true) public ToolCallbackProvider toolCallbacks(ObjectProvider> mcpClientsProvider) { List mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList(); - return ToolCallbackProvider.from(McpToolUtils.getToolCallbacksFromSyncClients(mcpClients)); + return new SyncMcpToolCallbackProvider(mcpClients); + } + + /** + * @deprecated replaced by {@link #toolCallbacks(ObjectProvider)} that returns a + * {@link ToolCallbackProvider} instead of a list of {@link ToolCallback} + */ + @Deprecated + @Bean + @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", + matchIfMissing = true) + public List toolCallbacksDeprecated(ObjectProvider> mcpClientsProvider) { + List mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList(); + return List.of(new SyncMcpToolCallbackProvider(mcpClients).getToolCallbacks()); } /** @@ -189,7 +204,7 @@ public ToolCallbackProvider toolCallbacks(ObjectProvider> mc * This class is responsible for closing all MCP sync clients when the application * context is closed, preventing resource leaks. */ - public record ClosebleMcpSyncClients(List clients) implements AutoCloseable { + public record CloseableMcpSyncClients(List clients) implements AutoCloseable { @Override public void close() { @@ -205,8 +220,8 @@ public void close() { @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", matchIfMissing = true) - public ClosebleMcpSyncClients makeSyncClientsClosable(List clients) { - return new ClosebleMcpSyncClients(clients); + public CloseableMcpSyncClients makeSyncClientsClosable(List clients) { + return new CloseableMcpSyncClients(clients); } /** @@ -263,14 +278,26 @@ public List mcpAsyncClients(McpAsyncClientConfigurer mcpSyncClie return mcpSyncClients; } + /** + * @deprecated replaced by {@link #asyncToolCallbacks(ObjectProvider)} that returns a + * {@link ToolCallbackProvider} instead of a list of {@link ToolCallback} + */ + @Deprecated + @Bean + @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") + public List asyncToolCallbacksDeprecated(ObjectProvider> mcpClientsProvider) { + List mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList(); + return List.of(new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks()); + } + @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") public ToolCallbackProvider asyncToolCallbacks(ObjectProvider> mcpClientsProvider) { List mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList(); - return ToolCallbackProvider.from(McpToolUtils.getToolCallbacksFromAsyncClinents(mcpClients)); + return new AsyncMcpToolCallbackProvider(mcpClients); } - public record ClosebleMcpAsyncClients(List clients) implements AutoCloseable { + public record CloseableMcpAsyncClients(List clients) implements AutoCloseable { @Override public void close() { this.clients.forEach(McpAsyncClient::close); @@ -279,8 +306,8 @@ public void close() { @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") - public ClosebleMcpAsyncClients makeAsynClientsClosable(List clients) { - return new ClosebleMcpAsyncClients(clients); + public CloseableMcpAsyncClients makeAsynClientsClosable(List clients) { + return new CloseableMcpAsyncClients(clients); } @Bean diff --git a/auto-configurations/spring-ai-mcp-client/src/test/java/org/springframework/ai/autoconfigure/mcp/client/McpClientAutoConfigurationIT.java b/auto-configurations/spring-ai-mcp-client/src/test/java/org/springframework/ai/autoconfigure/mcp/client/McpClientAutoConfigurationIT.java index 780d105d973..dd08fd85916 100644 --- a/auto-configurations/spring-ai-mcp-client/src/test/java/org/springframework/ai/autoconfigure/mcp/client/McpClientAutoConfigurationIT.java +++ b/auto-configurations/spring-ai-mcp-client/src/test/java/org/springframework/ai/autoconfigure/mcp/client/McpClientAutoConfigurationIT.java @@ -122,7 +122,7 @@ void toolCallbacksCreation() { @Test void closeableWrappersCreation() { this.contextRunner.withUserConfiguration(TestTransportConfiguration.class).run(context -> { - assertThat(context).hasSingleBean(McpClientAutoConfiguration.ClosebleMcpSyncClients.class); + assertThat(context).hasSingleBean(McpClientAutoConfiguration.CloseableMcpSyncClients.class); }); } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java index f8f1bb1ebc0..6b645952de8 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java @@ -15,11 +15,12 @@ */ package org.springframework.ai.mcp; +import java.util.ArrayList; import java.util.List; import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; @@ -28,18 +29,20 @@ /** * Implementation of {@link ToolCallbackProvider} that discovers and provides MCP tools - * asynchronously. + * asynchronously from one or more MCP servers. *

* This class acts as a tool provider for Spring AI, automatically discovering tools from - * an MCP server and making them available as Spring AI tools. It: + * multiple MCP servers and making them available as Spring AI tools. It: *

    - *
  • Connects to an MCP server through an async client
  • - *
  • Lists and retrieves available tools from the server
  • + *
  • Connects to MCP servers through async clients
  • + *
  • Lists and retrieves available tools from each server asynchronously
  • *
  • Creates {@link AsyncMcpToolCallback} instances for each discovered tool
  • - *
  • Validates tool names to prevent duplicates
  • + *
  • Validates tool names to prevent duplicates across all servers
  • *
*

- * Example usage:

{@code
+ * Example usage with a single client:
+ *
+ * 
{@code
  * McpAsyncClient mcpClient = // obtain MCP client
  * ToolCallbackProvider provider = new AsyncMcpToolCallbackProvider(mcpClient);
  *
@@ -47,6 +50,19 @@
  * ToolCallback[] tools = provider.getToolCallbacks();
  * }
* + * Example usage with multiple clients: + * + *
{@code
+ * List mcpClients = // obtain multiple MCP clients
+ * ToolCallbackProvider provider = new AsyncMcpToolCallbackProvider(mcpClients);
+ *
+ * // Get tools from all clients
+ * ToolCallback[] tools = provider.getToolCallbacks();
+ *
+ * // Or use the reactive API
+ * Flux toolsFlux = AsyncMcpToolCallbackProvider.asyncToolCallbacks(mcpClients);
+ * }
+ * * @author Christian Tzolov * @since 1.0.0 * @see ToolCallbackProvider @@ -55,40 +71,61 @@ */ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider { - private final McpAsyncClient mcpClient; + private final List mcpClients; /** - * Creates a new {@code AsyncMcpToolCallbackProvider} instance. - * @param mcpClient the MCP client to use for discovering tools + * Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP + * clients. + * @param mcpClients the list of MCP clients to use for discovering tools. Each client + * typically connects to a different MCP server, allowing tool discovery from multiple + * sources. + * @throws IllegalArgumentException if mcpClients is null */ - public AsyncMcpToolCallbackProvider(McpAsyncClient mcpClient) { - this.mcpClient = mcpClient; + public AsyncMcpToolCallbackProvider(List mcpClients) { + Assert.notNull(mcpClients, "McpClients must not be null"); + this.mcpClients = mcpClients; + } + + public AsyncMcpToolCallbackProvider(McpAsyncClient... mcpClients) { + Assert.notNull(mcpClients, "McpClients must not be null"); + this.mcpClients = List.of(mcpClients); } /** - * Discovers and returns all available tools from the MCP server asynchronously. + * Discovers and returns all available tools from the configured MCP servers. *

* This method: *

    - *
  1. Retrieves the list of tools from the MCP server
  2. - *
  3. Creates a {@link AsyncMcpToolCallback} for each tool
  4. - *
  5. Validates that there are no duplicate tool names
  6. + *
  7. Retrieves the list of tools from each MCP server asynchronously
  8. + *
  9. Creates a {@link AsyncMcpToolCallback} for each discovered tool
  10. + *
  11. Validates that there are no duplicate tool names across all servers
  12. *
+ *

+ * Note: While the underlying tool discovery is asynchronous, this method blocks until + * all tools are discovered from all servers. * @return an array of tool callbacks, one for each discovered tool * @throws IllegalStateException if duplicate tool names are found */ @Override public ToolCallback[] getToolCallbacks() { - var toolCallbacks = this.mcpClient.listTools() - .map(response -> response.tools() - .stream() - .map(tool -> new AsyncMcpToolCallback(this.mcpClient, tool)) - .toArray(ToolCallback[]::new)) - .block(); - validateToolCallbacks(toolCallbacks); + List toolCallbackList = new ArrayList<>(); + + for (McpAsyncClient mcpClient : this.mcpClients) { + + ToolCallback[] toolCallbacks = mcpClient.listTools() + .map(response -> response.tools() + .stream() + .map(tool -> new AsyncMcpToolCallback(mcpClient, tool)) + .toArray(ToolCallback[]::new)) + .block(); - return toolCallbacks; + validateToolCallbacks(toolCallbacks); + + toolCallbackList.addAll(List.of(toolCallbacks)); + } + + return toolCallbackList.toArray(new ToolCallback[0]); } /** @@ -110,12 +147,19 @@ private void validateToolCallbacks(ToolCallback[] toolCallbacks) { /** * Creates a reactive stream of tool callbacks from multiple MCP clients. *

- * This utility method: + * This utility method provides a reactive way to work with tool callbacks from + * multiple MCP clients in a single operation. It: *

    - *
  1. Takes a list of MCP clients
  2. - *
  3. Creates a provider for each client
  4. - *
  5. Retrieves and flattens all tool callbacks into a single stream
  6. + *
  7. Takes a list of MCP clients as input
  8. + *
  9. Creates a provider instance to manage all clients
  10. + *
  11. Retrieves tools from all clients asynchronously
  12. + *
  13. Combines them into a single reactive stream
  14. + *
  15. Ensures there are no naming conflicts between tools from different clients
  16. *
+ *

+ * Unlike {@link #getToolCallbacks()}, this method provides a fully reactive way to + * work with tool callbacks, making it suitable for non-blocking applications. Any + * errors during tool discovery will be propagated through the returned Flux. * @param mcpClients the list of MCP clients to create callbacks from * @return a Flux of tool callbacks from all provided clients */ @@ -124,9 +168,7 @@ public static Flux asyncToolCallbacks(List mcpClie return Flux.empty(); } - return Flux.fromIterable(mcpClients) - .flatMap(mcpClient -> Mono.just(new AsyncMcpToolCallbackProvider(mcpClient).getToolCallbacks())) - .flatMap(callbacks -> Flux.fromArray(callbacks)); + return Flux.fromArray(new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks()); } } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java index 5a69b92b8e5..3a58b9d2737 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java @@ -211,10 +211,7 @@ public static List getToolCallbacksFromSyncClients(List List.of((new SyncMcpToolCallbackProvider(mcpClient).getToolCallbacks()))) - .flatMap(List::stream) - .toList(); + return List.of((new SyncMcpToolCallbackProvider(mcpClients).getToolCallbacks())); } /** @@ -247,10 +244,7 @@ public static List getToolCallbacksFromAsyncClinents(List List.of((new AsyncMcpToolCallbackProvider(mcpClient).getToolCallbacks()))) - .flatMap(List::stream) - .toList(); + return List.of((new AsyncMcpToolCallbackProvider(asynMcpClients).getToolCallbacks())); } } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java index ea18baf42e3..945dc4c1f9f 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.mcp; +import java.util.ArrayList; import java.util.List; import io.modelcontextprotocol.client.McpSyncClient; @@ -25,18 +26,21 @@ import org.springframework.util.CollectionUtils; /** - * Implementation of {@link ToolCallbackProvider} that discovers and provides MCP tools. + * Implementation of {@link ToolCallbackProvider} that discovers and provides MCP tools + * from one or more MCP servers. *

* This class acts as a tool provider for Spring AI, automatically discovering tools from - * an MCP server and making them available as Spring AI tools. It: + * multiple MCP servers and making them available as Spring AI tools. It: *

    - *
  • Connects to an MCP server through a sync client
  • - *
  • Lists and retrieves available tools from the server
  • + *
  • Connects to one or more MCP servers through sync clients
  • + *
  • Lists and retrieves available tools from all connected servers
  • *
  • Creates {@link SyncMcpToolCallback} instances for each discovered tool
  • - *
  • Validates tool names to prevent duplicates
  • + *
  • Validates tool names to prevent duplicates across all servers
  • *
*

- * Example usage:

{@code
+ * Example usage with a single client:
+ *
+ * 
{@code
  * McpSyncClient mcpClient = // obtain MCP client
  * ToolCallbackProvider provider = new SyncMcpToolCallbackProvider(mcpClient);
  *
@@ -44,6 +48,16 @@
  * ToolCallback[] tools = provider.getToolCallbacks();
  * }
* + * Example usage with multiple clients: + * + *
{@code
+ * List mcpClients = // obtain multiple MCP clients
+ * ToolCallbackProvider provider = new SyncMcpToolCallbackProvider(mcpClients);
+ *
+ * // Get tools from all clients
+ * ToolCallback[] tools = provider.getToolCallbacks();
+ * }
+ * * @author Christian Tzolov * @since 1.0.0 * @see ToolCallbackProvider @@ -53,24 +67,29 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider { - private final McpSyncClient mcpClient; + private final List mcpClients; /** - * Creates a new {@code SyncMcpToolCallbackProvider} instance. - * @param mcpClient the MCP client to use for discovering tools + * Creates a new {@code SyncMcpToolCallbackProvider} instance with a list of MCP + * clients. + * @param mcpClients the list of MCP clients to use for discovering tools */ - public SyncMcpToolCallbackProvider(McpSyncClient mcpClient) { - this.mcpClient = mcpClient; + public SyncMcpToolCallbackProvider(List mcpClients) { + this.mcpClients = mcpClients; + } + + public SyncMcpToolCallbackProvider(McpSyncClient... mcpClients) { + this.mcpClients = List.of(mcpClients); } /** - * Discovers and returns all available tools from the MCP server. + * Discovers and returns all available tools from all connected MCP servers. *

* This method: *

    - *
  1. Retrieves the list of tools from the MCP server
  2. - *
  3. Creates a {@link SyncMcpToolCallback} for each tool
  4. - *
  5. Validates that there are no duplicate tool names
  6. + *
  7. Retrieves the list of tools from each connected MCP server
  8. + *
  9. Creates a {@link SyncMcpToolCallback} for each discovered tool
  10. + *
  11. Validates that there are no duplicate tool names across all servers
  12. *
* @return an array of tool callbacks, one for each discovered tool * @throws IllegalStateException if duplicate tool names are found @@ -78,16 +97,18 @@ public SyncMcpToolCallbackProvider(McpSyncClient mcpClient) { @Override public ToolCallback[] getToolCallbacks() { - var toolCallbacks = this.mcpClient.listTools() - .tools() - .stream() - .map(tool -> new SyncMcpToolCallback(this.mcpClient, tool)) - .toArray(ToolCallback[]::new); - - validateToolCallbacks(toolCallbacks); - - return toolCallbacks; - + var toolCallbacks = new ArrayList<>(); + + mcpClients.stream().forEach(mcpClient -> { + toolCallbacks.addAll(mcpClient.listTools() + .tools() + .stream() + .map(tool -> new SyncMcpToolCallback(mcpClient, tool)) + .toList()); + }); + var array = toolCallbacks.toArray(new ToolCallback[0]); + validateToolCallbacks(array); + return array; } /** @@ -107,13 +128,15 @@ private void validateToolCallbacks(ToolCallback[] toolCallbacks) { } /** - * Creates a list of tool callbacks from multiple MCP clients. + * Creates a consolidated list of tool callbacks from multiple MCP clients. *

- * This utility method: + * This utility method provides a convenient way to create tool callbacks from + * multiple MCP clients in a single operation. It: *

    - *
  1. Takes a list of MCP clients
  2. - *
  3. Creates a provider for each client
  4. - *
  5. Retrieves and combines all tool callbacks into a single list
  6. + *
  7. Takes a list of MCP clients as input
  8. + *
  9. Creates a provider instance to manage all clients
  10. + *
  11. Retrieves tools from all clients and combines them into a single list
  12. + *
  13. Ensures there are no naming conflicts between tools from different clients
  14. *
* @param mcpClients the list of MCP clients to create callbacks from * @return a list of tool callbacks from all provided clients @@ -123,10 +146,7 @@ public static List syncToolCallbacks(List mcpClient if (CollectionUtils.isEmpty(mcpClients)) { return List.of(); } - return mcpClients.stream() - .map(mcpClient -> List.of((new SyncMcpToolCallbackProvider(mcpClient).getToolCallbacks()))) - .flatMap(List::stream) - .toList(); + return List.of((new SyncMcpToolCallbackProvider(mcpClients).getToolCallbacks())); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index 07bae3c30cc..6c1b99c65c5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -36,6 +36,7 @@ import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.Resource; import org.springframework.lang.Nullable; @@ -223,6 +224,8 @@ interface ChatClientRequestSpec { ChatClientRequestSpec tools(Object... toolObjects); + ChatClientRequestSpec tools(ToolCallbackProvider... toolCallbackProvider); + @Deprecated ChatClientRequestSpec functions(FunctionCallback... functionCallbacks); @@ -290,6 +293,8 @@ interface Builder { Builder defaultTools(Object... toolObjects); + Builder defaultTools(ToolCallbackProvider... toolCallbackProvider); + /** * @deprecated in favor of {@link #defaultTools(String...)} */ diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index cb7599f0b45..e0b68d75de3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -35,6 +35,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.ai.tool.ToolCallbacks; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; @@ -871,6 +872,16 @@ public ChatClientRequestSpec tools(Object... toolObjects) { return this; } + @Override + public ChatClientRequestSpec tools(ToolCallbackProvider... toolCallbackProviders) { + Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null"); + Assert.noNullElements(toolCallbackProviders, "toolCallbackProviders cannot contain null elements"); + for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) { + this.functionCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks())); + } + return this; + } + @Deprecated // Use tools() public ChatClientRequestSpec functions(String... functionBeanNames) { return tools(functionBeanNames); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 2349cd61945..594fd355222 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -36,6 +36,7 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.core.io.Resource; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -173,6 +174,12 @@ public Builder defaultTools(Object... toolObjects) { return this; } + @Override + public Builder defaultTools(ToolCallbackProvider... toolCallbackProviders) { + this.defaultRequest.tools(toolCallbackProviders); + return this; + } + @Deprecated // Use defaultTools() public Builder defaultFunction(String name, String description, java.util.function.Function function) { this.defaultRequest diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java index 51e326d8299..d405790208d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java @@ -34,7 +34,7 @@ public StaticToolCallbackProvider(FunctionCallback... toolCallbacks) { } public StaticToolCallbackProvider(List toolCallbacks) { - Assert.notNull(toolCallbacks, "ToolCallbacks must not be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); this.toolCallbacks = toolCallbacks.toArray(new FunctionCallback[0]); } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-client-boot-starter-docs.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-client-boot-starter-docs.adoc index 9d2c241e61a..f34e08d7d0d 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-client-boot-starter-docs.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-client-boot-starter-docs.adoc @@ -351,12 +351,14 @@ private List mcpSyncClients; // For sync client private List mcpAsyncClients; // For async client ---- -Additionally, the registered MCP Tools with all MCP clients are provided as a list of ToolCallback instances: +Additionally, the registered MCP Tools with all MCP clients are provided as a list of ToolCallback +throurgh a ToolCallbackProvider instance: [source,java] ---- @Autowired -private List toolCallbacks; +private SyncMcpToolCallbackProvider toolCallbackProvider; +ToolCallback[] toolCallbacks = toolCallbackProvider.getToolCallbacks(); ---- == Example Applications diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-server-boot-starter-docs.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-server-boot-starter-docs.adoc index 758859bd719..5d2598b8673 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-server-boot-starter-docs.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-server-boot-starter-docs.adoc @@ -126,9 +126,9 @@ Allows servers to expose tools that can be invoked by language models. The MCP S [source,java] ---- @Bean -public List myTools(...) { +public ToolCallbackProvier myTools(...) { List tools = ... - return tools; + return ToolCallbackProvier.from(tools); } ---- @@ -284,15 +284,15 @@ public class McpServerApplication { SpringApplication.run(McpServerApplication.class, args); } - @Bean - public List tools(WeatherService weatherService) { - return ToolCallbacks.from(weatherService); - } + @Bean + public ToolCallbackProvider weatherTools(WeatherService weatherService) { + return MethodToolCallbackProvider.builder().toolObjects(weatherService).build(); + } } ---- The auto-configuration will automatically register the tool callbacks as MCP tools. -You can have multiple beans producing lists of ToolCallbacks. The auto-configuration will merge them. +You can have multiple beans producing ToolCallbacks. The auto-configuration will merge them. == Example Applications * link:https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/weather/starter-webflux-server[Weather Server (WebFlux)] - Spring AI MCP Server Boot Starter with WebFlux transport. From fdc8a21f3203c95027e19c26e92ca554ca0ddf5e Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 13 Feb 2025 20:30:25 +0100 Subject: [PATCH 4/4] addres review comments Signed-off-by: Christian Tzolov --- .../ai/chat/client/ChatClient.java | 4 +- .../ai/tool/StaticToolCallbackProvider.java | 44 +++++++++++++++++++ .../api/mcp/mcp-client-boot-starter-docs.adoc | 2 +- .../api/mcp/mcp-server-boot-starter-docs.adoc | 4 +- 4 files changed, 49 insertions(+), 5 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index 6c1b99c65c5..1af0a20ea77 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -224,7 +224,7 @@ interface ChatClientRequestSpec { ChatClientRequestSpec tools(Object... toolObjects); - ChatClientRequestSpec tools(ToolCallbackProvider... toolCallbackProvider); + ChatClientRequestSpec tools(ToolCallbackProvider... toolCallbackProviders); @Deprecated ChatClientRequestSpec functions(FunctionCallback... functionCallbacks); @@ -293,7 +293,7 @@ interface Builder { Builder defaultTools(Object... toolObjects); - Builder defaultTools(ToolCallbackProvider... toolCallbackProvider); + Builder defaultTools(ToolCallbackProvider... toolCallbackProviders); /** * @deprecated in favor of {@link #defaultTools(String...)} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java index d405790208d..4ab2fc87069 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java @@ -21,23 +21,67 @@ import org.springframework.util.Assert; /** + * A simple implementation of {@link ToolCallbackProvider} that maintains a static array + * of {@link FunctionCallback} objects. This provider is immutable after construction and + * provides a straightforward way to supply a fixed set of tool callbacks to AI models. + * + *

+ * This implementation is thread-safe as it maintains an immutable array of callbacks that + * is set during construction and cannot be modified afterwards. + * + *

+ * Example usage:

{@code
+ * FunctionCallback callback1 = new MyFunctionCallback();
+ * FunctionCallback callback2 = new AnotherFunctionCallback();
+ *
+ * // Create provider with varargs constructor
+ * ToolCallbackProvider provider1 = new StaticToolCallbackProvider(callback1, callback2);
+ *
+ * // Or create provider with List constructor
+ * List callbacks = Arrays.asList(callback1, callback2);
+ * ToolCallbackProvider provider2 = new StaticToolCallbackProvider(callbacks);
+ * }
+ * * @author Christian Tzolov * @since 1.0.0 + * @see ToolCallbackProvider + * @see FunctionCallback */ public class StaticToolCallbackProvider implements ToolCallbackProvider { private final FunctionCallback[] toolCallbacks; + /** + * Constructs a new StaticToolCallbackProvider with the specified array of function + * callbacks. + * @param toolCallbacks the array of function callbacks to be provided by this + * provider. Must not be null, though an empty array is permitted. + * @throws IllegalArgumentException if the toolCallbacks array is null + */ public StaticToolCallbackProvider(FunctionCallback... toolCallbacks) { Assert.notNull(toolCallbacks, "ToolCallbacks must not be null"); this.toolCallbacks = toolCallbacks; } + /** + * Constructs a new StaticToolCallbackProvider with the specified list of function + * callbacks. The list is converted to an array internally. + * @param toolCallbacks the list of function callbacks to be provided by this + * provider. Must not be null and must not contain null elements. + * @throws IllegalArgumentException if the toolCallbacks list is null or contains null + * elements + */ public StaticToolCallbackProvider(List toolCallbacks) { Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); this.toolCallbacks = toolCallbacks.toArray(new FunctionCallback[0]); } + /** + * Returns the array of function callbacks held by this provider. + * @return an array containing all function callbacks provided during construction. + * The returned array is a direct reference to the internal array, as the callbacks + * are expected to be immutable. + */ @Override public FunctionCallback[] getToolCallbacks() { return this.toolCallbacks; diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-client-boot-starter-docs.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-client-boot-starter-docs.adoc index f34e08d7d0d..76cd532971d 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-client-boot-starter-docs.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-client-boot-starter-docs.adoc @@ -352,7 +352,7 @@ private List mcpAsyncClients; // For async client ---- Additionally, the registered MCP Tools with all MCP clients are provided as a list of ToolCallback -throurgh a ToolCallbackProvider instance: +through a ToolCallbackProvider instance: [source,java] ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-server-boot-starter-docs.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-server-boot-starter-docs.adoc index 5d2598b8673..0f9929a3678 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-server-boot-starter-docs.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-server-boot-starter-docs.adoc @@ -126,9 +126,9 @@ Allows servers to expose tools that can be invoked by language models. The MCP S [source,java] ---- @Bean -public ToolCallbackProvier myTools(...) { +public ToolCallbackProvider myTools(...) { List tools = ... - return ToolCallbackProvier.from(tools); + return ToolCallbackProvider.from(tools); } ----