Skip to content

Commit 97a4476

Browse files
committed
more test conversion
1 parent 8f5d7df commit 97a4476

File tree

14 files changed

+127
-58
lines changed

14 files changed

+127
-58
lines changed

FUNCTIONS-TO-TOOLS-API-MIGRATION-GUIDE.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,22 @@ MethodToolCallback.builder()
8888
.build()
8989
```
9090

91+
And you can use the same `ChatClient#tools()` API to register method-based tool callbackes:
92+
93+
```java
94+
String response = ChatClient.create(chatModel)
95+
.prompt()
96+
.user("What's the weather like in San Francisco?")
97+
.tools(MethodToolCallback.builder()
98+
.toolDefinition(ToolDefinition.builder(toolMethod)
99+
.description("Get the weather in location")
100+
.build())
101+
.toolMethod(toolMethod)
102+
.build())
103+
.call()
104+
.content();
105+
```
106+
91107
### 4. Options Configuration
92108

93109
Before:
@@ -181,6 +197,31 @@ The following methods are deprecated and will be removed in a future release:
181197

182198
Use their `tools` counterparts instead.
183199

200+
## @Tool tool definition path.
201+
202+
Now you can use the method-level annothation (`@Tool`) to register tools with Spring AI
203+
204+
```java
205+
public class Home {
206+
207+
@Tool(description = "Turn light On or Off in a room.")
208+
public void turnLight(String roomName, boolean on) {
209+
// ...
210+
logger.info("Turn light in room: {} to: {}", roomName, on);
211+
}
212+
}
213+
214+
Home homeAutomation = new HomeAutomation();
215+
216+
String response = ChatClient.create(this.chatModel).prompt()
217+
.user("Turn the light in the living room On.")
218+
.tools(homeAutomation)
219+
.call()
220+
.content();
221+
222+
```
223+
224+
184225
## Additional Notes
185226

