Skip to content

Commit cc9a5ec

Browse files
committed
feat: Support both FunctionCallback and ToolCallback in ToolCallingAutoConfiguration
Enhances the ToolCallingAutoConfiguration to handle both FunctionCallback and ToolCallback types by: Modifying toolCallbackResolver bean to aggregate both callback types Adding comprehensive test coverage for different callback implementations Signed-off-by: Christian Tzolov <[email protected]>
1 parent c19287d commit cc9a5ec

File tree

2 files changed

+138
-2
lines changed

2 files changed

+138
-2
lines changed

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.springframework.ai.chat.model.ChatModel;
2121
import org.springframework.ai.model.function.FunctionCallback;
2222
import org.springframework.ai.model.tool.ToolCallingManager;
23+
import org.springframework.ai.tool.ToolCallback;
2324
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
2425
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
2526
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
@@ -33,12 +34,14 @@
3334
import org.springframework.context.annotation.Bean;
3435
import org.springframework.context.support.GenericApplicationContext;
3536

37+
import java.util.ArrayList;
3638
import java.util.List;
3739

3840
/**
3941
* Auto-configuration for common tool calling features of {@link ChatModel}.
4042
*
4143
* @author Thomas Vitale
44+
* @author Christian Tzolov
4245
* @since 1.0.0
4346
*/
4447
@AutoConfiguration
@@ -48,8 +51,16 @@ public class ToolCallingAutoConfiguration {
4851
@Bean
4952
@ConditionalOnMissingBean
5053
ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext,
51-
List<FunctionCallback> toolCallbacks) {
52-
var staticToolCallbackResolver = new StaticToolCallbackResolver(toolCallbacks);
54+
ObjectProvider<List<FunctionCallback>> functionCallbacksProvider,
55+
ObjectProvider<List<ToolCallback>> toolCallbacksProvider) {
56+
57+
List<FunctionCallback> functionCallbacks = functionCallbacksProvider.stream().flatMap(List::stream).toList();
58+
59+
List<FunctionCallback> allFunctionAndToolCallbacks = new ArrayList<>(functionCallbacks);
60+
allFunctionAndToolCallbacks.addAll(toolCallbacksProvider.stream().flatMap(List::stream).toList());
61+
62+
var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionAndToolCallbacks);
63+
5364
var springBeanToolCallbackResolver = SpringBeanToolCallbackResolver.builder()
5465
.applicationContext(applicationContext)
5566
.build();

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,39 @@
1616

1717
package org.springframework.ai.autoconfigure.chat.model;
1818

19+
import java.util.List;
20+
import java.util.function.Function;
21+
1922
import org.junit.jupiter.api.Test;
23+
24+
import org.springframework.ai.model.function.FunctionCallback;
2025
import org.springframework.ai.model.tool.DefaultToolCallingManager;
26+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
2127
import org.springframework.ai.model.tool.ToolCallingManager;
28+
import org.springframework.ai.tool.ToolCallback;
29+
import org.springframework.ai.tool.ToolCallbacks;
30+
import org.springframework.ai.tool.annotation.Tool;
31+
import org.springframework.ai.tool.definition.ToolDefinition;
2232
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
2333
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
34+
import org.springframework.ai.tool.function.FunctionToolCallback;
35+
import org.springframework.ai.tool.method.MethodToolCallback;
2436
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
2537
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
2638
import org.springframework.boot.autoconfigure.AutoConfigurations;
2739
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
40+
import org.springframework.context.annotation.Bean;
41+
import org.springframework.context.annotation.Configuration;
42+
import org.springframework.context.annotation.Description;
43+
import org.springframework.util.ReflectionUtils;
2844

2945
import static org.assertj.core.api.Assertions.assertThat;
3046

3147
/**
3248
* Unit tests for {@link ToolCallingAutoConfiguration}.
3349
*
3450
* @author Thomas Vitale
51+
* @author Christian Tzolov
3552
*/
3653
class ToolCallingAutoConfigurationTests {
3754

@@ -50,4 +67,112 @@ void beansAreCreated() {
5067
});
5168
}
5269

