Skip to content

Commit 1b5cd09

Browse files
authored
Wrap Exception from toolFunction in ToolExecutionException and rethrow (#3918)
Fixes #2857 Auto-cherry-pick to 1.0.x - When `toolFunction` throws a `ToolExecutionException`, rethrow it directly. - When `toolFunction` throws a `Exception`, wrap it in a `ToolExecutionException` and rethrow it. - Add related tests Signed-off-by: YunKui Lu <[email protected]>
1 parent c5a9568 commit 1b5cd09

File tree

2 files changed

+202
-1
lines changed

2 files changed

+202
-1
lines changed

spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.ai.tool.definition.ToolDefinition;
3232
import org.springframework.ai.tool.execution.DefaultToolCallResultConverter;
3333
import org.springframework.ai.tool.execution.ToolCallResultConverter;
34+
import org.springframework.ai.tool.execution.ToolExecutionException;
3435
import org.springframework.ai.tool.metadata.ToolMetadata;
3536
import org.springframework.ai.tool.support.ToolUtils;
3637
import org.springframework.ai.util.json.JsonParser;
@@ -44,6 +45,7 @@
4445
* A {@link ToolCallback} implementation to invoke functions as tools.
4546
*
4647
* @author Thomas Vitale
48+
* @author YunKui Lu
4749
* @since 1.0.0
4850
*/
4951
public class FunctionToolCallback<I, O> implements ToolCallback {
@@ -99,13 +101,25 @@ public String call(String toolInput, @Nullable ToolContext toolContext) {
99101
logger.debug("Starting execution of tool: {}", this.toolDefinition.name());
100102

101103
I request = JsonParser.fromJson(toolInput, this.toolInputType);
102-
O response = this.toolFunction.apply(request, toolContext);
104+
O response = callMethod(request, toolContext);
103105

104106
logger.debug("Successful execution of tool: {}", this.toolDefinition.name());
105107

106108
return this.toolCallResultConverter.convert(response, null);
107109
}
108110

111+
private O callMethod(I request, @Nullable ToolContext toolContext) {
112+
try {
113+
return this.toolFunction.apply(request, toolContext);
114+
}
115+
catch (ToolExecutionException ex) {
116+
throw ex;
117+
}
118+
catch (Exception ex) {
119+
throw new ToolExecutionException(this.toolDefinition, ex);
120+
}
121+
}
122+
109123
@Override
110124
public String toString() {
111125
return "FunctionToolCallback{" + "toolDefinition=" + this.toolDefinition + ", toolMetadata=" + this.toolMetadata
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.tool.function;
18+
19+
import java.util.Map;
20+
import java.util.concurrent.atomic.AtomicReference;
21+
import java.util.function.BiFunction;
22+
import java.util.function.Consumer;
23+
import java.util.function.Function;
24+
import java.util.function.Supplier;
25+
26+
import org.junit.jupiter.api.Test;
27+
28+
import org.springframework.ai.chat.model.ToolContext;
29+
import org.springframework.ai.tool.execution.ToolExecutionException;
30+
import org.springframework.ai.tool.metadata.ToolMetadata;
31+
32+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
33+
import static org.assertj.core.api.InstanceOfAssertFactories.type;
34+
import static org.junit.jupiter.api.Assertions.assertEquals;
35+
36+
/**
37+
* @author YunKui Lu
38+
*/
39+
class FunctionToolCallbackTest {
40+
41+
@Test
42+
void testConsumerToolCall() {
43+
TestFunctionTool tool = new TestFunctionTool();
44+
FunctionToolCallback<String, Void> callback = FunctionToolCallback.builder("testTool", tool.stringConsumer())
45+
.toolMetadata(ToolMetadata.builder().returnDirect(true).build())
46+
.description("test description")
47+
.inputType(String.class)
48+
.build();
49+
50+
callback.call("\"test string param\"");
51+
52+
assertEquals("test string param", tool.calledValue.get());
53+
}
54+
55+
@Test
56+
void testBiFunctionToolCall() {
57+
TestFunctionTool tool = new TestFunctionTool();
58+
FunctionToolCallback<String, String> callback = FunctionToolCallback
59+
.builder("testTool", tool.stringBiFunction())
60+
.toolMetadata(ToolMetadata.builder().returnDirect(true).build())
61+
.description("test description")
62+
.inputType(String.class)
63+
.build();
64+
65+
ToolContext toolContext = new ToolContext(Map.of("foo", "bar"));
66+
67+
String callResult = callback.call("\"test string param\"", toolContext);
68+
69+
assertEquals("test string param", tool.calledValue.get());
70+
assertEquals("\"return value = test string param\"", callResult);
71+
assertEquals(toolContext, tool.calledToolContext.get());
72+
}
73+
74+
@Test
75+
void testFunctionToolCall() {
76+
TestFunctionTool tool = new TestFunctionTool();
77+
FunctionToolCallback<String, String> callback = FunctionToolCallback.builder("testTool", tool.stringFunction())
78+
.toolMetadata(ToolMetadata.builder().returnDirect(true).build())
79+
.description("test description")
80+
.inputType(String.class)
81+
.build();
82+
83+
ToolContext toolContext = new ToolContext(Map.of());
84+
85+
String callResult = callback.call("\"test string param\"", toolContext);
86+
87+
assertEquals("test string param", tool.calledValue.get());
88+
assertEquals("\"return value = test string param\"", callResult);
89+
}
90+
91+
@Test
92+
void testSupplierToolCall() {
93+
TestFunctionTool tool = new TestFunctionTool();
94+
95+
FunctionToolCallback<Void, String> callback = FunctionToolCallback.builder("testTool", tool.stringSupplier())
96+
.toolMetadata(ToolMetadata.builder().returnDirect(true).build())
97+
.description("test description")
98+
.inputType(Void.class)
99+
.build();
100+
101+
ToolContext toolContext = new ToolContext(Map.of());
102+
103+
String callResult = callback.call("\"test string param\"", toolContext);
104+
105+
assertEquals("not params", tool.calledValue.get());
106+
assertEquals("\"return value = \"", callResult);
107+
}
108+
109+
@Test
110+
void testThrowRuntimeException() {
111+
TestFunctionTool tool = new TestFunctionTool();
112+
FunctionToolCallback<String, Void> callback = FunctionToolCallback
113+
.builder("testTool", tool.throwRuntimeException())
114+
.toolMetadata(ToolMetadata.builder().returnDirect(true).build())
115+
.description("test description")
116+
.inputType(String.class)
117+
.build();
118+
119+
assertThatThrownBy(() -> callback.call("\"test string param\"")).hasMessage("test exception")
120+
.hasCauseInstanceOf(RuntimeException.class)
121+
.asInstanceOf(type(ToolExecutionException.class))
122+
.extracting(ToolExecutionException::getToolDefinition)
123+
.isEqualTo(callback.getToolDefinition());
124+
}
125+
126+
@Test
127+
void testThrowToolExecutionException() {
128+
TestFunctionTool tool = new TestFunctionTool();
129+
FunctionToolCallback<String, Void> callback = FunctionToolCallback
130+
.builder("testTool", tool.throwToolExecutionException())
131+
.toolMetadata(ToolMetadata.builder().returnDirect(true).build())
132+
.description("test description")
133+
.inputType(String.class)
134+
.build();
135+
136+
assertThatThrownBy(() -> callback.call("\"test string param\"")).hasMessage("test exception")
137+
.hasCauseInstanceOf(RuntimeException.class)
138+
.isInstanceOf(ToolExecutionException.class);
139+
}
140+
141+
static class TestFunctionTool {
142+
143+
AtomicReference<Object> calledValue = new AtomicReference<>();
144+
145+
AtomicReference<ToolContext> calledToolContext = new AtomicReference<>();
146+
147+
public Consumer<String> stringConsumer() {
148+
return s -> {
149+
calledValue.set(s);
150+
};
151+
}
152+
153+
public BiFunction<String, ToolContext, String> stringBiFunction() {
154+
return (s, context) -> {
155+
calledValue.set(s);
156+
calledToolContext.set(context);
157+
return "return value = " + s;
158+
};
159+
}
160+
161+
public Function<String, String> stringFunction() {
162+
return s -> {
163+
calledValue.set(s);
164+
return "return value = " + s;
165+
};
166+
}
167+
168+
public Supplier<String> stringSupplier() {
169+
calledValue.set("not params");
170+
return () -> "return value = ";
171+
}
172+
173+
public Consumer<String> throwRuntimeException() {
174+
return s -> {
175+
throw new RuntimeException("test exception");
176+
};
177+
}
178+
179+
public Consumer<String> throwToolExecutionException() {
180+
return s -> {
181+
throw new ToolExecutionException(null, new RuntimeException("test exception"));
182+
};
183+
}
184+
185+
}
186+
187+
}

0 commit comments

Comments
 (0)