Skip to content

Commit 3c4b6dc

Browse files
tzolovilayaperumalg
authored andcommitted
Add Tool Argument Augmentation for dynamic tool schema enhancement
Introduce utilities to augment tool input schemas with additional arguments (e.g., reasoning, metadata) without modifying tool implementations. Includes AugmentedToolCallbackProvider, AugmentedToolCallback, and ToolInputSchemaAugmenter components. Signed-off-by: Christian Tzolov <[email protected]>
1 parent 08fadd0 commit 3c4b6dc

File tree

10 files changed

+1810
-6
lines changed

10 files changed

+1810
-6
lines changed

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,6 +1328,74 @@ ToolCallbackResolver toolCallbackResolver(List<FunctionCallback> toolCallbacks)
13281328

13291329
The `ToolCallbackResolver` is used internally by the `ToolCallingManager` to resolve tools dynamically at runtime, supporting both xref:_framework_controlled_tool_execution[] and xref:_user_controlled_tool_execution[].
13301330

1331+
[[tool-argument-augmentation]]
1332+
== Tool Argument Augmentation
1333+
1334+
Spring AI provides a utility for **dynamic augmentation of tool input schemas** with additional arguments. This allows capturing extra information from the model—such as reasoning or metadata—without modifying the underlying tool implementation.
1335+
1336+
Common use cases include:
1337+
1338+
* **Inner Thinking/Reasoning**: Capture the model's step-by-step reasoning before executing a tool
1339+
* **Memory Enhancement**: Extract insights to store in long-term memory
1340+
* **Analytics & Tracking**: Collect metadata, user intent, or usage patterns
1341+
* **Multi-Agent Coordination**: Pass agent identifiers or coordination signals
1342+
1343+
=== Quick Start
1344+
1345+
1. **Define augmented arguments** as a Java Record:
1346+
1347+
[source,java]
1348+
----
1349+
public record AgentThinking(
1350+
@ToolParam(description = "Your reasoning for calling this tool", required = true)
1351+
String innerThought,
1352+
1353+
@ToolParam(description = "Confidence level (low, medium, high)", required = false)
1354+
String confidence
1355+
) {}
1356+
----
1357+
1358+
2. **Wrap your tool** with `AugmentedToolCallbackProvider`:
1359+
1360+
[source,java]
1361+
----
1362+
AugmentedToolCallbackProvider<AgentThinking> provider = AugmentedToolCallbackProvider
1363+
.<AgentThinking>builder()
1364+
.toolObject(new MyTools()) // Your @Tool annotated class
1365+
.argumentType(AgentThinking.class)
1366+
.argumentConsumer(event -> {
1367+
AgentThinking thinking = event.arguments();
1368+
log.info("Tool: {} | Reasoning: {}", event.toolDefinition().name(), thinking.innerThought());
1369+
})
1370+
.removeExtraArgumentsAfterProcessing(true)
1371+
.build();
1372+
----
1373+
1374+
3. **Use with ChatClient**:
1375+
1376+
[source,java]
1377+
----
1378+
ChatClient chatClient = ChatClient.builder(chatModel)
1379+
.defaultToolCallbacks(provider)
1380+
.build();
1381+
----
1382+
1383+
The LLM sees the augmented schema with your additional fields. Your consumer receives the `AgentThinking` record, while the original tool receives only its expected arguments.
1384+
1385+
=== Core Components
1386+
1387+
* `AugmentedToolCallbackProvider<T>` - Wraps tool objects or providers, augmenting all tools with the specified Record type
1388+
* `AugmentedToolCallback<T>` - Wraps individual `ToolCallback` instances
1389+
* `AugmentedArgumentEvent<T>` - Contains `toolDefinition()`, `rawInput()`, and `arguments()` for consumers
1390+
* `ToolInputSchemaAugmenter` - Low-level utility for schema manipulation
1391+
1392+
=== Configuration
1393+
1394+
The `removeExtraArgumentsAfterProcessing` option controls whether augmented arguments are passed to the original tool:
1395+
1396+
* `true` (default) - Remove augmented arguments before calling the tool
1397+
* `false` - Preserve augmented arguments in the input (if the tool can ignore extra fields)
1398+
13311399
== Observability
13321400

