|
20 | 20 |
|
21 | 21 | import org.junit.jupiter.api.Test; |
22 | 22 |
|
| 23 | +import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; |
| 24 | +import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; |
23 | 25 | import org.springframework.ai.model.tool.DefaultToolCallingManager; |
24 | 26 | import org.springframework.ai.model.tool.ToolCallingManager; |
25 | 27 | import org.springframework.ai.tool.StaticToolCallbackProvider; |
|
45 | 47 | import org.springframework.util.ReflectionUtils; |
46 | 48 |
|
47 | 49 | import static org.assertj.core.api.Assertions.assertThat; |
| 50 | +import static org.mockito.Mockito.mock; |
| 51 | +import static org.mockito.Mockito.never; |
| 52 | +import static org.mockito.Mockito.verify; |
| 53 | +import static org.mockito.Mockito.when; |
48 | 54 |
|
49 | 55 | /** |
50 | 56 | * Unit tests for {@link ToolCallingAutoConfiguration}. |
51 | 57 | * |
52 | 58 | * @author Thomas Vitale |
53 | 59 | * @author Christian Tzolov |
| 60 | + * @author Yanming Zhou |
54 | 61 | */ |
55 | 62 | class ToolCallingAutoConfigurationTests { |
56 | 63 |
|
@@ -185,6 +192,27 @@ void throwExceptionOnErrorEnabled() { |
185 | 192 | }); |
186 | 193 | } |
187 | 194 |
|
| 195 | + @Test |
| 196 | + void toolCallbackResolverDoesNotUseMcpToolCallbackProviders() { |
| 197 | + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) |
| 198 | + .withUserConfiguration(Config.class) |
| 199 | + .run(context -> { |
| 200 | + var syncMcpToolCallbackProvider = context.getBean("syncMcpToolCallbackProvider", |
| 201 | + ToolCallbackProvider.class); |
| 202 | + var asyncMcpToolCallbackProvider = context.getBean("asyncMcpToolCallbackProvider", |
| 203 | + ToolCallbackProvider.class); |
| 204 | + |
| 205 | + verify(syncMcpToolCallbackProvider, never()).getToolCallbacks(); |
| 206 | + verify(asyncMcpToolCallbackProvider, never()).getToolCallbacks(); |
| 207 | + |
| 208 | + var toolCallbackResolver = context.getBean(ToolCallbackResolver.class); |
| 209 | + assertThat(toolCallbackResolver.resolve("getForecast")).isNotNull(); |
| 210 | + |
| 211 | + verify(syncMcpToolCallbackProvider, never()).getToolCallbacks(); |
| 212 | + verify(asyncMcpToolCallbackProvider, never()).getToolCallbacks(); |
| 213 | + }); |
| 214 | + } |
| 215 | + |
188 | 216 | static class WeatherService { |
189 | 217 |
|
190 | 218 | @Tool(description = "Get the weather in location. Return temperature in 36°F or 36°C format.") |
@@ -267,6 +295,20 @@ public ToolCallback toolCallbacks6() { |
267 | 295 | .build(); |
268 | 296 | } |
269 | 297 |
|
| 298 | + @Bean |
| 299 | + public SyncMcpToolCallbackProvider syncMcpToolCallbackProvider() { |
| 300 | + SyncMcpToolCallbackProvider provider = mock(SyncMcpToolCallbackProvider.class); |
| 301 | + when(provider.getToolCallbacks()).thenReturn(new ToolCallback[0]); |
| 302 | + return provider; |
| 303 | + } |
| 304 | + |
| 305 | + @Bean |
| 306 | + public AsyncMcpToolCallbackProvider asyncMcpToolCallbackProvider() { |
| 307 | + AsyncMcpToolCallbackProvider provider = mock(AsyncMcpToolCallbackProvider.class); |
| 308 | + when(provider.getToolCallbacks()).thenReturn(new ToolCallback[0]); |
| 309 | + return provider; |
| 310 | + } |
| 311 | + |
270 | 312 | public record Request(String location) { |
271 | 313 | } |
272 | 314 |
|
|
0 commit comments