70+
@Test
71+
void beansAreCreated2() {
72+
new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class))
73+
.withUserConfiguration(Config.class)
74+
.run(context -> {
75+
var toolCallbackResolver = context.getBean(ToolCallbackResolver.class);
76+
assertThat(toolCallbackResolver).isInstanceOf(DelegatingToolCallbackResolver.class);
77+
78+
var toolExecutionExceptionProcessor = context.getBean(ToolExecutionExceptionProcessor.class);
79+
assertThat(toolExecutionExceptionProcessor).isInstanceOf(DefaultToolExecutionExceptionProcessor.class);
80+
81+
var toolCallingManager = context.getBean(ToolCallingManager.class);
82+
assertThat(toolCallingManager).isInstanceOf(DefaultToolCallingManager.class);
83+
84+
assertThat(toolCallbackResolver.resolve("getForecast")).isNotNull();
85+
assertThat(toolCallbackResolver.resolve("getForecast").getName()).isEqualTo("getForecast");
86+
87+
assertThat(toolCallbackResolver.resolve("getAlert")).isNotNull();
88+
assertThat(toolCallbackResolver.resolve("getAlert").getName()).isEqualTo("getAlert");
89+
90+
assertThat(toolCallbackResolver.resolve("weatherFunction1")).isNotNull();
91+
assertThat(toolCallbackResolver.resolve("weatherFunction1").getName()).isEqualTo("weatherFunction1");
92+
93+
assertThat(toolCallbackResolver.resolve("toolCallbacks5")).isNotNull();
94+
assertThat(toolCallbackResolver.resolve("toolCallbacks5").getName()).isEqualTo("toolCallbacks5");
95+
96+
// assertThat(toolCallbackResolver.resolve("functionCallbacks3")).isNotNull();
97+
// assertThat(toolCallbackResolver.resolve("functionCallbacks3").getName()).isEqualTo("functionCallbacks3");
98+
});
99+
}
100+
101+
static class WeatherService {
102+
103+
@Tool(description = "Get the weather in location. Return temperature in 36°F or 36°C format.")
104+
public String getForecast(String location) {
105+
return "30";
106+
}
107+
108+
public String getAlert(String usState) {
109+
return "Alergt";
110+
}
111+
112+
}
113+
114+
@Configuration
115+
static class Config {
116+
117+
public record Request(String location) {
118+
}
119+
120+
public record Response(String temperature) {
121+
}
122+
123+
@Bean
124+
@Description("Get the weather in location. Return temperature in 36°F or 36°C format.")
125+
public Function<Request, Response> weatherFunction1() {
126+
return request -> new Response("30");
127+
}
128+
129+
@Bean
130+
public List<ToolCallback> toolCallbacks() {
131+
return List.of(ToolCallbacks.from(new WeatherService()));
132+
}
133+
134+
@Bean
135+
public List<FunctionCallback> functionCallbacks3() {
136+
return List.of(FunctionCallback.builder()
137+
.function("getCurrentWeather3", (Request request) -> "15.0°C")
138+
.description("Gets the weather in location")
139+
.inputType(Request.class)
140+
.build());
141+
142+
}
143+
144+
@Bean
145+
public List<FunctionCallback> functionCallbacks4() {
146+
return List.of(FunctionCallback.builder()
147+
.function("getCurrentWeather4", (Request request) -> "15.0°C")
148+
.description("Gets the weather in location")
149+
.inputType(Request.class)
150+
.build());
151+
152+
}
153+
154+
@Bean
155+
public List<ToolCallback> toolCallbacks5() {
156+
return List.of(FunctionToolCallback.builder("getCurrentWeather5", (Request request) -> "15.0°C")
157+
.description("Gets the weather in location")
158+
.inputType(Request.class)
159+
.build());
160+
161+
}
162+
163+
@Bean
164+
public List<ToolCallback> toolCallbacks6() {
165+
166+
var toolMethod = ReflectionUtils.findMethod(WeatherService.class, "getAlert", String.class);
167+
168+
return List.of(MethodToolCallback.builder()
169+
.toolDefinition(ToolDefinition.builder(toolMethod).build())
170+
.toolMethod(toolMethod)
171+
.toolObject(new WeatherService())
172+
.build());
173+
174+
}
175+
176+
}
177+
53178
}

0 commit comments

Comments
 (0)