Skip to content

Commit d5412a7

Browse files
committed
refactor: introduce new StaticToolCallbackProvider implementation
- Update ToolCallbackProvider to return FunctionCallback[] - Migrate from List to ToolCallbackProvider in configurations - Update tests to use new provider pattern Signed-off-by: Christian Tzolov <[email protected]>
1 parent c945741 commit d5412a7

File tree

7 files changed

+134
-42
lines changed

7 files changed

+134
-42
lines changed

auto-configurations/spring-ai-mcp-client/src/main/java/org/springframework/ai/autoconfigure/mcp/client/McpClientAutoConfiguration.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import org.springframework.ai.mcp.McpToolUtils;
3131
import org.springframework.ai.mcp.customizer.McpAsyncClientCustomizer;
3232
import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer;
33-
import org.springframework.ai.tool.ToolCallback;
33+
import org.springframework.ai.tool.ToolCallbackProvider;
3434
import org.springframework.beans.factory.ObjectProvider;
3535
import org.springframework.boot.autoconfigure.AutoConfiguration;
3636
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
@@ -176,9 +176,9 @@ public List<McpSyncClient> mcpSyncClients(McpSyncClientConfigurer mcpSyncClientC
176176
@Bean
177177
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
178178
matchIfMissing = true)
179-
public List<ToolCallback> toolCallbacks(ObjectProvider<List<McpSyncClient>> mcpClientsProvider) {
179+
public ToolCallbackProvider toolCallbacks(ObjectProvider<List<McpSyncClient>> mcpClientsProvider) {
180180
List<McpSyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
181-
return McpToolUtils.getToolCallbacksFromSyncClients(mcpClients);
181+
return ToolCallbackProvider.from(McpToolUtils.getToolCallbacksFromSyncClients(mcpClients));
182182
}
183183

184184
/**
@@ -265,9 +265,9 @@ public List<McpAsyncClient> mcpAsyncClients(McpAsyncClientConfigurer mcpSyncClie
265265

266266
@Bean
267267
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
268-
public List<ToolCallback> asyncToolCallbacks(ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
268+
public ToolCallbackProvider asyncToolCallbacks(ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
269269
List<McpAsyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
270-
return McpToolUtils.getToolCallbacksFromAsyncClinents(mcpClients);
270+
return ToolCallbackProvider.from(McpToolUtils.getToolCallbacksFromAsyncClinents(mcpClients));
271271
}
272272

273273
public record ClosebleMcpAsyncClients(List<McpAsyncClient> clients) implements AutoCloseable {

auto-configurations/spring-ai-mcp-server/src/main/java/org/springframework/ai/autoconfigure/mcp/server/MpcServerAutoConfiguration.java

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616

1717
package org.springframework.ai.autoconfigure.mcp.server;
1818

19+
import java.util.ArrayList;
1920
import java.util.List;
2021
import java.util.function.Consumer;
2122
import java.util.function.Function;
23+
import java.util.stream.Stream;
2224

2325
import io.modelcontextprotocol.server.McpAsyncServer;
2426
import io.modelcontextprotocol.server.McpServer;
@@ -39,7 +41,9 @@
3941
import reactor.core.publisher.Mono;
4042

4143
import org.springframework.ai.mcp.McpToolUtils;
44+
import org.springframework.ai.model.function.FunctionCallback;
4245
import org.springframework.ai.tool.ToolCallback;
46+
import org.springframework.ai.tool.ToolCallbackProvider;
4347
import org.springframework.beans.factory.ObjectProvider;
4448
import org.springframework.boot.autoconfigure.AutoConfiguration;
4549
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
@@ -135,15 +139,23 @@ public McpSyncServer mcpSyncServer(ServerMcpTransport transport,
135139
McpSchema.ServerCapabilities.Builder capabilitiesBuilder, McpServerProperties serverProperties,
136140
ObjectProvider<List<SyncToolRegistration>> tools, ObjectProvider<List<SyncResourceRegistration>> resources,
137141
ObjectProvider<List<SyncPromptRegistration>> prompts,
138-
ObjectProvider<Consumer<List<McpSchema.Root>>> rootsChangeConsumers) {
142+
ObjectProvider<Consumer<List<McpSchema.Root>>> rootsChangeConsumers,
143+
List<ToolCallbackProvider> toolCallbackProvider) {
139144

140145
McpSchema.Implementation serverInfo = new Implementation(serverProperties.getName(),
141146
serverProperties.getVersion());
142147

143148
// Create the server with both tool and resource capabilities
144149
SyncSpec serverBuilder = McpServer.sync(transport).serverInfo(serverInfo);
145150

146-
List<SyncToolRegistration> toolResgistrations = tools.stream().flatMap(List::stream).toList();
151+
List<SyncToolRegistration> toolResgistrations = new ArrayList<>(tools.stream().flatMap(List::stream).toList());
152+
List<ToolCallback> providerToolCallbacks = toolCallbackProvider.stream()
153+
.map(pr -> List.of(pr.getToolCallbacks()))
154+
.flatMap(List::stream)
155+
.filter(fc -> fc instanceof ToolCallback)
156+
.map(fc -> (ToolCallback) fc)
157+
.toList();
158+
toolResgistrations.addAll(McpToolUtils.toSyncToolRegistration(providerToolCallbacks));
147159
if (!CollectionUtils.isEmpty(toolResgistrations)) {
148160
serverBuilder.tools(toolResgistrations);
149161
capabilitiesBuilder.tools(serverProperties.isToolChangeNotification());
@@ -191,15 +203,23 @@ public McpAsyncServer mcpAsyncServer(ServerMcpTransport transport,
191203
ObjectProvider<List<AsyncToolRegistration>> tools,
192204
ObjectProvider<List<AsyncResourceRegistration>> resources,
193205
ObjectProvider<List<AsyncPromptRegistration>> prompts,
194-
ObjectProvider<Consumer<List<McpSchema.Root>>> rootsChangeConsumer) {
206+
ObjectProvider<Consumer<List<McpSchema.Root>>> rootsChangeConsumer,
207+
List<ToolCallbackProvider> toolCallbackProvider) {
195208

196209
McpSchema.Implementation serverInfo = new Implementation(serverProperties.getName(),
197210
serverProperties.getVersion());
198211

199212
// Create the server with both tool and resource capabilities
200213
AsyncSpec serverBilder = McpServer.async(transport).serverInfo(serverInfo);
201214

202-
List<AsyncToolRegistration> toolResgistrations = tools.stream().flatMap(List::stream).toList();
215+
List<AsyncToolRegistration> toolResgistrations = new ArrayList<>(tools.stream().flatMap(List::stream).toList());
216+
List<ToolCallback> providerToolCallbacks = toolCallbackProvider.stream()
217+
.map(pr -> List.of(pr.getToolCallbacks()))
218+
.flatMap(List::stream)
219+
.filter(fc -> fc instanceof ToolCallback)
220+
.map(fc -> (ToolCallback) fc)
221+
.toList();
222+
toolResgistrations.addAll(McpToolUtils.toAsyncToolRegistration(providerToolCallbacks));
203223
if (!CollectionUtils.isEmpty(toolResgistrations)) {
204224
serverBilder.tools(toolResgistrations);
205225
capabilitiesBuilder.tools(serverProperties.isToolChangeNotification());

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
3939
import org.springframework.ai.model.function.FunctionCallback;
4040
import org.springframework.ai.model.tool.ToolCallingManager;
41-
import org.springframework.ai.tool.ToolCallback;
41+
import org.springframework.ai.tool.ToolCallbackProvider;
4242
import org.springframework.ai.tool.ToolCallbacks;
4343
import org.springframework.ai.tool.annotation.Tool;
4444
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
@@ -183,9 +183,8 @@ public List<Status> statusespaymentStatuses(List<Transaction> transactions) {
183183
public static class TestConfiguration {
184184

185185
@Bean
186-
public List<ToolCallback> paymentServiceTools() {
187-
var tools = List.of(ToolCallbacks.from(new PaymentService()));
188-
return tools;
186+
public ToolCallbackProvider paymentServiceTools() {
187+
return ToolCallbackProvider.from(List.of(ToolCallbacks.from(new PaymentService())));
189188
}
190189

191190
@Bean
@@ -221,11 +220,11 @@ public VertexAiGeminiChatModel vertexAiChatModel(VertexAI vertexAi, ToolCallingM
221220

222221
@Bean
223222
ToolCallingManager toolCallingManager(GenericApplicationContext applicationContext,
224-
List<ToolCallback> toolCallbacks, List<FunctionCallback> functionCallbacks,
223+
List<ToolCallbackProvider> tcps, List<FunctionCallback> functionCallbacks,
225224
ObjectProvider<ObservationRegistry> observationRegistry) {
226225

227226
List<FunctionCallback> allFunctionCallbacks = new ArrayList(functionCallbacks);
228-
allFunctionCallbacks.addAll(toolCallbacks.stream().map(tc -> (FunctionCallback) tc).toList());
227+
tcps.stream().map(pr -> List.of(pr.getToolCallbacks())).forEach(allFunctionCallbacks::addAll);
229228

230229
var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionCallbacks);
231230

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright 2025 - 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.tool;
17+
18+
import java.util.List;
19+
20+
import org.springframework.ai.model.function.FunctionCallback;
21+
import org.springframework.util.Assert;
22+
23+
/**
24+
* @author Christian Tzolov
25+
* @since 1.0.0
26+
*/
27+
public class StaticToolCallbackProvider implements ToolCallbackProvider {
28+
29+
private final FunctionCallback[] toolCallbacks;
30+
31+
public StaticToolCallbackProvider(FunctionCallback... toolCallbacks) {
32+
Assert.notNull(toolCallbacks, "ToolCallbacks must not be null");
33+
this.toolCallbacks = toolCallbacks;
34+
}
35+
36+
public StaticToolCallbackProvider(List<? extends FunctionCallback> toolCallbacks) {
37+
Assert.notNull(toolCallbacks, "ToolCallbacks must not be null");
38+
this.toolCallbacks = toolCallbacks.toArray(new FunctionCallback[0]);
39+
}
40+
41+
@Override
42+
public FunctionCallback[] getToolCallbacks() {
43+
return this.toolCallbacks;
44+
}
45+
46+
}

spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallbackProvider.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616

1717
package org.springframework.ai.tool;
1818

19+
import java.util.List;
20+
21+
import org.springframework.ai.model.function.FunctionCallback;
22+
1923
/**
2024
* Provides {@link ToolCallback} instances for tools defined in different sources.
2125
*
@@ -24,6 +28,14 @@
2428
*/
2529
public interface ToolCallbackProvider {
2630

27-
ToolCallback[] getToolCallbacks();
31+
FunctionCallback[] getToolCallbacks();
32+
33+
public static ToolCallbackProvider from(List<? extends FunctionCallback> toolCallbacks) {
34+
return new StaticToolCallbackProvider(toolCallbacks);
35+
}
36+
37+
public static ToolCallbackProvider from(FunctionCallback... toolCallbacks) {
38+
return new StaticToolCallbackProvider(toolCallbacks);
39+
}
2840

2941
}

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,15 @@
1616

1717
package org.springframework.ai.autoconfigure.chat.model;
1818

19+
import java.util.ArrayList;
20+
import java.util.List;
21+
1922
import io.micrometer.observation.ObservationRegistry;
23+
2024
import org.springframework.ai.chat.model.ChatModel;
2125
import org.springframework.ai.model.function.FunctionCallback;
2226
import org.springframework.ai.model.tool.ToolCallingManager;
23-
import org.springframework.ai.tool.ToolCallback;
27+
import org.springframework.ai.tool.ToolCallbackProvider;
2428
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
2529
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
2630
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
@@ -34,9 +38,6 @@
3438
import org.springframework.context.annotation.Bean;
3539
import org.springframework.context.support.GenericApplicationContext;
3640

37-
import java.util.ArrayList;
38-
import java.util.List;
39-
4041
/**
4142
* Auto-configuration for common tool calling features of {@link ChatModel}.
4243
*
@@ -51,12 +52,10 @@ public class ToolCallingAutoConfiguration {
5152
@Bean
5253
@ConditionalOnMissingBean
5354
ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext,
54-
ObjectProvider<List<FunctionCallback>> functionCallbacksProvider,
55-
ObjectProvider<List<ToolCallback>> toolCallbacksProvider) {
55+
List<FunctionCallback> functionCallbacks, List<ToolCallbackProvider> tcbProviders) {
5656

57-
List<FunctionCallback> allFunctionAndToolCallbacks = new ArrayList<>(
58-
functionCallbacksProvider.stream().flatMap(List::stream).toList());
59-
allFunctionAndToolCallbacks.addAll(toolCallbacksProvider.stream().flatMap(List::stream).toList());
57+
List<FunctionCallback> allFunctionAndToolCallbacks = new ArrayList<>(functionCallbacks);
58+
tcbProviders.stream().map(pr -> List.of(pr.getToolCallbacks())).forEach(allFunctionAndToolCallbacks::addAll);
6059

6160
var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionAndToolCallbacks);
6261

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,23 @@
1616

1717
package org.springframework.ai.autoconfigure.chat.model;
1818

19-
import java.util.List;
2019
import java.util.function.Function;
2120

2221
import org.junit.jupiter.api.Test;
2322

2423
import org.springframework.ai.model.function.FunctionCallback;
2524
import org.springframework.ai.model.tool.DefaultToolCallingManager;
2625
import org.springframework.ai.model.tool.ToolCallingManager;
26+
import org.springframework.ai.tool.StaticToolCallbackProvider;
2727
import org.springframework.ai.tool.ToolCallback;
28-
import org.springframework.ai.tool.ToolCallbacks;
28+
import org.springframework.ai.tool.ToolCallbackProvider;
2929
import org.springframework.ai.tool.annotation.Tool;
3030
import org.springframework.ai.tool.definition.ToolDefinition;
3131
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
3232
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
3333
import org.springframework.ai.tool.function.FunctionToolCallback;
3434
import org.springframework.ai.tool.method.MethodToolCallback;
35+
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
3536
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
3637
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
3738
import org.springframework.boot.autoconfigure.AutoConfigurations;
@@ -104,8 +105,13 @@ public String getForecast(String location) {
104105
return "30";
105106
}
106107

108+
@Tool(description = "Get the weather in location. Return temperature in 36°F or 36°C format.")
109+
public String getForecast2(String location) {
110+
return "30";
111+
}
112+
107113
public String getAlert(String usState) {
108-
return "Alergt";
114+
return "Alert";
109115
}
110116

111117
}
@@ -118,8 +124,8 @@ static class Config {
118124
// Therefore we need to provide the ToolCallback instances explicitly using the
119125
// ToolCallbacks.from(...) utility method.
120126
@Bean
121-
public List<ToolCallback> toolCallbacks() {
122-
return List.of(ToolCallbacks.from(new WeatherService()));
127+
public ToolCallbackProvider toolCallbacks() {
128+
return MethodToolCallbackProvider.builder().toolObjects(new WeatherService()).build();
123129
}
124130

125131
public record Request(String location) {
@@ -135,41 +141,51 @@ public Function<Request, Response> weatherFunction1() {
135141
}
136142

137143
@Bean
138-
public List<FunctionCallback> functionCallbacks3() {
139-
return List.of(FunctionCallback.builder()
144+
public FunctionCallback functionCallbacks3() {
145+
return FunctionCallback.builder()
140146
.function("getCurrentWeather3", (Request request) -> "15.0°C")
141147
.description("Gets the weather in location")
142148
.inputType(Request.class)
143-
.build());
149+
.build();
144150
}
145151

146152
@Bean
147-
public List<FunctionCallback> functionCallbacks4() {
148-
return List.of(FunctionCallback.builder()
153+
public FunctionCallback functionCallbacks4() {
154+
return FunctionCallback.builder()
149155
.function("getCurrentWeather4", (Request request) -> "15.0°C")
150156
.description("Gets the weather in location")
151157
.inputType(Request.class)
152-
.build());
158+
.build();
153159

154160
}
155161

156162
@Bean
157-
public List<ToolCallback> toolCallbacks5() {
158-
return List.of(FunctionToolCallback.builder("getCurrentWeather5", (Request request) -> "15.0°C")
163+
public ToolCallback toolCallbacks5() {
164+
return FunctionToolCallback.builder("getCurrentWeather5", (Request request) -> "15.0°C")
159165
.description("Gets the weather in location")
160166
.inputType(Request.class)
161-
.build());
167+
.build();
168+
169+
}
170+
171+
@Bean
172+
public ToolCallbackProvider blabla() {
173+
return new StaticToolCallbackProvider(
174+
FunctionToolCallback.builder("getCurrentWeather5", (Request request) -> "15.0°C")
175+
.description("Gets the weather in location")
176+
.inputType(Request.class)
177+
.build());
162178

163179
}
164180

165181
@Bean
166-
public List<ToolCallback> toolCallbacks6() {
182+
public ToolCallback toolCallbacks6() {
167183
var toolMethod = ReflectionUtils.findMethod(WeatherService.class, "getAlert", String.class);
168-
return List.of(MethodToolCallback.builder()
184+
return MethodToolCallback.builder()
169185
.toolDefinition(ToolDefinition.builder(toolMethod).build())
170186
.toolMethod(toolMethod)
171187
.toolObject(new WeatherService())
172-
.build());
188+
.build();
173189
}
174190

175191
}

0 commit comments

Comments
 (0)