Skip to content

Commit c945741

Browse files
committed
feat(autoconfigure): Support both FunctionCallback and ToolCallback in ToolCallingAutoConfiguration
- Extends the ToolCallingAutoConfiguration to support both FunctionCallback and ToolCallback types. - The toolCallbackResolver bean now handles both callback types through ObjectProvider injection. - Added comprehensive tests to verify the resolution of multiple function and tool callbacks. Signed-off-by: Christian Tzolov <[email protected]>
1 parent a528253 commit c945741

File tree

2 files changed

+136
-2
lines changed

2 files changed

+136
-2
lines changed

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.springframework.ai.chat.model.ChatModel;
2121
import org.springframework.ai.model.function.FunctionCallback;
2222
import org.springframework.ai.model.tool.ToolCallingManager;
23+
import org.springframework.ai.tool.ToolCallback;
2324
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
2425
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
2526
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
@@ -33,12 +34,14 @@
3334
import org.springframework.context.annotation.Bean;
3435
import org.springframework.context.support.GenericApplicationContext;
3536

37+
import java.util.ArrayList;
3638
import java.util.List;
3739

3840
/**
3941
* Auto-configuration for common tool calling features of {@link ChatModel}.
4042
*
4143
* @author Thomas Vitale
44+
* @author Christian Tzolov
4245
* @since 1.0.0
4346
*/
4447
@AutoConfiguration
@@ -48,8 +51,15 @@ public class ToolCallingAutoConfiguration {
4851
@Bean
4952
@ConditionalOnMissingBean
5053
ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext,
51-
List<FunctionCallback> toolCallbacks) {
52-
var staticToolCallbackResolver = new StaticToolCallbackResolver(toolCallbacks);
54+
ObjectProvider<List<FunctionCallback>> functionCallbacksProvider,
55+
ObjectProvider<List<ToolCallback>> toolCallbacksProvider) {
56+
57+
List<FunctionCallback> allFunctionAndToolCallbacks = new ArrayList<>(
58+
functionCallbacksProvider.stream().flatMap(List::stream).toList());
59+
allFunctionAndToolCallbacks.addAll(toolCallbacksProvider.stream().flatMap(List::stream).toList());
60+
61+
var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionAndToolCallbacks);
62+
5363
var springBeanToolCallbackResolver = SpringBeanToolCallbackResolver.builder()
5464
.applicationContext(applicationContext)
5565
.build();

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,38 @@
1616

1717
package org.springframework.ai.autoconfigure.chat.model;
1818

19+
import java.util.List;
20+
import java.util.function.Function;
21+
1922
import org.junit.jupiter.api.Test;
23+
24+
import org.springframework.ai.model.function.FunctionCallback;
2025
import org.springframework.ai.model.tool.DefaultToolCallingManager;
2126
import org.springframework.ai.model.tool.ToolCallingManager;
27+
import org.springframework.ai.tool.ToolCallback;
28+
import org.springframework.ai.tool.ToolCallbacks;
29+
import org.springframework.ai.tool.annotation.Tool;
30+
import org.springframework.ai.tool.definition.ToolDefinition;
2231
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
2332
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
33+
import org.springframework.ai.tool.function.FunctionToolCallback;
34+
import org.springframework.ai.tool.method.MethodToolCallback;
2435
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
2536
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
2637
import org.springframework.boot.autoconfigure.AutoConfigurations;
2738
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
39+
import org.springframework.context.annotation.Bean;
40+
import org.springframework.context.annotation.Configuration;
41+
import org.springframework.context.annotation.Description;
42+
import org.springframework.util.ReflectionUtils;
2843

2944
import static org.assertj.core.api.Assertions.assertThat;
3045

3146
/**
3247
* Unit tests for {@link ToolCallingAutoConfiguration}.
3348
*
3449
* @author Thomas Vitale
50+
* @author Christian Tzolov
3551
*/
3652
class ToolCallingAutoConfigurationTests {
3753

@@ -50,4 +66,112 @@ void beansAreCreated() {
5066
});
5167
}
5268