186227
1. The new API provides better separation between tool definition and implementation.

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientMethodInvokingFunctionCallbackIT.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.ai.chat.messages.Message;
3232
import org.springframework.ai.chat.model.ChatModel;
3333
import org.springframework.ai.chat.model.ToolContext;
34+
import org.springframework.ai.tool.annotation.Tool;
3435
import org.springframework.ai.tool.definition.ToolDefinition;
3536
import org.springframework.ai.tool.method.MethodToolCallback;
3637
import org.springframework.beans.factory.annotation.Autowired;
@@ -227,7 +228,7 @@ void methodNoParameters() {
227228

228229
String response = ChatClient.create(this.chatModel).prompt()
229230
.user("Turn light on in the living room.")
230-
.functions(MethodToolCallback.builder()
231+
.tools(MethodToolCallback.builder()
231232
.toolMethod(toolMethod)
232233
.toolDefinition(ToolDefinition.builder(toolMethod)
233234
.description("Can turn lights on in the Living Room")
@@ -243,6 +244,25 @@ void methodNoParameters() {
243244
assertThat(arguments).containsEntry("turnLivingRoomLightOn", true);
244245
}
245246

247+
@Test
248+
void toolAnnotation() {
249+
250+
TestFunctionClass targetObject = new TestFunctionClass();
251+
252+
// @formatter:off
253+
String response = ChatClient.create(this.chatModel).prompt()
254+
.user("Turn light red in the living room.")
255+
.tools(targetObject)
256+
.call()
257+
.content();
258+
// @formatter:on
259+
260+
logger.info("Response: {}", response);
261+
262+
assertThat(arguments).containsEntry("roomName", "living room")
263+
.containsEntry("color", TestFunctionClass.LightColor.RED);
264+
}
265+
246266
@Autowired
247267
ChatModel chatModel;
248268

@@ -306,6 +326,19 @@ public void turnLivingRoomLightOn() {
306326
arguments.put("turnLivingRoomLightOn", true);
307327
}
308328

329+
enum LightColor {
330+
331+
RED, GREEN, BLUE
332+
333+
}
334+
335+
@Tool(description = "Change the lamp color in a room.")
336+
public void changeRoomLightColor(String roomName, LightColor color) {
337+
arguments.put("roomName", roomName);
338+
arguments.put("color", color);
339+
logger.info("Change light colur in room: {} to color: {}", roomName, color);
340+
}
341+
309342
}
310343

311344
}

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
import org.springframework.ai.chat.model.ChatResponse;
4141
import org.springframework.ai.chat.model.Generation;
4242
import org.springframework.ai.chat.prompt.Prompt;
43-
import org.springframework.ai.model.function.FunctionCallback;
43+
import org.springframework.ai.tool.function.FunctionToolCallback;
4444
import org.springframework.beans.factory.annotation.Autowired;
4545
import org.springframework.boot.SpringBootConfiguration;
4646
import org.springframework.boot.test.context.SpringBootTest;
@@ -70,8 +70,7 @@ void functionCallTest() {
7070

7171
var promptOptions = AzureOpenAiChatOptions.builder()
7272
.deploymentName(this.selectedModel)
73-
.functionCallbacks(List.of(FunctionCallback.builder()
74-
.function("getCurrentWeather", new MockWeatherService())
73+
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
7574
.description("Get the current weather in a given location")
7675
.inputType(MockWeatherService.Request.class)
7776
.build()))
@@ -99,8 +98,7 @@ void functionCallSequentialTest() {
9998

10099
var promptOptions = AzureOpenAiChatOptions.builder()
101100
.deploymentName(this.selectedModel)
102-
.functionCallbacks(List.of(FunctionCallback.builder()
103-
.function("getCurrentWeather", new MockWeatherService())
101+
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
104102
.description("Get the current weather in a given location")
105103
.inputType(MockWeatherService.Request.class)
106104
.build()))
@@ -121,8 +119,7 @@ void streamFunctionCallTest() {
121119

122120
var promptOptions = AzureOpenAiChatOptions.builder()
123121
.deploymentName(this.selectedModel)
124-
.functionCallbacks(List.of(FunctionCallback.builder()
125-
.function("getCurrentWeather", new MockWeatherService())
122+
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
126123
.description("Get the current weather in a given location")
127124
.inputType(MockWeatherService.Request.class)
128125
.build()))
@@ -159,8 +156,7 @@ void streamFunctionCallUsageTest() {
159156

160157
var promptOptions = AzureOpenAiChatOptions.builder()
161158
.deploymentName(this.selectedModel)
162-
.functionCallbacks(List.of(FunctionCallback.builder()
163-
.function("getCurrentWeather", new MockWeatherService())
159+
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
164160
.description("Get the current weather in a given location")
165161
.inputType(MockWeatherService.Request.class)
166162
.build()))
@@ -186,8 +182,7 @@ void functionCallSequentialAndStreamTest() {
186182

187183
var promptOptions = AzureOpenAiChatOptions.builder()
188184
.deploymentName(this.selectedModel)
189-
.functionCallbacks(List.of(FunctionCallback.builder()
190-
.function("getCurrentWeather", new MockWeatherService())
185+
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
191186
.description("Get the current weather in a given location")
192187
.inputType(MockWeatherService.Request.class)
193188
.build()))

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import org.springframework.ai.converter.ListOutputConverter;
3535
import org.springframework.ai.mistralai.api.MistralAiApi;
3636
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice;
37-
import org.springframework.ai.model.function.FunctionCallback;
37+
import org.springframework.ai.tool.function.FunctionToolCallback;
3838
import org.springframework.beans.factory.annotation.Autowired;
3939
import org.springframework.beans.factory.annotation.Value;
4040
import org.springframework.boot.test.context.SpringBootTest;
@@ -225,8 +225,7 @@ void functionCallTest() {
225225
String response = ChatClient.create(this.chatModel).prompt()
226226
.options(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL).toolChoice(ToolChoice.AUTO).build())
227227
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius."))
228-
.functions(FunctionCallback.builder()
229-
.function("getCurrentWeather", new MockWeatherService())
228+
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
230229
.description("Get the weather in location")
231230
.inputType(MockWeatherService.Request.class)
232231
.build())
@@ -247,8 +246,7 @@ void defaultFunctionCallTest() {
247246
// @formatter:off
248247
String response = ChatClient.builder(this.chatModel)
249248
.defaultOptions(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL).build())
250-
.defaultFunctions(FunctionCallback.builder()
251-
.function("getCurrentWeather", new MockWeatherService())
249+
.defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
252250
.description("Get the weather in location")
253251
.inputType(MockWeatherService.Request.class)
254252
.build())
@@ -271,8 +269,7 @@ void streamFunctionCallTest() {
271269
Flux<String> response = ChatClient.create(this.chatModel).prompt()
272270
.options(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL).build())
273271
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")
274-
.functions(FunctionCallback.builder()
275-
.function("getCurrentWeather", new MockWeatherService())
272+
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
276273
.description("Get the weather in location")
277274
.inputType(MockWeatherService.Request.class)
278275
.build())

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.springframework.ai.mistralai.api.MistralAiApi;
5050
import org.springframework.ai.model.Media;
5151
import org.springframework.ai.model.function.FunctionCallback;
52+
import org.springframework.ai.tool.function.FunctionToolCallback;
5253
import org.springframework.beans.factory.annotation.Autowired;
5354
import org.springframework.beans.factory.annotation.Value;
5455
import org.springframework.boot.test.context.SpringBootTest;
@@ -98,7 +99,8 @@ void roleTest() {
9899
"Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");
99100
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
100101
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
101-
// NOTE: Mistral expects the system message to be before the user message or will
102+
// NOTE: Mistral expects the system message to be before the user message or
103+
// will
102104
// fail with 400 error.
103105
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
104106
ChatResponse response = this.chatModel.call(prompt);
@@ -203,8 +205,7 @@ void functionCallTest() {
203205

204206
var promptOptions = MistralAiChatOptions.builder()
205207
.model(MistralAiApi.ChatModel.SMALL.getValue())
206-
.functionCallbacks(List.of(FunctionCallback.builder()
207-
.function("getCurrentWeather", new MockWeatherService())
208+
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
208209
.description("Get the weather in location")
209210
.inputType(MockWeatherService.Request.class)
210211
.build()))
@@ -229,8 +230,7 @@ void streamFunctionCallTest() {
229230

230231
var promptOptions = MistralAiChatOptions.builder()
231232
.model(MistralAiApi.ChatModel.SMALL.getValue())
232-
.functionCallbacks(List.of(FunctionCallback.builder()
233-
.function("getCurrentWeather", new MockWeatherService())
233+
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
234234
.description("Get the weather in location")
235235
.inputType(MockWeatherService.Request.class)
236236
.build()))
@@ -317,8 +317,7 @@ void streamFunctionCallUsageTest() {
317317

318318
var promptOptions = MistralAiChatOptions.builder()
319319
.model(MistralAiApi.ChatModel.SMALL.getValue())
320-
.functionCallbacks(List.of(FunctionCallback.builder()
321-
.function("getCurrentWeather", new MockWeatherService())
320+
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
322321
.description("Get the weather in location")
323322
.inputType(MockWeatherService.Request.class)
324323
.build()))

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232
import org.springframework.ai.chat.model.ChatResponse;
3333
import org.springframework.ai.chat.model.Generation;
3434
import org.springframework.ai.chat.prompt.Prompt;
35-
import org.springframework.ai.model.function.FunctionCallback;
3635
import org.springframework.ai.ollama.api.OllamaApi;
3736
import org.springframework.ai.ollama.api.OllamaOptions;
3837
import org.springframework.ai.ollama.api.tool.MockWeatherService;
38+
import org.springframework.ai.tool.function.FunctionToolCallback;
3939
import org.springframework.beans.factory.annotation.Autowired;
4040
import org.springframework.boot.SpringBootConfiguration;
4141
import org.springframework.boot.test.context.SpringBootTest;
@@ -62,8 +62,7 @@ void functionCallTest() {
6262

6363
var promptOptions = OllamaOptions.builder()
6464
.model(MODEL)
65-
.functionCallbacks(List.of(FunctionCallback.builder()
66-
.function("getCurrentWeather", new MockWeatherService())
65+
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
6766
.description(
6867
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
6968
.inputType(MockWeatherService.Request.class)
@@ -86,8 +85,7 @@ void streamFunctionCallTest() {
8685

8786
var promptOptions = OllamaOptions.builder()
8887
.model(MODEL)
89-
.functionCallbacks(List.of(FunctionCallback.builder()
90-
.function("getCurrentWeather", new MockWeatherService())
88+
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
9189
.description(
9290
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
9391
.inputType(MockWeatherService.Request.class)

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.ai.chat.prompt.Prompt;
3838
import org.springframework.ai.model.function.FunctionCallback;
3939
import org.springframework.ai.model.function.FunctionCallback.SchemaType;
40+
import org.springframework.ai.tool.function.FunctionToolCallback;
4041
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
4142
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions;
4243
import org.springframework.beans.factory.annotation.Autowired;
@@ -83,10 +84,9 @@ public void functionCallExplicitOpenApiSchema() {
8384
""";
8485

8586
var promptOptions = VertexAiGeminiChatOptions.builder()
86-
.functionCallbacks(List.of(FunctionCallback.builder()
87-
.function("get_current_weather", new MockWeatherService())
87+
.functionCallbacks(List.of(FunctionToolCallback.builder("get_current_weather", new MockWeatherService())
8888
.description("Get the current weather in a given location")
89-
.inputTypeSchema(openApiSchema)
89+
.inputSchema(openApiSchema)
9090
.inputType(MockWeatherService.Request.class)
9191
.build()))
9292
.build();

spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ interface ChatClientRequestSpec {
218218

219219
ChatClientRequestSpec tools(Object... toolObjects);
220220

221-
ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks);
221+
// ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks);
222222

223223
<I, O> ChatClientRequestSpec functions(FunctionCallback... functionCallbacks);
224224

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -861,22 +861,26 @@ public ChatClientRequestSpec tools(Object... toolObjects) {
861861
}
862862
}
863863
this.functionCallbacks.addAll(functionCallbacks);
864-
this.functionCallbacks.addAll(Arrays.asList(ToolCallbacks.from(nonFunctinCallbacks)));
864+
this.functionCallbacks.addAll(Arrays
865+
.asList(ToolCallbacks.from(nonFunctinCallbacks.toArray(new Object[nonFunctinCallbacks.size()]))));
865866
return this;
866867
}
867868

868-
@Override
869-
public ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks) {
870-
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
871-
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
872-
this.functionCallbacks.addAll(Arrays.asList(toolCallbacks));
873-
return this;
874-
}
869+
// @Override
870+
// public ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks) {
871+
// Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
872+
// Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null
873+
// elements");
874+
// this.functionCallbacks.addAll(Arrays.asList(toolCallbacks));
875+
// return this;
876+
// }
875877

878+
@Deprecated
876879
public ChatClientRequestSpec functions(String... functionBeanNames) {
877880
return tools(functionBeanNames);
878881
}
879882

883+
@Deprecated
880884
public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) {
881885
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
882886
Assert.noNullElements(functionCallbacks, "functionCallbacks cannot contain null elements");

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ void addMessages(List<Message> messages) {
201201

202202
void addToolCallbacks(List<FunctionCallback> toolCallbacks) {
203203
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
204-
this.defaultRequest.toolCallbacks(toolCallbacks.toArray(FunctionCallback[]::new));
204+
this.defaultRequest.tools(toolCallbacks.toArray(FunctionCallback[]::new));
205205
}
206206

207207
void addToolContext(Map<String, Object> toolContext) {

0 commit comments

Comments
 (0)