|
46 | 46 | import org.springframework.ai.converter.BeanOutputConverter; |
47 | 47 | import org.springframework.ai.converter.ListOutputConverter; |
48 | 48 | import org.springframework.ai.converter.MapOutputConverter; |
| 49 | +import org.springframework.ai.model.tool.ToolCallingChatOptions; |
49 | 50 | import org.springframework.ai.tool.function.FunctionToolCallback; |
50 | 51 | import org.springframework.beans.factory.annotation.Autowired; |
51 | 52 | import org.springframework.beans.factory.annotation.Value; |
@@ -279,6 +280,29 @@ void functionCallTest() { |
279 | 280 | assertThat(generation.getOutput().getText()).contains("30", "10", "15"); |
280 | 281 | } |
281 | 282 |
|
| 283 | + @Test |
| 284 | + void functionCallTestWithToolCallingOptions() { |
| 285 | + |
| 286 | + UserMessage userMessage = new UserMessage( |
| 287 | + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); |
| 288 | + |
| 289 | + List<Message> messages = new ArrayList<>(List.of(userMessage)); |
| 290 | + |
| 291 | + var promptOptions = ToolCallingChatOptions.builder() |
| 292 | + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) |
| 293 | + .description("Get the weather in location. Return in 36°C format") |
| 294 | + .inputType(MockWeatherService.Request.class) |
| 295 | + .build())) |
| 296 | + .build(); |
| 297 | + |
| 298 | + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); |
| 299 | + |
| 300 | + logger.info("Response: {}", response); |
| 301 | + |
| 302 | + Generation generation = response.getResult(); |
| 303 | + assertThat(generation.getOutput().getText()).contains("30", "10", "15"); |
| 304 | + } |
| 305 | + |
282 | 306 | @Test |
283 | 307 | void streamFunctionCallTest() { |
284 | 308 |
|
|
0 commit comments