Skip to content

Commit 560baaa

Browse files
ThomasVitaletzolov
authored andcommitted
Advancing Tool Support - Part 5
* Introduced new ToolParam annotation for defining a description for tool parameters and marking them as (non)required. * Improved the JSON Schema generation for tools, solving inconsistencies between methods and functions, and ensuring a predictable outcome. * Added support for returning tool results directly to the user instead of passing them back to the model. Introduced new ToolExecutionResult API to propagate this information. * Consolidated naming of tool-related options in ToolCallingChatOptions. * Fixed varargs issue in ChatClient when passing ToolCallback[]. * Introduced new documentation for the tool calling capabilities in Spring AI, and deprecated the old one. * Bumped jsonschema dependency to 4.37.0. Relates to gh-2049 Signed-off-by: Thomas Vitale <[email protected]>
1 parent 854e545 commit 560baaa

File tree

43 files changed

+1746
-382
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1746
-382
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.ai.model.tool.LegacyToolCallingManager;
3232
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3333
import org.springframework.ai.model.tool.ToolCallingManager;
34+
import org.springframework.ai.model.tool.ToolExecutionResult;
3435
import org.springframework.ai.tool.definition.ToolDefinition;
3536
import org.springframework.ai.util.json.JsonParser;
3637
import reactor.core.publisher.Flux;
@@ -271,10 +272,19 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
271272

272273
if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null
273274
&& response.hasToolCalls()) {
274-
var toolCallConversation = this.toolCallingManager.executeToolCalls(prompt, response);
275-
// Recursively call the call method with the tool call message
276-
// conversation that contains the call responses.
277-
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
275+
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
276+
if (toolExecutionResult.returnDirect()) {
277+
// Return tool execution result directly to the client.
278+
return ChatResponse.builder()
279+
.from(response)
280+
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
281+
.build();
282+
}
283+
else {
284+
// Send the tool execution result back to the model.
285+
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
286+
response);
287+
}
278288
}
279289

280290
return response;
@@ -335,10 +345,17 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh
335345
// @formatter:off
336346
Flux<ChatResponse> chatResponseFlux = chatResponse.flatMap(response -> {
337347
if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) {
338-
var toolCallConversation = this.toolCallingManager.executeToolCalls(prompt, response);
339-
// Recursively call the stream method with the tool call message
340-
// conversation that contains the call responses.
341-
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
348+
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
349+
if (toolExecutionResult.returnDirect()) {
350+
// Return tool execution result directly to the client.
351+
return Flux.just(ChatResponse.builder().from(response)
352+
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
353+
.build());
354+
} else {
355+
// Send the tool execution result back to the model.
356+
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
357+
response);
358+
}
342359
}
343360
else {
344361
return Flux.just(response);
@@ -379,13 +396,13 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOp
379396
// Merge tool names and tool callbacks explicitly since they are ignored by
380397
// Jackson, used by ModelOptionsUtils.
381398
if (runtimeOptions != null) {
382-
requestOptions.setTools(
383-
ToolCallingChatOptions.mergeToolNames(runtimeOptions.getTools(), this.defaultOptions.getTools()));
399+
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
400+
this.defaultOptions.getToolNames()));
384401
requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
385402
this.defaultOptions.getToolCallbacks()));
386403
}
387404
else {
388-
requestOptions.setTools(this.defaultOptions.getTools());
405+
requestOptions.setToolNames(this.defaultOptions.getToolNames());
389406
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
390407
}
391408

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
383383
.mirostatEta(fromOptions.getMirostatEta())
384384
.penalizeNewline(fromOptions.getPenalizeNewline())
385385
.stop(fromOptions.getStop())
386-
.tools(fromOptions.getTools())
386+
.toolNames(fromOptions.getToolNames())
387387
.internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled())
388388
.toolCallbacks(fromOptions.getToolCallbacks())
389389
.toolContext(fromOptions.getToolContext()).build();
@@ -700,13 +700,13 @@ public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
700700