13331401
Tool calling includes observability support with spring.ai.tool observations that measure completion time and propagate tracing information. See xref:observability/index.adoc#_tool_calling[Tool Calling Observability].

spring-ai-model/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,18 @@ private static String toGetName(String name) {
378378
*/
379379
public static String getJsonSchema(Type inputType, boolean toUpperCaseTypeValues) {
380380

381+
ObjectNode node = getJsonSchema(inputType);
382+
383+
if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI
384+
// version of it).
385+
toUpperCaseTypeValues(node);
386+
}
387+
388+
return node.toPrettyString();
389+
}
390+
391+
public static ObjectNode getJsonSchema(Type inputType) {
392+
381393
if (SCHEMA_GENERATOR_CACHE.get() == null) {
382394

383395
JacksonModule jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED);
@@ -405,12 +417,7 @@ public static String getJsonSchema(Type inputType, boolean toUpperCaseTypeValues
405417
node.putObject("properties");
406418
}
407419

408-
if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI
409-
// version of it).
410-
toUpperCaseTypeValues(node);
411-
}
412-
413-
return node.toPrettyString();
420+
return node;
414421
}
415422

416423
public static void toUpperCaseTypeValues(ObjectNode node) {
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.tool.augment;
18+
19+
import org.springframework.ai.tool.definition.ToolDefinition;
20+
21+
/**
22+
* An event that encapsulates the augmented arguments extracted from a tool input, along
23+
* with the associated tool definition and raw input data.
24+
*
25+
* @param <T> The type of the augmented arguments record.
26+
* @param toolDefinition The tool definition associated with the event.
27+
* @param rawInput The raw input data as a string.
28+
* @param arguments The augmented arguments extracted from the input.
29+
* @author Christian Tzolov
30+
*/
31+
public record AugmentedArgumentEvent<T>(ToolDefinition toolDefinition, String rawInput, T arguments) {
32+
}
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.tool.augment;
18+
19+
import java.util.List;
20+
import java.util.Map;
21+
import java.util.function.Consumer;
22+
23+
import com.fasterxml.jackson.core.type.TypeReference;
24+
25+
import org.springframework.ai.chat.model.ToolContext;
26+
import org.springframework.ai.tool.ToolCallback;
27+
import org.springframework.ai.tool.augment.ToolInputSchemaAugmenter.AugmentedArgumentType;
28+
import org.springframework.ai.tool.definition.ToolDefinition;
29+
import org.springframework.ai.util.json.JsonParser;
30+
import org.springframework.lang.Nullable;
31+
import org.springframework.util.Assert;
32+
33+
/**
34+
* This class wraps an existing {@link ToolCallback} and modifies its input schema to
35+
* include additional fields defined in the provided Record type. It also provides a
36+
* mechanism to handle these extended arguments, either by consuming them via a provided
37+
* {@link Consumer} or by removing them from the input after processing.
38+
*
39+
* @author Christian Tzolov
40+
*/
41+
public class AugmentedToolCallback<T extends Record> implements ToolCallback {
42+
43+
/**
44+
* The delegate ToolCallback that this class extends.
45+
*/
46+
private final ToolCallback delegate;
47+
48+
/**
49+
* The augmented ToolDefinition that includes the augmented input schema.
50+
*/
51+
private ToolDefinition augmentedToolDefinition;
52+
53+
/**
54+
* The record class type that defines the structure of the augmented arguments.
55+
*/
56+
private Class<T> augmentedArgumentsClass;
57+
58+
/**
59+
* A consumer that processes the augmented arguments extracted from the tool input.
60+
*/
61+
private Consumer<AugmentedArgumentEvent<T>> augmentedArgumentsConsumer;
62+
63+
/**
64+
* The list of tool argument types that have been added to the tool input schema.
65+
*/
66+
private List<AugmentedArgumentType> augmentedArgumentTypes;
67+
68+
/**
69+
* A flag indicating whether to remove the augmented arguments from the tool input
70+
* after they have been processed. If the arguments are not removed, they will remain
71+
* in the tool input for the delegate to process. In many cases this could be useful.
72+
*/
73+
private boolean removeAugmentedArgumentsAfterProcessing = false;
74+
75+
public AugmentedToolCallback(ToolCallback delegate, Class<T> augmentedArgumentsClass,
76+
Consumer<AugmentedArgumentEvent<T>> augmentedArgumentsConsumer,
77+
boolean removeExtraArgumentsAfterProcessing) {
78+
Assert.notNull(delegate, "Delegate ToolCallback must not be null");
79+
Assert.notNull(augmentedArgumentsClass, "Argument types must not be null");
80+
Assert.isTrue(augmentedArgumentsClass.isRecord(), "Argument types must be a Record type");
81+
Assert.isTrue(augmentedArgumentsClass.getRecordComponents().length > 0,
82+
"Argument types must have at least one field");
83+
84+
this.delegate = delegate;
85+
this.augmentedArgumentTypes = ToolInputSchemaAugmenter.toAugmentedArgumentTypes(augmentedArgumentsClass);
86+
String originalSchema = this.delegate.getToolDefinition().inputSchema();
87+
String augmentedSchema = ToolInputSchemaAugmenter.augmentToolInputSchema(originalSchema,
88+
this.augmentedArgumentTypes);
89+
this.augmentedToolDefinition = ToolDefinition.builder()
90+
.name(this.delegate.getToolDefinition().name())
91+
.description(this.delegate.getToolDefinition().description())
92+
.inputSchema(augmentedSchema)
93+
.build();
94+
95+
this.augmentedArgumentsClass = augmentedArgumentsClass;
96+
this.augmentedArgumentsConsumer = augmentedArgumentsConsumer;
97+
this.removeAugmentedArgumentsAfterProcessing = removeExtraArgumentsAfterProcessing;
98+
}
99+
100+
@Override
101+
public ToolDefinition getToolDefinition() {
102+
return this.augmentedToolDefinition;
103+
}
104+
105+
@Override
106+
public String call(String toolInput) {
107+
return this.delegate.call(this.handleAugmentedArguments(toolInput));
108+
}
109+
110+
@Override
111+
public String call(String toolInput, @Nullable ToolContext tooContext) {
112+
return this.delegate.call(this.handleAugmentedArguments(toolInput), tooContext);
113+
}
114+
115+
/**
116+
* Handles the augmented arguments in the tool input. It extracts the augmented
117+
* arguments from the tool input, processes them using the provided consumer, and
118+
* optionally removes them from the tool input.
119+
* @param toolInput the input as received from the LLM.
120+
* @return the input to send to the delegate ToolCallback
121+
*/
122+
private String handleAugmentedArguments(String toolInput) {
123+
124+
// Extract the augmented arguments from the toolInput and send them to the
125+
// consumer if provided.
126+
if (this.augmentedArgumentsConsumer != null) {
127+
T augmentedArguments = JsonParser.fromJson(toolInput, this.augmentedArgumentsClass);
128+
this.augmentedArgumentsConsumer
129+
.accept(new AugmentedArgumentEvent<>(this.augmentedToolDefinition, toolInput, augmentedArguments));
130+
}
131+
132+
// Optionally remove the extra arguments from the toolInput
133+
if (this.removeAugmentedArgumentsAfterProcessing) {
134+
var args = JsonParser.fromJson(toolInput, new TypeReference<Map<String, Object>>() {
135+
});
136+
137+
for (AugmentedArgumentType newFieldType : this.augmentedArgumentTypes) {
138+
args.remove(newFieldType.name());
139+
}
140+
toolInput = JsonParser.toJson(args);
141+
}
142+
143+
return toolInput;
144+
}
145+
146+
}

0 commit comments

Comments
 (0)