69+
@Test
70+
void resolveMultipleFuncitonAndToolCallbacks() {
71+
new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class))
72+
.withUserConfiguration(Config.class)
73+
.run(context -> {
74+
var toolCallbackResolver = context.getBean(ToolCallbackResolver.class);
75+
assertThat(toolCallbackResolver).isInstanceOf(DelegatingToolCallbackResolver.class);
76+
77+
assertThat(toolCallbackResolver.resolve("getForecast")).isNotNull();
78+
assertThat(toolCallbackResolver.resolve("getForecast").getName()).isEqualTo("getForecast");
79+
80+
assertThat(toolCallbackResolver.resolve("getAlert")).isNotNull();
81+
assertThat(toolCallbackResolver.resolve("getAlert").getName()).isEqualTo("getAlert");
82+
83+
assertThat(toolCallbackResolver.resolve("weatherFunction1")).isNotNull();
84+
assertThat(toolCallbackResolver.resolve("weatherFunction1").getName()).isEqualTo("weatherFunction1");
85+
86+
assertThat(toolCallbackResolver.resolve("getCurrentWeather3")).isNotNull();
87+
assertThat(toolCallbackResolver.resolve("getCurrentWeather3").getName())
88+
.isEqualTo("getCurrentWeather3");
89+
90+
assertThat(toolCallbackResolver.resolve("getCurrentWeather4")).isNotNull();
91+
assertThat(toolCallbackResolver.resolve("getCurrentWeather4").getName())
92+
.isEqualTo("getCurrentWeather4");
93+
94+
assertThat(toolCallbackResolver.resolve("getCurrentWeather5")).isNotNull();
95+
assertThat(toolCallbackResolver.resolve("getCurrentWeather5").getName())
96+
.isEqualTo("getCurrentWeather5");
97+
});
98+
}
99+
100+
static class WeatherService {
101+
102+
@Tool(description = "Get the weather in location. Return temperature in 36°F or 36°C format.")
103+
public String getForecast(String location) {
104+
return "30";
105+
}
106+
107+
public String getAlert(String usState) {
108+
return "Alergt";
109+
}
110+
111+
}
112+
113+
@Configuration
114+
static class Config {
115+
116+
// Note: Currently we do not have ToolCallbackResolver implementation that can
117+
// resolve the ToolCallback from the Tool annotation.
118+
// Therefore we need to provide the ToolCallback instances explicitly using the
119+
// ToolCallbacks.from(...) utility method.
120+
@Bean
121+
public List<ToolCallback> toolCallbacks() {
122+
return List.of(ToolCallbacks.from(new WeatherService()));
123+
}
124+
125+
public record Request(String location) {
126+
}
127+
128+
public record Response(String temperature) {
129+
}
130+
131+
@Bean
132+
@Description("Get the weather in location. Return temperature in 36°F or 36°C format.")
133+
public Function<Request, Response> weatherFunction1() {
134+
return request -> new Response("30");
135+
}
136+
137+
@Bean
138+
public List<FunctionCallback> functionCallbacks3() {
139+
return List.of(FunctionCallback.builder()
140+
.function("getCurrentWeather3", (Request request) -> "15.0°C")
141+
.description("Gets the weather in location")
142+
.inputType(Request.class)
143+
.build());
144+
}
145+
146+
@Bean
147+
public List<FunctionCallback> functionCallbacks4() {
148+
return List.of(FunctionCallback.builder()
149+
.function("getCurrentWeather4", (Request request) -> "15.0°C")
150+
.description("Gets the weather in location")
151+
.inputType(Request.class)
152+
.build());
153+
154+
}
155+
156+
@Bean
157+
public List<ToolCallback> toolCallbacks5() {
158+
return List.of(FunctionToolCallback.builder("getCurrentWeather5", (Request request) -> "15.0°C")
159+
.description("Gets the weather in location")
160+
.inputType(Request.class)
161+
.build());
162+
163+
}
164+
165+
@Bean
166+
public List<ToolCallback> toolCallbacks6() {
167+
var toolMethod = ReflectionUtils.findMethod(WeatherService.class, "getAlert", String.class);
168+
return List.of(MethodToolCallback.builder()
169+
.toolDefinition(ToolDefinition.builder(toolMethod).build())
170+
.toolMethod(toolMethod)
171+
.toolObject(new WeatherService())
172+
.build());
173+
}
174+
175+
}
176+
53177
}

0 commit comments

Comments
 (0)