701701
@Override
702702
@JsonIgnore
703-
public Set<String> getTools() {
703+
public Set<String> getToolNames() {
704704
return this.toolNames;
705705
}
706706

707707
@Override
708708
@JsonIgnore
709-
public void setTools(Set<String> toolNames) {
709+
public void setToolNames(Set<String> toolNames) {
710710
Assert.notNull(toolNames, "toolNames cannot be null");
711711
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
712712
toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements"));
@@ -744,14 +744,14 @@ public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
744744
@Deprecated
745745
@JsonIgnore
746746
public Set<String> getFunctions() {
747-
return this.getTools();
747+
return this.getToolNames();
748748
}
749749

750750
@Override
751751
@Deprecated
752752
@JsonIgnore
753753
public void setFunctions(Set<String> functions) {
754-
this.setTools(functions);
754+
this.setToolNames(functions);
755755
}
756756

757757
@Override
@@ -1028,12 +1028,12 @@ public Builder toolCallbacks(FunctionCallback... toolCallbacks) {
10281028
return this;
10291029
}
10301030

1031-
public Builder tools(Set<String> toolNames) {
1032-
this.options.setTools(toolNames);
1031+
public Builder toolNames(Set<String> toolNames) {
1032+
this.options.setToolNames(toolNames);
10331033
return this;
10341034
}
10351035

1036-
public Builder tools(String... toolNames) {
1036+
public Builder toolNames(String... toolNames) {
10371037
Assert.notNull(toolNames, "toolNames cannot be null");
10381038
this.options.toolNames.addAll(Set.of(toolNames));
10391039
return this;
@@ -1051,12 +1051,12 @@ public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
10511051

10521052
@Deprecated
10531053
public Builder functions(Set<String> functions) {
1054-
return tools(functions);
1054+
return toolNames(functions);
10551055
}
10561056

10571057
@Deprecated
10581058
public Builder function(String functionName) {
1059-
return tools(functionName);
1059+
return toolNames(functionName);
10601060
}
10611061

10621062
@Deprecated

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@
3636
import org.springframework.ai.chat.model.Generation;
3737
import org.springframework.ai.chat.prompt.Prompt;
3838
import org.springframework.ai.tool.function.FunctionToolCallback;
39-
import org.springframework.ai.util.json.JsonSchemaGenerator;
39+
import org.springframework.ai.util.json.schema.JsonSchemaGenerator;
4040
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
4141
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions;
4242
import org.springframework.beans.factory.annotation.Autowired;

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@
189189
<com.google.cloud.version>26.48.0</com.google.cloud.version>
190190
<qdrant.version>1.9.1</qdrant.version>
191191
<ibm.sdk.version>9.20.0</ibm.sdk.version>
192-
<jsonschema.version>4.35.0</jsonschema.version>
192+
<jsonschema.version>4.37.0</jsonschema.version>
193193
<swagger-annotations.version>2.2.25</swagger-annotations.version>
194194
<spring-cloud-bindings.version>2.0.3</spring-cloud-bindings.version>
195195

spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,9 @@ interface ChatClientRequestSpec {
216216

217217
ChatClientRequestSpec tools(String... toolNames);
218218

219-
ChatClientRequestSpec tools(Object... toolObjects);
219+
ChatClientRequestSpec tools(FunctionCallback... toolCallbacks);
220220

221-
// ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks);
221+
ChatClientRequestSpec tools(Object... toolObjects);
222222

223223
@Deprecated
224224
<I, O> ChatClientRequestSpec functions(FunctionCallback... functionCallbacks);
@@ -281,6 +281,8 @@ interface Builder {
281281

282282
Builder defaultTools(String... toolNames);
283283

284+
Builder defaultTools(FunctionCallback... toolCallbacks);
285+
284286
Builder defaultTools(Object... toolObjects);
285287

286288
/**

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -845,42 +845,28 @@ public ChatClientRequestSpec tools(String... toolNames) {
845845
return this;
846846
}
847847

848+
@Override
849+
public ChatClientRequestSpec tools(FunctionCallback... toolCallbacks) {
850+
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
851+
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
852+
this.functionCallbacks.addAll(List.of(toolCallbacks));
853+
return this;
854+
}
855+
848856
@Override
849857
public ChatClientRequestSpec tools(Object... toolObjects) {
850858
Assert.notNull(toolObjects, "toolObjects cannot be null");
851859
Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
852-
853-
List<FunctionCallback> functionCallbacks = new ArrayList<>();
854-
List<Object> nonFunctinCallbacks = new ArrayList<>();
855-
for (Object toolObject : toolObjects) {
856-
if (toolObject instanceof FunctionCallback) {
857-
functionCallbacks.add((FunctionCallback) toolObject);
858-
}
859-
else {
860-
nonFunctinCallbacks.add(toolObject);
861-
}
862-
}
863-
this.functionCallbacks.addAll(functionCallbacks);
864-
this.functionCallbacks.addAll(Arrays
865-
.asList(ToolCallbacks.from(nonFunctinCallbacks.toArray(new Object[nonFunctinCallbacks.size()]))));
860+
this.functionCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects)));
866861
return this;
867862
}
868863

869-
// @Override
870-
// public ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks) {
871-
// Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
872-
// Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null
873-
// elements");
874-
// this.functionCallbacks.addAll(Arrays.asList(toolCallbacks));
875-
// return this;
876-
// }
877-
878-
@Deprecated
864+
@Deprecated // Use tools()
879865
public ChatClientRequestSpec functions(String... functionBeanNames) {
880866
return tools(functionBeanNames);
881867
}
882868

883-
@Deprecated
869+
@Deprecated // Use tools()
884870
public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) {
885871
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
886872
Assert.noNullElements(functionCallbacks, "functionCallbacks cannot contain null elements");

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import org.springframework.ai.chat.model.ToolContext;
3636
import org.springframework.ai.chat.prompt.ChatOptions;
3737
import org.springframework.ai.model.function.FunctionCallback;
38-
import org.springframework.ai.tool.ToolCallbacks;
3938
import org.springframework.core.io.Resource;
4039
import org.springframework.lang.Nullable;
4140
import org.springframework.util.Assert;
@@ -151,7 +150,13 @@ public Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer) {
151150

152151
@Override
153152
public Builder defaultTools(String... toolNames) {
154-
this.defaultRequest.functions(toolNames);
153+
this.defaultRequest.tools(toolNames);
154+
return this;
155+
}
156+
157+
@Override
158+
public Builder defaultTools(FunctionCallback... toolCallbacks) {
159+
this.defaultRequest.tools(toolCallbacks);
155160
return this;
156161
}
157162

@@ -161,24 +166,28 @@ public Builder defaultTools(Object... toolObjects) {
161166
return this;
162167
}
163168

169+
@Deprecated // Use defaultTools()
164170
public <I, O> Builder defaultFunction(String name, String description, java.util.function.Function<I, O> function) {
165171
this.defaultRequest
166172
.functions(FunctionCallback.builder().function(name, function).description(description).build());
167173
return this;
168174
}
169175

176+
@Deprecated // Use defaultTools()
170177
public <I, O> Builder defaultFunction(String name, String description,
171178
java.util.function.BiFunction<I, ToolContext, O> biFunction) {
172179
this.defaultRequest
173180
.functions(FunctionCallback.builder().function(name, biFunction).description(description).build());
174181
return this;
175182
}
176183

184+
@Deprecated // Use defaultTools()
177185
public Builder defaultFunctions(String... functionNames) {
178186
this.defaultRequest.functions(functionNames);
179187
return this;
180188
}
181189

190+
@Deprecated // Use defaultTools()
182191
public Builder defaultFunctions(FunctionCallback... functionCallbacks) {
183192
this.defaultRequest.functions(functionCallbacks);
184193
return this;

spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions {
3939

4040
private List<FunctionCallback> toolCallbacks = new ArrayList<>();
4141

42-
private Set<String> tools = new HashSet<>();
42+
private Set<String> toolNames = new HashSet<>();
4343

4444
private Map<String, Object> toolContext = new HashMap<>();
4545

@@ -83,16 +83,16 @@ public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
8383
}
8484

8585
@Override
86-
public Set<String> getTools() {
87-
return Set.copyOf(this.tools);
86+
public Set<String> getToolNames() {
87+
return Set.copyOf(this.toolNames);
8888
}
8989

9090
@Override
91-
public void setTools(Set<String> tools) {
92-
Assert.notNull(tools, "tools cannot be null");
93-
Assert.noNullElements(tools, "tools cannot contain null elements");
94-
tools.forEach(tool -> Assert.hasText(tool, "tools cannot contain empty elements"));
95-
this.tools = new HashSet<>(tools);
91+
public void setToolNames(Set<String> toolNames) {
92+
Assert.notNull(toolNames, "toolNames cannot be null");
93+
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
94+
toolNames.forEach(toolName -> Assert.hasText(toolName, "toolNames cannot contain empty elements"));
95+
this.toolNames = new HashSet<>(toolNames);
9696
}
9797

9898
@Override
@@ -130,12 +130,12 @@ public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
130130

131131
@Override
132132
public Set<String> getFunctions() {
133-
return getTools();
133+
return getToolNames();
134134
}
135135

136136
@Override
137137
public void setFunctions(Set<String> functions) {
138-
setTools(functions);
138+
setToolNames(functions);
139139
}
140140

141141
@Override
@@ -234,7 +234,7 @@ public void setTopP(@Nullable Double topP) {
234234
public <T extends ChatOptions> T copy() {
235235
DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions();
236236
options.setToolCallbacks(getToolCallbacks());
237-
options.setTools(getTools());
237+
options.setToolNames(getToolNames());
238238
options.setToolContext(getToolContext());
239239
options.setInternalToolExecutionEnabled(isInternalToolExecutionEnabled());
240240
options.setModel(getModel());
@@ -273,15 +273,15 @@ public ToolCallingChatOptions.Builder toolCallbacks(FunctionCallback... toolCall
273273
}
274274

275275
@Override
276-
public ToolCallingChatOptions.Builder tools(Set<String> toolNames) {
277-
this.options.setTools(toolNames);
276+
public ToolCallingChatOptions.Builder toolNames(Set<String> toolNames) {
277+
this.options.setToolNames(toolNames);
278278
return this;
279279
}
280280

281281
@Override
282-
public ToolCallingChatOptions.Builder tools(String... toolNames) {
282+
public ToolCallingChatOptions.Builder toolNames(String... toolNames) {
283283
Assert.notNull(toolNames, "toolNames cannot be null");
284-
this.options.setTools(Set.of(toolNames));
284+
this.options.setToolNames(Set.of(toolNames));
285285
return this;
286286
}
287287

@@ -322,15 +322,15 @@ public ToolCallingChatOptions.Builder functionCallbacks(FunctionCallback... func
322322
}
323323

324324
@Override
325-
@Deprecated // Use tools() instead
325+
@Deprecated // Use toolNames() instead
326326
public ToolCallingChatOptions.Builder functions(Set<String> functions) {
327-
return tools(functions);
327+
return toolNames(functions);
328328
}
329329

330330
@Override
331-
@Deprecated // Use tools() instead
331+
@Deprecated // Use toolNames() instead
332332
public ToolCallingChatOptions.Builder function(String function) {
333-
return tools(function);
333+
return toolNames(function);
334334
}
335335

336336
@Override

0 commit comments

Comments
 (0)